// Includes followed by the start of RadiusSearchOpKernel, a class that
// publicly derives from tensorflow::OpKernel.
// NOTE(review): this text looks like a rendered source listing. Every line
// carries its original line number as a leading token and some original lines
// are absent, so it is not compilable as shown - confirm against the real
// header before editing.
29 #include "../TensorFlowHelper.h" 31 #include "tensorflow/core/framework/op.h" 32 #include "tensorflow/core/framework/op_kernel.h" 33 #include "tensorflow/core/lib/core/errors.h" 39 class RadiusSearchOpKernel :
public tensorflow::OpKernel {
// Constructor: forwards the construction context to the OpKernel base and
// reads the op attributes metric, ignore_query_point, return_distances and
// normalize_distances through construction->GetAttr. Every GetAttr call is
// wrapped in OP_REQUIRES_OK with the construction context.
41 explicit RadiusSearchOpKernel(
42 tensorflow::OpKernelConstruction* construction)
43 : OpKernel(construction) {
// The metric attribute is read as a string first.
46 std::string metric_str;
47 OP_REQUIRES_OK(construction,
48 construction->GetAttr(
"metric", &metric_str));
// NOTE(review): the statements selected by this comparison are not visible
// in this listing; presumably they set a metric member - verify against the
// original header.
49 if (metric_str ==
"L1")
// Boolean attribute stored directly into the member of the same name.
54 OP_REQUIRES_OK(construction,
55 construction->GetAttr(
"ignore_query_point",
56 &ignore_query_point));
// NOTE(review): this call is truncated in the listing; the destination
// argument is missing. Presumably it is &return_distances - TODO confirm.
58 OP_REQUIRES_OK(construction, construction->GetAttr(
"return_distances",
// Boolean attribute stored directly into the member of the same name.
60 OP_REQUIRES_OK(construction,
61 construction->GetAttr(
"normalize_distances",
62 &normalize_distances));
// Compute override: fetches the five op inputs, checks the row-splits shapes,
// allocates output 1 and then hands everything to the virtual Kernel method.
//
// context: the TensorFlow kernel context that supplies the inputs and
//          allocates the output.
65 void Compute(tensorflow::OpKernelContext*
context)
override {
// Compile-time guard that int64 and int64_t have the same size.
67 static_assert(
sizeof(int64) ==
sizeof(int64_t),
68 "int64 type is not compatible");
// Inputs are taken by position: 0 points, 1 queries, 2 radii,
// 3 points_row_splits, 4 queries_row_splits.
70 const Tensor&
points = context->input(0);
71 const Tensor& queries = context->input(1);
72 const Tensor& radii = context->input(2);
73 const Tensor& points_row_splits = context->input(3);
74 const Tensor& queries_row_splits = context->input(4);
// Named dimension placeholders used by the shape checks.
// NOTE(review): num_points and num_queries are declared but no check that
// uses them is visible in this listing - lines appear to be missing.
78 Dim num_points(
"num_points");
79 Dim num_queries(
"num_queries");
80 Dim batch_size(
"batch_size");
// Both row-splits tensors are checked against the shape (batch_size + 1),
// using the same batch_size placeholder for both.
84 CHECK_SHAPE(context, points_row_splits, batch_size + 1);
85 CHECK_SHAPE(context, queries_row_splits, batch_size + 1);
// Output 1 is a 1-D tensor with one more element than the first dimension
// of queries.
88 Tensor* query_neighbors_row_splits = 0;
89 TensorShape query_neighbors_row_splits_shape(
90 {queries.shape().dim_size(0) + 1});
91 OP_REQUIRES_OK(context, context->allocate_output(
92 1, query_neighbors_row_splits_shape,
93 &query_neighbors_row_splits));
// Delegate to the subclass implementation. The allocated output is passed
// by reference.
95 Kernel(context, points, queries, radii, points_row_splits,
96 queries_row_splits, *query_neighbors_row_splits);
// Pure virtual hook that subclasses must implement. Compute calls it with the
// kernel context, the five input tensors and the pre-allocated
// query_neighbors_row_splits output tensor.
// Note that Compute passes its radii tensor as the radius parameter here.
99 virtual void Kernel(tensorflow::OpKernelContext* context,
100 const tensorflow::Tensor& points,
101 const tensorflow::Tensor& queries,
102 const tensorflow::Tensor& radius,
103 const tensorflow::Tensor& points_row_splits,
104 const tensorflow::Tensor& queries_row_splits,
105 tensorflow::Tensor& query_neighbors_row_splits) = 0;
// Flags that mirror the op attributes of the same names.
// ignore_query_point and normalize_distances are filled by GetAttr calls in
// the constructor. The GetAttr call for return_distances is truncated in this
// listing, so its destination is presumably this member - TODO confirm.
109 bool ignore_query_point;
110 bool return_distances;
111 bool normalize_distances;
Definition: NeighborSearchCommon.h:38
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:205
Metric
Supported metrics.
Definition: NeighborSearchCommon.h:38
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:69
ImGuiContext * context
Definition: Window.cpp:95
Definition: NeighborSearchCommon.h:38
Definition: ShapeChecking.h:35
Definition: FaissIndex.cpp:47