Open3D (C++ API)  0.19.0
SamplingOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
#include <cstdint>
#include <limits>

#include "../TensorFlowHelper.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
14 
15 class FurthestPointSamplingOpKernel : public tensorflow::OpKernel {
16 public:
18  tensorflow::OpKernelConstruction* construction)
19  : OpKernel(construction) {
20  using namespace tensorflow;
21 
22  OP_REQUIRES_OK(construction,
23  construction->GetAttr("sample_size", &sample_size));
24  OP_REQUIRES(construction, sample_size > 0,
25  absl::InvalidArgumentError(
26  "FurthestPointSampling expects positive npoint"));
27  }
28 
29  void Compute(tensorflow::OpKernelContext* context) override {
30  using namespace tensorflow;
31 
32  const Tensor& inp_tensor = context->input(0);
33  OP_REQUIRES(
34  context,
35  inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
36  absl::InvalidArgumentError(
37  "FurthestPointSampling expects "
38  "(batch_size,num_points,3) inp shape"));
39  int batch_size = inp_tensor.shape().dim_size(0);
40  int pts_size = inp_tensor.shape().dim_size(1);
41  auto inp_flat = inp_tensor.flat<float>();
42  const float* inp = &(inp_flat(0));
43 
44  Tensor* out_tensor;
45  OP_REQUIRES_OK(context, context->allocate_output(
46  0, TensorShape{batch_size, sample_size},
47  &out_tensor));
48  auto out_flat = out_tensor->flat<int>();
49  int* out = &(out_flat(0));
50 
51  Tensor temp_tensor;
52  OP_REQUIRES_OK(context,
53  context->allocate_temp(DataTypeToEnum<float>::value,
54  TensorShape{batch_size, pts_size},
55  &temp_tensor));
56  auto temp_flat = temp_tensor.flat<float>();
57  float* temp = &(temp_flat(0));
58 
59  Kernel(context, batch_size, pts_size, sample_size, inp, temp, out);
60  }
61 
62  virtual void Kernel(tensorflow::OpKernelContext* context,
63  int b,
64  int n,
65  int m,
66  const float* dataset,
67  float* temp,
68  int* idxs) = 0;
69 
70 protected:
72 };
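Member index (Doxygen cross-references):
- ImGuiContext * context — Definition: Window.cpp:76. Note: Doxygen mis-linked this to an unrelated ImGui symbol; every `context` in this header is a `tensorflow::OpKernelContext*` (or `tensorflow::OpKernelConstruction*` in the constructor).
- FurthestPointSamplingOpKernel (class) — Definition: SamplingOpKernel.h:15
- void Compute(tensorflow::OpKernelContext *context) override — Definition: SamplingOpKernel.h:29
- virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *dataset, float *temp, int *idxs) = 0
- int sample_size — Definition: SamplingOpKernel.h:71
- FurthestPointSamplingOpKernel(tensorflow::OpKernelConstruction *construction) — Definition: SamplingOpKernel.h:17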