31 #include "tensorflow/core/framework/op.h" 32 #include "tensorflow/core/framework/op_kernel.h" 33 #include "tensorflow/core/lib/core/errors.h" 39 class OutputAllocator {
// OutputAllocator: small helper bound to one tensorflow::OpKernelContext.
// Its only visible operation, AllocKeepIndices, allocates output 0 of the op
// through that context and hands the raw buffer back to the caller.
//
// Constructor: stores the given context pointer in the `context` member.
41 OutputAllocator(tensorflow::OpKernelContext*
context) : context(context) {}
// AllocKeepIndices: allocates output 0 with shape {num} and writes the address
// of its element buffer to *ptr.
//   ptr - out-parameter that receives the buffer address.
//   num - number of elements; used as the single dimension of the shape.
// NOTE(review): the declaration of `tensor`, and whatever brings the
// unqualified `TensorShape` / `int64` names into scope, are not visible in
// this listing - confirm against the full file.
43 void AllocKeepIndices(int64_t** ptr, int64_t num) {
47 TensorShape shape({num});
// OP_REQUIRES_OK reports a non-OK allocation status on `context` and returns
// from this function, so the statements below run only when allocation
// succeeded.
48 OP_REQUIRES_OK(
context,
context->allocate_output(0, shape, &tensor));
// Flat view of the freshly allocated output; its data pointer is what gets
// exported through *ptr.
// NOTE(review): the C-style cast assumes TensorFlow's `int64` and `int64_t`
// have the same size and representation - confirm for all target platforms.
49 auto flat_tensor = tensor->flat<int64>();
50 *ptr = (int64_t*)flat_tensor.data();
// Context supplied to the constructor; used by AllocKeepIndices for the
// output allocation and for error reporting.
54 tensorflow::OpKernelContext*
context;
// NmsOpKernel: abstract tensorflow::OpKernel base. The constructor reads the
// "nms_overlap_thresh" attribute, Compute() fetches the two op inputs and
// forwards them to the pure-virtual Kernel(), which concrete subclasses
// implement.
58 class NmsOpKernel :
public tensorflow::OpKernel {
// Constructor: forwards `construction` to the OpKernel base and loads the
// "nms_overlap_thresh" attribute into the member of the same name. The lookup
// is wrapped in OP_REQUIRES_OK so a failed GetAttr is reported on
// `construction`.
60 explicit NmsOpKernel(tensorflow::OpKernelConstruction* construction)
61 : OpKernel(construction) {
62 OP_REQUIRES_OK(construction,
63 construction->GetAttr(
"nms_overlap_thresh",
64 &nms_overlap_thresh));
// Compute: OpKernel entry point. Reads the inputs and delegates to Kernel().
67 void Compute(tensorflow::OpKernelContext*
context)
override {
// Input 0 is bound to `boxes`, input 1 to `scores`.
69 const Tensor& boxes = context->input(0);
70 const Tensor& scores = context->input(1);
// Named dimension object "num_points".
// NOTE(review): the statements that use `num_points` (listing lines 75-79) are
// not visible here; presumably they are shape checks on `boxes` and `scores` -
// confirm against the full file.
74 Dim num_points(
"num_points");
// Hand both inputs to the subclass implementation.
80 Kernel(context, boxes, scores);
// Kernel: pure-virtual hook that receives the op context and the `boxes` /
// `scores` input tensors from Compute(). Must be overridden by subclasses.
84 virtual void Kernel(tensorflow::OpKernelContext* context,
85 const tensorflow::Tensor& boxes,
86 const tensorflow::Tensor& scores) = 0;
// Value of the "nms_overlap_thresh" op attribute, filled in by the
// constructor.
89 float nms_overlap_thresh;
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:205
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:69
ImGuiContext * context
Definition: Window.cpp:95
Definition: ShapeChecking.h:35