25 #include "../TensorFlowHelper.h" 26 #include "tensorflow/core/framework/op.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/lib/core/errors.h" 33 tensorflow::OpKernelConstruction* construction)
34 : OpKernel(construction) {
37 OP_REQUIRES_OK(construction,
38 construction->GetAttr(
"sample_size", &
sample_size));
40 errors::InvalidArgument(
41 "FurthestPointSampling expects positive npoint"));
47 const Tensor& inp_tensor = context->input(0);
50 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
51 errors::InvalidArgument(
"FurthestPointSampling expects " 52 "(batch_size,num_points,3) inp shape"));
53 int batch_size = inp_tensor.shape().dim_size(0);
54 int pts_size = inp_tensor.shape().dim_size(1);
55 auto inp_flat = inp_tensor.flat<
float>();
56 const float* inp = &(inp_flat(0));
59 OP_REQUIRES_OK(context, context->allocate_output(
60 0, TensorShape{batch_size, sample_size},
62 auto out_flat = out_tensor->flat<
int>();
63 int* out = &(out_flat(0));
66 OP_REQUIRES_OK(context,
67 context->allocate_temp(DataTypeToEnum<float>::value,
68 TensorShape{batch_size, pts_size},
70 auto temp_flat = temp_tensor.flat<
float>();
71 float* temp = &(temp_flat(0));
Definition: SamplingOpKernel.h:30
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *dataset, float *temp, int *idxs)=0
ImGuiContext * context
Definition: Window.cpp:95
int sample_size
Definition: SamplingOpKernel.h:85
FurthestPointSamplingOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: SamplingOpKernel.h:32
void Compute(tensorflow::OpKernelContext *context) override
Definition: SamplingOpKernel.h:44