spine.model.grappa

GrapPA: Graph Neural Network for Particle Aggregation.

This module implements the GrapPA (Graph Particle Aggregation) architecture, a graph neural network designed for clustering and grouping particle instances.

GrapPA learns to aggregate fragment-level features into particle-level clusters through message passing and edge classification, enabling particle instance segmentation and identification.

Classes

GrapPA(*args, **kwargs)

Graph Particle Aggregator (GrapPA) model.

GrapPALoss(*args, **kwargs)

Takes the output of GrapPA and computes the total loss.

class spine.model.grappa.GrapPA(*args: Any, **kwargs: Any)[source]

Graph Particle Aggregator (GrapPA) model.

This class mostly acts as a wrapper that will hand the graph data to the underlying graph neural network (GNN).

When trained standalone, this model must be provided with a cluster label tensor, which allows it to build a set of input clusters based on the cluster label boundaries and their semantic types.

A typical configuration looks like this:

model:
  name: grappa
  modules:
    grappa:
      nodes:
        <dictionary of arguments to specify the input type>
      graph:
        name: <name of the input graph type>
        <dictionary of arguments to specify the graph>
      node_encoder:
        name: <name of the type of node encoder>
        <dictionary of arguments to specify the node encoder>
      edge_encoder:
        name: <name of the type of edge encoder>
        <dictionary of arguments to specify the edge encoder>
      global_encoder:
        name: <name of the type of global encoder>
        <dictionary of arguments to specify the global encoder>
      gnn_model:
        name: <name of the type of backbone GNN feature extractor>
        <dictionary of arguments to specify the GNN>

See configuration files prefixed with grappa_ under the config directory for detailed examples of working configurations.

See also

GrapPALoss

Methods

__call__(*args, **kwargs)

Call self as a function.

forward(data[, coord_label, clusts, ...])

Prepares particle clusters and feeds them to the GNN model.

process_dbscan_config([shapes, min_size])

Process the DBSCAN fragmenter configuration.

process_final_config(final, prefix)

Process a final layer configuration.

process_gnn_config([node_pred, edge_pred, ...])

Process the GNN backbone structure and the output layers.

process_model_config(gnn_model[, nodes, ...])

Process the top-level configuration block.

process_node_config([source, shapes, ...])

Process the node parameters of the model.

MODULES = [('grappa', ['base', 'dbscan', 'node_encoder', 'edge_encoder', 'gnn_model']), 'grappa_loss']
process_model_config(gnn_model, nodes=None, graph=None, node_encoder=None, edge_encoder=None, global_encoder=None, dbscan=None, return_features=False)[source]

Process the top-level configuration block.

This dispatches each block to its own configuration processor.

Parameters:
  • gnn_model (dict) – Underlying graph neural network configuration

  • nodes (dict, optional) – Input node configuration

  • graph (dict, optional) – Input graph configuration

  • node_encoder (dict, optional) – Node encoder configuration

  • edge_encoder (dict, optional) – Edge encoder configuration

  • global_encoder (dict, optional) – Global encoder configuration

  • dbscan (dict, optional) – DBSCAN fragmentation configuration

  • return_features (bool, default False) – If True, the model will return the node/edge/global features
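As a rough illustration of the dispatch described above, the sketch below splits a `grappa` configuration block into the keyword arguments listed in the signature and rejects unknown keys. The helper `split_grappa_config` is hypothetical and not part of SPINE; only the key names and defaults are taken from the signature above.

```python
# Hypothetical helper, not part of SPINE: mirrors the keyword arguments of
# GrapPA.process_model_config as documented above.
OPTIONAL_BLOCKS = ("nodes", "graph", "node_encoder", "edge_encoder",
                   "global_encoder", "dbscan")


def split_grappa_config(cfg):
    """Validate a `grappa` block and return one entry per documented argument."""
    if "gnn_model" not in cfg:
        raise KeyError("'gnn_model' is the only required block")

    known = set(OPTIONAL_BLOCKS) | {"gnn_model", "return_features"}
    unknown = sorted(set(cfg) - known)
    if unknown:
        raise KeyError(f"Unexpected configuration keys: {unknown}")

    # Optional blocks default to None, return_features defaults to False
    blocks = {key: cfg.get(key) for key in OPTIONAL_BLOCKS}
    blocks["gnn_model"] = cfg["gnn_model"]
    blocks["return_features"] = cfg.get("return_features", False)
    return blocks


# Placeholder values only; real blocks carry a `name` and its arguments
example = {"gnn_model": {"name": "<gnn name>"}, "nodes": {"source": "cluster"}}
blocks = split_grappa_config(example)
```

Each entry of `blocks` would then be handed to the matching processor (`process_node_config`, `process_gnn_config`, and so on).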

process_node_config(source='cluster', shapes=None, min_size=-1, make_groups=False, grouping_method='score', grouping_through_track=False)[source]

Process the node parameters of the model.

Parameters:
  • source (str, default 'cluster') – Column name in the label tensor which contains the input cluster IDs

  • shapes (int, optional) – Type of nodes to include in the input. If not specified, include all types

  • min_size (int, default -1) – Minimum number of voxels in a cluster to be included in the input

  • make_groups (bool, default False) – Use edge predictions to build node groups

  • grouping_method (str, default 'score') – Algorithm used to build a node partition

  • grouping_through_track (bool, default False) – If True, shower objects can only be connected to one track object
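To make the `make_groups` option more concrete, here is a minimal sketch of one way to turn edge predictions into node groups: keep the edges whose "on" logit wins, then take connected components with a union-find. This is an assumption for illustration only. It is not necessarily what the `'score'` grouping method does, and `groups_from_edges` is a hypothetical name.

```python
def groups_from_edges(num_nodes, edge_index, edge_logits):
    """Assign a group ID to each node from two-class edge logits.

    Assumption: edge_logits[i] = (off_logit, on_logit) for edge_index[i].
    """
    parent = list(range(num_nodes))

    def find(i):
        # Path halving keeps the trees shallow
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for (src, dst), (off, on) in zip(edge_index, edge_logits):
        if on > off:
            root_s, root_d = find(src), find(dst)
            # Always keep the smallest index as the root
            parent[max(root_s, root_d)] = min(root_s, root_d)

    return [find(i) for i in range(num_nodes)]


# Made-up predictions on a chain of four nodes
edge_index = [(0, 1), (1, 2), (2, 3)]
edge_logits = [(-1.0, 2.0), (0.5, -0.5), (-0.3, 0.8)]
groups = groups_from_edges(4, edge_index, edge_logits)
```

Here the middle edge is predicted "off", so the chain splits into two groups labeled by their smallest node index.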

process_gnn_config(node_pred=None, edge_pred=None, global_pred=None, **gnn_model)[source]

Process the GNN backbone structure and the output layers.

Parameters:
  • node_pred (Union[int, dict], optional) – Number of node predictions. If there are multiple node predictions, provide a (key, value) pair for each type of prediction

  • edge_pred (Union[int, dict], optional) – Number of edge predictions. If there are multiple edge predictions, provide a (key, value) pair for each type of prediction

  • global_pred (Union[int, dict], optional) – Number of global predictions. If there are multiple global predictions, provide a (key, value) pair for each type of prediction

  • **gnn_model (dict) – Parameters to initialize the GNN backbone

process_final_config(final, prefix)[source]

Process a final layer configuration.

Parameters:
  • final (Union[int, dict]) – Final layer configuration

  • prefix (str) – Name of the final layer
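The `Union[int, dict]` convention shared by `node_pred`, `edge_pred`, `global_pred` and `final` can be pictured as follows: an integer describes a single output, a dictionary describes one output per key. The sketch is an assumption about how such a value could be normalized. `normalize_final`, the naming of the returned keys and the example keys `type` and `primary` are all hypothetical and do not describe SPINE's actual behavior.

```python
def normalize_final(final, prefix):
    """Return a {output_name: num_outputs} dictionary."""
    if isinstance(final, int):
        # Single prediction: one output named after the prefix
        return {prefix: final}
    if isinstance(final, dict):
        # Multiple predictions: one output per (key, value) pair
        return {f"{prefix}_{key}": int(value) for key, value in final.items()}
    raise TypeError(
        f"Expected int or dict for '{prefix}', got {type(final).__name__}")


single = normalize_final(2, "edge_pred")
multi = normalize_final({"type": 5, "primary": 2}, "node_pred")
```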

process_dbscan_config(shapes=None, min_size=None, **kwargs)[source]

Process the DBSCAN fragmenter configuration.

Parameters:
  • shapes (Union[int, list], optional) – This should not be specified (fetched from the node configuration)

  • min_size (Union[int, list], optional) – This should not be specified (fetched from the node configuration)

  • **kwargs (dict, optional) – Rest of the DBSCAN configuration

forward(data, coord_label=None, clusts=None, edge_index=None, node_features=None, edge_features=None, global_features=None, shapes=None, groups=None, points=None, extra=None)[source]

Prepares particle clusters and feeds them to the GNN model.

Parameters:
  • data (TensorBatch) – Tensor of voxel/value pairs with shape (N, 1 + D + N_f), where N is the total number of voxels, the leading column stores the batch ID, D is the image dimensionality and N_f is the number of features. When clusts is not provided, the features must also contain the labels needed to build clusters on the fly.

  • coord_label (TensorBatch, optional) – (P, 1 + 2*D + 2) Tensor of label points (start/end/time/shape)

  • clusts (IndexBatch, optional) – (C) List of indexes corresponding to each cluster

  • edge_index (EdgeIndexBatch, optional) – (E, 2) Incidence matrix. If not provided, it will be built based on the cluster indexes and the graph configuration

  • node_features (TensorBatch, optional) – (C, N_c,f) Node features. If not provided, they will be built based on the cluster indexes and the node encoder configuration

  • edge_features (TensorBatch, optional) – (C, N_e,f) Edge features. If not provided, they will be built based on the cluster indexes and the edge encoder configuration

  • global_features (TensorBatch, optional) – (C, N_g,f) Global features. If not provided, they will be built based on the cluster indexes and the global encoder configuration

  • shapes (TensorBatch, optional) – (C) List of cluster semantic classes used to define the max length

  • groups (TensorBatch, optional) – (C) List of node groups, one per cluster. If specified, removes connections between nodes that belong to different groups.

  • points (TensorBatch, optional) – (C, 3/6) Tensor of start (and end) points

  • extra (TensorBatch, optional) – (C, N_f) Batch of features to append to the existing node features

Returns:

  • clusts (IndexBatch) – (C, N_c, N_{c,i}) Cluster indexes

  • edge_index (EdgeIndexBatch) – (E, 2) Incidence matrix

  • node_features (TensorBatch) – (C, N_c,f) Node features

  • edge_features (TensorBatch) – (C, N_e,f) Edge features

  • global_features (TensorBatch) – (C, N_g,f) Global features

  • node_pred (TensorBatch) – (C, N_n) Node predictions (logits)

  • edge_pred (TensorBatch) – (C, N_e) Edge predictions (logits)

  • global_pred (TensorBatch) – (C, N_g) Global predictions (logits)
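For intuition about the `edge_index` and `groups` arguments, the following sketch builds a complete directed graph over C clusters and then drops the edges that cross group boundaries. It assumes a complete graph, which is only one possible graph configuration, and uses plain Python lists in place of `EdgeIndexBatch` and `TensorBatch`. Both function names are hypothetical.

```python
from itertools import permutations


def complete_edge_index(num_clusters):
    """List every directed edge (i, j), i != j, as rows of an (E, 2) table."""
    return list(permutations(range(num_clusters), 2))


def restrict_to_groups(edge_index, groups):
    """Keep only the edges whose two nodes share a group ID."""
    return [(i, j) for i, j in edge_index if groups[i] == groups[j]]


# Three clusters, the first two in group 0 and the last one in group 1
edges = complete_edge_index(3)
kept = restrict_to_groups(edges, groups=[0, 0, 1])
```

With three clusters the complete graph has six directed edges, and only the two edges between clusters 0 and 1 survive the group restriction.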

class spine.model.grappa.GrapPALoss(*args: Any, **kwargs: Any)[source]

Takes the output of GrapPA and computes the total loss.

For use in config:

model:
  name: grappa
  modules:
    grappa_loss:
      node_loss:
        name: <name of the node loss>
        <dictionary of arguments to pass to the loss>
      edge_loss:
        name: <name of the edge loss>
        <dictionary of arguments to pass to the loss>
      global_loss:
        name: <name of the global loss>
        <dictionary of arguments to pass to the loss>

Each of the specific loss blocks can also contain multiple losses by providing a name key in a loss block nested below it. Each named loss of a given type must be matched with a corresponding output from GrapPA.

See configuration files prefixed with grappa_ under the config directory for detailed examples of working configurations.
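The distinction between a single loss and several nested losses can be sketched as below. The rule used here (a block with a top-level `name` key is a single loss, otherwise every nested block must carry its own `name`) is inferred from the description above. `split_loss_block` and the keys used in the example are hypothetical.

```python
def split_loss_block(prefix, block):
    """Return a {output_name: loss_config} dictionary for one loss type."""
    if "name" in block:
        # Single loss applied to the output named after the prefix
        return {prefix: block}

    losses = {}
    for key, sub_block in block.items():
        if not isinstance(sub_block, dict) or "name" not in sub_block:
            raise ValueError(f"Nested loss '{key}' must provide a 'name' key")
        losses[key] = sub_block
    return losses


# Placeholder loss names, mirroring the configuration template above
single = split_loss_block("edge", {"name": "<edge loss name>"})
nested = split_loss_block("node", {
    "type": {"name": "<first node loss name>"},
    "primary": {"name": "<second node loss name>"},
})
```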

Methods

__call__(*args, **kwargs)

Call self as a function.

forward(clust_label[, coord_label, ...])

Apply the node/edge/global losses to the logits from GrapPA.

process_loss_config([node_loss, edge_loss, ...])

Process the loss configuration.

process_single_loss_config(prefix, loss, ...)

Process a loss configuration.

process_loss_config(node_loss=None, edge_loss=None, global_loss=None)[source]

Process the loss configuration.

Parameters:
  • node_loss (Union[dict, Dict[dict]], optional) – Node loss configuration

  • edge_loss (Union[dict, Dict[dict]], optional) – Edge loss configuration

  • global_loss (Union[dict, Dict[dict]], optional) – Global loss configuration

process_single_loss_config(prefix, loss, constructor)[source]

Process a loss configuration.

Parameters:
  • prefix (str) – Name of the output type to apply the loss to

  • loss (Union[int, dict]) – Loss configuration

  • constructor (object) – Loss constructor function

forward(clust_label, coord_label=None, graph_label=None, iteration=None, **output)[source]

Apply the node/edge/global losses to the logits from GrapPA.

Parameters:
  • clust_label (TensorBatch) – Tensor of voxel/value pairs with shape (N, 1 + D + N_f), where N is the total number of voxels in the image, the leading column stores the batch ID, D is the number of dimensions in the input image and N_f is the number of cluster labels

  • coord_label (TensorBatch, optional) – (P, 1 + D + 8) Tensor of start/end point labels for each true particle in the image

  • graph_label (EdgeIndexTensor, optional) – (2, E) Tensor of edges that correspond to physical connections between true particles in the image

  • iteration (int, optional) – Iteration index

  • **output (dict) – Output of the GrapPA model
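The `(N, 1 + D + N_f)` layout of `clust_label` (shared by `data` in `GrapPA.forward`) can be unpacked column-wise as sketched below. Plain nested lists stand in for `TensorBatch`, the values are made up, and `split_columns` is a hypothetical helper.

```python
def split_columns(rows, dim=3):
    """Split (N, 1 + D + N_f) rows into batch IDs, coordinates and values."""
    batch_ids = [int(row[0]) for row in rows]
    coords = [tuple(row[1:1 + dim]) for row in rows]
    values = [tuple(row[1 + dim:]) for row in rows]
    return batch_ids, coords, values


# Two voxels in batch entry 0 and one in entry 1, with D = 3 and N_f = 2
rows = [
    [0, 10, 12, 7, 0.5, 3],
    [0, 11, 12, 7, 0.8, 3],
    [1, 4, 9, 20, 1.2, 1],
]
batch_ids, coords, values = split_columns(rows)
```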