class open3d.ml.tf.layers.KNNSearch(*args, **kwargs)

KNN search for 3D point clouds.

This layer computes the k nearest neighbors for each query point.


This example shows a neighbor search that returns the indices of the found neighbors and their distances:

import tensorflow as tf
import open3d.ml.tf as ml3d

points = tf.random.normal([20,3])
queries = tf.random.normal([10,3])
k = 8

nsearch = ml3d.layers.KNNSearch(return_distances=True)
ans = nsearch(points, queries, k)
# returns a tuple of neighbors_index, neighbors_row_splits, and neighbors_distance
# Since there are more than k points and we do not ignore any points we can
# reshape the output to [num_queries, k] with
neighbors_index = tf.reshape(ans.neighbors_index, [10,k])
neighbors_distance = tf.reshape(ans.neighbors_distance, [10,k])

Parameters:

  • metric – Either L1, L2 or Linf. Default is L2.

  • ignore_query_point – If True, points that coincide with the center of the search window are ignored. This excludes the query point itself when ‘queries’ and ‘points’ are the same point cloud.

  • return_distances – If True, the distance to each neighbor is returned. If False, a zero-length Tensor is returned instead.
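To make the three metric names concrete, here is a minimal pure-Python sketch of how each one measures the distance between two 3D points. The helper name `distance` is hypothetical and is not part of the library. Returning the squared value for L2 is an assumption taken from the note in the Returns description of call(), not from the library source.

```python
def distance(p, q, metric="L2"):
    """Distance between two 3D points under the named metric.

    Illustrative only. 'L2' returns the SQUARED Euclidean distance,
    which is an assumption based on the Returns note of call().
    """
    diffs = [abs(a - b) for a, b in zip(p, q)]
    if metric == "L1":
        return sum(diffs)                 # sum of absolute differences
    if metric == "L2":
        return sum(d * d for d in diffs)  # squared Euclidean distance
    if metric == "Linf":
        return max(diffs)                 # largest absolute difference
    raise ValueError(f"unknown metric: {metric}")


p = (0.0, 0.0, 0.0)
q = (1.0, -2.0, 2.0)
l1 = distance(p, q, "L1")      # 1 + 2 + 2
l2 = distance(p, q, "L2")      # 1 + 4 + 4 (squared, no square root)
linf = distance(p, q, "Linf")  # max(1, 2, 2)
```

Because the square root is monotonic, ranking neighbors by squared L2 distance gives the same order as ranking by the true Euclidean distance.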

__init__(metric='L2', ignore_query_point=False, return_distances=False, index_dtype=tf.int32, **kwargs)

build(input_shape)

Creates the variables of the layer (optional, for subclass implementers).

This is a method that implementers of subclasses of Layer or Model can override if they need a state-creation step in-between layer instantiation and layer call. It is invoked automatically before the first execution of call().

This is typically used to create the weights of Layer subclasses (at the discretion of the subclass implementer).


Parameters:

  • input_shape – Instance of TensorShape, or list of instances of TensorShape if the layer expects a list of inputs (one instance per input).
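The deferred state-creation step described above can be pictured with a toy class. This is a plain-Python sketch of the pattern only: `LazyLayer` is a hypothetical stand-in and does not reproduce the Keras Layer base class or its real bookkeeping.

```python
class LazyLayer:
    """Toy stand-in showing build() running once before the first call()."""

    def __init__(self):
        self.built = False
        self.input_shape = None

    def build(self, input_shape):
        # State that depends on the input shape would be created here.
        self.input_shape = tuple(input_shape)
        self.built = True

    def call(self, rows):
        # The actual computation; here it only counts the input rows.
        return len(rows)

    def __call__(self, rows):
        # build() is triggered automatically on first use, then skipped.
        if not self.built:
            self.build((len(rows), len(rows[0])))
        return self.call(rows)


layer = LazyLayer()
built_before = layer.built
n = layer([(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)])
```

After the first invocation, `layer.built` is set and `build()` is not run again on later calls.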

call(points, queries, k, points_row_splits=None, queries_row_splits=None)

This function computes the k nearest neighbors for each query point.

Parameters:

  • points – The 3D positions of the input points. This argument must be given as a positional argument!

  • queries – The 3D positions of the query points.

  • k – The number of nearest neighbors to search.

  • points_row_splits – Optional 1D vector with the row splits information if points is batched. This vector is [0, num_points] if there is only 1 batch item.

  • queries_row_splits – Optional 1D vector with the row splits information if queries is batched. This vector is [0, num_queries] if there is only 1 batch item.
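As an illustration of the row splits format, the following sketch builds such a vector from the number of points in each batch item. The helper names `row_splits_from_sizes` and `batch_item_range` are hypothetical, and the batch sizes are made up for the example.

```python
from itertools import accumulate


def row_splits_from_sizes(sizes):
    """Return [0, s0, s0+s1, ...] for a list of per-item point counts."""
    return [0] + list(accumulate(sizes))


def batch_item_range(row_splits, i):
    """Return the (start, end) rows that belong to batch item i."""
    return row_splits[i], row_splits[i + 1]


# Three batch items holding 5, 3 and 12 points, stored back to back
# in one [20, 3] points array.
points_row_splits = row_splits_from_sizes([5, 3, 12])

# A single batch item with 20 points gives [0, num_points].
single_item_row_splits = row_splits_from_sizes([20])

start, end = batch_item_range(points_row_splits, 1)
```

The vector always has one more entry than there are batch items, and its last entry equals the total number of rows.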

Returns: Three Tensors in the following order:

  • neighbors_index – The compact list of indices of the neighbors. The corresponding query point can be inferred from the ‘neighbors_row_splits’ vector.

  • neighbors_row_splits – The exclusive prefix sum of the neighbor count for the query points including the total neighbor count as the last element. The size of this array is the number of queries + 1.

  • neighbors_distance – Stores the distance to each neighbor if ‘return_distances’ is True. Note that the distances are squared if metric is L2. This is a zero-length Tensor if ‘return_distances’ is False.
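The output layout can be checked against a brute-force reference. The sketch below is plain Python, not the library implementation: `knn_reference` and `neighbors_of` are hypothetical helpers, only the L2 metric is covered, and sorting each query's neighbors by increasing distance is an assumption of this sketch rather than documented behavior.

```python
def knn_reference(points, queries, k, ignore_query_point=False):
    """Brute-force k nearest neighbors with squared L2 distances.

    Returns (neighbors_index, neighbors_row_splits, neighbors_distance)
    as flat Python lists in the compact layout described above.
    """
    neighbors_index = []
    neighbors_row_splits = [0]
    neighbors_distance = []
    for q in queries:
        candidates = []
        for i, p in enumerate(points):
            d = sum((a - b) ** 2 for a, b in zip(p, q))  # squared L2
            if ignore_query_point and d == 0.0:
                continue  # skip points that coincide with the query
            candidates.append((d, i))
        candidates.sort()  # nearest first (ordering is an assumption)
        nearest = candidates[:k]
        neighbors_index.extend(i for _, i in nearest)
        neighbors_distance.extend(d for d, _ in nearest)
        # Running total of neighbors found so far.
        neighbors_row_splits.append(len(neighbors_index))
    return neighbors_index, neighbors_row_splits, neighbors_distance


def neighbors_of(query_idx, neighbors_index, neighbors_row_splits):
    """Slice the compact index list down to one query's neighbors."""
    start = neighbors_row_splits[query_idx]
    end = neighbors_row_splits[query_idx + 1]
    return neighbors_index[start:end]


points = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (5.0, 5.0, 5.0)]
queries = [(0.0, 0.0, 0.0), (4.0, 5.0, 5.0)]

idx, splits, dist = knn_reference(points, queries, k=2)
second_query_neighbors = neighbors_of(1, idx, splits)

# With ignore_query_point=True the first query no longer finds itself.
idx_ignored, _, _ = knn_reference(points, queries, k=2, ignore_query_point=True)
```

When every query finds exactly k neighbors, the row splits are evenly spaced and the flat lists can be reshaped to [num_queries, k], as done in the class example.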