10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
18 : OpKernel(construction) {
19 using namespace tensorflow;
20 OP_REQUIRES_OK(construction, construction->GetAttr(
"sampled_pts_num",
25 using namespace tensorflow;
27 const Tensor& inp_tensor =
context->input(0);
30 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
31 absl::InvalidArgumentError(
33 "(batch_size,num_points,3) inp shape"));
34 int batch_size = inp_tensor.shape().dim_size(0);
35 int pts_num = inp_tensor.shape().dim_size(1);
36 auto inp_flat = inp_tensor.flat<
float>();
37 const float* inp = &(inp_flat(0));
39 const Tensor& boxes3d_tensor =
context->input(1);
41 boxes3d_tensor.dims() == 3 &&
42 boxes3d_tensor.shape().dim_size(2) == 7,
43 absl::InvalidArgumentError(
45 "(batch_size,num_boxes,7) boxes3d shape"));
46 int boxes_num = boxes3d_tensor.shape().dim_size(1);
47 auto boxes3d_flat = boxes3d_tensor.flat<
float>();
48 const float* boxes3d = &(boxes3d_flat(0));
50 const Tensor& feats_tensor =
context->input(2);
52 feats_tensor.dims() == 3 &&
53 feats_tensor.shape().dim_size(1) == pts_num,
54 absl::InvalidArgumentError(
56 "(batch_size,num_points,feats) feats shape"));
57 int feature_in_len = feats_tensor.shape().dim_size(2);
58 auto feats_flat = feats_tensor.flat<
float>();
59 const float* feats = &(feats_flat(0));
65 TensorShape{batch_size, boxes_num,
66 sampled_pts_num, 3 + feature_in_len},
68 auto out_flat0 = out_feats->flat<
float>();
69 float* out0 = &(out_flat0(0));
73 1, TensorShape{batch_size, boxes_num},
75 auto out_flat1 = out_flags->flat<
int>();
76 int* out1 = &(out_flat1(0));
78 Kernel(
context, batch_size, pts_num, boxes_num, feature_in_len,
90 const float* pts_feature,
91 float* pooled_features,
92 int* pooled_empty_flag) = 0;
ImGuiContext * context
Definition: Window.cpp:76
Definition: RoiPoolOpKernel.h:15
RoiPoolOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: RoiPoolOpKernel.h:17
int sampled_pts_num
Definition: RoiPoolOpKernel.h:95
virtual void Kernel(tensorflow::OpKernelContext *context, int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num, const float *xyz, const float *boxes3d, const float *pts_feature, float *pooled_features, int *pooled_empty_flag)=0
void Compute(tensorflow::OpKernelContext *context) override
Definition: RoiPoolOpKernel.h:24