Source code for spine.model.grappa

"""GrapPA: Graph Neural Network for Particle Aggregation.

This module implements the GrapPA (Graph Particle Aggregation) architecture,
a graph neural network designed for clustering and grouping particle instances.

GrapPA learns to aggregate fragment-level features into particle-level clusters
through message passing and edge classification, enabling particle instance
segmentation and identification.
"""

import numpy as np
import torch
from torch import nn

from spine.constants import GROUP_COL, LOWES_SHP, SHAPE_COL, TRACK_SHP
from spine.constants.factory import enum_factory
from spine.data import TensorBatch
from spine.utils.gnn.cluster import (
    form_clusters_batch,
    get_cluster_label_batch,
    get_cluster_primary_label_batch,
)
from spine.utils.gnn.evaluation import (
    node_assignment_batch,
    node_assignment_score_batch,
)

from .layer.common.dbscan import DBSCAN
from .layer.factories import final_factory
from .layer.gnn.factories import (
    edge_encoder_factory,
    edge_loss_factory,
    global_encoder_factory,
    global_loss_factory,
    gnn_model_factory,
    graph_factory,
    node_encoder_factory,
    node_loss_factory,
)

# Public names exported by a star-import of this module.
__all__ = ["GrapPA", "GrapPALoss"]


class GrapPA(torch.nn.Module):
    """Graph Particle Aggregator (GrapPA) model.

    This class mostly acts as a wrapper that will hand the graph data to the
    underlying graph neural network (GNN). When trained standalone, this model
    must be provided with a cluster label tensor, allowing it to build a set
    of input clusters based on the label boundaries of the clusters and their
    semantic types.

    Typical configuration can look like this:

    .. code-block:: yaml

        model:
          name: grappa
          modules:
            grappa:
              nodes:
                <dictionary of arguments to specify the input type>
              graph:
                name: <name of the input graph type>
                <dictionary of arguments to specify the graph>
              node_encoder:
                name: <name of the type of node encoder>
                <dictionary of arguments to specify the node encoder>
              edge_encoder:
                name: <name of the type of edge encoder>
                <dictionary of arguments to specify the edge encoder>
              global_encoder:
                name: <name of the type of global encoder>
                <dictionary of arguments to specify the global encoder>
              gnn_model:
                name: <name of the type of backbone GNN feature extractor>
                <dictionary of arguments to specify the GNN>

    See configuration files prefixed with `grappa_` under the `config`
    directory for detailed examples of working configurations.

    See Also
    --------
    :class:`GrapPALoss`
    """

    # TODO: update
    MODULES = [
        ("grappa", ["base", "dbscan", "node_encoder", "edge_encoder", "gnn_model"]),
        "grappa_loss",
    ]

    def __init__(self, grappa, grappa_loss=None):
        """Initialize the GrapPA model.

        Parameters
        ----------
        grappa : dict
            Model configuration
        grappa_loss : dict, optional
            Loss configuration (not used by the model itself)
        """
        # Initialize the parent class
        super().__init__()

        # Process the model configuration
        self.process_model_config(**grappa)

    def process_model_config(
        self,
        gnn_model,
        nodes=None,
        graph=None,
        node_encoder=None,
        edge_encoder=None,
        global_encoder=None,
        dbscan=None,
        return_features=False,
    ):
        """Process the top-level configuration block.

        This dispatches each block to its own configuration processor.

        Parameters
        ----------
        gnn_model : dict
            Underlying graph neural network configuration
        nodes : dict, optional
            Input node configuration
        graph : dict, optional
            Input graph configuration
        node_encoder : dict, optional
            Node encoder configuration
        edge_encoder : dict, optional
            Edge encoder configuration
        global_encoder : dict, optional
            Global encoder configuration
        dbscan : dict, optional
            DBSCAN fragmentation configuration
        return_features : bool, default False
            If `True`, the model will return the node/edge/global features

        Raises
        ------
        ValueError
            If a `graph` or `dbscan` block is provided without a `nodes` block
        """
        # Store the output types of GNNs
        self.out_types = ("node", "edge", "global")

        # The graph constructor and the DBSCAN fragmenter both rely on
        # attributes which are set by the node configuration
        if nodes is None and (graph is not None or dbscan is not None):
            raise ValueError(
                "Must provide a `nodes` configuration when using GrapPA to "
                "build clusters/graphs on the fly."
            )

        # Construct the underlying graph neural network
        self.process_gnn_config(**gnn_model)

        # Process the node configuration
        if nodes is not None:
            self.process_node_config(**nodes)
        else:
            # Without a node block, clusters/edges must be given to `forward`.
            # Keep the attributes it reads defined so that it does not fail
            # with an AttributeError at the grouping stage.
            self.node_source = None
            self.node_type = None
            self.node_min_size = None
            self.make_groups = False
            self.grouping_method = None
            self.grouping_through_track = False

        # Process the graph configuration
        self.graph_constructor = None
        if graph is not None:
            self.graph_constructor = graph_factory(graph, self.node_type)

        # Process the encoder configurations
        self.node_encoder = None
        if node_encoder is not None:
            self.node_encoder = node_encoder_factory(node_encoder)

        # Initialize edge encoder
        self.edge_encoder = None
        if edge_encoder is not None:
            self.edge_encoder = edge_encoder_factory(edge_encoder)

        # Initialize the global encoder
        self.global_encoder = None
        if global_encoder is not None:
            self.global_encoder = global_encoder_factory(global_encoder)

        # Process the dbscan fragmenter configuration, if provided. The block
        # must be unpacked: passed positionally, the whole dictionary would be
        # bound to `shapes` and rejected by `process_dbscan_config`.
        self.dbscan = None
        if dbscan is not None:
            self.process_dbscan_config(**dbscan)

        # Store whether to return the features
        self.return_features = return_features

    def process_node_config(
        self,
        source="cluster",
        shapes=None,
        min_size=-1,
        make_groups=False,
        grouping_method="score",
        grouping_through_track=False,
    ):
        """Process the node parameters of the model.

        Parameters
        ----------
        source : str, default 'cluster'
            Column name in the label tensor which contains the input cluster IDs
        shapes : list, optional
            Type of nodes to include in the input. If not specified, include
            all types
        min_size : int, default -1
            Minimum number of voxels in a cluster to be included in the input
        make_groups : bool, default False
            Use edge predictions to build node groups
        grouping_method : str, default 'score'
            Algorithm used to build a node partition
        grouping_through_track : bool, default False
            If `True`, shower objects can only be connected to one track object

        Raises
        ------
        ValueError
            If `shapes` is provided as a scalar rather than a list
        """
        # Parse the node source
        self.node_source = enum_factory("cluster", source)

        # Interpret node type as list of shapes to cluster
        if shapes is None:
            self.node_type = list(np.arange(LOWES_SHP))
        else:
            if np.isscalar(shapes):
                raise ValueError("Semantic classes should be provided as a list.")
            self.node_type = enum_factory("shape", shapes)

        # Store the node parameters
        self.node_min_size = min_size
        self.make_groups = make_groups
        self.grouping_method = grouping_method
        self.grouping_through_track = grouping_through_track

    def process_gnn_config(
        self, node_pred=None, edge_pred=None, global_pred=None, **gnn_model
    ):
        """Process the GNN backbone structure and the output layers.

        Parameters
        ----------
        node_pred : Union[int, dict], optional
            Number of node predictions. If there are multiple node predictions,
            provide a (key, value) pair for each type of prediction
        edge_pred : Union[int, dict], optional
            Number of edge predictions. If there are multiple edge predictions,
            provide a (key, value) pair for each type of prediction
        global_pred : Union[int, dict], optional
            Number of global predictions. If there are multiple global
            predictions, provide a (key, value) pair for each type of prediction
        **gnn_model : dict
            Parameters to initialize the GNN backbone
        """
        # Initialize the GNN backbone
        self.gnn = gnn_model_factory(
            gnn_model,
            node_pred is not None,
            edge_pred is not None,
            global_pred is not None,
        )

        # Initialize output layers based on the configuration
        self.process_final_config(node_pred, "node")
        self.process_final_config(edge_pred, "edge")
        self.process_final_config(global_pred, "global")

    def process_final_config(self, final, prefix):
        """Process a final layer configuration.

        The layers are stored as attributes of the model and their names are
        listed under `<prefix>_pred_keys`.

        Parameters
        ----------
        final : Union[int, dict]
            Final layer configuration
        prefix : str
            Name of the final layer
        """
        # If the final layer is not specified, nothing to do here
        if final is None:
            setattr(self, f"{prefix}_pred_keys", [])
            return

        # If the final layer is specified as a number, use linear layer
        if isinstance(final, int):
            final = {"name": "linear", "out_channels": final}

        # Process the configuration dictionary otherwise
        out_keys = []
        in_channels = getattr(self.gnn, f"{prefix}_feature_size")
        if "name" in final:
            # Initialize a single final layer (single prediction of this type)
            out_key = f"{prefix}_pred"
            out_keys.append(out_key)
            setattr(self, out_key, final_factory(in_channels, **final))

        else:
            # Otherwise, initialize one final layer per prediction type
            for key, cfg in final.items():
                out_key = f"{prefix}_{key}_pred"
                out_keys.append(out_key)

                # If the final layer is specified as a number, use linear layer
                if isinstance(cfg, int):
                    cfg = {"name": "linear", "out_channels": cfg}
                setattr(self, out_key, final_factory(in_channels, **cfg))

        setattr(self, f"{prefix}_pred_keys", out_keys)

    def process_dbscan_config(self, shapes=None, min_size=None, **kwargs):
        """Process the DBSCAN fragmenter configuration.

        Parameters
        ----------
        shapes : Union[int, list], optional
            This should not be specified (fetched from the node configuration)
        min_size : Union[int, list], optional
            This should not be specified (fetched from the node configuration)
        **kwargs : dict, optional
            Rest of the DBSCAN configuration

        Raises
        ------
        ValueError
            If `shapes` or `min_size` is specified in the DBSCAN block
        """
        # Make sure the basic parameters are not specified twice
        if shapes is not None or min_size is not None:
            raise ValueError(
                "Do not specify 'shapes' or 'min_size' in the "
                "`dbscan` block, it is shared with the `node` block"
            )

        # Initialize DBSCAN fragmenter
        self.dbscan = DBSCAN(
            shapes=self.node_type, min_size=self.node_min_size, **kwargs
        )

    def forward(
        self,
        data,
        coord_label=None,
        clusts=None,
        edge_index=None,
        node_features=None,
        edge_features=None,
        global_features=None,
        shapes=None,
        groups=None,
        points=None,
        extra=None,
    ):
        """Prepares particle clusters and feed them to the GNN model.

        Parameters
        ----------
        data : TensorBatch
            Tensor of voxel/value pairs with shape `(N, 1 + D + N_f)`, where
            `N` is the total number of voxels, the leading column stores the
            batch ID, `D` is the image dimensionality and `N_f` is the number
            of features. When `clusts` is not provided, the features must also
            contain the labels needed to build clusters on the fly.
        coord_label : TensorBatch, optional
            (P, 1 + 2*D + 2) Tensor of label points (start/end/time/shape)
        clusts : IndexBatch, optional
            (C) List of indexes corresponding to each cluster
        edge_index : EdgeIndexBatch, optional
            (E, 2) Incidence matrix. If not provided, it will be built based
            on the cluster indexes and the graph configuration
        node_features : TensorBatch, optional
            (C, N_c,f) Node features. If not provided, they will be built
            based on the cluster indexes and the node encoder configuration
        edge_features : TensorBatch, optional
            (C, N_e,f) Edge features. If not provided, they will be built
            based on the cluster indexes and the edge encoder configuration
        global_features : TensorBatch, optional
            (C, N_g,f) Global features. If not provided, they will be built
            based on the cluster indexes and the global encoder configuration
        shapes : TensorBatch, optional
            (C) List of cluster semantic class used to define the max length
        groups : TensorBatch, optional
            (C) List of node groups, one per cluster. If specified, removes
            connections between nodes that belong to different groups.
        points : TensorBatch, optional
            (C, 3/6) Tensor of start (and end) points
        extra : TensorBatch, optional
            (C, N_f) Batch of features to append to the existing node features

        Returns
        -------
        dict
            Dictionary of outputs, which may contain:

            clusts : IndexBatch
                (C, N_c, N_{c,i}) Cluster indexes
            edge_index : EdgeIndexBatch
                (E, 2) Incidence matrix
            node_features : TensorBatch
                (C, N_c,f) Node features
            edge_features : TensorBatch
                (C, N_e,f) Edge features
            global_features : TensorBatch
                (C, N_g,f) Global features
            node_pred : TensorBatch
                (C, N_n) Node predictions (logits)
            edge_pred : TensorBatch
                (C, N_e) Edge predictions (logits)
            global_pred : TensorBatch
                (C, N_g) Global predictions (logits)

        Raises
        ------
        ValueError
            If a required input is neither provided nor buildable from the
            model configuration
        """
        # Initialize the result dictionary returned at the end of the method
        result = {}

        # Encode the node boundaries as clusters if not provided directly
        if clusts is None:
            clusts = self._make_clusters(data, coord_label=coord_label)
        result["clusts"] = clusts

        # If needed, infer per-cluster shapes once and reuse them downstream
        shapes = self._get_shapes(data, clusts, shapes)

        # Build the graph if it is not provided directly
        closest_index = None
        if edge_index is None:
            if self.graph_constructor is None:
                raise ValueError(
                    "Must provide edge_index or graph configuration to build it."
                )
            edge_index, closest_index = self._make_edge_index(
                data, clusts, shapes=shapes, groups=groups
            )
        result["edge_index"] = edge_index

        # Fetch the node features
        if node_features is None:
            if self.node_encoder is None:
                raise ValueError(
                    "Must provide node_features or node encoder configuration to build them."
                )
            node_features = self.node_encoder(
                data, clusts, coord_label=coord_label, points=points, extra=extra
            )

        if isinstance(node_features, tuple):
            # If the output of the node encoder is a tuple, separate points
            node_features, points = node_features
            start_points, end_points = points.tensor.split(3, dim=1)
            result["start_points"] = TensorBatch(
                start_points, points.counts, coord_cols=np.array([0, 1, 2])
            )
            result["end_points"] = TensorBatch(
                end_points, points.counts, coord_cols=np.array([0, 1, 2])
            )

        if self.return_features:
            result["node_features"] = node_features

        # Fetch the edge features
        if edge_features is None and self.edge_encoder is not None:
            edge_features = self.edge_encoder(
                data, clusts, edge_index, closest_index=closest_index
            )

        if self.return_features and edge_features is not None:
            result["edge_features"] = edge_features

        # Fetch the global_features
        if global_features is None and self.global_encoder is not None:
            global_features = self.global_encoder(data, clusts)

        if global_features is not None and self.return_features:
            result["global_features"] = global_features

        # Bring edge_index and batch_ids to device
        # TODO: try to keep everything (apart from clusts?) on GPU?
        index = torch.tensor(edge_index.index, device=data.tensor.device)
        xbatch = torch.tensor(clusts.batch_ids, device=data.tensor.device)

        # Pass through the model, update results
        out = self.gnn(node_features, index, edge_features, global_features, xbatch)

        # Loop over the necessary node/edge/global predictions, store output
        for t in self.out_types:
            for key in getattr(self, f"{t}_pred_keys"):
                result[key] = getattr(self, key)(out[f"{t}_features"])

        # If requested, build node groups from edge predictions
        if self.make_groups:
            self._make_groups(result, clusts, edge_index, shapes=shapes)

        return result

    def _make_clusters(self, data, coord_label=None):
        """Make the list of node clusters based on the label tensor and the
        requested class.

        Parameters
        ----------
        data : TensorBatch
            Tensor of voxel/value pairs with shape `(N, 1 + D + N_f)`, where
            `N` is the total number of voxels, the leading column stores the
            batch ID, `D` is the image dimensionality and `N_f` is the number
            of features. The features must also contain the labels needed to
            build clusters on the fly.
        coord_label : TensorBatch, optional
            (P, 1 + 2*D + 2) Tensor of label points

        Returns
        -------
        clusts : IndexBatch
            (C, N_c, N_{c,i}) Cluster indexes

        Raises
        ------
        ValueError
            If the model was configured without a `nodes` block
        """
        # Clusters can only be built if the node configuration was provided
        if getattr(self, "node_type", None) is None:
            raise ValueError(
                "Must provide clusts or a `nodes` configuration to build them."
            )

        if self.dbscan is not None:
            # Use the DBSCAN fragmenter to build the clusters
            seg_label = TensorBatch(data.tensor[:, SHAPE_COL], data.counts)
            clusts, _ = self.dbscan(data, seg_label, coord_label)

        else:
            # Use the label tensor to build the clusters
            clusts = form_clusters_batch(
                data.to_numpy(),
                self.node_min_size,
                self.node_source,
                shapes=self.node_type,
            )

        return clusts

    def _get_shapes(self, data, clusts, shapes=None):
        """Return per-cluster semantic labels if the graph logic needs them.

        Parameters
        ----------
        data : TensorBatch
            Tensor of voxel/value pairs with shape `(N, 1 + D + N_f)`.
        clusts : IndexBatch
            (C) List of indexes corresponding to each cluster
        shapes : TensorBatch, optional
            (C) Explicit semantic label per cluster

        Returns
        -------
        TensorBatch or None
            Cluster semantic labels, or `None` if they are not needed and
            were not provided.
        """
        # Explicitly provided shapes always take precedence
        if shapes is not None:
            return shapes

        # Shapes are only inferred when there is a graph constructor whose
        # `max_length` is a sized object (has `__len__`)
        if self.graph_constructor is None or not hasattr(
            self.graph_constructor.max_length, "__len__"
        ):
            return None

        # Fetch the shape labels; groups use the label of their primary
        data_np = data.to_numpy()
        if self.node_source == GROUP_COL:
            shapes = get_cluster_primary_label_batch(data_np, clusts, SHAPE_COL)
        else:
            shapes = get_cluster_label_batch(data_np, clusts, SHAPE_COL)

        shapes.data = shapes.data.astype(np.int64)

        return shapes

    def _make_edge_index(self, data, clusts, shapes=None, groups=None):
        """Make the edge index based on the cluster indexes and the graph
        configuration.

        Parameters
        ----------
        data : TensorBatch
            Tensor of voxel/value pairs with shape `(N, 1 + D + N_f)`, where
            `N` is the total number of voxels, the leading column stores the
            batch ID, `D` is the image dimensionality and `N_f` is the number
            of features.
        clusts : IndexBatch
            (C) List of indexes corresponding to each cluster
        shapes : TensorBatch, optional
            (C) List of cluster semantic class used to define the max length
        groups : TensorBatch, optional
            (C) List of node groups, one per cluster. If specified, removes
            connections between nodes that belong to different groups.

        Returns
        -------
        edge_index : EdgeIndexBatch
            (E, 2) Incidence matrix
        closest_index : TensorBatch
            (E) List of closest voxel index for each edge

        Raises
        ------
        ValueError
            If the model has no graph constructor
        """
        # Check that the graph constructor is defined
        if self.graph_constructor is None:
            raise ValueError(
                "Must provide graph configuration to build edge index from clusters."
            )

        # Bring data to numpy for the graph construction
        data_np = data.to_numpy()

        # Initialize the input graph
        edge_index, _, closest_index = self.graph_constructor(
            data_np, clusts, shapes, groups
        )

        return edge_index, closest_index

    def _make_groups(self, result, clusts, edge_index, shapes=None):
        """Make node groups based on edge predictions.

        The group predictions are added to `result` in place, one entry per
        edge prediction, under `group_pred` or `group_{key}_pred`.

        Parameters
        ----------
        result : dict
            Dictionary containing the output of the model, including edge
            predictions. The edge predictions should be stored under
            `edge_pred` or under keys with the format `edge_{key}_pred`,
            where `key` is the name of the prediction head.
        clusts : IndexBatch
            (C) List of indexes corresponding to each cluster
        edge_index : EdgeIndexBatch
            (E, 2) Incidence matrix
        shapes : TensorBatch, optional
            (C) List of cluster semantic class used to restrict track
            association

        Raises
        ------
        ValueError
            If there are no edge predictions, if the grouping method is not
            recognized or if shapes are required but missing
        """
        # Fetch the list of edge prediction keys
        edge_pred_keys = [
            key for key in result if key.startswith("edge") and key.endswith("pred")
        ]
        if not edge_pred_keys:
            raise ValueError(
                "Must provide edge predictions to build node groups. "
                "Edge predictions should be stored under keys with the format "
                "`edge{key}_pred`, where `key` is the name of the prediction head."
            )

        # Loop over the edge predictions, build node groups based on each
        for key in edge_pred_keys:
            edge_pred = result[key].to_numpy()
            prefix = "group" + key.replace("edge", "").replace("_pred", "")
            if self.grouping_method == "threshold":
                result[f"{prefix}_pred"] = node_assignment_batch(
                    edge_index, edge_pred, clusts
                )

            elif self.grouping_method == "score":
                if not self.grouping_through_track:
                    result[f"{prefix}_pred"] = node_assignment_score_batch(
                        edge_index, edge_pred, clusts
                    )
                else:
                    if shapes is None:
                        raise ValueError(
                            "Must provide shapes to restrict track association."
                        )
                    track_node = TensorBatch(
                        shapes.data == TRACK_SHP, counts=shapes.counts
                    )
                    result[f"{prefix}_pred"] = node_assignment_score_batch(
                        edge_index, edge_pred, clusts, track_node
                    )

            else:
                # Single formatted message (two positional arguments would be
                # rendered as a tuple)
                raise ValueError(
                    "Group prediction algorithm not recognized: "
                    f"{self.grouping_method}"
                )
