29 #include "../TensorFlowHelper.h" 30 #include "tensorflow/core/framework/op.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/lib/core/errors.h" 37 tensorflow::OpKernelConstruction*
context)
// TrilinearDevoxelizeOpKernel constructor: reads the op attributes once, at
// kernel construction time.
//   "resolution"  -> r            (must be > 0, otherwise InvalidArgument)
//   "is_training" -> is_training
// A failed attribute lookup is reported through OP_REQUIRES_OK.
40 OP_REQUIRES_OK(context, context->GetAttr(
"resolution", &
r));
41 OP_REQUIRES_OK(context, context->GetAttr(
"is_training", &
is_training));
// Reject non-positive resolutions up front; Compute later passes r, r*r and
// r*r*r to Kernel, none of which make sense for r <= 0.
42 OP_REQUIRES(context,
r > 0,
43 errors::InvalidArgument(
44 "TrilinearDevoxelize expects positive resolution"));
// Forward Compute: validates the two inputs, allocates three outputs and hands
// raw buffers to the pure-virtual Kernel(), whose parameter order is
// (context, b, c, n, r, r2, r3, training, coords, feat, inds, wgts, outs).
//
// Input 0 `coords`: rank 3 with dim 1 == 3, i.e. (batch_size, 3, N).
49 const Tensor& coords = context->input(0);
51 context, coords.dims() == 3 && coords.shape().dim_size(1) == 3,
52 errors::InvalidArgument(
"TrilinearDevoxelize expects " 53 "(batch_size, 3, N) coordinate shape"));
// Input 1 `feat`: only the rank (5) is validated.
// NOTE(review): feat.dim_size(0) is never compared with batch_size, and the
// trailing three (presumably spatial) dims are never compared with r, even
// though Kernel receives r, r*r and r*r*r. If Kernel indexes feat using r, a
// mismatched tensor could be read out of bounds. Consider checking
// feat.dim_size(0) == batch_size and feat.dim_size(2..4) == r.
54 const Tensor& feat = context->input(1);
55 OP_REQUIRES(context, feat.dims() == 5,
56 errors::InvalidArgument(
"TrilinearDevoxelize expects " 57 "5 dimensions for features"));
// Problem sizes: batch and point count come from coords, channels from feat.
// NOTE(review): dim_size() returns a 64-bit value that is narrowed to int
// here; confirm these dims can never exceed INT_MAX.
59 int batch_size = coords.shape().dim_size(0);
60 int num_points = coords.shape().dim_size(2);
61 int feat_dim = feat.shape().dim_size(1);
// Both inputs are read as float buffers.
63 auto coords_flat = coords.flat<
float>();
64 auto feat_flat = feat.flat<
float>();
// NOTE(review): &flat(0) indexes element 0 and is only well-defined when the
// tensor has at least one element (e.g. not when N == 0);
// flat<float>().data() yields the same pointer without that requirement.
66 const float* inp_coords = &(coords_flat(0));
67 const float* inp_feat = &(feat_flat(0));
// Outputs, and the Kernel parameter each one is passed as:
//   0: float (batch_size, feat_dim, N) -> outs
//   1: int   (batch_size, 8, N)        -> inds
//   2: float (batch_size, 8, N)        -> wgts
// (8 is presumably the number of corners of the voxel cell around each point.)
70 OP_REQUIRES_OK(context,
71 context->allocate_output(
72 0, TensorShape{batch_size, feat_dim, num_points},
75 OP_REQUIRES_OK(context,
76 context->allocate_output(
77 1, TensorShape{batch_size, 8, num_points},
80 OP_REQUIRES_OK(context,
81 context->allocate_output(
82 2, TensorShape{batch_size, 8, num_points},
84 auto flat_0 = out_tensor_0->flat<
float>();
85 auto flat_1 = out_tensor_1->flat<
int>();
86 auto flat_2 = out_tensor_2->flat<
float>();
88 float* out_0 = &(flat_0(0));
89 int* out_1 = &(flat_1(0));
90 float* out_2 = &(flat_2(0));
// The same Kernel call is issued with training = true or training = false.
// The condition selecting between the two calls is not visible here
// (presumably the is_training attribute read in the constructor).
93 Kernel(context, batch_size, feat_dim, num_points,
r,
r *
r,
94 r * r * r,
true, inp_coords, inp_feat, out_1, out_2, out_0);
96 Kernel(context, batch_size, feat_dim, num_points, r, r * r,
97 r * r * r,
false, inp_coords, inp_feat, out_1, out_2, out_0);
122 tensorflow::OpKernelConstruction*
context)
// TrilinearDevoxelizeGradOpKernel constructor: reads the "resolution"
// attribute into r. The condition of the check that raises the
// "expects positive resolution" error is not visible here; presumably it
// is r > 0, mirroring the forward kernel's constructor.
125 OP_REQUIRES_OK(context, context->GetAttr(
"resolution", &
r));
128 errors::InvalidArgument(
129 "TrilinearDevoxelizeGrad expects positive resolution"));
// Gradient Compute: validates grad_y / inds / wgts, allocates a
// (batch_size, feat_dim, r, r, r) float output and forwards raw buffers to
// Kernel(context, b, c, n, r*r*r, inds, wgts, grad_y, out).
//
// Input 0 `grad_y`: rank 3, i.e. (batch_size, C, N).
134 const Tensor& grad_y = context->input(0);
136 context, grad_y.dims() == 3,
137 errors::InvalidArgument(
"TrilinearDevoxelizeGrad expects " 138 "(batch_size, C, N) gradient shape"));
// Input 1 `inds`: rank 3 with dim 1 == 8, i.e. (batch_size, 8, N); read as int.
139 const Tensor& inds = context->input(1);
141 context, inds.dims() == 3 && inds.shape().dim_size(1) == 8,
142 errors::InvalidArgument(
"TrilinearDevoxelizeGrad expects " 143 "(batch_size, 8, N) indices shape"));
// Input 2 `wgts`: rank 3 with dim 1 == 8, i.e. (batch_size, 8, N); read as float.
144 const Tensor& wgts = context->input(2);
146 context, wgts.dims() == 3 && wgts.shape().dim_size(1) == 8,
147 errors::InvalidArgument(
"TrilinearDevoxelizeGrad expects " 148 "(batch_size, 8, N) weights shape"));
// NOTE(review): batch_size and num_points are taken from grad_y only. The
// batch and N dims of inds and wgts are never compared with grad_y's, so
// mismatched inputs would let Kernel read past the end of the inds/wgts
// buffers. Consider validating dim_size(0) and dim_size(2) of both.
// NOTE(review): dim_size() is narrowed from 64-bit to int here.
150 int batch_size = grad_y.shape().dim_size(0);
151 int num_points = grad_y.shape().dim_size(2);
152 int feat_dim = grad_y.shape().dim_size(1);
154 auto grad_y_flat = grad_y.flat<
float>();
155 auto inds_flat = inds.flat<
int>();
156 auto wgts_flat = wgts.flat<
float>();
// NOTE(review): &flat(0) is only well-defined on non-empty tensors;
// flat<T>().data() gives the buffer pointer without that requirement.
158 const float* inp_grad_y = &(grad_y_flat(0));
159 const int* inp_inds = &(inds_flat(0));
160 const float* inp_wgts = &(wgts_flat(0));
// Output 0: float tensor of shape (batch_size, feat_dim, r, r, r),
// presumably the gradient w.r.t. the voxelized features.
// NOTE(review): allocate_output does not zero the buffer; confirm Kernel
// initializes `out` before accumulating into it.
163 OP_REQUIRES_OK(context,
164 context->allocate_output(
165 0, TensorShape{batch_size, feat_dim, r, r, r},
167 auto flat_tensor = out_tensor->flat<
float>();
169 float* out = &(flat_tensor(0));
// Kernel receives r*r*r (presumably the voxel count per channel), not r.
171 Kernel(context, batch_size, feat_dim, num_points,
r *
r *
r, inp_inds,
172 inp_wgts, inp_grad_y, out);
bool is_training
Definition: TrilinearDevoxelizeKernel.h:116
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r, int r2, int r3, bool training, const float *coords, const float *feat, int *inds, float *wgts, float *outs)=0
TrilinearDevoxelizeGradOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:121
Definition: TrilinearDevoxelizeKernel.h:34
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:132
int r
Definition: TrilinearDevoxelizeKernel.h:186
ImGuiContext * context
Definition: Window.cpp:95
Definition: TrilinearDevoxelizeKernel.h:119
TrilinearDevoxelizeOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:36
int r
Definition: TrilinearDevoxelizeKernel.h:115
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:47