// OutputAllocator: helper that allocates the index tensor (output 0) and the
// distance tensor (output 2) on the OpKernelContext it was constructed with,
// and hands back raw data pointers into those tensors.
// NOTE(review): this text is a lossy source-listing extract. Leading numbers
// are listing line numbers, and several lines are not present here (e.g. the
// template header declaring T, the declarations of `tensor`, access
// specifiers, closing braces). Confirm against the real header.
29 #include "../TensorFlowHelper.h" 32 #include "tensorflow/core/framework/op.h" 33 #include "tensorflow/core/framework/op_kernel.h" 34 #include "tensorflow/core/lib/core/errors.h" 42 class OutputAllocator {
// Stores `context`; all later allocate_output calls go through it.
44 OutputAllocator(tensorflow::OpKernelContext*
context) : context(context) {}
// Allocates output 0 as a 1-D int32 tensor with `num` elements and writes a
// pointer to its data to *ptr.
// NOTE(review): OP_REQUIRES_OK returns from this helper only. If allocation
// fails, *ptr is left unwritten and the caller keeps running. Confirm callers
// check context->status() before using *ptr.
46 void AllocIndices(
int32_t** ptr,
size_t num) {
50 TensorShape shape({int64_t(num)});
51 OP_REQUIRES_OK(
context,
context->allocate_output(0, shape, &tensor));
52 auto flat_tensor = tensor->flat<int32>();
// The int32 buffer is handed out as int32_t*, so both types must have the
// same size.
53 static_assert(
sizeof(int32) ==
sizeof(
int32_t),
54 "int32 and int32_t not compatible");
55 *ptr = (
int32_t*)flat_tensor.data();
// Allocates output 2 as a 1-D tensor of `num` elements of type T and writes a
// pointer to its data to *ptr. T is presumably the class template parameter,
// but its declaration is not visible in this extract. The same early-return
// caveat as AllocIndices applies.
58 void AllocDistances(T** ptr,
size_t num) {
62 TensorShape shape({int64_t(num)});
63 OP_REQUIRES_OK(
context,
context->allocate_output(2, shape, &tensor));
64 auto flat_tensor = tensor->flat<T>();
65 *ptr = flat_tensor.data();
// Context whose outputs are allocated; set by the constructor.
69 tensorflow::OpKernelContext*
context;
// FixedRadiusSearchOpKernel: abstract TensorFlow OpKernel. It reads the op's
// attributes, validates the inputs, allocates output 1
// (query_neighbors_row_splits) and delegates the search itself to the pure
// virtual Kernel(), which subclasses implement.
// NOTE(review): lossy listing extract. Leading numbers are listing line
// numbers, and several lines are missing (access specifiers, branch bodies,
// the end of one GetAttr call, closing braces, some member declarations).
72 class FixedRadiusSearchOpKernel :
public tensorflow::OpKernel {
// Reads the "metric", "ignore_query_point" and "return_distances" attributes.
// A GetAttr failure is reported through OP_REQUIRES_OK on `construction`.
74 explicit FixedRadiusSearchOpKernel(
75 tensorflow::OpKernelConstruction* construction)
76 : OpKernel(construction) {
80 std::string metric_str;
81 OP_REQUIRES_OK(construction,
82 construction->GetAttr(
"metric", &metric_str));
// The visible branches compare metric_str against "L1" and "L2". The branch
// bodies and any further cases (listing lines 84 and 86-89) are missing from
// this extract. They presumably set a metric member; confirm.
83 if (metric_str ==
"L1")
85 else if (metric_str ==
"L2")
90 OP_REQUIRES_OK(construction,
91 construction->GetAttr(
"ignore_query_point",
92 &ignore_query_point));
// NOTE(review): the following GetAttr call is truncated in this listing.
94 OP_REQUIRES_OK(construction, construction->GetAttr(
"return_distances",
// Validates the eight inputs, allocates output 1 and calls Kernel().
// Inputs by index: 0 points, 1 queries, 2 radius (must be a scalar),
// 3 points_row_splits, 4 queries_row_splits, 5 hash_table_splits,
// 6 hash_table_index, 7 hash_table_cell_splits.
98 void Compute(tensorflow::OpKernelContext*
context)
override {
// Compile-time check that TensorFlow's int64 and int64_t have the same size.
100 static_assert(
sizeof(int64) ==
sizeof(int64_t),
101 "int64 type is not compatible");
103 const Tensor&
points = context->input(0);
104 const Tensor& queries = context->input(1);
106 const Tensor& radius = context->input(2);
// Rejects a non-scalar radius with InvalidArgument, including its shape.
107 OP_REQUIRES(context, TensorShapeUtils::IsScalar(radius.shape()),
108 errors::InvalidArgument(
"radius must be scalar, got shape ",
109 radius.shape().DebugString()));
111 const Tensor& points_row_splits = context->input(3);
112 const Tensor& queries_row_splits = context->input(4);
114 const Tensor& hash_table_splits = context->input(5);
115 const Tensor& hash_table_index = context->input(6);
116 const Tensor& hash_table_cell_splits = context->input(7);
// Dim objects are dimensions whose values are inferred across the CHECK_SHAPE
// calls below. This makes all three row-split/table-split tensors agree on
// batch_size + 1 entries, and makes hash_table_index's length a shared
// num_points.
// NOTE(review): num_queries, and the shapes of points/queries, are not checked
// on any visible line. Presumably this happens on lines missing from this
// extract.
121 Dim num_points(
"num_points");
122 Dim num_queries(
"num_queries");
123 Dim batch_size(
"batch_size");
124 Dim num_cells(
"num_cells");
126 CHECK_SHAPE(context, hash_table_index, num_points);
128 CHECK_SHAPE(context, points_row_splits, batch_size + 1);
129 CHECK_SHAPE(context, queries_row_splits, batch_size + 1);
130 CHECK_SHAPE(context, hash_table_splits, batch_size + 1);
131 CHECK_SHAPE(context, hash_table_cell_splits, num_cells + 1);
// Output 1 is a 1-D tensor with queries.dim_size(0) + 1 entries, one more than
// the size of the first dimension of `queries`.
133 Tensor* query_neighbors_row_splits = 0;
134 TensorShape query_neighbors_row_splits_shape(
135 {queries.shape().dim_size(0) + 1});
136 OP_REQUIRES_OK(context, context->allocate_output(
137 1, query_neighbors_row_splits_shape,
138 &query_neighbors_row_splits));
// Delegate the search to the subclass implementation. The allocated output 1
// is passed by reference.
140 Kernel(context, points, queries, radius, points_row_splits,
141 queries_row_splits, hash_table_splits, hash_table_index,
142 hash_table_cell_splits, *query_neighbors_row_splits);
// Implemented by subclasses to perform the search. Receives the inputs read in
// Compute() and the already-allocated query_neighbors_row_splits output.
// Presumably it allocates outputs 0 and 2 (e.g. via OutputAllocator); confirm
// in the subclasses.
145 virtual void Kernel(tensorflow::OpKernelContext* context,
146 const tensorflow::Tensor& points,
147 const tensorflow::Tensor& queries,
148 const tensorflow::Tensor& radius,
149 const tensorflow::Tensor& points_row_splits,
150 const tensorflow::Tensor& queries_row_splits,
151 const tensorflow::Tensor& hash_table_splits,
152 const tensorflow::Tensor& hash_table_index,
153 const tensorflow::Tensor& hash_table_cell_splits,
154 tensorflow::Tensor& query_neighbors_row_splits) = 0;
// Attribute values read in the constructor.
158 bool ignore_query_point;
159 bool return_distances;
Definition: NeighborSearchCommon.h:38
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:205
Metric
Supported metrics.
Definition: NeighborSearchCommon.h:38
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t int32_t
Definition: K4aPlugin.cpp:398
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:69
ImGuiContext * context
Definition: Window.cpp:95
Definition: NeighborSearchCommon.h:38
Definition: NeighborSearchCommon.h:38
Definition: ShapeChecking.h:35
Definition: FaissIndex.cpp:47