29 #include "../TensorFlowHelper.h" 30 #include "tensorflow/core/framework/op.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/lib/core/errors.h" 37 tensorflow::OpKernelConstruction* construction)
38 : OpKernel(construction) {
41 OP_REQUIRES_OK(construction,
42 construction->GetAttr(
"sample_size", &
sample_size));
44 errors::InvalidArgument(
45 "FurthestPointSampling expects positive npoint"));
51 const Tensor& inp_tensor = context->input(0);
54 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
55 errors::InvalidArgument(
"FurthestPointSampling expects " 56 "(batch_size,num_points,3) inp shape"));
57 int batch_size = inp_tensor.shape().dim_size(0);
58 int pts_size = inp_tensor.shape().dim_size(1);
59 auto inp_flat = inp_tensor.flat<
float>();
60 const float* inp = &(inp_flat(0));
63 OP_REQUIRES_OK(context, context->allocate_output(
64 0, TensorShape{batch_size, sample_size},
66 auto out_flat = out_tensor->flat<
int>();
67 int* out = &(out_flat(0));
70 OP_REQUIRES_OK(context,
71 context->allocate_temp(DataTypeToEnum<float>::value,
72 TensorShape{batch_size, pts_size},
74 auto temp_flat = temp_tensor.flat<
float>();
75 float* temp = &(temp_flat(0));
Definition: SamplingOpKernel.h:34
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *dataset, float *temp, int *idxs)=0
ImGuiContext * context
Definition: Window.cpp:95
int sample_size
Definition: SamplingOpKernel.h:89
FurthestPointSamplingOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: SamplingOpKernel.h:36
void Compute(tensorflow::OpKernelContext *context) override
Definition: SamplingOpKernel.h:48