"""Functions to find the best overlaps between point sets."""
import numba as nb
import numpy as np
from spine.math.distance import cdist
# Public API of this module: overlap-matrix builders between two collections
# of index sets (or point clouds, for the Chamfer variant). The sorted-set
# intersection helper is intentionally not exported.
__all__ = [
    "overlap_count",
    "overlap_iou",
    "overlap_weighted_iou",
    "overlap_dice",
    "overlap_weighted_dice",
    "overlap_chamfer",
]
@nb.njit(cache=True)
def intersection_size_sorted(x: nb.int64[:], y: nb.int64[:]) -> nb.int64:
    """Count the elements shared by two ascending, duplicate-free arrays.

    Walks both arrays in a single merge-style pass, advancing whichever
    cursor points at the smaller value, so the cost is linear in the total
    number of elements.

    Parameters
    ----------
    x : np.ndarray
        Sorted (ascending) array of unique integers
    y : np.ndarray
        Sorted (ascending) array of unique integers

    Returns
    -------
    int
        Number of values present in both arrays
    """
    size_x, size_y = len(x), len(y)
    a = 0
    b = 0
    shared = 0
    while a < size_x and b < size_y:
        value_x = x[a]
        value_y = y[b]
        if value_x < value_y:
            a += 1
        elif value_x > value_y:
            b += 1
        else:
            # Common element: count it and move past it in both arrays
            shared += 1
            a += 1
            b += 1

    return shared
[docs]
@nb.njit(cache=True, parallel=True)
def overlap_count(
    index_x: nb.types.List(nb.int64[:]), index_y: nb.types.List(nb.int64[:])
) -> nb.int64[:, :]:
    """Computes a set overlap matrix by overlap count.

    Every index array must be sorted in ascending order with unique entries:
    both the range-disjointness shortcut below and
    :func:`intersection_size_sorted` rely on it.

    Parameters
    ----------
    index_x: nb.types.List[np.ndarray]
        (N) List of tensor index arrays, one per object to match
    index_y: nb.types.List[np.ndarray]
        (M) List of tensor index arrays, one per object to be matched to

    Returns
    -------
    np.ndarray
        (N, M) Overlap count matrix; entry ``[i, j]`` is the number of indexes
        shared by ``index_x[i]`` and ``index_y[j]`` (0 if either is empty)
    """
    overlap_matrix = np.zeros((len(index_x), len(index_y)), dtype=np.int64)
    for i in nb.prange(len(index_x)):
        px = index_x[i]
        if len(px):
            for j, py in enumerate(index_y):
                # Skip empty sets, and sets whose [min, max] ranges do not
                # intersect (valid because the arrays are sorted)
                if len(py) == 0 or px[-1] < py[0] or py[-1] < px[0]:
                    continue
                overlap_matrix[i, j] = intersection_size_sorted(px, py)

    return overlap_matrix
[docs]
@nb.njit(cache=True, parallel=True)
def overlap_iou(
    index_x: nb.types.List(nb.int64[:]), index_y: nb.types.List(nb.int64[:])
) -> nb.float32[:, :]:
    """Computes a set overlap matrix by IoU.

    IoU stands for Intersection-over-Union: ``|x & y| / |x | y|``.

    Every index array must be sorted in ascending order with unique entries:
    both the range-disjointness shortcut below and
    :func:`intersection_size_sorted` rely on it.

    Parameters
    ----------
    index_x: nb.types.List[np.ndarray]
        (N) List of tensor index arrays, one per object to match
    index_y: nb.types.List[np.ndarray]
        (M) List of tensor index arrays, one per object to be matched to

    Returns
    -------
    np.ndarray
        (N, M) Overlap IoU matrix with values in [0, 1]; pairs involving an
        empty set, or with no common index, are 0
    """
    overlap_matrix = np.zeros((len(index_x), len(index_y)), dtype=np.float32)
    for i in nb.prange(len(index_x)):
        px = index_x[i]
        n = len(px)
        if n:
            for j, py in enumerate(index_y):
                # Skip empty sets, and sets whose [min, max] ranges do not
                # intersect (valid because the arrays are sorted)
                if len(py) == 0 or px[-1] < py[0] or py[-1] < px[0]:
                    continue
                cap = intersection_size_sorted(px, py)
                if cap > 0:
                    # cap > 0 guarantees a non-zero union
                    cup = n + len(py) - cap
                    overlap_matrix[i, j] = cap / cup

    return overlap_matrix
