10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
18 tensorflow::OpKernelConstruction*
context)
19 : tensorflow::OpKernel(
context) {
20 using namespace tensorflow;
24 absl::InvalidArgumentError(
25 "TrilinearDevoxelize expects positive resolution"));
29 using namespace tensorflow;
30 const Tensor& coords =
context->input(0);
32 coords.dims() == 3 && coords.shape().dim_size(1) == 3,
33 absl::InvalidArgumentError(
34 "TrilinearDevoxelize expects "
35 "(batch_size, 3, N) coordinate shape"));
36 const Tensor& feat =
context->input(1);
37 OP_REQUIRES(
context, feat.dims() == 5,
38 absl::InvalidArgumentError(
"TrilinearDevoxelize expects "
39 "5 dimensions for features"));
41 int batch_size = coords.shape().dim_size(0);
42 int num_points = coords.shape().dim_size(2);
43 int feat_dim = feat.shape().dim_size(1);
45 auto coords_flat = coords.flat<
float>();
46 auto feat_flat = feat.flat<
float>();
48 const float* inp_coords = &(coords_flat(0));
49 const float* inp_feat = &(feat_flat(0));
54 0, TensorShape{batch_size, feat_dim, num_points},
59 1, TensorShape{batch_size, 8, num_points},
64 2, TensorShape{batch_size, 8, num_points},
66 auto flat_0 = out_tensor_0->flat<
float>();
67 auto flat_1 = out_tensor_1->flat<
int>();
68 auto flat_2 = out_tensor_2->flat<
float>();
70 float* out_0 = &(flat_0(0));
71 int* out_1 = &(flat_1(0));
72 float* out_2 = &(flat_2(0));
76 r *
r *
r,
true, inp_coords, inp_feat, out_1, out_2, out_0);
79 r *
r *
r,
false, inp_coords, inp_feat, out_1, out_2, out_0);
104 tensorflow::OpKernelConstruction*
context)
105 : tensorflow::OpKernel(
context) {
106 using namespace tensorflow;
110 absl::InvalidArgumentError(
111 "TrilinearDevoxelizeGrad expects positive resolution"));
115 using namespace tensorflow;
116 const Tensor& grad_y =
context->input(0);
117 OP_REQUIRES(
context, grad_y.dims() == 3,
118 absl::InvalidArgumentError(
119 "TrilinearDevoxelizeGrad expects "
120 "(batch_size, C, N) gradient shape"));
121 const Tensor& inds =
context->input(1);
123 context, inds.dims() == 3 && inds.shape().dim_size(1) == 8,
124 absl::InvalidArgumentError(
"TrilinearDevoxelizeGrad expects "
125 "(batch_size, 8, N) indices shape"));
126 const Tensor& wgts =
context->input(2);
128 context, wgts.dims() == 3 && wgts.shape().dim_size(1) == 8,
129 absl::InvalidArgumentError(
"TrilinearDevoxelizeGrad expects "
130 "(batch_size, 8, N) weights shape"));
132 int batch_size = grad_y.shape().dim_size(0);
133 int num_points = grad_y.shape().dim_size(2);
134 int feat_dim = grad_y.shape().dim_size(1);
136 auto grad_y_flat = grad_y.flat<
float>();
137 auto inds_flat = inds.flat<
int>();
138 auto wgts_flat = wgts.flat<
float>();
140 const float* inp_grad_y = &(grad_y_flat(0));
141 const int* inp_inds = &(inds_flat(0));
142 const float* inp_wgts = &(wgts_flat(0));
147 0, TensorShape{batch_size, feat_dim, r, r, r},
149 auto flat_tensor = out_tensor->flat<
float>();
151 float* out = &(flat_tensor(0));
154 inp_wgts, inp_grad_y, out);
ImGuiContext * context
Definition: Window.cpp:76
Definition: TrilinearDevoxelizeKernel.h:101
int r
Definition: TrilinearDevoxelizeKernel.h:168
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:114
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r3, const int *inds, const float *wgts, const float *grad_y, float *grad_x)=0
TrilinearDevoxelizeGradOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:103
Definition: TrilinearDevoxelizeKernel.h:15
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r, int r2, int r3, bool training, const float *coords, const float *feat, int *inds, float *wgts, float *outs)=0
TrilinearDevoxelizeOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:28
int r
Definition: TrilinearDevoxelizeKernel.h:97
bool is_training
Definition: TrilinearDevoxelizeKernel.h:98