Source code for spine.vis.drawer.network

"""Draw graph connectivity on top of clustered point clouds."""

from __future__ import annotations

from typing import Any

import numpy as np
import plotly.graph_objs as go

from spine.constants import COORD_COLS
from spine.math.distance import closest_pair

from ..trace.cluster import scatter_clusters
from ..trace.point import scatter_points_2d, scatter_points_3d

__all__ = ["network_topology", "network_schematic"]


[docs] def network_topology( points: np.ndarray, clusts: list[np.ndarray], edge_index: np.ndarray, clust_labels: np.ndarray | None = None, edge_labels: np.ndarray | None = None, mode: str = "scatter", color: str | np.ndarray | None = None, line: dict[str, Any] | None = None, linewidth: float = 2, name: str | None = None, **kwargs: Any, ) -> list[go.Scatter3d] | list[go.Scatter3d | go.Mesh3d]: """Network 3D topological representation in Euclidean space. Parameters ---------- points : np.ndarray (N, 3) array of N points of (..., x, y, z,...) coordinate information clusts : List[np.ndarray] (C) List of cluster indexes edge_index : np.ndarray (E, 2) List of connections between clusters clust_labels : np.ndarray, optional (C) List of cluster labels edge_labels : np.ndarray, optional (E) List of edge labels mode : str, default 'scatter' Drawing mode; one of 'circle', 'scatter', 'ellipsoid', 'cone' or 'hull' color : Union[str, np.ndarray], optional Color of clusters or (C) list of color of clusters line : dict, optional Line property dictionary linewidth : float, default 2 Width of the edge lines name : str, optional Name of the network **kwargs : dict, optional List of additional arguments to pass to plotly Returns ------- List[Union[plotly.graph_objs.Scatter3d, plotly.graph_objs.Mesh3d]] Node and edge traces in the same list """ # Fetch the list of point coordinates if points.shape[1] != 3: points = points[:, COORD_COLS] # Check that color is not passed directly, ambiguous for a network if color is not None: raise ValueError( "Use `clust_labels` instead of `color` to specify node colors." 
) # Set the prefix to add to the trace names prefix = f"{name}" if name is not None else "Graph" node_name = name if edge_index is None else f"{prefix} nodes" edge_name = f"{prefix} edges" # Define the trace(s) associated with the graph nodes single_trace = mode in ["circle", "scatter"] traces = scatter_clusters( points, clusts, color=clust_labels, single_trace=single_trace, name=node_name, mode=mode, **kwargs, ) # Define the trace associated with graph edges edge_vertices = np.empty((0, 3), dtype=points.dtype) if len(edge_index): edge_vertices = [] if mode in ["circle", "ellipsoid"]: # For circles and ellipsoids, join centroid to centroid cent = [points[c].mean(axis=0) for c in clusts] for i, j in edge_index: edge_vertices.extend([cent[i], cent[j], [None, None, None]]) elif mode in ["scatter", "hull"]: # For scatter and hull, join closest point to closest point for i, j in edge_index: vi, vj = points[clusts[i]], points[clusts[j]] i1, i2, _ = closest_pair(vi, vj, iterative=True) edge_vertices.extend([vi[i1], vj[i2], [None, None, None]]) else: # For cones, use the cone start points sts = [] for trace in traces: x, y, z = trace["x"], trace["y"], trace["z"] if x is None or y is None or z is None: raise ValueError("Trace is missing coordinate information.") start = [x[0], y[0], z[0]] sts.append(start) for i, j in edge_index: edge_vertices.extend([sts[i], sts[j], [None, None, None]]) edge_vertices = np.vstack(edge_vertices) # Initialize the edge labels, if they are provided if edge_labels is not None: edge_labels = np.repeat(edge_labels, 3) # Add the edge trace traces += scatter_points_3d( edge_vertices, color=edge_labels, line=line, linewidth=linewidth, mode="lines", name=edge_name, ) # Return return traces
def network_schematic(
    clusts: list[np.ndarray],
    edge_index: np.ndarray | None,
    clust_labels: np.ndarray,
    edge_labels: np.ndarray | None = None,
    color: str | np.ndarray | None = None,
    name: str | None = None,
    linewidth: float = 2,
    **kwargs: Any,
) -> list[go.Scatter]:
    """Network 2D schematic representation.

    This is to be used exclusively with bipartite graphs where the nodes are
    either classified as primaries or secondaries under `clust_labels` and
    connections only exist between primaries and secondaries.

    Nodes are laid out horizontally by cluster index: primaries (label 1) sit
    on the row y = 1 and secondaries (label 0) on the row y = 0. Each node's
    marker size scales as the square root of its cluster size.

    Parameters
    ----------
    clusts : List[np.ndarray]
        (C) List of cluster indexes
    edge_index : np.ndarray, optional
        (E, 2) List of connections between clusters. If `None`, only the
        node trace is returned.
    clust_labels : np.ndarray
        (C) Whether a cluster is a primary (1) or a secondary (0)
    edge_labels : np.ndarray, optional
        (E) List of edge labels, shown as hover text on the edges
    color : Union[str, np.ndarray], optional
        Not supported; use `clust_labels` to color the nodes
    name : str, optional
        Name of the network
    linewidth : float, default 2
        Width of the edge lines
    **kwargs : dict, optional
        List of additional arguments to pass to plotly (node trace only)

    Returns
    -------
    List[plotly.graph_objs.Scatter]
        Node and edge traces in the same list

    Raises
    ------
    ValueError
        If `color` is provided, if `clust_labels` does not have one entry per
        cluster, or if any cluster label is not 0 or 1
    """
    # Check that color is not passed directly, ambiguous for a network
    if color is not None:
        raise ValueError(
            "Use `clust_labels` instead of `color` to specify node colors."
        )

    # Check that the labels are binary (0 or 1). Convert first so that plain
    # Python sequences are compared element-wise rather than as a whole.
    clust_labels = np.asarray(clust_labels)
    if len(clust_labels) != len(clusts):
        raise ValueError("Must provide a primary label for each cluster.")
    if not np.all((clust_labels == 0) | (clust_labels == 1)):
        raise ValueError("All cluster labels should be 0 or 1.")

    # Set the prefix to add to the trace names
    prefix = f"{name}" if name is not None else "Graph"
    node_name = name if edge_index is None else f"{prefix} nodes"
    edge_name = f"{prefix} edges"

    # Define the node size on the basis of the cluster size
    counts = np.array([len(c) for c in clusts])
    node_sizes = np.sqrt(counts)

    # Define the hovertext attribute
    node_labels = [
        f"Cluster ID: {i:d}<br>Primary: {label:0.0f}<br>Size: {count:d}"
        for i, (label, count) in enumerate(zip(clust_labels, counts))
    ]

    # Define the positions: x is the cluster index, y is the primary label.
    # column_stack keeps a (0, 2) shape when there are no clusters.
    pos = np.column_stack((np.arange(len(clusts)), clust_labels))

    # Define the trace associated with the graph nodes
    node_trace = scatter_points_2d(
        pos,
        color=clust_labels,
        hovertext=node_labels,
        hoverinfo=["x", "y", "text"],
        markersize=node_sizes,
        name=node_name,
        **kwargs,
    )

    # Without an edge index there is nothing to connect: nodes only
    if edge_index is None:
        return node_trace

    # Define the trace associated with the graph edges. Each edge spans three
    # vertices: its two end points and a None separator breaking the line.
    edge_vertices = np.empty((0, 2), dtype=pos.dtype)
    if len(edge_index):
        edge_vertices = []
        for i, j in edge_index:
            edge_vertices.extend([pos[i], pos[j], [None, None]])
        edge_vertices = np.vstack(edge_vertices)

    # Initialize the edge labels, if they are provided. Plotly 2D line traces
    # do not support per-vertex color arrays, so expose labels as hover text.
    edge_hovertext = None
    if edge_labels is not None:
        edge_hovertext = np.repeat(
            [f"Edge label: {label}" for label in edge_labels], 3
        )

    # Add the edge trace
    edge_trace = scatter_points_2d(
        edge_vertices,
        hovertext=edge_hovertext,
        linewidth=linewidth,
        mode="lines",
        name=edge_name,
    )

    return node_trace + edge_trace