[docs]
@nb.njit(cache=True, parallel=True)
def overlap_weighted_iou(
    index_x: nb.types.List(nb.int64[:]), index_y: nb.types.List(nb.int64[:])
) -> nb.float32[:, :]:
    """Computes a set overlap matrix by IoU, weighted by the set sizes.

    IoU stands for Intersection-over-Union. The weighting scheme is as follows:
    `w = (size_x + size_y) / (abs(size_x - size_y) + 1)`. The weight favors
    large sets of similar size, so the weighted score is NOT bounded by 1
    (two identical sets of size n score 2n).

    Every index array must be sorted in ascending order with unique entries:
    both the range-disjointness shortcut below and
    :func:`intersection_size_sorted` rely on it.

    Parameters
    ----------
    index_x: nb.types.List[np.ndarray]
        (N) List of tensor index arrays, one per object to match
    index_y: nb.types.List[np.ndarray]
        (M) List of tensor index arrays, one per object to be matched to

    Returns
    -------
    np.ndarray
        (N, M) Overlap weighted IoU matrix; pairs involving an empty set, or
        with no common index, are 0
    """
    overlap_matrix = np.zeros((len(index_x), len(index_y)), dtype=np.float32)
    for i in nb.prange(len(index_x)):
        px = index_x[i]
        n = len(px)
        if n:
            for j, py in enumerate(index_y):
                # Skip empty sets, and sets whose [min, max] ranges do not
                # intersect (valid because the arrays are sorted)
                if len(py) == 0 or px[-1] < py[0] or py[-1] < px[0]:
                    continue
                cap = intersection_size_sorted(px, py)
                if cap > 0:
                    m = len(py)
                    iou = cap / (n + m - cap)
                    overlap_matrix[i, j] = iou * (n + m) / (1 + abs(n - m))

    return overlap_matrix
[docs]
@nb.njit(cache=True, parallel=True)
def overlap_dice(
    index_x: nb.types.List(nb.int64[:]), index_y: nb.types.List(nb.int64[:])
) -> nb.float32[:, :]:
    """Computes a set overlap matrix by Dice coefficient.

    The Dice coefficient is twice the size of the intersection of two sets
    divided by the sum of the set sizes: ``2 |x & y| / (|x| + |y|)``.

    Every index array must be sorted in ascending order with unique entries:
    both the range-disjointness shortcut below and
    :func:`intersection_size_sorted` rely on it.

    Parameters
    ----------
    index_x: nb.types.List[np.ndarray]
        (N) List of tensor index arrays, one per object to match
    index_y: nb.types.List[np.ndarray]
        (M) List of tensor index arrays, one per object to be matched to

    Returns
    -------
    np.ndarray
        (N, M) Overlap Dice coefficient matrix with values in [0, 1]; pairs
        involving an empty set, or with no common index, are 0
    """
    overlap_matrix = np.zeros((len(index_x), len(index_y)), dtype=np.float32)
    for i in nb.prange(len(index_x)):
        px = index_x[i]
        n = len(px)
        if n:
            for j, py in enumerate(index_y):
                # Skip empty sets, and sets whose [min, max] ranges do not
                # intersect (valid because the arrays are sorted)
                if len(py) == 0 or px[-1] < py[0] or py[-1] < px[0]:
                    continue
                cap = intersection_size_sorted(px, py)
                if cap > 0:
                    overlap_matrix[i, j] = 2.0 * cap / (n + len(py))

    return overlap_matrix
