25 #include "tensorflow/core/framework/op.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/lib/core/errors.h" 34 class ReduceSubarraysSumOpKernel :
public tensorflow::OpKernel {
36 explicit ReduceSubarraysSumOpKernel(
37 tensorflow::OpKernelConstruction* construction)
38 : OpKernel(construction) {}
// Overrides tensorflow::OpKernel::Compute.
// Reads input 0 ("values") and input 1 ("row_splits"), requires both to be
// rank 1, allocates output 0 ("sums"), and hands the actual reduction to the
// subclass-provided Kernel().
//
// NOTE(review): the per-row semantics (presumably sums[i] = sum of
// values[row_splits[i] : row_splits[i + 1]]) are inferred from the op name
// and are implemented by Kernel(), not here — confirm against the subclasses.
40 void Compute(tensorflow::OpKernelContext* context)
override {
// Compile-time guard: TensorFlow's int64 alias must have the same size as
// int64_t (presumably so int64 tensor buffers can be read as int64_t — confirm).
42 static_assert(
sizeof(int64) ==
sizeof(int64_t),
43 "int64 type is not compatible");
// Input 0: the flat values to be reduced; must be rank 1.
45 const Tensor& values = context->input(0);
46 OP_REQUIRES(context, values.shape().dims() == 1,
47 errors::InvalidArgument(
"values must be a rank 1 tensor"));
// Input 1: the row boundaries; must be rank 1.
49 const Tensor& row_splits = context->input(1);
51 context, row_splits.shape().dims() == 1,
52 errors::InvalidArgument(
"row_splits must be a rank 1 tensor"));
// Empty-values special case: output 0 is allocated with values' own shape,
// i.e. [0].
// NOTE(review): this ignores row_splits, so e.g. row_splits = [0, 0]
// (one empty row, whose sum would be 0) still yields a length-0 output,
// whereas the general path below sizes the output from row_splits.
// Confirm this is intended.
55 if (values.shape().dim_size(0) == 0) {
56 Tensor* sums_tensor = 0;
57 OP_REQUIRES_OK(context, context->allocate_output(0, values.shape(),
// General case: one output element per adjacent pair of split points, i.e.
// output shape [row_splits.dim_size(0) - 1].
// NOTE(review): if row_splits can be empty, this requests a dimension of -1.
// Confirm that an earlier check rejects that input, or add an OP_REQUIRES that
// row_splits.dim_size(0) >= 1 before building sums_shape.
62 Tensor* sums_tensor = 0;
63 TensorShape sums_shape({row_splits.shape().dim_size(0) - 1});
64 OP_REQUIRES_OK(context,
65 context->allocate_output(0, sums_shape, &sums_tensor));
// Delegate the reduction into the freshly allocated output to the subclass.
67 Kernel(context, values, row_splits, *sums_tensor);
71 virtual void Kernel(tensorflow::OpKernelContext* context,
72 const tensorflow::Tensor& values,
73 const tensorflow::Tensor& row_splits,
74 tensorflow::Tensor& sums) = 0;