25 #include "../TensorFlowHelper.h" 26 #include "tensorflow/core/framework/op.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/lib/core/errors.h" 33 : OpKernel(construction) {}
38 const Tensor& inp_tensor = context->input(0);
41 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
42 errors::InvalidArgument(
"ThreeNN expects " 43 "(batch_size,num_points,3) inp shape"));
44 int batch_size = inp_tensor.shape().dim_size(0);
45 int pts_num_out = inp_tensor.shape().dim_size(1);
46 auto inp_flat = inp_tensor.flat<
float>();
47 const float* inp = &(inp_flat(0));
49 const Tensor& data_tensor = context->input(1);
52 data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
53 errors::InvalidArgument(
55 "(batch_size,num_points,3) data shape"));
56 int pts_num_in = data_tensor.shape().dim_size(1);
57 auto data_flat = data_tensor.flat<
float>();
58 const float*
data = &(data_flat(0));
63 context->allocate_output(
64 0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
65 auto out_flat0 = out_dist->flat<
float>();
66 float* out0 = &(out_flat0(0));
71 context->allocate_output(
72 1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
73 auto out_flat1 = out_idx->flat<
int>();
74 int* out1 = &(out_flat1(0));
76 Kernel(context, batch_size, pts_num_out, pts_num_in, inp, data, out0,
93 tensorflow::OpKernelConstruction* construction)
94 : OpKernel(construction) {}
99 const Tensor& inp_tensor = context->input(0);
101 context, inp_tensor.dims() == 3,
102 errors::InvalidArgument(
"ThreeInterpolate expects " 103 "(batch_size,num_points,3) inp shape"));
104 int batch_size = inp_tensor.shape().dim_size(0);
105 int C = inp_tensor.shape().dim_size(1);
106 int M = inp_tensor.shape().dim_size(2);
107 auto inp_flat = inp_tensor.flat<
float>();
108 const float* inp = &(inp_flat(0));
110 const Tensor& idx_tensor = context->input(1);
112 context, idx_tensor.dims() == 3,
113 errors::InvalidArgument(
"ThreeInterpolate expects " 114 "(batch_size,num_points,3) idx shape"));
115 int N = idx_tensor.shape().dim_size(1);
116 auto idx_flat = idx_tensor.flat<
int>();
117 const int* idx = &(idx_flat(0));
119 const Tensor& weights_tensor = context->input(2);
120 OP_REQUIRES(context, weights_tensor.dims() == 3,
121 errors::InvalidArgument(
122 "ThreeInterpolate expects " 123 "(batch_size,num_points,3) weights shape"));
124 auto weights_flat = weights_tensor.flat<
float>();
125 const float* weights = &(weights_flat(0));
128 OP_REQUIRES_OK(context,
129 context->allocate_output(
130 0, TensorShape{batch_size, C, N}, &out_tensor));
131 auto out_flat = out_tensor->flat<
float>();
132 float* out = &(out_flat(0));
134 Kernel(context, batch_size, C, M, N, inp, idx, weights, out);
151 tensorflow::OpKernelConstruction* construction)
152 : OpKernel(construction) {
153 OP_REQUIRES_OK(construction, construction->GetAttr(
"M", &M));
159 const Tensor& inp_tensor = context->input(0);
161 context, inp_tensor.dims() == 3,
162 errors::InvalidArgument(
"ThreeInterpolateGrad expects " 163 "(batch_size,num_points,3) inp shape"));
164 int batch_size = inp_tensor.shape().dim_size(0);
165 int C = inp_tensor.shape().dim_size(1);
166 int N = inp_tensor.shape().dim_size(2);
167 auto inp_flat = inp_tensor.flat<
float>();
168 const float* inp = &(inp_flat(0));
170 const Tensor& idx_tensor = context->input(1);
172 context, idx_tensor.dims() == 3,
173 errors::InvalidArgument(
"ThreeInterpolateGrad expects " 174 "(batch_size,num_points,3) idx shape"));
175 auto idx_flat = idx_tensor.flat<
int>();
176 const int* idx = &(idx_flat(0));
178 const Tensor& weights_tensor = context->input(2);
179 OP_REQUIRES(context, weights_tensor.dims() == 3,
180 errors::InvalidArgument(
181 "ThreeInterpolateGrad expects " 182 "(batch_size,num_points,3) weights shape"));
183 auto weights_flat = weights_tensor.flat<
float>();
184 const float* weights = &(weights_flat(0));
187 OP_REQUIRES_OK(context,
188 context->allocate_output(
189 0, TensorShape{batch_size, C, M}, &out_tensor));
190 auto out_flat = out_tensor->flat<
float>();
191 float* out = &(out_flat(0));
193 Kernel(context, batch_size, C, N, M, inp, idx, weights, out);
201 const float* grad_out,
204 float* grad_points) = 0;
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:32
Definition: InterpolateOpKernel.h:30
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:96
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:156
ImGuiContext * context
Definition: Window.cpp:95
int points
Definition: FilePCD.cpp:73
Definition: InterpolateOpKernel.h:90
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:92
int M
Definition: InterpolateOpKernel.h:207
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:150
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:35
Definition: InterpolateOpKernel.h:148
const char const char value recording_handle imu_sample recording_handle uint8_t data
Definition: K4aPlugin.cpp:274