29 #include "../TensorFlowHelper.h" 30 #include "tensorflow/core/framework/op.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/lib/core/errors.h" 37 : OpKernel(construction) {
40 OP_REQUIRES_OK(construction,
41 construction->GetAttr(
"nsample", &
nsample));
42 OP_REQUIRES_OK(construction, construction->GetAttr(
"radius", &
radius));
45 errors::InvalidArgument(
"BallQuery expects positive nsample"));
51 const Tensor& inp_tensor = context->input(0);
54 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
55 errors::InvalidArgument(
"BallQuery expects " 56 "(batch_size,num_points,3) inp shape"));
57 int batch_size = inp_tensor.shape().dim_size(0);
58 int pts_size = inp_tensor.shape().dim_size(1);
59 auto inp_flat = inp_tensor.flat<
float>();
60 const float* inp = &(inp_flat(0));
62 const Tensor& center_tensor = context->input(1);
64 center_tensor.dims() == 3 &&
65 center_tensor.shape().dim_size(2) == 3,
66 errors::InvalidArgument(
68 "(batch_size,num_points,3) center shape"));
69 int ball_size = center_tensor.shape().dim_size(1);
70 auto center_flat = center_tensor.flat<
float>();
71 const float* center = &(center_flat(0));
74 OP_REQUIRES_OK(context,
75 context->allocate_output(
76 0, TensorShape{batch_size, ball_size, nsample},
78 auto out_flat = out_tensor->flat<
int>();
79 int* out = &(out_flat(0));
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)=0
ImGuiContext * context
Definition: Window.cpp:95
int nsample
Definition: BallQueryOpKernel.h:96
float radius
Definition: BallQueryOpKernel.h:97
Definition: BallQueryOpKernel.h:34
BallQueryOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: BallQueryOpKernel.h:36
void Compute(tensorflow::OpKernelContext *context) override
Definition: BallQueryOpKernel.h:48