10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
18 : OpKernel(construction) {}
21 using namespace tensorflow;
23 const Tensor& inp_tensor =
context->input(0);
26 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
27 absl::InvalidArgumentError(
29 "(batch_size,num_points,3) inp shape"));
30 int batch_size = inp_tensor.shape().dim_size(0);
31 int pts_num_out = inp_tensor.shape().dim_size(1);
32 auto inp_flat = inp_tensor.flat<
float>();
33 const float* inp = &(inp_flat(0));
35 const Tensor& data_tensor =
context->input(1);
38 data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
39 absl::InvalidArgumentError(
41 "(batch_size,num_points,3) data shape"));
42 int pts_num_in = data_tensor.shape().dim_size(1);
43 auto data_flat = data_tensor.flat<
float>();
44 const float*
data = &(data_flat(0));
50 0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
51 auto out_flat0 = out_dist->flat<
float>();
52 float* out0 = &(out_flat0(0));
58 1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
59 auto out_flat1 = out_idx->flat<
int>();
60 int* out1 = &(out_flat1(0));
79 tensorflow::OpKernelConstruction* construction)
80 : OpKernel(construction) {}
83 using namespace tensorflow;
85 const Tensor& inp_tensor =
context->input(0);
86 OP_REQUIRES(
context, inp_tensor.dims() == 3,
87 absl::InvalidArgumentError(
88 "ThreeInterpolate expects "
89 "(batch_size,num_points,3) inp shape"));
90 int batch_size = inp_tensor.shape().dim_size(0);
91 int C = inp_tensor.shape().dim_size(1);
92 int M = inp_tensor.shape().dim_size(2);
93 auto inp_flat = inp_tensor.flat<
float>();
94 const float* inp = &(inp_flat(0));
96 const Tensor& idx_tensor =
context->input(1);
97 OP_REQUIRES(
context, idx_tensor.dims() == 3,
98 absl::InvalidArgumentError(
99 "ThreeInterpolate expects "
100 "(batch_size,num_points,3) idx shape"));
101 int N = idx_tensor.shape().dim_size(1);
102 auto idx_flat = idx_tensor.flat<
int>();
103 const int* idx = &(idx_flat(0));
105 const Tensor& weights_tensor =
context->input(2);
106 OP_REQUIRES(
context, weights_tensor.dims() == 3,
107 absl::InvalidArgumentError(
108 "ThreeInterpolate expects "
109 "(batch_size,num_points,3) weights shape"));
110 auto weights_flat = weights_tensor.flat<
float>();
111 const float* weights = &(weights_flat(0));
116 0, TensorShape{batch_size, C, N}, &out_tensor));
117 auto out_flat = out_tensor->flat<
float>();
118 float* out = &(out_flat(0));
137 tensorflow::OpKernelConstruction* construction)
138 : OpKernel(construction) {
139 OP_REQUIRES_OK(construction, construction->GetAttr(
"M", &
M));
143 using namespace tensorflow;
145 const Tensor& inp_tensor =
context->input(0);
146 OP_REQUIRES(
context, inp_tensor.dims() == 3,
147 absl::InvalidArgumentError(
148 "ThreeInterpolateGrad expects "
149 "(batch_size,num_points,3) inp shape"));
150 int batch_size = inp_tensor.shape().dim_size(0);
151 int C = inp_tensor.shape().dim_size(1);
152 int N = inp_tensor.shape().dim_size(2);
153 auto inp_flat = inp_tensor.flat<
float>();
154 const float* inp = &(inp_flat(0));
156 const Tensor& idx_tensor =
context->input(1);
157 OP_REQUIRES(
context, idx_tensor.dims() == 3,
158 absl::InvalidArgumentError(
159 "ThreeInterpolateGrad expects "
160 "(batch_size,num_points,3) idx shape"));
161 auto idx_flat = idx_tensor.flat<
int>();
162 const int* idx = &(idx_flat(0));
164 const Tensor& weights_tensor =
context->input(2);
165 OP_REQUIRES(
context, weights_tensor.dims() == 3,
166 absl::InvalidArgumentError(
167 "ThreeInterpolateGrad expects "
168 "(batch_size,num_points,3) weights shape"));
169 auto weights_flat = weights_tensor.flat<
float>();
170 const float* weights = &(weights_flat(0));
175 0, TensorShape{batch_size, C, M}, &out_tensor));
176 auto out_flat = out_tensor->flat<
float>();
177 float* out = &(out_flat(0));
187 const float* grad_out,
190 float* grad_points) = 0;
Eigen::Matrix3Xd M
Definition: PointCloudPlanarPatchDetection.cpp:520
Real weight
Definition: SurfaceReconstructionPoisson.cpp:267
ImGuiContext * context
Definition: Window.cpp:76
Definition: InterpolateOpKernel.h:134
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:136
int M
Definition: InterpolateOpKernel.h:193
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points)=0
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:142
Definition: InterpolateOpKernel.h:76
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:78
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:82
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out)=0
Definition: InterpolateOpKernel.h:15
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:20
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:17
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
const char const char value recording_handle imu_sample recording_handle uint8_t data
Definition: K4aPlugin.cpp:269