# Source code for spine.math.graph

"""Numba JIT compiled implementation of graph routines.

In particular, this module supports the CSR data structure and derived methods,
which tremendously speeds up graph construction and computation in Numba.
"""

import numba as nb
import numpy as np

from .distance import (
    CHEBYSHEV,
    CITYBLOCK,
    EUCLIDEAN,
    METRICS,
    MINKOWSKI,
    SQEUCLIDEAN,
    chebyshev,
    cityblock,
    minkowski,
    sqeuclidean,
)

# Field specification handed to ``nb.experimental.jitclass`` when declaring
# ``CSRGraph``: each entry pairs an attribute name with its Numba type
# (one 64-bit integer scalar and two 1D arrays of 64-bit integers).
CSR_DTYPE = (
    ("num_nodes", nb.int64),
    ("neighbors", nb.int64[:]),
    ("offsets", nb.int64[:]),
)


@nb.experimental.jitclass(spec=CSR_DTYPE)  # type: ignore[call-arg]
class CSRGraph:
    """Numba jitclass holding a graph in Compressed Sparse Row (CSR) form.

    The neighbors of node ``i`` are stored contiguously in
    ``neighbors[offsets[i]:offsets[i + 1]]``.

    Attributes
    ----------
    neighbors : np.ndarray
        (E,) Flat array containing the neighbor lists of all nodes
    offsets : np.ndarray
        (N + 1,) Slice boundaries delimiting each node's neighbor list
    num_nodes : int
        Number of nodes in the graph, N
    """

    def __init__(self, neighbors: np.ndarray, offsets: np.ndarray, num_nodes: int):
        """Store the pre-built CSR arrays.

        Parameters
        ----------
        neighbors : np.ndarray
            (E,) Flat array containing the neighbor lists of all nodes
        offsets : np.ndarray
            (N + 1,) Slice boundaries delimiting each node's neighbor list
        num_nodes : int
            Number of nodes in the graph, N
        """
        self.num_nodes = num_nodes
        self.offsets = offsets
        self.neighbors = neighbors

    def __getitem__(self, node_id: int) -> np.ndarray:
        """Return the neighbor list of one node.

        Parameters
        ----------
        node_id : int
            Node index i

        Returns
        -------
        np.ndarray
            Slice of ``neighbors`` belonging to node i
        """
        first = self.offsets[node_id]
        last = self.offsets[node_id + 1]
        return self.neighbors[first:last]

    def num_neighbors(self, node_id: int) -> int:
        """Return how many neighbors one node has.

        Parameters
        ----------
        node_id : int
            Node index i

        Returns
        -------
        int
            Length of the neighbor list of node i
        """
        # The width of the node's slice is the difference of its boundaries
        return self.offsets[node_id + 1] - self.offsets[node_id]
@nb.njit
def csr_graph(
    edge_index: np.ndarray, num_nodes: int, directed: bool = True
) -> CSRGraph:
    """Build a :class:`CSRGraph` from an edge list.

    Parameters
    ----------
    edge_index : np.ndarray
        (E, 2) List of active edge indices in the graph
    num_nodes : int
        Number of nodes in the graph, N
    directed : bool
        If `True`, each edge ``(s, t)`` only adds ``t`` to the neighbors
        of ``s``. If `False`, ``s`` is also added to the neighbors of ``t``.

    Returns
    -------
    CSRGraph
        CSR representation of the graph
    """
    # Tally how many neighbor slots each node needs
    degree = np.zeros(num_nodes, dtype=np.int64)
    for src, dst in edge_index:
        degree[src] += 1
        if not directed:
            degree[dst] += 1

    # Turn the per-node tallies into slice boundaries (exclusive prefix sum)
    offsets = np.empty(num_nodes + 1, dtype=np.int64)
    offsets[0] = 0
    for node in range(1, num_nodes + 1):
        offsets[node] = offsets[node - 1] + degree[node - 1]

    # Fill the neighbor array; `cursor` tracks the next free slot of each node
    neighbors = np.empty(offsets[num_nodes], dtype=np.int64)
    cursor = offsets[:num_nodes].copy()
    for src, dst in edge_index:
        neighbors[cursor[src]] = dst
        cursor[src] += 1
        if not directed:
            neighbors[cursor[dst]] = src
            cursor[dst] += 1

    return CSRGraph(neighbors, offsets, num_nodes)
