Open3D (C++ API)  0.19.0
VoxelPoolingGradOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include <cstdint>
11 
12 #include "absl/status/status.h"
14 #include "tensorflow/core/framework/op.h"
15 #include "tensorflow/core/framework/op_kernel.h"
16 #include "tensorflow/core/lib/core/errors.h"
17 
19 // namespace for code that is common for all kernels
20 namespace voxel_pooling_opkernel {
21 
22 template <class TReal, class TFeat>
23 class OutputAllocator {
24 public:
25  OutputAllocator(tensorflow::OpKernelContext* context) : context(context) {}
26 
27  void AllocPooledPositions(TReal** ptr, size_t num) {
28  using namespace tensorflow;
29  *ptr = nullptr;
30  Tensor* tensor = 0;
31  TensorShape shape({int64_t(num), 3});
32  OP_REQUIRES_OK(context, context->allocate_output(0, shape, &tensor));
33  auto flat_tensor = tensor->flat<TReal>();
34  *ptr = flat_tensor.data();
35  }
36 
37  void AllocPooledFeatures(TFeat** ptr, size_t num, int channels) {
38  using namespace tensorflow;
39  *ptr = nullptr;
40  Tensor* tensor = 0;
41  TensorShape shape({int64_t(num), channels});
42  OP_REQUIRES_OK(context, context->allocate_output(1, shape, &tensor));
43  auto flat_tensor = tensor->flat<TFeat>();
44  *ptr = flat_tensor.data();
45  }
46 
47 private:
48  tensorflow::OpKernelContext* context;
49 };
50 
51 // Base class with common code for the OpKernel implementations
52 class VoxelPoolingGradOpKernel : public tensorflow::OpKernel {
53 public:
54  explicit VoxelPoolingGradOpKernel(
55  tensorflow::OpKernelConstruction* construction)
56  : OpKernel(construction) {
57  using namespace tensorflow;
58  using namespace open3d::ml::impl;
59  std::string pos_fn_str;
60  OP_REQUIRES_OK(construction,
61  construction->GetAttr("position_fn", &pos_fn_str));
62 
63  if (pos_fn_str == "average")
64  position_fn = AVERAGE;
65  else if (pos_fn_str == "nearest_neighbor")
66  position_fn = NEAREST_NEIGHBOR;
67  else
68  position_fn = CENTER;
69 
70  std::string feat_fn_str;
71  OP_REQUIRES_OK(construction,
72  construction->GetAttr("feature_fn", &feat_fn_str));
73 
74  if (feat_fn_str == "average")
75  feature_fn = AVERAGE;
76  else if (feat_fn_str == "nearest_neighbor")
77  feature_fn = NEAREST_NEIGHBOR;
78  else
79  feature_fn = MAX;
80  }
81 
82  void Compute(tensorflow::OpKernelContext* context) override {
83  using namespace tensorflow;
84  using namespace open3d::ml::impl;
85 
86  const Tensor& positions = context->input(0);
87  OP_REQUIRES(context, positions.shape().dims() == 2,
88  absl::InvalidArgumentError(
89  "positions must be a rank 2 tensor"));
90 
91  const Tensor& features = context->input(1);
92  OP_REQUIRES(
93  context, features.shape().dims() == 2,
94  absl::InvalidArgumentError("features must be a rank 2 tensor"));
95 
96  const Tensor& voxel_size = context->input(2);
97  OP_REQUIRES(
98  context, TensorShapeUtils::IsScalar(voxel_size.shape()),
99  absl::InvalidArgumentError(
100  std::string("voxel_size must be a scalar, but is ") +
101  voxel_size.shape().DebugString()));
102 
103  const Tensor& pooled_positions = context->input(3);
104  OP_REQUIRES(context, pooled_positions.shape().dims() == 2,
105  absl::InvalidArgumentError(
106  "pooled_positions must be a rank 2 tensor"));
107 
108  const Tensor& pooled_features_gradient = context->input(4);
109  OP_REQUIRES(
110  context, pooled_features_gradient.shape().dims() == 2,
111  absl::InvalidArgumentError(
112  "pooled_features_gradient must be a rank 2 tensor"));
113 
114  Tensor* features_backprop = nullptr;
115  OP_REQUIRES_OK(context, context->allocate_output(0, features.shape(),
116  &features_backprop));
117 
118  Kernel(context, *features_backprop, positions, features,
119  pooled_positions, pooled_features_gradient, voxel_size);
120  }
121 
122  // Function with the device specific code
123  virtual void Kernel(tensorflow::OpKernelContext* context,
124  tensorflow::Tensor& features_backprop,
125  const tensorflow::Tensor& positions,
126  const tensorflow::Tensor& features,
127  const tensorflow::Tensor& pooled_positions,
128  const tensorflow::Tensor& pooled_features_gradient,
129  const tensorflow::Tensor& voxel_size) = 0;
130 
131 protected:
134 };
135 
136 } // namespace voxel_pooling_opkernel
ImGuiContext * context
Definition: Window.cpp:76
Definition: ContinuousConv.h:16
AccumulationFn
Definition: VoxelPooling.h:21
@ CENTER
Definition: VoxelPooling.h:21
@ NEAREST_NEIGHBOR
Definition: VoxelPooling.h:21
@ MAX
Definition: VoxelPooling.h:21
@ AVERAGE
Definition: VoxelPooling.h:21