Loading [MathJax]/extensions/TeX/AMSsymbols.js
Open3D (C++ API)  0.14.1
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
TrilinearDevoxelizeKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2018-2021 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
27 #pragma once
28 
29 #include "../TensorFlowHelper.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 
34 class TrilinearDevoxelizeOpKernel : public tensorflow::OpKernel {
35 public:
37  tensorflow::OpKernelConstruction* context)
38  : tensorflow::OpKernel(context) {
39  using namespace tensorflow;
40  OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
41  OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training));
42  OP_REQUIRES(context, r > 0,
43  errors::InvalidArgument(
44  "TrilinearDevoxelize expects positive resolution"));
45  }
46 
47  void Compute(tensorflow::OpKernelContext* context) override {
48  using namespace tensorflow;
49  const Tensor& coords = context->input(0);
50  OP_REQUIRES(
51  context, coords.dims() == 3 && coords.shape().dim_size(1) == 3,
52  errors::InvalidArgument("TrilinearDevoxelize expects "
53  "(batch_size, 3, N) coordinate shape"));
54  const Tensor& feat = context->input(1);
55  OP_REQUIRES(context, feat.dims() == 5,
56  errors::InvalidArgument("TrilinearDevoxelize expects "
57  "5 dimensions for features"));
58 
59  int batch_size = coords.shape().dim_size(0);
60  int num_points = coords.shape().dim_size(2);
61  int feat_dim = feat.shape().dim_size(1);
62 
63  auto coords_flat = coords.flat<float>();
64  auto feat_flat = feat.flat<float>();
65 
66  const float* inp_coords = &(coords_flat(0));
67  const float* inp_feat = &(feat_flat(0));
68 
69  Tensor* out_tensor_0;
70  OP_REQUIRES_OK(context,
71  context->allocate_output(
72  0, TensorShape{batch_size, feat_dim, num_points},
73  &out_tensor_0));
74  Tensor* out_tensor_1;
75  OP_REQUIRES_OK(context,
76  context->allocate_output(
77  1, TensorShape{batch_size, 8, num_points},
78  &out_tensor_1));
79  Tensor* out_tensor_2;
80  OP_REQUIRES_OK(context,
81  context->allocate_output(
82  2, TensorShape{batch_size, 8, num_points},
83  &out_tensor_2));
84  auto flat_0 = out_tensor_0->flat<float>();
85  auto flat_1 = out_tensor_1->flat<int>();
86  auto flat_2 = out_tensor_2->flat<float>();
87 
88  float* out_0 = &(flat_0(0));
89  int* out_1 = &(flat_1(0));
90  float* out_2 = &(flat_2(0));
91 
92  if (is_training)
93  Kernel(context, batch_size, feat_dim, num_points, r, r * r,
94  r * r * r, true, inp_coords, inp_feat, out_1, out_2, out_0);
95  else
96  Kernel(context, batch_size, feat_dim, num_points, r, r * r,
97  r * r * r, false, inp_coords, inp_feat, out_1, out_2, out_0);
98  }
99 
100  virtual void Kernel(tensorflow::OpKernelContext* context,
101  int b,
102  int c,
103  int n,
104  int r,
105  int r2,
106  int r3,
107  bool training,
108  const float* coords,
109  const float* feat,
110  int* inds,
111  float* wgts,
112  float* outs) = 0;
113 
114 protected:
115  int r;
117 };
118 
119 class TrilinearDevoxelizeGradOpKernel : public tensorflow::OpKernel {
120 public:
122  tensorflow::OpKernelConstruction* context)
123  : tensorflow::OpKernel(context) {
124  using namespace tensorflow;
125  OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
126  OP_REQUIRES(
127  context, r > 0,
128  errors::InvalidArgument(
129  "TrilinearDevoxelizeGrad expects positive resolution"));
130  }
131 
132  void Compute(tensorflow::OpKernelContext* context) override {
133  using namespace tensorflow;
134  const Tensor& grad_y = context->input(0);
135  OP_REQUIRES(
136  context, grad_y.dims() == 3,
137  errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
138  "(batch_size, C, N) gradient shape"));
139  const Tensor& inds = context->input(1);
140  OP_REQUIRES(
141  context, inds.dims() == 3 && inds.shape().dim_size(1) == 8,
142  errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
143  "(batch_size, 8, N) indices shape"));
144  const Tensor& wgts = context->input(2);
145  OP_REQUIRES(
146  context, wgts.dims() == 3 && wgts.shape().dim_size(1) == 8,
147  errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
148  "(batch_size, 8, N) weights shape"));
149 
150  int batch_size = grad_y.shape().dim_size(0);
151  int num_points = grad_y.shape().dim_size(2);
152  int feat_dim = grad_y.shape().dim_size(1);
153 
154  auto grad_y_flat = grad_y.flat<float>();
155  auto inds_flat = inds.flat<int>();
156  auto wgts_flat = wgts.flat<float>();
157 
158  const float* inp_grad_y = &(grad_y_flat(0));
159  const int* inp_inds = &(inds_flat(0));
160  const float* inp_wgts = &(wgts_flat(0));
161 
162  Tensor* out_tensor;
163  OP_REQUIRES_OK(context,
164  context->allocate_output(
165  0, TensorShape{batch_size, feat_dim, r, r, r},
166  &out_tensor));
167  auto flat_tensor = out_tensor->flat<float>();
168 
169  float* out = &(flat_tensor(0));
170 
171  Kernel(context, batch_size, feat_dim, num_points, r * r * r, inp_inds,
172  inp_wgts, inp_grad_y, out);
173  }
174 
175  virtual void Kernel(tensorflow::OpKernelContext* context,
176  int b,
177  int c,
178  int n,
179  int r3,
180  const int* inds,
181  const float* wgts,
182  const float* grad_y,
183  float* grad_x) = 0;
184 
185 protected:
186  int r;
187 };
bool is_training
Definition: TrilinearDevoxelizeKernel.h:116
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r, int r2, int r3, bool training, const float *coords, const float *feat, int *inds, float *wgts, float *outs)=0
TrilinearDevoxelizeGradOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:121
Definition: TrilinearDevoxelizeKernel.h:34
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:132
int r
Definition: TrilinearDevoxelizeKernel.h:186
ImGuiContext * context
Definition: Window.cpp:95
Definition: TrilinearDevoxelizeKernel.h:119
TrilinearDevoxelizeOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:36
int r
Definition: TrilinearDevoxelizeKernel.h:115
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:47