25 #include "../TensorFlowHelper.h" 26 #include "tensorflow/core/framework/op.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/lib/core/errors.h" 33 : OpKernel(construction) {
35 OP_REQUIRES_OK(construction, construction->GetAttr(
"sampled_pts_num",
42 const Tensor& inp_tensor = context->input(0);
45 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
46 errors::InvalidArgument(
"RoiPool expects " 47 "(batch_size,num_points,3) inp shape"));
48 int batch_size = inp_tensor.shape().dim_size(0);
49 int pts_num = inp_tensor.shape().dim_size(1);
50 auto inp_flat = inp_tensor.flat<
float>();
51 const float* inp = &(inp_flat(0));
53 const Tensor& boxes3d_tensor = context->input(1);
55 boxes3d_tensor.dims() == 3 &&
56 boxes3d_tensor.shape().dim_size(2) == 7,
57 errors::InvalidArgument(
59 "(batch_size,num_boxes,7) boxes3d shape"));
60 int boxes_num = boxes3d_tensor.shape().dim_size(1);
61 auto boxes3d_flat = boxes3d_tensor.flat<
float>();
62 const float* boxes3d = &(boxes3d_flat(0));
64 const Tensor& feats_tensor = context->input(2);
66 feats_tensor.dims() == 3 &&
67 feats_tensor.shape().dim_size(1) == pts_num,
68 errors::InvalidArgument(
70 "(batch_size,num_points,feats) feats shape"));
71 int feature_in_len = feats_tensor.shape().dim_size(2);
72 auto feats_flat = feats_tensor.flat<
float>();
73 const float* feats = &(feats_flat(0));
76 OP_REQUIRES_OK(context,
77 context->allocate_output(
79 TensorShape{batch_size, boxes_num,
80 sampled_pts_num, 3 + feature_in_len},
82 auto out_flat0 = out_feats->flat<
float>();
83 float* out0 = &(out_flat0(0));
86 OP_REQUIRES_OK(context, context->allocate_output(
87 1, TensorShape{batch_size, boxes_num},
89 auto out_flat1 = out_flags->flat<
int>();
90 int* out1 = &(out_flat1(0));
92 Kernel(context, batch_size, pts_num, boxes_num, feature_in_len,
103 const float* boxes3d,
104 const float* pts_feature,
105 float* pooled_features,
106 int* pooled_empty_flag) = 0;
virtual void Kernel(tensorflow::OpKernelContext *context, int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num, const float *xyz, const float *boxes3d, const float *pts_feature, float *pooled_features, int *pooled_empty_flag)=0
ImGuiContext * context
Definition: Window.cpp:95
void Compute(tensorflow::OpKernelContext *context) override
Definition: RoiPoolOpKernel.h:39
int sampled_pts_num
Definition: RoiPoolOpKernel.h:109
RoiPoolOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: RoiPoolOpKernel.h:32
Definition: RoiPoolOpKernel.h:30