@nb.njit(cache=True)
def connected_components(
    edge_index: np.ndarray,
    num_nodes: int,
    min_samples: int = 1,
    directed: bool = True,
) -> np.ndarray:
    """Label the nodes of a graph by connected component.

    Nodes are scanned in increasing index order. An unvisited node with at
    least ``min_samples - 1`` neighbors seeds a depth-first search, and every
    node reached by that search receives the current label. An unvisited node
    with fewer neighbors receives the current label on its own.

    Parameters
    ----------
    edge_index : np.ndarray
        (E, 2) List of active edge indices in the graph
    num_nodes : int
        Number of nodes in the graph, N
    min_samples : int, default 1
        A node must have at least ``min_samples - 1`` neighbors to seed a
        depth-first search
    directed : bool, default True
        Whether the input graph is directed or not

    Returns
    -------
    np.ndarray
        (N,) Cluster label associated with each node

    Notes
    -----
    A node labeled on its own is not marked as visited. A search seeded
    from a higher-index node can therefore still reach it and overwrite its
    label, in which case the first label is left unused. Label values are
    thus not guaranteed to be contiguous when ``min_samples > 1``.
    """
    # Build the CSR adjacency structure
    graph = csr_graph(edge_index, num_nodes, directed)

    # Allocate the output and the scratch buffers shared by all searches
    labels = np.arange(graph.num_nodes)
    visited = np.zeros(graph.num_nodes, dtype=np.bool_)
    component = np.empty(graph.num_nodes, dtype=np.int64)
    comp_idx = np.empty(1, dtype=np.int64)  # Acts as pointer

    label = 0
    min_neighbors = min_samples - 1
    for node in range(graph.num_nodes):
        # Nodes already claimed by an earlier search keep their label
        if visited[node]:
            continue

        if graph.num_neighbors(node) < min_neighbors:
            # Too few neighbors to seed a search: label the node on its own
            labels[node] = label
        else:
            # Gather the whole component reachable from this node
            comp_idx[0] = 0
            dfs_iterative(graph, visited, node, component, comp_idx)
            for k in range(comp_idx[0]):
                labels[component[k]] = label

        # One label is consumed per unvisited node encountered
        label += 1

    return labels
@nb.njit(cache=True)
def dfs(
    graph: CSRGraph,
    visited: np.ndarray,
    node: int,
    component: np.ndarray,
    comp_idx: np.ndarray,
) -> None:
    """Recursively collect the nodes reachable from ``node``.

    Parameters
    ----------
    graph : CSRGraph
        CSR representation of a graph
    visited : np.ndarray
        (N,) Boolean array flagging visited nodes; updated in place
    node : int
        Node index to visit
    component : np.ndarray
        (N,) Buffer receiving the visited node indices; updated in place
    comp_idx : np.ndarray
        One-element array holding the next write position in ``component``;
        updated in place

    Notes
    -----
    This version recurses once per newly visited node. The original notes
    describe it as the fastest variant but warn that it segfaults silently
    when the maximum recursion depth is reached. :func:`dfs_iterative` is
    the safer, slightly slower alternative.
    """
    # Record this node at the current write position, then advance it
    write_pos = comp_idx[0]
    component[write_pos] = node
    comp_idx[0] = write_pos + 1
    visited[node] = True

    # Descend into every neighbor that has not been visited yet
    for nbr in graph[node]:
        if visited[nbr]:
            continue
        dfs(graph, visited, nbr, component, comp_idx)
@nb.njit(cache=True)
def dfs_iterative(
    graph: CSRGraph,
    visited: np.ndarray,
    start_node: int,
    component: np.ndarray,
    comp_idx: np.ndarray,
) -> None:
    """Collect the nodes reachable from ``start_node`` with an explicit stack.

    Parameters
    ----------
    graph : CSRGraph
        CSR representation of a graph
    visited : np.ndarray
        (N,) Boolean array flagging visited nodes; updated in place
    start_node : int
        Node index the search starts from
    component : np.ndarray
        (N,) Buffer receiving the visited node indices; updated in place
    comp_idx : np.ndarray
        One-element array holding the next write position in ``component``;
        updated in place

    Notes
    -----
    Unlike :func:`dfs`, this version does not recurse, so it is not subject
    to the recursion depth limit. The original notes report a small cost in
    execution speed.
    """
    # Nodes are flagged as visited when pushed, so each is pushed at most
    # once and a stack of `num_nodes` entries is sufficient
    pending = np.empty(graph.num_nodes, dtype=np.int64)
    pending[0] = start_node
    visited[start_node] = True
    top = 1

    while top > 0:
        # Pop the most recently pushed node and record it
        top -= 1
        current = pending[top]
        component[comp_idx[0]] = current
        comp_idx[0] += 1

        # Push every neighbor that has not been seen yet
        for nbr in graph[current]:
            if visited[nbr]:
                continue
            visited[nbr] = True
            pending[top] = nbr
            top += 1
