Open3D (C++ API)  0.19.0
InterpolateOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
15 class ThreeNNOpKernel : public tensorflow::OpKernel {
16 public:
17  explicit ThreeNNOpKernel(tensorflow::OpKernelConstruction* construction)
18  : OpKernel(construction) {}
19 
20  void Compute(tensorflow::OpKernelContext* context) override {
21  using namespace tensorflow;
22 
23  const Tensor& inp_tensor = context->input(0);
24  OP_REQUIRES(
25  context,
26  inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
27  absl::InvalidArgumentError(
28  "ThreeNN expects "
29  "(batch_size,num_points,3) inp shape"));
30  int batch_size = inp_tensor.shape().dim_size(0);
31  int pts_num_out = inp_tensor.shape().dim_size(1);
32  auto inp_flat = inp_tensor.flat<float>();
33  const float* inp = &(inp_flat(0));
34 
35  const Tensor& data_tensor = context->input(1);
36  OP_REQUIRES(
37  context,
38  data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
39  absl::InvalidArgumentError(
40  "ThreeNN expects "
41  "(batch_size,num_points,3) data shape"));
42  int pts_num_in = data_tensor.shape().dim_size(1);
43  auto data_flat = data_tensor.flat<float>();
44  const float* data = &(data_flat(0));
45 
46  Tensor* out_dist;
47  OP_REQUIRES_OK(
48  context,
49  context->allocate_output(
50  0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
51  auto out_flat0 = out_dist->flat<float>();
52  float* out0 = &(out_flat0(0));
53 
54  Tensor* out_idx;
55  OP_REQUIRES_OK(
56  context,
57  context->allocate_output(
58  1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
59  auto out_flat1 = out_idx->flat<int>();
60  int* out1 = &(out_flat1(0));
61 
62  Kernel(context, batch_size, pts_num_out, pts_num_in, inp, data, out0,
63  out1);
64  }
65 
66  virtual void Kernel(tensorflow::OpKernelContext* context,
67  int b,
68  int n,
69  int m,
70  const float* unknown,
71  const float* known,
72  float* dist2,
73  int* idx) = 0;
74 };
75 
76 class ThreeInterpolateOpKernel : public tensorflow::OpKernel {
77 public:
79  tensorflow::OpKernelConstruction* construction)
80  : OpKernel(construction) {}
81 
82  void Compute(tensorflow::OpKernelContext* context) override {
83  using namespace tensorflow;
84 
85  const Tensor& inp_tensor = context->input(0);
86  OP_REQUIRES(context, inp_tensor.dims() == 3,
87  absl::InvalidArgumentError(
88  "ThreeInterpolate expects "
89  "(batch_size,num_points,3) inp shape"));
90  int batch_size = inp_tensor.shape().dim_size(0);
91  int C = inp_tensor.shape().dim_size(1);
92  int M = inp_tensor.shape().dim_size(2);
93  auto inp_flat = inp_tensor.flat<float>();
94  const float* inp = &(inp_flat(0));
95 
96  const Tensor& idx_tensor = context->input(1);
97  OP_REQUIRES(context, idx_tensor.dims() == 3,
98  absl::InvalidArgumentError(
99  "ThreeInterpolate expects "
100  "(batch_size,num_points,3) idx shape"));
101  int N = idx_tensor.shape().dim_size(1);
102  auto idx_flat = idx_tensor.flat<int>();
103  const int* idx = &(idx_flat(0));
104 
105  const Tensor& weights_tensor = context->input(2);
106  OP_REQUIRES(context, weights_tensor.dims() == 3,
107  absl::InvalidArgumentError(
108  "ThreeInterpolate expects "
109  "(batch_size,num_points,3) weights shape"));
110  auto weights_flat = weights_tensor.flat<float>();
111  const float* weights = &(weights_flat(0));
112 
113  Tensor* out_tensor;
114  OP_REQUIRES_OK(context,
115  context->allocate_output(
116  0, TensorShape{batch_size, C, N}, &out_tensor));
117  auto out_flat = out_tensor->flat<float>();
118  float* out = &(out_flat(0));
119 
120  Kernel(context, batch_size, C, M, N, inp, idx, weights, out);
121  }
122 
123  virtual void Kernel(tensorflow::OpKernelContext* context,
124  int b,
125  int c,
126  int m,
127  int n,
128  const float* points,
129  const int* idx,
130  const float* weight,
131  float* out) = 0;
132 };
133 
134 class ThreeInterpolateGradOpKernel : public tensorflow::OpKernel {
135 public:
137  tensorflow::OpKernelConstruction* construction)
138  : OpKernel(construction) {
139  OP_REQUIRES_OK(construction, construction->GetAttr("M", &M));
140  }
141 
142  void Compute(tensorflow::OpKernelContext* context) override {
143  using namespace tensorflow;
144 
145  const Tensor& inp_tensor = context->input(0);
146  OP_REQUIRES(context, inp_tensor.dims() == 3,
147  absl::InvalidArgumentError(
148  "ThreeInterpolateGrad expects "
149  "(batch_size,num_points,3) inp shape"));
150  int batch_size = inp_tensor.shape().dim_size(0);
151  int C = inp_tensor.shape().dim_size(1);
152  int N = inp_tensor.shape().dim_size(2);
153  auto inp_flat = inp_tensor.flat<float>();
154  const float* inp = &(inp_flat(0));
155 
156  const Tensor& idx_tensor = context->input(1);
157  OP_REQUIRES(context, idx_tensor.dims() == 3,
158  absl::InvalidArgumentError(
159  "ThreeInterpolateGrad expects "
160  "(batch_size,num_points,3) idx shape"));
161  auto idx_flat = idx_tensor.flat<int>();
162  const int* idx = &(idx_flat(0));
163 
164  const Tensor& weights_tensor = context->input(2);
165  OP_REQUIRES(context, weights_tensor.dims() == 3,
166  absl::InvalidArgumentError(
167  "ThreeInterpolateGrad expects "
168  "(batch_size,num_points,3) weights shape"));
169  auto weights_flat = weights_tensor.flat<float>();
170  const float* weights = &(weights_flat(0));
171 
172  Tensor* out_tensor;
173  OP_REQUIRES_OK(context,
174  context->allocate_output(
175  0, TensorShape{batch_size, C, M}, &out_tensor));
176  auto out_flat = out_tensor->flat<float>();
177  float* out = &(out_flat(0));
178 
179  Kernel(context, batch_size, C, N, M, inp, idx, weights, out);
180  }
181 
182  virtual void Kernel(tensorflow::OpKernelContext* context,
183  int b,
184  int c,
185  int n,
186  int m,
187  const float* grad_out,
188  const int* idx,
189  const float* weight,
190  float* grad_points) = 0;
191 
192 protected:
193  int M;
194 };
Eigen::Matrix3Xd M
Definition: PointCloudPlanarPatchDetection.cpp:520
Real weight
Definition: SurfaceReconstructionPoisson.cpp:267
ImGuiContext * context
Definition: Window.cpp:76
Definition: InterpolateOpKernel.h:134
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:136
int M
Definition: InterpolateOpKernel.h:193
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points)=0
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:142
Definition: InterpolateOpKernel.h:76
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:78
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:82
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out)=0
Definition: InterpolateOpKernel.h:15
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:20
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:17
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
int points
Definition: FilePCD.cpp:54
const char const char value recording_handle imu_sample recording_handle uint8_t data
Definition: K4aPlugin.cpp:269