25 #include "../TensorFlowHelper.h" 26 #include "tensorflow/core/framework/op.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/lib/core/errors.h" 33 : OpKernel(construction) {
36 OP_REQUIRES_OK(construction,
37 construction->GetAttr(
"nsample", &
nsample));
38 OP_REQUIRES_OK(construction, construction->GetAttr(
"radius", &
radius));
41 errors::InvalidArgument(
"BallQuery expects positive nsample"));
47 const Tensor& inp_tensor = context->input(0);
50 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
51 errors::InvalidArgument(
"BallQuery expects " 52 "(batch_size,num_points,3) inp shape"));
53 int batch_size = inp_tensor.shape().dim_size(0);
54 int pts_size = inp_tensor.shape().dim_size(1);
55 auto inp_flat = inp_tensor.flat<
float>();
56 const float* inp = &(inp_flat(0));
58 const Tensor& center_tensor = context->input(1);
60 center_tensor.dims() == 3 &&
61 center_tensor.shape().dim_size(2) == 3,
62 errors::InvalidArgument(
64 "(batch_size,num_points,3) center shape"));
65 int ball_size = center_tensor.shape().dim_size(1);
66 auto center_flat = center_tensor.flat<
float>();
67 const float* center = &(center_flat(0));
70 OP_REQUIRES_OK(context,
71 context->allocate_output(
72 0, TensorShape{batch_size, ball_size, nsample},
74 auto out_flat = out_tensor->flat<
int>();
75 int* out = &(out_flat(0));
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)=0
ImGuiContext * context
Definition: Window.cpp:95
int nsample
Definition: BallQueryOpKernel.h:92
float radius
Definition: BallQueryOpKernel.h:93
Definition: BallQueryOpKernel.h:30
BallQueryOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: BallQueryOpKernel.h:32
void Compute(tensorflow::OpKernelContext *context) override
Definition: BallQueryOpKernel.h:44