Open3D (C++ API)  0.13.0
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
InterpolateOpKernel.h
Go to the documentation of this file.
1 // The MIT License (MIT)
2 //
3 // Copyright (c) 2020 www.open3d.org
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining a copy
6 // of this software and associated documentation files (the "Software"), to deal
7 // in the Software without restriction, including without limitation the rights
8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // copies of the Software, and to permit persons to whom the Software is
10 // furnished to do so, subject to the following conditions:
11 //
12 // The above copyright notice and this permission notice shall be included in
13 // all copies or substantial portions of the Software.
14 //
15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 // IN THE SOFTWARE.
22 // ----------------------------------------------------------------------------
23 #pragma once
24 
25 #include "../TensorFlowHelper.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 
30 class ThreeNNOpKernel : public tensorflow::OpKernel {
31 public:
32  explicit ThreeNNOpKernel(tensorflow::OpKernelConstruction* construction)
33  : OpKernel(construction) {}
34 
35  void Compute(tensorflow::OpKernelContext* context) override {
36  using namespace tensorflow;
37 
38  const Tensor& inp_tensor = context->input(0);
39  OP_REQUIRES(
40  context,
41  inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
42  errors::InvalidArgument("ThreeNN expects "
43  "(batch_size,num_points,3) inp shape"));
44  int batch_size = inp_tensor.shape().dim_size(0);
45  int pts_num_out = inp_tensor.shape().dim_size(1);
46  auto inp_flat = inp_tensor.flat<float>();
47  const float* inp = &(inp_flat(0));
48 
49  const Tensor& data_tensor = context->input(1);
50  OP_REQUIRES(
51  context,
52  data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
53  errors::InvalidArgument(
54  "ThreeNN expects "
55  "(batch_size,num_points,3) data shape"));
56  int pts_num_in = data_tensor.shape().dim_size(1);
57  auto data_flat = data_tensor.flat<float>();
58  const float* data = &(data_flat(0));
59 
60  Tensor* out_dist;
61  OP_REQUIRES_OK(
62  context,
63  context->allocate_output(
64  0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
65  auto out_flat0 = out_dist->flat<float>();
66  float* out0 = &(out_flat0(0));
67 
68  Tensor* out_idx;
69  OP_REQUIRES_OK(
70  context,
71  context->allocate_output(
72  1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
73  auto out_flat1 = out_idx->flat<int>();
74  int* out1 = &(out_flat1(0));
75 
76  Kernel(context, batch_size, pts_num_out, pts_num_in, inp, data, out0,
77  out1);
78  }
79 
80  virtual void Kernel(tensorflow::OpKernelContext* context,
81  int b,
82  int n,
83  int m,
84  const float* unknown,
85  const float* known,
86  float* dist2,
87  int* idx) = 0;
88 };
89 
90 class ThreeInterpolateOpKernel : public tensorflow::OpKernel {
91 public:
93  tensorflow::OpKernelConstruction* construction)
94  : OpKernel(construction) {}
95 
96  void Compute(tensorflow::OpKernelContext* context) override {
97  using namespace tensorflow;
98 
99  const Tensor& inp_tensor = context->input(0);
100  OP_REQUIRES(
101  context, inp_tensor.dims() == 3,
102  errors::InvalidArgument("ThreeInterpolate expects "
103  "(batch_size,num_points,3) inp shape"));
104  int batch_size = inp_tensor.shape().dim_size(0);
105  int C = inp_tensor.shape().dim_size(1);
106  int M = inp_tensor.shape().dim_size(2);
107  auto inp_flat = inp_tensor.flat<float>();
108  const float* inp = &(inp_flat(0));
109 
110  const Tensor& idx_tensor = context->input(1);
111  OP_REQUIRES(
112  context, idx_tensor.dims() == 3,
113  errors::InvalidArgument("ThreeInterpolate expects "
114  "(batch_size,num_points,3) idx shape"));
115  int N = idx_tensor.shape().dim_size(1);
116  auto idx_flat = idx_tensor.flat<int>();
117  const int* idx = &(idx_flat(0));
118 
119  const Tensor& weights_tensor = context->input(2);
120  OP_REQUIRES(context, weights_tensor.dims() == 3,
121  errors::InvalidArgument(
122  "ThreeInterpolate expects "
123  "(batch_size,num_points,3) weights shape"));
124  auto weights_flat = weights_tensor.flat<float>();
125  const float* weights = &(weights_flat(0));
126 
127  Tensor* out_tensor;
128  OP_REQUIRES_OK(context,
129  context->allocate_output(
130  0, TensorShape{batch_size, C, N}, &out_tensor));
131  auto out_flat = out_tensor->flat<float>();
132  float* out = &(out_flat(0));
133 
134  Kernel(context, batch_size, C, M, N, inp, idx, weights, out);
135  }
136 
137  virtual void Kernel(tensorflow::OpKernelContext* context,
138  int b,
139  int c,
140  int m,
141  int n,
142  const float* points,
143  const int* idx,
144  const float* weight,
145  float* out) = 0;
146 };
147 
148 class ThreeInterpolateGradOpKernel : public tensorflow::OpKernel {
149 public:
151  tensorflow::OpKernelConstruction* construction)
152  : OpKernel(construction) {
153  OP_REQUIRES_OK(construction, construction->GetAttr("M", &M));
154  }
155 
156  void Compute(tensorflow::OpKernelContext* context) override {
157  using namespace tensorflow;
158 
159  const Tensor& inp_tensor = context->input(0);
160  OP_REQUIRES(
161  context, inp_tensor.dims() == 3,
162  errors::InvalidArgument("ThreeInterpolateGrad expects "
163  "(batch_size,num_points,3) inp shape"));
164  int batch_size = inp_tensor.shape().dim_size(0);
165  int C = inp_tensor.shape().dim_size(1);
166  int N = inp_tensor.shape().dim_size(2);
167  auto inp_flat = inp_tensor.flat<float>();
168  const float* inp = &(inp_flat(0));
169 
170  const Tensor& idx_tensor = context->input(1);
171  OP_REQUIRES(
172  context, idx_tensor.dims() == 3,
173  errors::InvalidArgument("ThreeInterpolateGrad expects "
174  "(batch_size,num_points,3) idx shape"));
175  auto idx_flat = idx_tensor.flat<int>();
176  const int* idx = &(idx_flat(0));
177 
178  const Tensor& weights_tensor = context->input(2);
179  OP_REQUIRES(context, weights_tensor.dims() == 3,
180  errors::InvalidArgument(
181  "ThreeInterpolateGrad expects "
182  "(batch_size,num_points,3) weights shape"));
183  auto weights_flat = weights_tensor.flat<float>();
184  const float* weights = &(weights_flat(0));
185 
186  Tensor* out_tensor;
187  OP_REQUIRES_OK(context,
188  context->allocate_output(
189  0, TensorShape{batch_size, C, M}, &out_tensor));
190  auto out_flat = out_tensor->flat<float>();
191  float* out = &(out_flat(0));
192 
193  Kernel(context, batch_size, C, N, M, inp, idx, weights, out);
194  }
195 
196  virtual void Kernel(tensorflow::OpKernelContext* context,
197  int b,
198  int c,
199  int n,
200  int m,
201  const float* grad_out,
202  const int* idx,
203  const float* weight,
204  float* grad_points) = 0;
205 
206 protected:
207  int M;
208 };
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:32
Definition: InterpolateOpKernel.h:30
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:96
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:156
ImGuiContext * context
Definition: Window.cpp:95
int points
Definition: FilePCD.cpp:73
Definition: InterpolateOpKernel.h:90
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:92
int M
Definition: InterpolateOpKernel.h:207
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:150
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:35
Definition: InterpolateOpKernel.h:148
const char const char value recording_handle imu_sample recording_handle uint8_t data
Definition: K4aPlugin.cpp:274