[docs]
@nb.njit(cache=True, parallel=True)
def overlap_weighted_dice(
    index_x: nb.types.List(nb.int64[:]), index_y: nb.types.List(nb.int64[:])
) -> nb.float32[:, :]:
    """Computes a set overlap matrix by Dice coefficient, weighted by the
    set sizes.

    The Dice coefficient is twice the size of the intersection of two sets
    divided by the sum of the set sizes. The weighting scheme is as follows:
    `w = (size_x + size_y) / (abs(size_x - size_y) + 1)`. The weight favors
    large sets of similar size, so the weighted score is NOT bounded by 1
    (two identical sets of size n score 2n).

    Every index array must be sorted in ascending order with unique entries:
    both the range-disjointness shortcut below and
    :func:`intersection_size_sorted` rely on it.

    Parameters
    ----------
    index_x: nb.types.List[np.ndarray]
        (N) List of tensor index arrays, one per object to match
    index_y: nb.types.List[np.ndarray]
        (M) List of tensor index arrays, one per object to be matched to

    Returns
    -------
    np.ndarray
        (N, M) Overlap weighted Dice coefficient matrix; pairs involving an
        empty set, or with no common index, are 0
    """
    overlap_matrix = np.zeros((len(index_x), len(index_y)), dtype=np.float32)
    for i in nb.prange(len(index_x)):
        px = index_x[i]
        n = len(px)
        if n:
            for j, py in enumerate(index_y):
                # Skip empty sets, and sets whose [min, max] ranges do not
                # intersect (valid because the arrays are sorted)
                if len(py) == 0 or px[-1] < py[0] or py[-1] < px[0]:
                    continue
                cap = intersection_size_sorted(px, py)
                if cap > 0:
                    m = len(py)
                    dice = 2.0 * cap / (n + m)
                    w = (n + m) / (1 + abs(n - m))
                    overlap_matrix[i, j] = dice * w

    return overlap_matrix
[docs]
@nb.njit(cache=True, parallel=True)
def overlap_chamfer(
    points_x: nb.types.List(nb.int64[:]), points_y: nb.types.List(nb.int64[:])
) -> nb.float32[:, :]:
    """Computes a set overlap matrix by Chamfer distance.

    This function can match two arbitrary points clouds, hence there is no need
    for the two particle lists to share the same underlying voxel sets.

    The Chamfer distance between two clouds is the mean distance from each
    point of the first cloud to its nearest neighbor in the second, plus the
    same quantity computed from the second cloud to the first.

    Parameters
    ----------
    points_x: nb.types.List[np.ndarray]
        (N) List of (n_i, 3) coordinate arrays, one per object to match
    points_y: nb.types.List[np.ndarray]
        (M) List of (m_j, 3) coordinate arrays, one per object to be matched to

    Returns
    -------
    np.ndarray
        (N, M) Chamfer distance matrix; pairs involving an empty point cloud
        are left at `np.inf`

    Notes
    -----
    Unlike the overlap metrics, this metric should be minimized.
    """
    # NOTE(review): the signature annotations describe 1-D int64 arrays, but
    # this function receives coordinate arrays that are fed to `cdist`;
    # confirm the intended element type/rank and update the annotations.

    # np.full takes the shape first and the fill value second
    overlap_matrix = np.full(
        (len(points_x), len(points_y)), np.inf, dtype=np.float32
    )
    for i in nb.prange(len(points_x)):
        px = points_x[i]
        if len(px):
            for j, py in enumerate(points_y):
                if len(py) == 0:
                    continue

                # Compute the voxel pairwise distances
                dist = cdist(px, py)

                # Nearest-neighbor distance of every point to the other cloud.
                # Explicit loops keep this independent of numba's
                # (version-dependent) support for `np.min(..., axis=...)`.
                n, m = dist.shape
                min_x = np.full(n, np.inf)
                min_y = np.full(m, np.inf)
                for a in range(n):
                    for b in range(m):
                        d = dist[a, b]
                        if d < min_x[a]:
                            min_x[a] = d
                        if d < min_y[b]:
                            min_y[b] = d

                # Symmetric chamfer distance: sum of both mean NN distances
                overlap_matrix[i, j] = min_x.mean() + min_y.mean()

    return overlap_matrix