class GrapPALoss(torch.nn.modules.loss._Loss):
    """Takes the output of the GrapPA and computes the total loss.

    For use in config:

    .. code-block:: yaml

        model:
          name: grappa
          modules:
            grappa_loss:
              node_loss:
                name: <name of the node loss>
                <dictionary of arguments to pass to the loss>
              edge_loss:
                name: <name of the edge loss>
                <dictionary of arguments to pass to the loss>
              global_loss:
                name: <name of the global loss>
                <dictionary of arguments to pass to the loss>

    Each of the specific loss blocks can also contain multiple losses by
    providing a name key in a loss block nested below it. Each loss name of
    a specific type should be provided with a corresponding output from
    GrapPA.

    See configuration files prefixed with `grappa_` under the `config`
    directory for detailed examples of working configurations.
    """

    def __init__(self, grappa_loss, grappa=None):
        """Initialize the GrapPA loss function.

        Parameters
        ----------
        grappa_loss : dict
            Loss configuration
        grappa : dict, optional
            Model configuration (not used by the loss itself)
        """
        # Initialize the parent class
        super().__init__()

        # Process the loss configuration
        self.process_loss_config(**grappa_loss)

    def process_loss_config(self, node_loss=None, edge_loss=None, global_loss=None):
        """Process the loss configuration.

        Parameters
        ----------
        node_loss : Union[dict, Dict[dict]], optional
            Node loss configuration
        edge_loss : Union[dict, Dict[dict]], optional
            Edge loss configuration
        global_loss : Union[dict, Dict[dict]], optional
            Global loss configuration

        Raises
        ------
        ValueError
            If the configuration does not define any loss
        """
        # Check that there is at least one loss to apply
        self.out_types = ("node", "edge", "global")
        message = (
            "Must provide at least one of `node_loss`, `edge_loss` or "
            "`global_loss` to the GrapPA loss function."
        )
        if node_loss is None and edge_loss is None and global_loss is None:
            raise ValueError(message)

        # Initialize the node/edge/global losses
        self.process_single_loss_config("node", node_loss, node_loss_factory)
        self.process_single_loss_config("edge", edge_loss, edge_loss_factory)
        self.process_single_loss_config("global", global_loss, global_loss_factory)

        # Empty blocks (e.g. `node_loss: {}`) register no loss: fail here
        # rather than on the normalization by zero losses in `forward`
        if not any(getattr(self, f"{t}_loss_keys") for t in self.out_types):
            raise ValueError(message)

    def process_single_loss_config(self, prefix, loss, constructor):
        """Process a loss configuration.

        The losses are stored as attributes of this object and their names
        are listed under `<prefix>_loss_keys`.

        Parameters
        ----------
        prefix : str
            Name of the output type to apply the loss to
        loss : dict
            Loss configuration. Either a single loss block (contains a `name`
            key) or a dictionary which maps a prediction name onto a loss block
        constructor : object
            Loss constructor function
        """
        # If the loss is not specified, nothing to do here
        if loss is None:
            setattr(self, f"{prefix}_loss_keys", [])
            return

        # Process the configuration dictionary otherwise
        loss_keys = []
        if "name" in loss:
            # Initialize a single loss
            loss_key = f"{prefix}_loss"
            loss_keys.append(loss_key)
            setattr(self, loss_key, constructor(loss))

        else:
            # Otherwise, initialize one loss per prediction type
            for key, cfg in loss.items():
                loss_key = f"{prefix}_{key}_loss"
                loss_keys.append(loss_key)
                setattr(self, loss_key, constructor(cfg))

        setattr(self, f"{prefix}_loss_keys", loss_keys)

    def forward(
        self, clust_label, coord_label=None, graph_label=None, iteration=None, **output
    ):
        """Apply the node/edge/global losses to the logits from GrapPA.

        Parameters
        ----------
        clust_label : TensorBatch
            (N, 1 + D + N_f) Tensor of voxel/value pairs
            - N is the total number of voxels in the image
            - 1 is the batch ID
            - D is the number of dimensions in the input image
            - N_f is the number of cluster labels
        coord_label : TensorBatch, optional
            (P, 1 + D + 8) Tensor of start/end point labels for each
            true particle in the image
        graph_label : EdgeIndexTensor, optional
            (2, E) Tensor of edges that correspond to physical
            connections between true particle in the image
        iteration : int, optional
            Iteration index
        **output : dict
            Output of the GrapPA model

        Returns
        -------
        dict
            Outputs of each loss, with keys prefixed by the loss name without
            its `_loss` suffix, along with the `loss` and `accuracy` averaged
            over all the losses.
        """
        # Loop and apply the losses
        result = {}
        num_losses = 0
        loss, accuracy = 0.0, 0.0
        for t in self.out_types:
            loss_keys = getattr(self, f"{t}_loss_keys")
            for key in loss_keys:
                # If the number of loss keys is > 1 for this type of
                # prediction, must rename the prediction appropriately
                extra = {}
                if len(loss_keys) > 1:
                    extra[f"{t}_pred"] = output[key.replace("loss", "pred")]

                # Compute the loss
                out = getattr(self, key)(
                    clust_label=clust_label,
                    coord_label=coord_label,
                    graph_label=graph_label,
                    iteration=iteration,
                    **output,
                    **extra,
                )

                # Increment the loss and accuracy
                loss += out["loss"]
                accuracy += out["accuracy"]
                num_losses += 1

                # Update the result dictionary
                prefix = "_".join(key.split("_")[:-1])
                for k, v in out.items():
                    result[f"{prefix}_{k}"] = v

        # Append the total loss and total accuracy
        result["loss"] = torch.sum(loss) / num_losses
        result["accuracy"] = np.sum(accuracy) / num_losses

        return result