29 #include "../TensorFlowHelper.h" 30 #include "tensorflow/core/framework/op.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/lib/core/errors.h" 37 : OpKernel(construction) {}
42 const Tensor& inp_tensor = context->input(0);
45 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
46 errors::InvalidArgument(
"ThreeNN expects " 47 "(batch_size,num_points,3) inp shape"));
48 int batch_size = inp_tensor.shape().dim_size(0);
49 int pts_num_out = inp_tensor.shape().dim_size(1);
50 auto inp_flat = inp_tensor.flat<
float>();
51 const float* inp = &(inp_flat(0));
53 const Tensor& data_tensor = context->input(1);
56 data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
57 errors::InvalidArgument(
59 "(batch_size,num_points,3) data shape"));
60 int pts_num_in = data_tensor.shape().dim_size(1);
61 auto data_flat = data_tensor.flat<
float>();
62 const float*
data = &(data_flat(0));
67 context->allocate_output(
68 0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
69 auto out_flat0 = out_dist->flat<
float>();
70 float* out0 = &(out_flat0(0));
75 context->allocate_output(
76 1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
77 auto out_flat1 = out_idx->flat<
int>();
78 int* out1 = &(out_flat1(0));
80 Kernel(context, batch_size, pts_num_out, pts_num_in, inp, data, out0,
97 tensorflow::OpKernelConstruction* construction)
98 : OpKernel(construction) {}
103 const Tensor& inp_tensor = context->input(0);
105 context, inp_tensor.dims() == 3,
106 errors::InvalidArgument(
"ThreeInterpolate expects " 107 "(batch_size,num_points,3) inp shape"));
108 int batch_size = inp_tensor.shape().dim_size(0);
109 int C = inp_tensor.shape().dim_size(1);
110 int M = inp_tensor.shape().dim_size(2);
111 auto inp_flat = inp_tensor.flat<
float>();
112 const float* inp = &(inp_flat(0));
114 const Tensor& idx_tensor = context->input(1);
116 context, idx_tensor.dims() == 3,
117 errors::InvalidArgument(
"ThreeInterpolate expects " 118 "(batch_size,num_points,3) idx shape"));
119 int N = idx_tensor.shape().dim_size(1);
120 auto idx_flat = idx_tensor.flat<
int>();
121 const int* idx = &(idx_flat(0));
123 const Tensor& weights_tensor = context->input(2);
124 OP_REQUIRES(context, weights_tensor.dims() == 3,
125 errors::InvalidArgument(
126 "ThreeInterpolate expects " 127 "(batch_size,num_points,3) weights shape"));
128 auto weights_flat = weights_tensor.flat<
float>();
129 const float* weights = &(weights_flat(0));
132 OP_REQUIRES_OK(context,
133 context->allocate_output(
134 0, TensorShape{batch_size, C, N}, &out_tensor));
135 auto out_flat = out_tensor->flat<
float>();
136 float* out = &(out_flat(0));
138 Kernel(context, batch_size, C, M, N, inp, idx, weights, out);
155 tensorflow::OpKernelConstruction* construction)
156 : OpKernel(construction) {
157 OP_REQUIRES_OK(construction, construction->GetAttr(
"M", &M));
163 const Tensor& inp_tensor = context->input(0);
165 context, inp_tensor.dims() == 3,
166 errors::InvalidArgument(
"ThreeInterpolateGrad expects " 167 "(batch_size,num_points,3) inp shape"));
168 int batch_size = inp_tensor.shape().dim_size(0);
169 int C = inp_tensor.shape().dim_size(1);
170 int N = inp_tensor.shape().dim_size(2);
171 auto inp_flat = inp_tensor.flat<
float>();
172 const float* inp = &(inp_flat(0));
174 const Tensor& idx_tensor = context->input(1);
176 context, idx_tensor.dims() == 3,
177 errors::InvalidArgument(
"ThreeInterpolateGrad expects " 178 "(batch_size,num_points,3) idx shape"));
179 auto idx_flat = idx_tensor.flat<
int>();
180 const int* idx = &(idx_flat(0));
182 const Tensor& weights_tensor = context->input(2);
183 OP_REQUIRES(context, weights_tensor.dims() == 3,
184 errors::InvalidArgument(
185 "ThreeInterpolateGrad expects " 186 "(batch_size,num_points,3) weights shape"));
187 auto weights_flat = weights_tensor.flat<
float>();
188 const float* weights = &(weights_flat(0));
191 OP_REQUIRES_OK(context,
192 context->allocate_output(
193 0, TensorShape{batch_size, C, M}, &out_tensor));
194 auto out_flat = out_tensor->flat<
float>();
195 float* out = &(out_flat(0));
197 Kernel(context, batch_size, C, N, M, inp, idx, weights, out);
205 const float* grad_out,
208 float* grad_points) = 0;
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:36
Definition: InterpolateOpKernel.h:34
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:100
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:160
ImGuiContext * context
Definition: Window.cpp:95
Definition: InterpolateOpKernel.h:94
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:96
int M
Definition: InterpolateOpKernel.h:211
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:154
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:39
Definition: InterpolateOpKernel.h:152
const char const char value recording_handle imu_sample recording_handle uint8_t data
Definition: K4aPlugin.cpp:274