14 #include "tensorflow/core/framework/op.h"
15 #include "tensorflow/core/framework/op_kernel.h"
16 #include "tensorflow/core/lib/core/errors.h"
20 namespace nms_opkernel {
22 class OutputAllocator {
26 void AllocKeepIndices(int64_t** ptr, int64_t num) {
27 using namespace tensorflow;
30 TensorShape shape({num});
31 OP_REQUIRES_OK(
context,
context->allocate_output(0, shape, &tensor));
32 auto flat_tensor = tensor->flat<int64_t>();
33 *ptr = (int64_t*)flat_tensor.data();
37 tensorflow::OpKernelContext*
context;
41 class NmsOpKernel :
public tensorflow::OpKernel {
43 explicit NmsOpKernel(tensorflow::OpKernelConstruction* construction)
44 : OpKernel(construction) {
45 OP_REQUIRES_OK(construction,
46 construction->GetAttr(
"nms_overlap_thresh",
47 &nms_overlap_thresh));
50 void Compute(tensorflow::OpKernelContext*
context)
override {
51 using namespace tensorflow;
52 const Tensor& boxes =
context->input(0);
53 const Tensor& scores =
context->input(1);
57 Dim num_points(
"num_points");
67 virtual void Kernel(tensorflow::OpKernelContext*
context,
68 const tensorflow::Tensor& boxes,
69 const tensorflow::Tensor& scores) = 0;
/// Overlap threshold read from the "nms_overlap_thresh" op attribute in the
/// NmsOpKernel constructor. It has a default initializer so the member is never
/// indeterminate, e.g. when GetAttr() fails and OP_REQUIRES_OK returns early
/// from the constructor.
float nms_overlap_thresh = 0.0f;
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:190
ImGuiContext * context
Definition: Window.cpp:76
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
Definition: ShapeChecking.h:16