29 #include "../TensorFlowHelper.h" 30 #include "tensorflow/core/framework/op.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/lib/core/errors.h" 37 : OpKernel(construction) {
39 OP_REQUIRES_OK(construction, construction->GetAttr(
"sampled_pts_num",
46 const Tensor& inp_tensor = context->input(0);
49 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
50 errors::InvalidArgument(
"RoiPool expects " 51 "(batch_size,num_points,3) inp shape"));
52 int batch_size = inp_tensor.shape().dim_size(0);
53 int pts_num = inp_tensor.shape().dim_size(1);
54 auto inp_flat = inp_tensor.flat<
float>();
55 const float* inp = &(inp_flat(0));
57 const Tensor& boxes3d_tensor = context->input(1);
59 boxes3d_tensor.dims() == 3 &&
60 boxes3d_tensor.shape().dim_size(2) == 7,
61 errors::InvalidArgument(
63 "(batch_size,num_boxes,7) boxes3d shape"));
64 int boxes_num = boxes3d_tensor.shape().dim_size(1);
65 auto boxes3d_flat = boxes3d_tensor.flat<
float>();
66 const float* boxes3d = &(boxes3d_flat(0));
68 const Tensor& feats_tensor = context->input(2);
70 feats_tensor.dims() == 3 &&
71 feats_tensor.shape().dim_size(1) == pts_num,
72 errors::InvalidArgument(
74 "(batch_size,num_points,feats) feats shape"));
75 int feature_in_len = feats_tensor.shape().dim_size(2);
76 auto feats_flat = feats_tensor.flat<
float>();
77 const float* feats = &(feats_flat(0));
80 OP_REQUIRES_OK(context,
81 context->allocate_output(
83 TensorShape{batch_size, boxes_num,
84 sampled_pts_num, 3 + feature_in_len},
86 auto out_flat0 = out_feats->flat<
float>();
87 float* out0 = &(out_flat0(0));
90 OP_REQUIRES_OK(context, context->allocate_output(
91 1, TensorShape{batch_size, boxes_num},
93 auto out_flat1 = out_flags->flat<
int>();
94 int* out1 = &(out_flat1(0));
96 Kernel(context, batch_size, pts_num, boxes_num, feature_in_len,
107 const float* boxes3d,
108 const float* pts_feature,
109 float* pooled_features,
110 int* pooled_empty_flag) = 0;
virtual void Kernel(tensorflow::OpKernelContext *context, int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num, const float *xyz, const float *boxes3d, const float *pts_feature, float *pooled_features, int *pooled_empty_flag)=0
ImGuiContext * context
Definition: Window.cpp:95
void Compute(tensorflow::OpKernelContext *context) override
Definition: RoiPoolOpKernel.h:43
int sampled_pts_num
Definition: RoiPoolOpKernel.h:113
RoiPoolOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: RoiPoolOpKernel.h:36
Definition: RoiPoolOpKernel.h:34