Source code for spine.math.neighbors

"""Numba JIT compiled implementation of neighbor query routines.

In particular, this module supports:
- Radius-based neighbor classification
- kNN-based neighbor classification
"""

import numba as nb
import numpy as np

from .base import mode
from .distance import cdist, get_metric_id

__all__ = ["RadiusNeighborsClassifier", "KNeighborsClassifier"]


RNC_DTYPE = (
    ("radius", nb.float32),
    ("metric_id", nb.int64),
    ("p", nb.float32),
    ("iterate", nb.boolean),
)


KNC_DTYPE = (("k", nb.int64), ("metric_id", nb.int64), ("p", nb.float32))


[docs] @nb.experimental.jitclass(spec=RNC_DTYPE) # type: ignore[call-arg] class RadiusNeighborsClassifier: """Class which assigns labels to points based on radial neighborhood majority vote. More specifically, for each point that is to be labeled: - Find all labeled points within some radius R; - Label the point based on majority vote. If there are no labeled points in the neighborhood of a query point, a label of -1 is assigned to the query point. Currently this is bruteforced with cdist, but in the future this is intended to be used with a KDTree backend for quicker query. Attributes ---------- radius : float Radius around which to check metric_id : int Distance metric enumerator p : float p-norm factor for the Minkowski metric, if used iterate : bool Whether to recurse the search until no new labels are assigned """ def __init__( self, radius: float, metric: str = "euclidean", p: float = 2.0, iterate: bool = True, ) -> None: """Initialize the RadiusNeighborsClassifier parameters. Parameters ---------- radius : float Radius around which to check metric : str, default 'euclidean' Distance metric p : float, default 2. p-norm factor for the Minkowski metric, if used iterate : bool, default True Whether to recurse the search until no new labels are assigned """ if radius < 0.0: raise ValueError("Radius must be non-negative.") # For Euclidean, save time by using squared Euclidean if metric == "euclidean": metric = "sqeuclidean" radius = radius * radius # Store parameters self.radius = radius self.metric_id = get_metric_id(metric, p) self.p = p self.iterate = iterate def fit_predict( self, X: np.ndarray, y: np.ndarray, Xq: np.ndarray ) -> tuple[np.ndarray, np.ndarray]: """Assign labels to a set of points given a set of reference points. Parameters ---------- X : np.ndarray (N, 3) Set of reference points y : np.ndarray (N,) Labels of reference points Xq : nb.ndarray (M, 3) Set of query points Returns ------- np.ndarray (M,) Labels assigned to the query points np.ndarray Index of points which have not been successfully assigned """ # Loop over query points until no new labels can be assigned num_query = len(Xq) labels = np.empty(num_query, dtype=np.int64) orphan_index = np.arange(num_query, dtype=np.int64) while num_query > 0: # Start by computing the distance between the query and reference dists = cdist(Xq, X, metric_id=self.metric_id, p=self.p) # Fetch the mask of reference points closer than some radius mask = dists < self.radius # Loop over query points assigned = np.zeros(num_query, dtype=np.bool_) for i in range(num_query): # Find the set of points within the predefined radius index = np.where(mask[i])[0] # Use the mode to define the label if len(index): labels[orphan_index[i]] = mode(y[index]) assigned[i] = True else: labels[orphan_index[i]] = -1 # If the number of orphans is unchanged, break orphan_update = np.where(~assigned)[0] if len(orphan_update) == 0 or len(orphan_update) == num_query: orphan_index = orphan_index[orphan_update] break # If no recursion is required, abort loop if not self.iterate: orphan_index = orphan_index[orphan_update] break # Update the reference and query points label_update = np.where(assigned)[0] X = Xq[label_update] Xq = Xq[orphan_update] y = labels[orphan_index[label_update]] # Update orphan list orphan_index = orphan_index[orphan_update] num_query = len(orphan_index) return labels, orphan_index
[docs] @nb.experimental.jitclass(spec=KNC_DTYPE) # type: ignore[call-arg] class KNeighborsClassifier: """Class which assigns labels to points based on a nearest neighbor majority vote. More specifically, for each point that is to be labeled: - Find the k closest labeled points; - Label the point based on majority vote. If there are no labeled points in the neighborhood of a query point, a label of -1 is assigned to the query point. Currently this is bruteforced with cdist, but in the future this is intended to be used with a KDTree backend for quicker query. Attributes ---------- k : int Number of neighbors to query metric_id : int Distance metric enumerator p : float p-norm factor for the Minkowski metric, if used """ def __init__(self, k: int, metric: str = "euclidean", p: float = 2.0) -> None: """Initialize the RadiusNeighborsClassifier parameters. Parameters ---------- k : int Number of neighbors to query metric : str, default 'euclidean' Distance metric p : float, default 2. p-norm factor for the Minkowski metric, if used """ if k <= 0: raise ValueError("Number of neighbors must be positive.") # For Euclidean, save time by using squared Euclidean if metric == "euclidean": metric = "sqeuclidean" # Store parameters self.k = k self.metric_id = get_metric_id(metric, p) self.p = p def fit_predict( self, X: np.ndarray, y: np.ndarray, Xq: np.ndarray ) -> tuple[np.ndarray, np.ndarray]: """Assign labels to a set of points given a set of reference points. Parameters ---------- X : np.ndarray (N, 3) Set of reference points y : np.ndarray (N,) Labels of reference points Xq : nb.ndarray (M, 3) Set of query points Returns ------- np.ndarray (M,) Labels assigned to the query points np.ndarray Index of points which have not been successfully assigned """ # If there are no labeled points provided, nothing to do if len(X) == 0: return ( np.full(len(Xq), -1, dtype=np.int64), np.arange(len(Xq), dtype=np.int64), ) # Start by computing the distance between the query and reference dists = cdist(Xq, X, metric_id=self.metric_id, p=self.p) # Loop over query points labels = np.empty(len(Xq), dtype=np.int64) for i in range(len(Xq)): # Find the list k closest labels index = np.argsort(dists[i])[: self.k] # Use the mode to define the label labels[i] = mode(y[index]) return labels, np.empty(0, dtype=np.int64)