"""Numba JIT compiled implementation of distance computation routines.
This module is entirely dedicated to 3D points, which is the core representation
of objects targeted by this software package.
"""
import numba as nb
import numpy as np
from .base import argmin, mean
# Names exported by `from <module> import *`. `get_metric_id`, `METRICS` and
# the `_pdist_*` / `_cdist_*` dispatch targets are not part of this list.
__all__ = [
    # Point-to-point distances
    "cityblock",
    "euclidean",
    "sqeuclidean",
    "minkowski",
    "chebyshev",
    # Distance matrices
    "pdist",
    "cdist",
    # Extremal pair searches
    "farthest_pair",
    "closest_pair",
    "closest_pair_legacy",
]
# Integer enumerators, one per supported distance metric. These are the values
# the `metric_id` arguments of the functions in this module are compared to.
MINKOWSKI = 0
CITYBLOCK = 1
EUCLIDEAN = 2
SQEUCLIDEAN = 3
CHEBYSHEV = 4
# Available distance metrics. Keep the public mapping for callers, while using
# named integer constants internally so Numba sees stable scalar IDs.
METRICS = {
    "minkowski": MINKOWSKI,
    "cityblock": CITYBLOCK,
    "euclidean": EUCLIDEAN,
    "sqeuclidean": SQEUCLIDEAN,
    "chebyshev": CHEBYSHEV,
}
@nb.njit(cache=True)
def get_metric_id(metric: str, p: float) -> int:
    """Translate a metric name into its integer enumerator.

    Parameters
    ----------
    metric : str
        Name of the distance metric: one of 'minkowski', 'cityblock',
        'euclidean', 'sqeuclidean' or 'chebyshev'
    p : float
        p-norm factor; only inspected when `metric` is 'minkowski'

    Returns
    -------
    int
        Enumerated form of the distance metric

    Raises
    ------
    ValueError
        If the metric name is not one of the supported names
    """
    # Metrics which map one-to-one onto an enumerator
    if metric == "cityblock":
        return CITYBLOCK
    if metric == "euclidean":
        return EUCLIDEAN
    if metric == "sqeuclidean":
        return SQEUCLIDEAN
    if metric == "chebyshev":
        return CHEBYSHEV

    # Anything that is not Minkowski at this point is unknown
    if metric != "minkowski":
        raise ValueError(f"Distance metric not recognized: {metric}")

    # Minkowski with p = 1 or p = 2 is routed to the dedicated L1 / L2
    # enumerators; every other p keeps the generic Minkowski enumerator
    if p == 1.0:
        return CITYBLOCK
    if p == 2.0:
        return EUCLIDEAN

    return MINKOWSKI
[docs]
@nb.njit(cache=True)
def cityblock(x: np.ndarray, y: np.ndarray) -> float:
    """Return the cityblock (L1) distance between two 3D points.

    Parameters
    ----------
    x : np.ndarray
        (3,) Coordinates of the first point
    y : np.ndarray
        (3,) Coordinates of the second point

    Returns
    -------
    float
        Sum of the absolute per-axis differences
    """
    # Only components 0, 1 and 2 are read
    dx = abs(y[0] - x[0])
    dy = abs(y[1] - x[1])
    dz = abs(y[2] - x[2])

    return dx + dy + dz
