Loading [MathJax]/extensions/TeX/AMSsymbols.js
Open3D (C++ API)  0.14.1
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
InterpolateOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2018-2021 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
27 #pragma once
28 
29 #include "../TensorFlowHelper.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 
34 class ThreeNNOpKernel : public tensorflow::OpKernel {
35 public:
36  explicit ThreeNNOpKernel(tensorflow::OpKernelConstruction* construction)
37  : OpKernel(construction) {}
38 
39  void Compute(tensorflow::OpKernelContext* context) override {
40  using namespace tensorflow;
41 
42  const Tensor& inp_tensor = context->input(0);
43  OP_REQUIRES(
44  context,
45  inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
46  errors::InvalidArgument("ThreeNN expects "
47  "(batch_size,num_points,3) inp shape"));
48  int batch_size = inp_tensor.shape().dim_size(0);
49  int pts_num_out = inp_tensor.shape().dim_size(1);
50  auto inp_flat = inp_tensor.flat<float>();
51  const float* inp = &(inp_flat(0));
52 
53  const Tensor& data_tensor = context->input(1);
54  OP_REQUIRES(
55  context,
56  data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
57  errors::InvalidArgument(
58  "ThreeNN expects "
59  "(batch_size,num_points,3) data shape"));
60  int pts_num_in = data_tensor.shape().dim_size(1);
61  auto data_flat = data_tensor.flat<float>();
62  const float* data = &(data_flat(0));
63 
64  Tensor* out_dist;
65  OP_REQUIRES_OK(
66  context,
67  context->allocate_output(
68  0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
69  auto out_flat0 = out_dist->flat<float>();
70  float* out0 = &(out_flat0(0));
71 
72  Tensor* out_idx;
73  OP_REQUIRES_OK(
74  context,
75  context->allocate_output(
76  1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
77  auto out_flat1 = out_idx->flat<int>();
78  int* out1 = &(out_flat1(0));
79 
80  Kernel(context, batch_size, pts_num_out, pts_num_in, inp, data, out0,
81  out1);
82  }
83 
84  virtual void Kernel(tensorflow::OpKernelContext* context,
85  int b,
86  int n,
87  int m,
88  const float* unknown,
89  const float* known,
90  float* dist2,
91  int* idx) = 0;
92 };
93 
94 class ThreeInterpolateOpKernel : public tensorflow::OpKernel {
95 public:
97  tensorflow::OpKernelConstruction* construction)
98  : OpKernel(construction) {}
99 
100  void Compute(tensorflow::OpKernelContext* context) override {
101  using namespace tensorflow;
102 
103  const Tensor& inp_tensor = context->input(0);
104  OP_REQUIRES(
105  context, inp_tensor.dims() == 3,
106  errors::InvalidArgument("ThreeInterpolate expects "
107  "(batch_size,num_points,3) inp shape"));
108  int batch_size = inp_tensor.shape().dim_size(0);
109  int C = inp_tensor.shape().dim_size(1);
110  int M = inp_tensor.shape().dim_size(2);
111  auto inp_flat = inp_tensor.flat<float>();
112  const float* inp = &(inp_flat(0));
113 
114  const Tensor& idx_tensor = context->input(1);
115  OP_REQUIRES(
116  context, idx_tensor.dims() == 3,
117  errors::InvalidArgument("ThreeInterpolate expects "
118  "(batch_size,num_points,3) idx shape"));
119  int N = idx_tensor.shape().dim_size(1);
120  auto idx_flat = idx_tensor.flat<int>();
121  const int* idx = &(idx_flat(0));
122 
123  const Tensor& weights_tensor = context->input(2);
124  OP_REQUIRES(context, weights_tensor.dims() == 3,
125  errors::InvalidArgument(
126  "ThreeInterpolate expects "
127  "(batch_size,num_points,3) weights shape"));
128  auto weights_flat = weights_tensor.flat<float>();
129  const float* weights = &(weights_flat(0));
130 
131  Tensor* out_tensor;
132  OP_REQUIRES_OK(context,
133  context->allocate_output(
134  0, TensorShape{batch_size, C, N}, &out_tensor));
135  auto out_flat = out_tensor->flat<float>();
136  float* out = &(out_flat(0));
137 
138  Kernel(context, batch_size, C, M, N, inp, idx, weights, out);
139  }
140 
141  virtual void Kernel(tensorflow::OpKernelContext* context,
142  int b,
143  int c,
144  int m,
145  int n,
146  const float* points,
147  const int* idx,
148  const float* weight,
149  float* out) = 0;
150 };
151 
152 class ThreeInterpolateGradOpKernel : public tensorflow::OpKernel {
153 public:
155  tensorflow::OpKernelConstruction* construction)
156  : OpKernel(construction) {
157  OP_REQUIRES_OK(construction, construction->GetAttr("M", &M));
158  }
159 
160  void Compute(tensorflow::OpKernelContext* context) override {
161  using namespace tensorflow;
162 
163  const Tensor& inp_tensor = context->input(0);
164  OP_REQUIRES(
165  context, inp_tensor.dims() == 3,
166  errors::InvalidArgument("ThreeInterpolateGrad expects "
167  "(batch_size,num_points,3) inp shape"));
168  int batch_size = inp_tensor.shape().dim_size(0);
169  int C = inp_tensor.shape().dim_size(1);
170  int N = inp_tensor.shape().dim_size(2);
171  auto inp_flat = inp_tensor.flat<float>();
172  const float* inp = &(inp_flat(0));
173 
174  const Tensor& idx_tensor = context->input(1);
175  OP_REQUIRES(
176  context, idx_tensor.dims() == 3,
177  errors::InvalidArgument("ThreeInterpolateGrad expects "
178  "(batch_size,num_points,3) idx shape"));
179  auto idx_flat = idx_tensor.flat<int>();
180  const int* idx = &(idx_flat(0));
181 
182  const Tensor& weights_tensor = context->input(2);
183  OP_REQUIRES(context, weights_tensor.dims() == 3,
184  errors::InvalidArgument(
185  "ThreeInterpolateGrad expects "
186  "(batch_size,num_points,3) weights shape"));
187  auto weights_flat = weights_tensor.flat<float>();
188  const float* weights = &(weights_flat(0));
189 
190  Tensor* out_tensor;
191  OP_REQUIRES_OK(context,
192  context->allocate_output(
193  0, TensorShape{batch_size, C, M}, &out_tensor));
194  auto out_flat = out_tensor->flat<float>();
195  float* out = &(out_flat(0));
196 
197  Kernel(context, batch_size, C, N, M, inp, idx, weights, out);
198  }
199 
200  virtual void Kernel(tensorflow::OpKernelContext* context,
201  int b,
202  int c,
203  int n,
204  int m,
205  const float* grad_out,
206  const int* idx,
207  const float* weight,
208  float* grad_points) = 0;
209 
210 protected:
211  int M;
212 };
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:36
Definition: InterpolateOpKernel.h:34
int points
Definition: FilePCD.cpp:73
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:100
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:160
ImGuiContext * context
Definition: Window.cpp:95
Definition: InterpolateOpKernel.h:94
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:96
int M
Definition: InterpolateOpKernel.h:211
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:154
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:39
Definition: InterpolateOpKernel.h:152
const char const char value recording_handle imu_sample recording_handle uint8_t data
Definition: K4aPlugin.cpp:274