Open3D (C++ API)  0.19.0
VoxelizeOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include <cstdint>
11 
12 // #include "open3d/ml/impl/misc/VoxelPooling.h"
14 #include "tensorflow/core/framework/op.h"
15 #include "tensorflow/core/framework/op_kernel.h"
16 #include "tensorflow/core/lib/core/errors.h"
17 
19 // namespace for code that is common for all kernels
20 namespace voxelize_opkernel {
21 
22 class OutputAllocator {
23 public:
24  OutputAllocator(tensorflow::OpKernelContext* context) : context(context) {}
25 
26  void AllocVoxelCoords(int32_t** ptr, int64_t rows, int64_t cols) {
27  using namespace tensorflow;
28  *ptr = nullptr;
29  Tensor* tensor = 0;
30  TensorShape shape({rows, cols});
31  OP_REQUIRES_OK(context, context->allocate_output(0, shape, &tensor));
32  auto flat_tensor = tensor->flat<int32_t>();
33  *ptr = flat_tensor.data();
34  }
35 
36  void AllocVoxelPointIndices(int64_t** ptr, int64_t num) {
37  using namespace tensorflow;
38  *ptr = nullptr;
39  Tensor* tensor = 0;
40  TensorShape shape({num});
41  OP_REQUIRES_OK(context, context->allocate_output(1, shape, &tensor));
42  auto flat_tensor = tensor->flat<int64_t>();
43  *ptr = (int64_t*)flat_tensor.data();
44  }
45 
46  void AllocVoxelPointRowSplits(int64_t** ptr, int64_t num) {
47  using namespace tensorflow;
48  *ptr = nullptr;
49  Tensor* tensor = 0;
50  TensorShape shape({num});
51  OP_REQUIRES_OK(context, context->allocate_output(2, shape, &tensor));
52  auto flat_tensor = tensor->flat<int64_t>();
53  *ptr = (int64_t*)flat_tensor.data();
54  }
55 
56  void AllocVoxelBatchSplits(int64_t** ptr, int64_t num) {
57  using namespace tensorflow;
58  *ptr = nullptr;
59  Tensor* tensor = 0;
60  TensorShape shape({num});
61  OP_REQUIRES_OK(context, context->allocate_output(3, shape, &tensor));
62  auto flat_tensor = tensor->flat<int64_t>();
63  *ptr = (int64_t*)flat_tensor.data();
64  }
65 
66 private:
67  tensorflow::OpKernelContext* context;
68 };
69 
70 // Base class with common code for the OpKernel implementations
71 class VoxelizeOpKernel : public tensorflow::OpKernel {
72 public:
73  explicit VoxelizeOpKernel(tensorflow::OpKernelConstruction* construction)
74  : OpKernel(construction) {
75  OP_REQUIRES_OK(construction,
76  construction->GetAttr("max_points_per_voxel",
77  &max_points_per_voxel));
78  OP_REQUIRES_OK(construction,
79  construction->GetAttr("max_voxels", &max_voxels));
80  }
81 
82  void Compute(tensorflow::OpKernelContext* context) override {
83  using namespace tensorflow;
84  const Tensor& points = context->input(0);
85  const Tensor& row_splits = context->input(1);
86  const Tensor& voxel_size = context->input(2);
87  const Tensor& points_range_min = context->input(3);
88  const Tensor& points_range_max = context->input(4);
89 
90  {
91  using namespace open3d::ml::op_util;
92  Dim num_points("num_points");
93  Dim ndim("ndim");
94  CHECK_SHAPE(context, points, num_points, ndim);
95  CHECK_SHAPE(context, voxel_size, ndim);
96  CHECK_SHAPE(context, points_range_min, ndim);
97  CHECK_SHAPE(context, points_range_max, ndim);
98  OP_REQUIRES(
99  context, ndim.value() > 0 && ndim.value() < 9,
100  absl::InvalidArgumentError(
101  "the number of dimensions must be in [1,..,8]"));
102  }
103 
104  Kernel(context, points, row_splits, voxel_size, points_range_min,
105  points_range_max);
106  }
107 
108  // Function with the device specific code
109  virtual void Kernel(tensorflow::OpKernelContext* context,
110  const tensorflow::Tensor& points,
111  const tensorflow::Tensor& row_splits,
112  const tensorflow::Tensor& voxel_size,
113  const tensorflow::Tensor& points_range_min,
114  const tensorflow::Tensor& points_range_max) = 0;
115 
116 protected:
117  int64_t max_points_per_voxel;
118  int64_t max_voxels;
119 };
120 
121 } // namespace voxelize_opkernel
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:190
ImGuiContext * context
Definition: Window.cpp:76
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
int points
Definition: FilePCD.cpp:54
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t int32_t
Definition: K4aPlugin.cpp:395
Definition: ShapeChecking.h:16