Open3D (C++ API)  0.19.0
BallQueryOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
15 class BallQueryOpKernel : public tensorflow::OpKernel {
16 public:
17  explicit BallQueryOpKernel(tensorflow::OpKernelConstruction* construction)
18  : OpKernel(construction) {
19  using namespace tensorflow;
20 
21  OP_REQUIRES_OK(construction,
22  construction->GetAttr("nsample", &nsample));
23  OP_REQUIRES_OK(construction, construction->GetAttr("radius", &radius));
24  OP_REQUIRES(construction, nsample > 0,
25  absl::InvalidArgumentError(
26  "BallQuery expects positive nsample"));
27  }
28 
29  void Compute(tensorflow::OpKernelContext* context) override {
30  using namespace tensorflow;
31 
32  const Tensor& inp_tensor = context->input(0);
33  OP_REQUIRES(
34  context,
35  inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
36  absl::InvalidArgumentError(
37  "BallQuery expects "
38  "(batch_size,num_points,3) inp shape"));
39  int batch_size = inp_tensor.shape().dim_size(0);
40  int pts_size = inp_tensor.shape().dim_size(1);
41  auto inp_flat = inp_tensor.flat<float>();
42  const float* inp = &(inp_flat(0));
43 
44  const Tensor& center_tensor = context->input(1);
45  OP_REQUIRES(context,
46  center_tensor.dims() == 3 &&
47  center_tensor.shape().dim_size(2) == 3,
48  absl::InvalidArgumentError(
49  "BallQuery expects "
50  "(batch_size,num_points,3) center shape"));
51  int ball_size = center_tensor.shape().dim_size(1);
52  auto center_flat = center_tensor.flat<float>();
53  const float* center = &(center_flat(0));
54 
55  Tensor* out_tensor;
56  OP_REQUIRES_OK(context,
57  context->allocate_output(
58  0, TensorShape{batch_size, ball_size, nsample},
59  &out_tensor));
60  auto out_flat = out_tensor->flat<int>();
61  int* out = &(out_flat(0));
62 
63  Kernel(context, batch_size, pts_size, ball_size, radius, nsample,
64  center, inp, out);
65  }
66 
67  virtual void Kernel(tensorflow::OpKernelContext* context,
68  int b,
69  int n,
70  int m,
71  float radius,
72  int nsample,
73  const float* new_xyz,
74  const float* xyz,
75  int* idx) = 0;
76 
77 protected:
78  int nsample;
79  float radius;
80 };
ImGuiContext * context
Definition: Window.cpp:76
Definition: BallQueryOpKernel.h:15
int nsample
Definition: BallQueryOpKernel.h:78
BallQueryOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: BallQueryOpKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition: BallQueryOpKernel.h:29
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)=0
float radius
Definition: BallQueryOpKernel.h:79