[docs]
@nb.njit(cache=True)
def euclidean(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between two 3D points.

    Parameters
    ----------
    x : np.ndarray
        (3,) Coordinates of the first point
    y : np.ndarray
        (3,) Coordinates of the second point

    Returns
    -------
    float
        Square root of the sum of the squared per-axis differences
    """
    # Only components 0, 1 and 2 are read
    dx = y[0] - x[0]
    dy = y[1] - x[1]
    dz = y[2] - x[2]

    return np.sqrt(dx ** 2 + dy ** 2 + dz ** 2)
[docs]
@nb.njit(cache=True)
def sqeuclidean(x: np.ndarray, y: np.ndarray) -> float:
    """Return the squared Euclidean distance between two 3D points.

    Parameters
    ----------
    x : np.ndarray
        (3,) Coordinates of the first point
    y : np.ndarray
        (3,) Coordinates of the second point

    Returns
    -------
    float
        Sum of the squared per-axis differences (no square root is taken)
    """
    # Only components 0, 1 and 2 are read
    dx = y[0] - x[0]
    dy = y[1] - x[1]
    dz = y[2] - x[2]

    return dx ** 2 + dy ** 2 + dz ** 2
[docs]
@nb.njit(cache=True)
def chebyshev(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Chebyshev (Linf) distance between two 3D points.

    Parameters
    ----------
    x : np.ndarray
        (3,) Coordinates of the first point
    y : np.ndarray
        (3,) Coordinates of the second point

    Returns
    -------
    float
        Largest absolute per-axis difference
    """
    # Only components 0, 1 and 2 are read
    dx = abs(y[0] - x[0])
    dy = abs(y[1] - x[1])
    dz = abs(y[2] - x[2])

    return max(dx, dy, dz)
[docs]
@nb.njit(cache=True)
def minkowski(x: np.ndarray, y: np.ndarray, p: float) -> float:
    """Return the Minkowski (Lp) distance between two 3D points.

    Parameters
    ----------
    x : np.ndarray
        (3,) Coordinates of the first point
    y : np.ndarray
        (3,) Coordinates of the second point
    p : float
        p-norm factor: exponent applied to each absolute difference, with
        `1 / p` applied to their sum

    Returns
    -------
    float
        Minkowski distance
    """
    # Only components 0, 1 and 2 are read
    dx = abs(y[0] - x[0])
    dy = abs(y[1] - x[1])
    dz = abs(y[2] - x[2])

    # Sum of the p-th powers, then the p-th root of the total
    total = dx ** p + dy ** p + dz ** p

    return pow(total, 1.0 / p)
[docs]
@nb.njit(cache=True)
def pdist(
    x: np.ndarray, metric_id: int = METRICS["euclidean"], p: float = 2.0
) -> np.ndarray:
    """Numba counterpart of `scipy.spatial.distance.pdist` for 3D points.

    Unlike the condensed scipy output, the result is returned as a full
    square matrix.

    Parameters
    ----------
    x : np.ndarray
        (N, 3) array of point coordinates in the set
    metric_id : int, default 2 (Euclidean)
        Distance metric enumerator
    p : float, default 2.
        p-norm factor for the Minkowski metric, if used

    Returns
    -------
    np.ndarray
        (N, N) array of pair-wise distances, in the requested metric

    Raises
    ------
    AssertionError
        If the second dimension of `x` is not 3
    ValueError
        If `metric_id` is not one of the metric enumerators
    """
    # Check on the input
    assert x.shape[1] == 3, "Only supports 3D points for now."

    # Pick the metric once here, rather than once per pair of points
    if metric_id == EUCLIDEAN:
        return _pdist_euclidean(x)
    if metric_id == SQEUCLIDEAN:
        return _pdist_sqeuclidean(x)
    if metric_id == CITYBLOCK:
        return _pdist_cityblock(x)
    if metric_id == CHEBYSHEV:
        return _pdist_chebyshev(x)
    if metric_id == MINKOWSKI:
        return _pdist_minkowski(x, p)

    raise ValueError("Distance metric not recognized.")
@nb.njit(cache=True)
def _pdist_cityblock(x: np.ndarray) -> np.ndarray:
    """Build the full (N, N) cityblock distance matrix of a point set.

    The output uses the dtype of `x`.
    """
    n_points = x.shape[0]
    out = np.empty((n_points, n_points), dtype=x.dtype)
    for row in range(n_points):
        point = x[row]
        out[row, row] = 0.0

        # Compute the upper triangle only and mirror it
        for col in range(row + 1, n_points):
            value = cityblock(point, x[col])
            out[row, col] = value
            out[col, row] = value

    return out
@nb.njit(cache=True)
def _pdist_euclidean(x: np.ndarray) -> np.ndarray:
    """Build the full (N, N) Euclidean distance matrix of a point set.

    The output uses the dtype of `x`.
    """
    n_points = x.shape[0]
    out = np.empty((n_points, n_points), dtype=x.dtype)
    for row in range(n_points):
        point = x[row]
        out[row, row] = 0.0

        # Compute the upper triangle only and mirror it
        for col in range(row + 1, n_points):
            value = euclidean(point, x[col])
            out[row, col] = value
            out[col, row] = value

    return out
@nb.njit(cache=True)
def _pdist_sqeuclidean(x: np.ndarray) -> np.ndarray:
    """Build the full (N, N) squared Euclidean distance matrix of a point set.

    The output uses the dtype of `x`.
    """
    n_points = x.shape[0]
    out = np.empty((n_points, n_points), dtype=x.dtype)
    for row in range(n_points):
        point = x[row]
        out[row, row] = 0.0

        # Compute the upper triangle only and mirror it
        for col in range(row + 1, n_points):
            value = sqeuclidean(point, x[col])
            out[row, col] = value
            out[col, row] = value

    return out
@nb.njit(cache=True)
def _pdist_chebyshev(x: np.ndarray) -> np.ndarray:
    """Build the full (N, N) Chebyshev distance matrix of a point set.

    The output uses the dtype of `x`.
    """
    n_points = x.shape[0]
    out = np.empty((n_points, n_points), dtype=x.dtype)
    for row in range(n_points):
        point = x[row]
        out[row, row] = 0.0

        # Compute the upper triangle only and mirror it
        for col in range(row + 1, n_points):
            value = chebyshev(point, x[col])
            out[row, col] = value
            out[col, row] = value

    return out
@nb.njit(cache=True)
def _pdist_minkowski(x: np.ndarray, p: float) -> np.ndarray:
    """Build the full (N, N) Minkowski distance matrix of a point set.

    `p` is the p-norm factor forwarded to `minkowski`. The output uses the
    dtype of `x`.
    """
    n_points = x.shape[0]
    out = np.empty((n_points, n_points), dtype=x.dtype)
    for row in range(n_points):
        point = x[row]
        out[row, row] = 0.0

        # Compute the upper triangle only and mirror it
        for col in range(row + 1, n_points):
            value = minkowski(point, x[col], p)
            out[row, col] = value
            out[col, row] = value

    return out
[docs]
@nb.njit(cache=True)
def cdist(
    x1: np.ndarray,
    x2: np.ndarray,
    metric_id: int = METRICS["euclidean"],
    p: float = 2.0,
) -> np.ndarray:
    """Numba counterpart of
    `scipy.spatial.distance.cdist(x1, x2, metric=metric, p=p)` in 3D.

    Parameters
    ----------
    x1 : np.ndarray
        (N, 3) array of point coordinates in the first set
    x2 : np.ndarray
        (M, 3) array of point coordinates in the second set
    metric_id : int, default 2 (Euclidean)
        Distance metric enumerator
    p : float, default 2.
        p-norm factor for the Minkowski metric, if used

    Returns
    -------
    np.ndarray
        (N, M) array of cross-set distances, in the requested metric

    Raises
    ------
    AssertionError
        If the second dimension of `x1` or `x2` is not 3
    ValueError
        If `metric_id` is not one of the metric enumerators
    """
    # Check on the input
    assert x1.shape[1] == 3 and x2.shape[1] == 3, "Only supports 3D points for now."

    # Pick the metric once here, rather than once per pair of points
    if metric_id == EUCLIDEAN:
        return _cdist_euclidean(x1, x2)
    if metric_id == SQEUCLIDEAN:
        return _cdist_sqeuclidean(x1, x2)
    if metric_id == CITYBLOCK:
        return _cdist_cityblock(x1, x2)
    if metric_id == CHEBYSHEV:
        return _cdist_chebyshev(x1, x2)
    if metric_id == MINKOWSKI:
        return _cdist_minkowski(x1, x2, p)

    raise ValueError("Distance metric not recognized.")
@nb.njit(cache=True)
def _cdist_cityblock(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Build the (N, M) cityblock distance matrix between two point sets.

    The output uses the dtype of `x1`.
    """
    n_rows, n_cols = x1.shape[0], x2.shape[0]
    out = np.empty((n_rows, n_cols), dtype=x1.dtype)
    for row in range(n_rows):
        source = x1[row]
        for col in range(n_cols):
            out[row, col] = cityblock(source, x2[col])

    return out
@nb.njit(cache=True)
def _cdist_euclidean(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Build the (N, M) Euclidean distance matrix between two point sets.

    The output uses the dtype of `x1`.
    """
    n_rows, n_cols = x1.shape[0], x2.shape[0]
    out = np.empty((n_rows, n_cols), dtype=x1.dtype)
    for row in range(n_rows):
        source = x1[row]
        for col in range(n_cols):
            out[row, col] = euclidean(source, x2[col])

    return out
@nb.njit(cache=True)
def _cdist_sqeuclidean(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Build the (N, M) squared Euclidean distance matrix between two sets.

    The output uses the dtype of `x1`.
    """
    n_rows, n_cols = x1.shape[0], x2.shape[0]
    out = np.empty((n_rows, n_cols), dtype=x1.dtype)
    for row in range(n_rows):
        source = x1[row]
        for col in range(n_cols):
            out[row, col] = sqeuclidean(source, x2[col])

    return out
@nb.njit(cache=True)
def _cdist_chebyshev(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Build the (N, M) Chebyshev distance matrix between two point sets.

    The output uses the dtype of `x1`.
    """
    n_rows, n_cols = x1.shape[0], x2.shape[0]
    out = np.empty((n_rows, n_cols), dtype=x1.dtype)
    for row in range(n_rows):
        source = x1[row]
        for col in range(n_cols):
            out[row, col] = chebyshev(source, x2[col])

    return out
@nb.njit(cache=True)
def _cdist_minkowski(x1: np.ndarray, x2: np.ndarray, p: float) -> np.ndarray:
    """Build the (N, M) Minkowski distance matrix between two point sets.

    `p` is the p-norm factor forwarded to `minkowski`. The output uses the
    dtype of `x1`.
    """
    n_rows, n_cols = x1.shape[0], x2.shape[0]
    out = np.empty((n_rows, n_cols), dtype=x1.dtype)
    for row in range(n_rows):
        source = x1[row]
        for col in range(n_cols):
            out[row, col] = minkowski(source, x2[col], p)

    return out
[docs]
@nb.njit(cache=True)
def farthest_pair(
    x: np.ndarray,
    iterative: bool = False,
    metric_id: int = METRICS["euclidean"],
    p: float = 2.0,
) -> tuple[int, int, float]:
    """Algorithm which finds the two points which are farthest from each other
    in a set, according to the requested distance metric.

    Two algorithms are available:
    - `brute`: computes all pairwise distances and uses `argmax`.
    - `iterative`: repeatedly jumps to the current farthest point until
      convergence. It is not exact, but it is fast.

    Parameters
    ----------
    x : np.ndarray
        (N, 3) array of point coordinates
    iterative : bool
        If `True`, uses an iterative, fast approximation
    metric_id : int, default 2 (Euclidean)
        Distance metric enumerator
    p : float
        p-norm factor for the Minkowski metric, if used

    Returns
    -------
    int
        ID of the first point that makes up the pair
    int
        ID of the second point that makes up the pair
    float
        Distance between the two points

    Notes
    -----
    If the iterative search selects a NaN distance, it stops immediately and
    reports that NaN as the distance instead of iterating forever.
    """
    # To save time, if Euclidean distance is used, use its square
    is_euclidean = False
    if metric_id == EUCLIDEAN:
        is_euclidean = True
        metric_id = SQEUCLIDEAN

    # Dispatch
    if not iterative:
        # Find the distance between every pair of points
        dist_mat = pdist(x, metric_id, p)

        # Select the pair with the farthest distance, fetch indexes by
        # unravelling the flat argmax over the (N, N) matrix
        index = np.argmax(dist_mat)
        i, j = index // len(x), index % len(x)

        # Record farthest distance
        dist = dist_mat[i, j]

    else:
        # Seed the search with the point farthest from the centroid
        # NOTE(review): `mean` comes from `.base`; presumably the mean over
        # axis 0, i.e. one value per coordinate - confirm there.
        centroid = mean(x, 0)
        start_idx = np.argmax(cdist(centroid[None, :], x, metric_id, p))

        # Jump to the farthest point until convergence. `pair_idxs[set_id]`
        # is the end of the pair the next jump starts from.
        pair_idxs, set_id = [start_idx, start_idx], 0
        dist = -np.inf
        while True:
            other_id = 1 - set_id
            dists = cdist(x[pair_idxs[set_id]][None, :], x, metric_id, p).flatten()
            farthest_idx = np.argmax(dists)
            farthest_dist = float(dists[farthest_idx])

            # Every comparison against NaN is False, so the convergence test
            # below can never fire once a NaN distance is selected. Record
            # the NaN and stop rather than looping forever.
            if np.isnan(farthest_dist):
                pair_idxs[other_id] = farthest_idx
                dist = farthest_dist
                break

            # Converged: this jump did not find a strictly larger distance
            if farthest_dist <= dist:
                break

            # Accept the jump and continue from the newly found end
            pair_idxs[other_id] = farthest_idx
            dist = farthest_dist
            set_id = other_id

        # Unroll index
        i, j = pair_idxs

    # If needed, take the square root of the distance
    if is_euclidean:
        dist = np.sqrt(dist)

    return int(i), int(j), float(dist)
[docs]
@nb.njit(cache=True)
def closest_pair_legacy(
    x1: np.ndarray,
    x2: np.ndarray,
    iterative: bool = False,
    seed: bool = True,
    metric_id: int = METRICS["euclidean"],
    p: float = 2.0,
) -> tuple[int, int, float]:
    """Legacy closest-pair implementation kept for model compatibility.

    This preserves the historical iterative behavior, including the missing
    set switch after each closest-point update. New code should use
    :func:`closest_pair`.

    Parameters
    ----------
    x1 : np.ndarray
        (N, 3) array of point coordinates in the first set
    x2 : np.ndarray
        (M, 3) array of point coordinates in the second set
    iterative : bool
        If `True`, uses the iterative approximation instead of brute force
    seed : bool
        Whether to seed the iterative search with the end points returned by
        the iterative `farthest_pair` on each set
    metric_id : int, default 2 (Euclidean)
        Distance metric enumerator
    p : float, default 2.
        p-norm factor for the Minkowski metric, if used

    Returns
    -------
    int
        ID of the point of the pair in the first set
    int
        ID of the point of the pair in the second set
    float
        Distance between the two points
    """
    # Euclidean requests are served on squared distances; the square root is
    # only taken once, on the final value
    take_sqrt = False
    if metric_id == EUCLIDEAN:
        take_sqrt = True
        metric_id = SQEUCLIDEAN

    if not iterative:
        # Brute force: build every cross-distance and unravel the flat argmin
        cross_dists = cdist(x1, x2, metric_id, p)
        flat_idx = np.argmin(cross_dists)
        i, j = flat_idx // len(x2), flat_idx % len(x2)
        dist = cross_dists[i, j]

    else:
        # Historical finite sentinels (not infinity) for the running and
        # previous distances; the gap guarantees one pass through the loop
        point_sets = [x1, x2]
        pair_idxs = [0, 0]
        set_id = 0
        dist = 1e9
        previous_dist = 1e9 + 1.0

        if seed:
            # Try the two end points of each set as starting candidates.
            # The end points come from `farthest_pair` with its default metric.
            for k, points in enumerate(point_sets):
                other_id = 1 - k
                end_idxs = np.array(farthest_pair(points, True)[:2])
                end_dists = cdist(
                    points[end_idxs], point_sets[other_id], metric_id, p
                )
                end_argmins = argmin(end_dists, axis=1)
                end_mins = np.array(
                    [end_dists[0][end_argmins[0]], end_dists[1][end_argmins[1]]]
                )

                # Keep the candidate if it beats the best distance so far;
                # the search then continues from the point in the other set
                if np.min(end_mins) < dist:
                    set_id = other_id
                    choice = int(np.argmin(end_mins))
                    pair_idxs[k] = int(end_idxs[choice])
                    pair_idxs[set_id] = int(end_argmins[choice])
                    dist = float(end_mins[choice])

        # Historically this loop did not switch `set_id` after updating the
        # closest point in the opposite set. Preserve that behavior here for
        # compatibility with trained models and reference outputs.
        while dist < previous_dist:
            previous_dist = dist
            other_id = 1 - set_id
            dists = cdist(
                point_sets[set_id][pair_idxs[set_id]][None, :],
                point_sets[other_id],
                metric_id,
                p,
            ).flatten()
            nearest_idx = int(np.argmin(dists))
            pair_idxs[other_id] = nearest_idx
            dist = float(dists[nearest_idx])

        # Unroll index
        i, j = pair_idxs

    # Undo the squared-distance shortcut, if it was taken
    if take_sqrt:
        dist = np.sqrt(dist)

    return int(i), int(j), float(dist)
[docs]
@nb.njit(cache=True)
def closest_pair(
    x1: np.ndarray,
    x2: np.ndarray,
    iterative: bool = False,
    seed: bool = True,
    metric_id: int = METRICS["euclidean"],
    p: float = 2.0,
) -> tuple[int, int, float]:
    """Algorithm which finds the two points which are closest to each other
    from two separate sets.

    Two algorithms are available:
    - `brute`: computes all cross-distances and uses `argmin`.
    - `iterative`: repeatedly jumps to the current closest point until
      convergence. It is not exact, but it is fast.

    Parameters
    ----------
    x1 : np.ndarray
        (N, 3) array of point coordinates in the first set
    x2 : np.ndarray
        (M, 3) array of point coordinates in the second set
    iterative : bool
        If `True`, uses an iterative, fast approximation
    seed : bool
        Whether or not to use the two farthest points in one of the two sets
        to seed the iterative algorithm
    metric_id : int, default 2 (Euclidean)
        Distance metric enumerator
    p : float, default 2.
        p-norm factor for the Minkowski metric, if used

    Returns
    -------
    int
        ID of the first point that makes up the pair
    int
        ID of the second point that makes up the pair
    float
        Distance between the two points

    Notes
    -----
    If the iterative search selects a NaN distance, it stops immediately and
    reports that NaN as the distance instead of iterating forever.
    """
    # To save time, if Euclidean distance is used, use its square
    is_euclidean = False
    if metric_id == EUCLIDEAN:
        is_euclidean = True
        metric_id = SQEUCLIDEAN

    # Find the two points in two sets of points that are closest to each other
    if not iterative:
        # Compute every pair-wise distances between the two sets
        dist_mat = cdist(x1, x2, metric_id, p)

        # Select the closest pair of point, fetch indexes by unravelling the
        # flat argmin over the (N, M) matrix
        index = np.argmin(dist_mat)
        i, j = index // len(x2), index % len(x2)

        # Record closest distance
        dist = dist_mat[i, j]

    else:
        # Pick the point to start iterating from. `pair_idxs[k]` indexes into
        # `point_sets[k]`; `set_id` is the set the next jump starts from.
        point_sets = [x1, x2]
        pair_idxs, set_id = [0, 0], 0
        dist = np.inf
        if seed:
            # Find the end points of the two sets
            # NOTE: `farthest_pair` is called with its default metric here,
            # whatever `metric_id` is; only the cross-distances below use it.
            for i, xi in enumerate(point_sets):
                other_id = 1 - i
                seed_idxs = np.array(farthest_pair(xi, True)[:2])
                seed_dists = cdist(xi[seed_idxs], point_sets[other_id], metric_id, p)
                seed_argmins = argmin(seed_dists, axis=1)
                seed_mins = np.array(
                    [seed_dists[0][seed_argmins[0]], seed_dists[1][seed_argmins[1]]]
                )

                # Keep the candidate if it beats the best distance so far;
                # the search then continues from the point in the other set
                if np.min(seed_mins) < dist:
                    set_id = other_id
                    seed_choice = int(np.argmin(seed_mins))
                    pair_idxs[i] = int(seed_idxs[seed_choice])
                    pair_idxs[set_id] = int(seed_argmins[seed_choice])
                    dist = float(seed_mins[seed_choice])

        # Find the closest point in the other set, repeat until convergence
        while True:
            other_id = 1 - set_id
            dists = cdist(
                point_sets[set_id][pair_idxs[set_id]][None, :],
                point_sets[other_id],
                metric_id,
                p,
            ).flatten()
            closest_idx = int(np.argmin(dists))
            closest_dist = float(dists[closest_idx])

            # Every comparison against NaN is False, so the convergence test
            # below can never fire once a NaN distance is selected. Record
            # the NaN and stop rather than looping forever.
            if np.isnan(closest_dist):
                pair_idxs[other_id] = closest_idx
                dist = closest_dist
                break

            # Converged: this jump did not find a strictly smaller distance
            if closest_dist >= dist:
                break

            # Accept the jump and continue from the other set
            pair_idxs[other_id] = closest_idx
            dist = closest_dist
            set_id = other_id

        # Unroll index
        i, j = pair_idxs

    # If needed, take the square root of the distance
    if is_euclidean:
        dist = np.sqrt(dist)

    return int(i), int(j), float(dist)