Open3D (C++ API)  0.19.0
VoxelPoolingOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include <cstdint>
11 
12 #include "absl/status/status.h"
14 #include "tensorflow/core/framework/op.h"
15 #include "tensorflow/core/framework/op_kernel.h"
16 #include "tensorflow/core/lib/core/errors.h"
17 
19 // namespace for code that is common for all kernels
20 namespace voxel_pooling_opkernel {
21 
22 template <class TReal, class TFeat>
23 class OutputAllocator {
24 public:
25  OutputAllocator(tensorflow::OpKernelContext* context) : context(context) {}
26 
27  void AllocPooledPositions(TReal** ptr, size_t num) {
28  using namespace tensorflow;
29  *ptr = nullptr;
30  Tensor* tensor = 0;
31  TensorShape shape({int64_t(num), 3});
32  OP_REQUIRES_OK(context, context->allocate_output(0, shape, &tensor));
33  auto flat_tensor = tensor->flat<TReal>();
34  *ptr = flat_tensor.data();
35  }
36 
37  void AllocPooledFeatures(TFeat** ptr, size_t num, int channels) {
38  using namespace tensorflow;
39  *ptr = nullptr;
40  Tensor* tensor = 0;
41  TensorShape shape({int64_t(num), channels});
42  OP_REQUIRES_OK(context, context->allocate_output(1, shape, &tensor));
43  auto flat_tensor = tensor->flat<TFeat>();
44  *ptr = flat_tensor.data();
45  }
46 
47 private:
48  tensorflow::OpKernelContext* context;
49 };
50 
51 // Base class with common code for the OpKernel implementations
52 class VoxelPoolingOpKernel : public tensorflow::OpKernel {
53 public:
54  explicit VoxelPoolingOpKernel(
55  tensorflow::OpKernelConstruction* construction)
56  : OpKernel(construction) {
57  using namespace tensorflow;
58  using namespace open3d::ml::impl;
59  std::string pos_fn_str;
60  OP_REQUIRES_OK(construction,
61  construction->GetAttr("position_fn", &pos_fn_str));
62 
63  if (pos_fn_str == "average")
64  position_fn = AVERAGE;
65  else if (pos_fn_str == "nearest_neighbor")
66  position_fn = NEAREST_NEIGHBOR;
67  else
68  position_fn = CENTER;
69 
70  std::string feat_fn_str;
71  OP_REQUIRES_OK(construction,
72  construction->GetAttr("feature_fn", &feat_fn_str));
73 
74  if (feat_fn_str == "average")
75  feature_fn = AVERAGE;
76  else if (feat_fn_str == "nearest_neighbor")
77  feature_fn = NEAREST_NEIGHBOR;
78  else
79  feature_fn = MAX;
80 
81  OP_REQUIRES_OK(construction, construction->GetAttr("debug", &debug));
82  }
83 
84  void Compute(tensorflow::OpKernelContext* context) override {
85  using namespace tensorflow;
86  using namespace open3d::ml::impl;
87  const Tensor& positions = context->input(0);
88  OP_REQUIRES(context, positions.shape().dims() == 2,
89  absl::InvalidArgumentError(
90  "positions must be a rank 2 tensor"));
91 
92  const Tensor& features = context->input(1);
93  OP_REQUIRES(
94  context, features.shape().dims() == 2,
95  absl::InvalidArgumentError("features must be a rank 2 tensor"));
96 
97  const Tensor& voxel_size = context->input(2);
98  OP_REQUIRES(
99  context, TensorShapeUtils::IsScalar(voxel_size.shape()),
100  absl::InvalidArgumentError(
101  std::string("voxel_size must be a scalar, but is ") +
102  voxel_size.shape().DebugString()));
103 
104  Kernel(context, positions, features, voxel_size);
105  }
106 
107  // Function with the device specific code
108  virtual void Kernel(tensorflow::OpKernelContext* context,
109  const tensorflow::Tensor& positions,
110  const tensorflow::Tensor& features,
111  const tensorflow::Tensor& voxel_size) = 0;
112 
113 protected:
116  bool debug;
117 };
118 
119 } // namespace voxel_pooling_opkernel
ImGuiContext * context
Definition: Window.cpp:76
Definition: ContinuousConv.h:16
AccumulationFn
Definition: VoxelPooling.h:21
@ CENTER
Definition: VoxelPooling.h:21
@ NEAREST_NEIGHBOR
Definition: VoxelPooling.h:21
@ MAX
Definition: VoxelPooling.h:21
@ AVERAGE
Definition: VoxelPooling.h:21