Open3D (C++ API)  0.19.0
NmsOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include <cstdint>
11 
12 // #include "open3d/ml/impl/misc/VoxelPooling.h"
14 #include "tensorflow/core/framework/op.h"
15 #include "tensorflow/core/framework/op_kernel.h"
16 #include "tensorflow/core/lib/core/errors.h"
17 
19 // namespace for code that is common for all kernels
20 namespace nms_opkernel {
21 
22 class OutputAllocator {
23 public:
24  OutputAllocator(tensorflow::OpKernelContext* context) : context(context) {}
25 
26  void AllocKeepIndices(int64_t** ptr, int64_t num) {
27  using namespace tensorflow;
28  *ptr = nullptr;
29  Tensor* tensor = 0;
30  TensorShape shape({num});
31  OP_REQUIRES_OK(context, context->allocate_output(0, shape, &tensor));
32  auto flat_tensor = tensor->flat<int64_t>();
33  *ptr = (int64_t*)flat_tensor.data();
34  }
35 
36 private:
37  tensorflow::OpKernelContext* context;
38 };
39 
40 // Base class with common code for the OpKernel implementations
41 class NmsOpKernel : public tensorflow::OpKernel {
42 public:
43  explicit NmsOpKernel(tensorflow::OpKernelConstruction* construction)
44  : OpKernel(construction) {
45  OP_REQUIRES_OK(construction,
46  construction->GetAttr("nms_overlap_thresh",
47  &nms_overlap_thresh));
48  }
49 
50  void Compute(tensorflow::OpKernelContext* context) override {
51  using namespace tensorflow;
52  const Tensor& boxes = context->input(0);
53  const Tensor& scores = context->input(1);
54 
55  {
56  using namespace open3d::ml::op_util;
57  Dim num_points("num_points");
58  Dim five(5, "five");
59  CHECK_SHAPE(context, boxes, num_points, five);
60  CHECK_SHAPE(context, scores, num_points);
61  }
62 
63  Kernel(context, boxes, scores);
64  }
65 
66  // Function with the device specific code
67  virtual void Kernel(tensorflow::OpKernelContext* context,
68  const tensorflow::Tensor& boxes,
69  const tensorflow::Tensor& scores) = 0;
70 
71 protected:
72  float nms_overlap_thresh;
73 };
74 
75 } // namespace nms_opkernel
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:190
ImGuiContext * context
Definition: Window.cpp:76
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
Definition: ShapeChecking.h:16