Open3D (C++ API)  0.19.0
TrilinearDevoxelizeKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
15 class TrilinearDevoxelizeOpKernel : public tensorflow::OpKernel {
16 public:
18  tensorflow::OpKernelConstruction* context)
19  : tensorflow::OpKernel(context) {
20  using namespace tensorflow;
21  OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
22  OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training));
23  OP_REQUIRES(context, r > 0,
24  absl::InvalidArgumentError(
25  "TrilinearDevoxelize expects positive resolution"));
26  }
27 
28  void Compute(tensorflow::OpKernelContext* context) override {
29  using namespace tensorflow;
30  const Tensor& coords = context->input(0);
31  OP_REQUIRES(context,
32  coords.dims() == 3 && coords.shape().dim_size(1) == 3,
33  absl::InvalidArgumentError(
34  "TrilinearDevoxelize expects "
35  "(batch_size, 3, N) coordinate shape"));
36  const Tensor& feat = context->input(1);
37  OP_REQUIRES(context, feat.dims() == 5,
38  absl::InvalidArgumentError("TrilinearDevoxelize expects "
39  "5 dimensions for features"));
40 
41  int batch_size = coords.shape().dim_size(0);
42  int num_points = coords.shape().dim_size(2);
43  int feat_dim = feat.shape().dim_size(1);
44 
45  auto coords_flat = coords.flat<float>();
46  auto feat_flat = feat.flat<float>();
47 
48  const float* inp_coords = &(coords_flat(0));
49  const float* inp_feat = &(feat_flat(0));
50 
51  Tensor* out_tensor_0;
52  OP_REQUIRES_OK(context,
53  context->allocate_output(
54  0, TensorShape{batch_size, feat_dim, num_points},
55  &out_tensor_0));
56  Tensor* out_tensor_1;
57  OP_REQUIRES_OK(context,
58  context->allocate_output(
59  1, TensorShape{batch_size, 8, num_points},
60  &out_tensor_1));
61  Tensor* out_tensor_2;
62  OP_REQUIRES_OK(context,
63  context->allocate_output(
64  2, TensorShape{batch_size, 8, num_points},
65  &out_tensor_2));
66  auto flat_0 = out_tensor_0->flat<float>();
67  auto flat_1 = out_tensor_1->flat<int>();
68  auto flat_2 = out_tensor_2->flat<float>();
69 
70  float* out_0 = &(flat_0(0));
71  int* out_1 = &(flat_1(0));
72  float* out_2 = &(flat_2(0));
73 
74  if (is_training)
75  Kernel(context, batch_size, feat_dim, num_points, r, r * r,
76  r * r * r, true, inp_coords, inp_feat, out_1, out_2, out_0);
77  else
78  Kernel(context, batch_size, feat_dim, num_points, r, r * r,
79  r * r * r, false, inp_coords, inp_feat, out_1, out_2, out_0);
80  }
81 
82  virtual void Kernel(tensorflow::OpKernelContext* context,
83  int b,
84  int c,
85  int n,
86  int r,
87  int r2,
88  int r3,
89  bool training,
90  const float* coords,
91  const float* feat,
92  int* inds,
93  float* wgts,
94  float* outs) = 0;
95 
96 protected:
97  int r;
99 };
100 
101 class TrilinearDevoxelizeGradOpKernel : public tensorflow::OpKernel {
102 public:
104  tensorflow::OpKernelConstruction* context)
105  : tensorflow::OpKernel(context) {
106  using namespace tensorflow;
107  OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
108  OP_REQUIRES(
109  context, r > 0,
110  absl::InvalidArgumentError(
111  "TrilinearDevoxelizeGrad expects positive resolution"));
112  }
113 
114  void Compute(tensorflow::OpKernelContext* context) override {
115  using namespace tensorflow;
116  const Tensor& grad_y = context->input(0);
117  OP_REQUIRES(context, grad_y.dims() == 3,
118  absl::InvalidArgumentError(
119  "TrilinearDevoxelizeGrad expects "
120  "(batch_size, C, N) gradient shape"));
121  const Tensor& inds = context->input(1);
122  OP_REQUIRES(
123  context, inds.dims() == 3 && inds.shape().dim_size(1) == 8,
124  absl::InvalidArgumentError("TrilinearDevoxelizeGrad expects "
125  "(batch_size, 8, N) indices shape"));
126  const Tensor& wgts = context->input(2);
127  OP_REQUIRES(
128  context, wgts.dims() == 3 && wgts.shape().dim_size(1) == 8,
129  absl::InvalidArgumentError("TrilinearDevoxelizeGrad expects "
130  "(batch_size, 8, N) weights shape"));
131 
132  int batch_size = grad_y.shape().dim_size(0);
133  int num_points = grad_y.shape().dim_size(2);
134  int feat_dim = grad_y.shape().dim_size(1);
135 
136  auto grad_y_flat = grad_y.flat<float>();
137  auto inds_flat = inds.flat<int>();
138  auto wgts_flat = wgts.flat<float>();
139 
140  const float* inp_grad_y = &(grad_y_flat(0));
141  const int* inp_inds = &(inds_flat(0));
142  const float* inp_wgts = &(wgts_flat(0));
143 
144  Tensor* out_tensor;
145  OP_REQUIRES_OK(context,
146  context->allocate_output(
147  0, TensorShape{batch_size, feat_dim, r, r, r},
148  &out_tensor));
149  auto flat_tensor = out_tensor->flat<float>();
150 
151  float* out = &(flat_tensor(0));
152 
153  Kernel(context, batch_size, feat_dim, num_points, r * r * r, inp_inds,
154  inp_wgts, inp_grad_y, out);
155  }
156 
157  virtual void Kernel(tensorflow::OpKernelContext* context,
158  int b,
159  int c,
160  int n,
161  int r3,
162  const int* inds,
163  const float* wgts,
164  const float* grad_y,
165  float* grad_x) = 0;
166 
167 protected:
168  int r;
169 };
ImGuiContext * context
Definition: Window.cpp:76
TrilinearDevoxelizeGradOpKernel
Definition: TrilinearDevoxelizeKernel.h:101
int r
Definition: TrilinearDevoxelizeKernel.h:168
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:114
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r3, const int *inds, const float *wgts, const float *grad_y, float *grad_x)=0
TrilinearDevoxelizeGradOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:103
TrilinearDevoxelizeOpKernel
Definition: TrilinearDevoxelizeKernel.h:15
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r, int r2, int r3, bool training, const float *coords, const float *feat, int *inds, float *wgts, float *outs)=0
TrilinearDevoxelizeOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:28
int r
Definition: TrilinearDevoxelizeKernel.h:97
bool is_training
Definition: TrilinearDevoxelizeKernel.h:98