@nb.njit(cache=True)
def radius_graph(
    x: np.ndarray,
    radius: float,
    metric_id: int = METRICS["euclidean"],
    p: float = 2.0,
) -> np.ndarray:
    """Build an undirected radius graph.

    Every pair of nodes whose distance does not exceed ``radius`` is
    connected by one edge ``(i, j)`` with ``i < j``.

    Parameters
    ----------
    x : np.ndarray
        Array of node coordinates, one row per node
    radius : float
        Radius within which to build connections in the graph
    metric_id : int, default ``METRICS["euclidean"]``
        Distance metric enumerator
    p : float, default 2.
        p-norm factor, only used for the Minkowski metric

    Returns
    -------
    np.ndarray
        (E, 2) array of edges in the radius graph

    Raises
    ------
    ValueError
        If ``metric_id`` is not one of the supported metric enumerators
    """
    # Euclidean distances are compared in squared form: squaring the radius
    # once is cheaper than taking a square root for every pair
    if metric_id == EUCLIDEAN:
        return _radius_graph_sqeuclidean(x, radius * radius)
    if metric_id == SQEUCLIDEAN:
        return _radius_graph_sqeuclidean(x, radius)
    if metric_id == MINKOWSKI:
        return _radius_graph_minkowski(x, radius, p)
    if metric_id == CITYBLOCK:
        return _radius_graph_cityblock(x, radius)
    if metric_id == CHEBYSHEV:
        return _radius_graph_chebyshev(x, radius)

    raise ValueError("Distance metric not recognized.")
