10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
18 : OpKernel(construction) {
19 using namespace tensorflow;
21 OP_REQUIRES_OK(construction,
22 construction->GetAttr(
"nsample", &
nsample));
23 OP_REQUIRES_OK(construction, construction->GetAttr(
"radius", &
radius));
24 OP_REQUIRES(construction,
nsample > 0,
25 absl::InvalidArgumentError(
26 "BallQuery expects positive nsample"));
30 using namespace tensorflow;
32 const Tensor& inp_tensor =
context->input(0);
35 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
36 absl::InvalidArgumentError(
38 "(batch_size,num_points,3) inp shape"));
39 int batch_size = inp_tensor.shape().dim_size(0);
40 int pts_size = inp_tensor.shape().dim_size(1);
41 auto inp_flat = inp_tensor.flat<
float>();
42 const float* inp = &(inp_flat(0));
44 const Tensor& center_tensor =
context->input(1);
46 center_tensor.dims() == 3 &&
47 center_tensor.shape().dim_size(2) == 3,
48 absl::InvalidArgumentError(
50 "(batch_size,num_points,3) center shape"));
51 int ball_size = center_tensor.shape().dim_size(1);
52 auto center_flat = center_tensor.flat<
float>();
53 const float* center = &(center_flat(0));
58 0, TensorShape{batch_size, ball_size, nsample},
60 auto out_flat = out_tensor->flat<
int>();
61 int* out = &(out_flat(0));
ImGuiContext * context
Definition: Window.cpp:76
Definition: BallQueryOpKernel.h:15
int nsample
Definition: BallQueryOpKernel.h:78
BallQueryOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: BallQueryOpKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition: BallQueryOpKernel.h:29
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)=0
float radius
Definition: BallQueryOpKernel.h:79