@nb.njit(cache=True)
def _radius_graph_minkowski(x: np.ndarray, radius: float, p: float) -> np.ndarray:
    """Build the radius graph edges using the Minkowski distance of order p."""
    # One row per unordered node pair is the most that can ever be needed
    n = len(x)
    edges = np.empty((n * (n - 1) // 2, 2), dtype=np.int64)

    # Test each unordered pair once and keep those within the radius
    n_edges = 0
    for i in range(n):
        x_i = x[i]
        for j in range(i + 1, n):
            if minkowski(x_i, x[j], p) <= radius:
                edges[n_edges, 0] = i
                edges[n_edges, 1] = j
                n_edges += 1

    return edges[:n_edges]


@nb.njit(cache=True)
def _radius_graph_cityblock(x: np.ndarray, radius: float) -> np.ndarray:
    """Build the radius graph edges using the cityblock distance."""
    # One row per unordered node pair is the most that can ever be needed
    n = len(x)
    edges = np.empty((n * (n - 1) // 2, 2), dtype=np.int64)

    # Test each unordered pair once and keep those within the radius
    n_edges = 0
    for i in range(n):
        x_i = x[i]
        for j in range(i + 1, n):
            if cityblock(x_i, x[j]) <= radius:
                edges[n_edges, 0] = i
                edges[n_edges, 1] = j
                n_edges += 1

    return edges[:n_edges]


@nb.njit(cache=True)
def _radius_graph_sqeuclidean(x: np.ndarray, radius: float) -> np.ndarray:
    """Build the radius graph edges using the squared Euclidean distance.

    ``radius`` is compared directly against the squared distance.
    """
    # One row per unordered node pair is the most that can ever be needed
    n = len(x)
    edges = np.empty((n * (n - 1) // 2, 2), dtype=np.int64)

    # Test each unordered pair once and keep those within the radius
    n_edges = 0
    for i in range(n):
        x_i = x[i]
        for j in range(i + 1, n):
            if sqeuclidean(x_i, x[j]) <= radius:
                edges[n_edges, 0] = i
                edges[n_edges, 1] = j
                n_edges += 1

    return edges[:n_edges]


@nb.njit(cache=True)
def _radius_graph_chebyshev(x: np.ndarray, radius: float) -> np.ndarray:
    """Build the radius graph edges using the Chebyshev distance."""
    # One row per unordered node pair is the most that can ever be needed
    n = len(x)
    edges = np.empty((n * (n - 1) // 2, 2), dtype=np.int64)

    # Test each unordered pair once and keep those within the radius
    n_edges = 0
    for i in range(n):
        x_i = x[i]
        for j in range(i + 1, n):
            if chebyshev(x_i, x[j]) <= radius:
                edges[n_edges, 0] = i
                edges[n_edges, 1] = j
                n_edges += 1

    return edges[:n_edges]


@nb.njit(cache=True)
def _find_root(parents: np.ndarray, node: int) -> int:
    """Find the root parent of a node with path compression.

    Parameters
    ----------
    parents : np.ndarray
        Parent index of each node; a root is its own parent. Updated in
        place so that every node on the walked path points to the root.
    node : int
        Node index to look up

    Returns
    -------
    int
        Index of the root of the tree containing ``node``
    """
    # First pass: climb the parent chain until a self-parented node is found
    root = node
    while True:
        above = parents[root]
        if above == root:
            break
        root = above

    # Second pass: re-attach every node on the path directly to the root
    current = node
    while current != root:
        above = parents[current]
        parents[current] = root
        current = above

    return root
@nb.njit(cache=True)
def union_find(
    edge_index: np.ndarray, count: int, return_inverse: bool = True
) -> tuple[np.ndarray, dict[int, np.ndarray]]:
    """Numba implementation of the Union-Find algorithm.

    This function assigns a group to each node in a graph, provided a set
    of edges connecting the nodes together.

    Parameters
    ----------
    edge_index : np.ndarray
        (E, 2) List of edges (sparse adjacency matrix)
    count : int
        Number of nodes in the graph, C
    return_inverse : bool, default True
        If `True`, group IDs are remapped to range from 0 to N_groups-1.
        If `False`, each group ID is the root node index of the group.

    Returns
    -------
    np.ndarray
        (C,) Group assignments for each of the nodes in the graph
    Dict[int, np.ndarray]
        Dictionary which maps each group ID to the indexes of its nodes,
        listed in increasing order

    Notes
    -----
    The group dictionary is built in two linear passes (count the members
    of each group, then fill pre-allocated arrays). This avoids growing each
    member array by one concatenation per node, which is quadratic for
    large groups.
    """
    if count == 0:
        # Seed the dictionary with one entry so that its key/value types
        # are defined, then remove it to return an empty dictionary
        labels = np.empty(0, dtype=np.int64)
        groups = {0: labels}
        del groups[0]
        return labels, groups

    # Every node starts as the root of its own tree
    parents = np.arange(count)
    for src, dst in edge_index:
        src_root = _find_root(parents, int(src))
        dst_root = _find_root(parents, int(dst))
        if src_root != dst_root:
            # Attach the larger root under the smaller one, so the root of
            # a group is always its smallest node index
            if src_root < dst_root:
                parents[dst_root] = src_root
            else:
                parents[src_root] = dst_root

    # Resolve the final root of each node
    labels = np.empty(count, dtype=np.int64)
    for node in range(count):
        labels[node] = _find_root(parents, node)

    if return_inverse:
        # Map the root indexes in use onto 0..N_groups-1, preserving order
        mask = np.zeros(count, dtype=np.bool_)
        mask[labels] = True
        mapping = np.empty(count, dtype=labels.dtype)
        mapping[mask] = np.arange(np.sum(mask))
        labels = mapping[labels]

    # Labels always lie in [0, count), so they can index the tally arrays
    sizes = np.zeros(count, dtype=np.int64)
    for node in range(count):
        sizes[labels[node]] += 1

    # Allocate each member array once, on first sight of its label, and fill
    # it in node order. Keys are inserted in order of first appearance.
    first_label = labels[0]
    groups = {first_label: np.empty(sizes[first_label], dtype=np.int64)}
    cursor = np.zeros(count, dtype=np.int64)
    for node in range(count):
        label = labels[node]
        if label not in groups:
            groups[label] = np.empty(sizes[label], dtype=np.int64)
        members = groups[label]
        members[cursor[label]] = node
        cursor[label] += 1

    return labels, groups