"""GrapPA: Graph Neural Network for Particle Aggregation.
This module implements the GrapPA (Graph Particle Aggregation) architecture,
a graph neural network designed for clustering and grouping particle instances.
GrapPA learns to aggregate fragment-level features into particle-level clusters
through message passing and edge classification, enabling particle instance
segmentation and identification.
"""
import numpy as np
import torch
from torch import nn
from spine.constants import GROUP_COL, LOWES_SHP, SHAPE_COL, TRACK_SHP
from spine.constants.factory import enum_factory
from spine.data import TensorBatch
from spine.utils.gnn.cluster import (
form_clusters_batch,
get_cluster_label_batch,
get_cluster_primary_label_batch,
)
from spine.utils.gnn.evaluation import (
node_assignment_batch,
node_assignment_score_batch,
)
from .layer.common.dbscan import DBSCAN
from .layer.factories import final_factory
from .layer.gnn.factories import (
edge_encoder_factory,
edge_loss_factory,
global_encoder_factory,
global_loss_factory,
gnn_model_factory,
graph_factory,
node_encoder_factory,
node_loss_factory,
)
__all__ = ["GrapPA", "GrapPALoss"]
[docs]
class GrapPA(torch.nn.Module):
    """Graph Particle Aggregator (GrapPA) model.

    This class mostly acts as a wrapper that will hand the graph data
    to the underlying graph neural network (GNN).

    When trained standalone, this model must be provided with a cluster
    label tensor, allowing it to build a set of input clusters based on the
    label boundaries of the clusters and their semantic types.

    Typical configuration can look like this:

    .. code-block:: yaml

        model:
          name: grappa
          modules:
            grappa:
              nodes:
                <dictionary of arguments to specify the input type>
              graph:
                name: <name of the input graph type>
                <dictionary of arguments to specify the graph>
              node_encoder:
                name: <name of the type of node encoder>
                <dictionary of arguments to specify the node encoder>
              edge_encoder:
                name: <name of the type of edge encoder>
                <dictionary of arguments to specify the edge encoder>
              global_encoder:
                name: <name of the type of global encoder>
                <dictionary of arguments to specify the global encoder>
              gnn_model:
                name: <name of the type of backbone GNN feature extractor>
                <dictionary of arguments to specify the GNN>

    See configuration files prefixed with `grappa_` under the `config`
    directory for detailed examples of working configurations.

    See Also
    --------
    :class:`GrapPALoss`
    """

    # TODO: update
    MODULES = [
        ("grappa", ["base", "dbscan", "node_encoder", "edge_encoder", "gnn_model"]),
        "grappa_loss",
    ]

    def __init__(self, grappa, grappa_loss=None):
        """Initialize the GrapPA model.

        Parameters
        ----------
        grappa : dict
            Model configuration
        grappa_loss : dict, optional
            Loss configuration (accepted for signature symmetry with the
            loss class, not used by the model)
        """
        # Initialize the parent class
        super().__init__()

        # Process the model configuration
        self.process_model_config(**grappa)

    def process_model_config(
        self,
        gnn_model,
        nodes=None,
        graph=None,
        node_encoder=None,
        edge_encoder=None,
        global_encoder=None,
        dbscan=None,
        return_features=False,
    ):
        """Process the top-level configuration block.

        This dispatches each block to its own configuration processor.

        Parameters
        ----------
        gnn_model : dict
            Underlying graph neural network configuration
        nodes : dict, optional
            Input node configuration
        graph : dict, optional
            Input graph configuration
        node_encoder : dict, optional
            Node encoder configuration
        edge_encoder : dict, optional
            Edge encoder configuration
        global_encoder : dict, optional
            Global encoder configuration
        dbscan : dict, optional
            DBSCAN fragmentation configuration
        return_features : bool, default False
            If `True`, the model will return the node/edge/global features

        Raises
        ------
        ValueError
            If a `graph` or `dbscan` block is given without a `nodes` block
        """
        # Store the output types of GNNs
        self.out_types = ("node", "edge", "global")

        if nodes is None and (graph is not None or dbscan is not None):
            raise ValueError(
                "Must provide a `nodes` configuration when using GrapPA to "
                "build clusters/graphs on the fly."
            )

        # Construct the underlying graph neural network
        self.process_gnn_config(**gnn_model)

        # Give every node attribute a defined value so that a model built
        # without a `nodes` block (clusters provided externally) can still
        # run `forward`, which reads `make_groups` unconditionally
        self.node_source = None
        self.node_type = None
        self.node_min_size = None
        self.make_groups = False
        self.grouping_method = "score"
        self.grouping_through_track = False

        # Process the node configuration
        if nodes is not None:
            self.process_node_config(**nodes)

        # Process the graph configuration
        self.graph_constructor = None
        if graph is not None:
            self.graph_constructor = graph_factory(graph, self.node_type)

        # Process the encoder configurations
        self.node_encoder = None
        if node_encoder is not None:
            self.node_encoder = node_encoder_factory(node_encoder)

        # Initialize edge encoder
        self.edge_encoder = None
        if edge_encoder is not None:
            self.edge_encoder = edge_encoder_factory(edge_encoder)

        # Initialize the global encoder
        self.global_encoder = None
        if global_encoder is not None:
            self.global_encoder = global_encoder_factory(global_encoder)

        # Process the dbscan fragmenter configuration, if provided. The block
        # must be unpacked: passing the dictionary positionally would bind it
        # to `shapes` and trip the duplicate-parameter check
        self.dbscan = None
        if dbscan is not None:
            self.process_dbscan_config(**dbscan)

        # Store whether to return the features
        self.return_features = return_features

    def process_node_config(
        self,
        source="cluster",
        shapes=None,
        min_size=-1,
        make_groups=False,
        grouping_method="score",
        grouping_through_track=False,
    ):
        """Process the node parameters of the model.

        Parameters
        ----------
        source : str, default 'cluster'
            Column name in the label tensor which contains the input cluster IDs
        shapes : list, optional
            Types of nodes to include in the input. If not specified, include
            all types. Scalars are rejected.
        min_size : int, default -1
            Minimum number of voxels in a cluster to be included in the input
        make_groups : bool, default False
            Use edge predictions to build node groups
        grouping_method : str, default 'score'
            Algorithm used to build a node partition
        grouping_through_track : bool, default False
            If `True`, shower objects can only be connected to one track object

        Raises
        ------
        ValueError
            If `shapes` is provided as a scalar rather than a list
        """
        # Parse the node source
        self.node_source = enum_factory("cluster", source)

        # Interpret node type as list of shapes to cluster
        if shapes is None:
            self.node_type = list(np.arange(LOWES_SHP))
        else:
            if np.isscalar(shapes):
                raise ValueError("Semantic classes should be provided as a list.")
            self.node_type = enum_factory("shape", shapes)

        # Store the node parameters
        self.node_min_size = min_size
        self.make_groups = make_groups
        self.grouping_method = grouping_method
        self.grouping_through_track = grouping_through_track

    def process_gnn_config(
        self, node_pred=None, edge_pred=None, global_pred=None, **gnn_model
    ):
        """Process the GNN backbone structure and the output layers.

        Parameters
        ----------
        node_pred : Union[int, dict], optional
            Number of node predictions. If there are multiple node predictions,
            provide a (key, value) pair for each type of prediction
        edge_pred : Union[int, dict], optional
            Number of edge predictions. If there are multiple edge predictions,
            provide a (key, value) pair for each type of prediction
        global_pred : Union[int, dict], optional
            Number of global predictions. If there are multiple global predictions,
            provide a (key, value) pair for each type of prediction
        **gnn_model : dict
            Parameters to initialize the GNN backbone
        """
        # Initialize the GNN backbone, telling it which output types are used
        self.gnn = gnn_model_factory(
            gnn_model,
            node_pred is not None,
            edge_pred is not None,
            global_pred is not None,
        )

        # Initialize output layers based on the configuration
        self.process_final_config(node_pred, "node")
        self.process_final_config(edge_pred, "edge")
        self.process_final_config(global_pred, "global")

    def process_final_config(self, final, prefix):
        """Process a final layer configuration.

        The names of the layers created here are stored in the
        `<prefix>_pred_keys` attribute (empty list if `final` is `None`).

        Parameters
        ----------
        final : Union[int, dict]
            Final layer configuration
        prefix : str
            Name of the final layer
        """
        # If the final layer is not specified, nothing to do here
        if final is None:
            setattr(self, f"{prefix}_pred_keys", [])
            return

        # If the final layer is specified as a number, use linear layer
        if isinstance(final, int):
            final = {"name": "linear", "out_channels": final}

        # Process the configuration dictionary otherwise
        out_keys = []
        in_channels = getattr(self.gnn, f"{prefix}_feature_size")
        if "name" in final:
            # Initialize a single final layer (single prediction of this type)
            out_key = f"{prefix}_pred"
            out_keys.append(out_key)
            setattr(self, out_key, final_factory(in_channels, **final))
        else:
            # Otherwise, initialize one final layer per prediction type
            for key, cfg in final.items():
                out_key = f"{prefix}_{key}_pred"
                out_keys.append(out_key)

                # If the final layer is specified as a number, use linear layer
                if isinstance(cfg, int):
                    cfg = {"name": "linear", "out_channels": cfg}
                setattr(self, out_key, final_factory(in_channels, **cfg))

        setattr(self, f"{prefix}_pred_keys", out_keys)

    def process_dbscan_config(self, shapes=None, min_size=None, **kwargs):
        """Process the DBSCAN fragmenter configuration.

        Parameters
        ----------
        shapes : Union[int, list], optional
            This should not be specified (fetched from the node configuration)
        min_size : Union[int, list], optional
            This should not be specified (fetched from the node configuration)
        **kwargs : dict, optional
            Rest of the DBSCAN configuration

        Raises
        ------
        ValueError
            If either `shapes` or `min_size` is specified
        """
        # Make sure the basic parameters are not specified twice
        if shapes is not None or min_size is not None:
            raise ValueError(
                "Do not specify 'shapes' or 'min_size' in the "
                "`dbscan` block, it is shared with the `node` block"
            )

        # Initialize DBSCAN fragmenter
        self.dbscan = DBSCAN(
            shapes=self.node_type, min_size=self.node_min_size, **kwargs
        )

    def forward(
        self,
        data,
        coord_label=None,
        clusts=None,
        edge_index=None,
        node_features=None,
        edge_features=None,
        global_features=None,
        shapes=None,
        groups=None,
        points=None,
        extra=None,
    ):
        """Prepares particle clusters and feed them to the GNN model.

        Parameters
        ----------
        data : TensorBatch
            Tensor of voxel/value pairs with shape `(N, 1 + D + N_f)`, where
            `N` is the total number of voxels, the leading column stores the
            batch ID, `D` is the image dimensionality and `N_f` is the number
            of features. When `clusts` is not provided, the features must also
            contain the labels needed to build clusters on the fly.
        coord_label : TensorBatch, optional
            (P, 1 + 2*D + 2) Tensor of label points (start/end/time/shape)
        clusts : IndexBatch, optional
            (C) List of indexes corresponding to each cluster
        edge_index : EdgeIndexBatch, optional
            (E, 2) Incidence matrix. If not provided, it will be built based on
            the cluster indexes and the graph configuration
        node_features : TensorBatch, optional
            (C, N_c,f) Node features. If not provided, they will be built based on
            the cluster indexes and the node encoder configuration
        edge_features : TensorBatch, optional
            (C, N_e,f) Edge features. If not provided, they will be built based on
            the cluster indexes and the edge encoder configuration
        global_features : TensorBatch, optional
            (C, N_g,f) Global features. If not provided, they will be built based on
            the cluster indexes and the global encoder configuration
        shapes : TensorBatch, optional
            (C) List of cluster semantic class used to define the max length
        groups : TensorBatch, optional
            (C) List of node groups, one per cluster. If specified, removes
            connections between nodes that belong to different groups.
        points : TensorBatch, optional
            (C, 3/6) Tensor of start (and end) points
        extra : TensorBatch, optional
            (C, N_f) Batch of features to append to the existing node features

        Returns
        -------
        dict
            Dictionary of outputs which contains the following keys:

            clusts : IndexBatch
                (C, N_c, N_{c,i}) Cluster indexes
            edge_index : EdgeIndexBatch
                (E, 2) Incidence matrix
            start_points, end_points : TensorBatch
                Node end points, only when the node encoder returns them
            node_features : TensorBatch
                (C, N_c,f) Node features (only if `return_features`)
            edge_features : TensorBatch
                (C, N_e,f) Edge features (only if `return_features`)
            global_features : TensorBatch
                (C, N_g,f) Global features (only if `return_features`)
            node_pred : TensorBatch
                (C, N_n) Node predictions (logits)
            edge_pred : TensorBatch
                (C, N_e) Edge predictions (logits)
            global_pred : TensorBatch
                (C, N_e) Global predictions (logits)

            Prediction keys follow the names built in `process_final_config`;
            group predictions are added if `make_groups` is enabled.

        Raises
        ------
        ValueError
            If the edge index or the node features are neither provided
            nor buildable from the configuration
        """
        # Initialize the result dictionary that will be returned at the end of the method
        result = {}

        # Encode the node boundaries as clusters if they are not provided directly
        if clusts is None:
            clusts = self._make_clusters(data, coord_label=coord_label)

        result["clusts"] = clusts

        # If needed, infer per-cluster shapes once and reuse them downstream
        shapes = self._get_shapes(data, clusts, shapes)

        # Build the graph if it is not provided directly
        closest_index = None
        if edge_index is None:
            if self.graph_constructor is None:
                raise ValueError(
                    "Must provide edge_index or graph configuration to build it."
                )
            edge_index, closest_index = self._make_edge_index(
                data, clusts, shapes=shapes, groups=groups
            )

        result["edge_index"] = edge_index

        # Fetch the node features
        if node_features is None:
            if self.node_encoder is None:
                raise ValueError(
                    "Must provide node_features or node encoder configuration to build them."
                )
            node_features = self.node_encoder(
                data, clusts, coord_label=coord_label, points=points, extra=extra
            )

            if isinstance(node_features, tuple):
                # If the output of the node encoder is a tuple, separate points
                node_features, points = node_features
                start_points, end_points = points.tensor.split(3, dim=1)
                result["start_points"] = TensorBatch(
                    start_points, points.counts, coord_cols=np.array([0, 1, 2])
                )
                result["end_points"] = TensorBatch(
                    end_points, points.counts, coord_cols=np.array([0, 1, 2])
                )

        if self.return_features:
            result["node_features"] = node_features

        # Fetch the edge features
        if edge_features is None and self.edge_encoder is not None:
            edge_features = self.edge_encoder(
                data, clusts, edge_index, closest_index=closest_index
            )

        if self.return_features and edge_features is not None:
            result["edge_features"] = edge_features

        # Fetch the global_features
        if global_features is None and self.global_encoder is not None:
            global_features = self.global_encoder(data, clusts)

        if global_features is not None and self.return_features:
            result["global_features"] = global_features

        # Bring edge_index and batch_ids to device
        # TODO: try to keep everything (apart from clusts?) on GPU?
        index = torch.tensor(edge_index.index, device=data.tensor.device)
        xbatch = torch.tensor(clusts.batch_ids, device=data.tensor.device)

        # Pass through the model, update results
        out = self.gnn(node_features, index, edge_features, global_features, xbatch)

        # Loop over the necessary node/edge/global predictions, store output
        for t in self.out_types:
            for key in getattr(self, f"{t}_pred_keys"):
                result[key] = getattr(self, key)(out[f"{t}_features"])

        # If requested, build node groups from edge predictions
        if self.make_groups:
            self._make_groups(result, clusts, edge_index, shapes=shapes)

        return result

    def _make_clusters(self, data, coord_label=None):
        """Make the list of node clusters based on the label tensor and the requested class.

        Parameters
        ----------
        data : TensorBatch
            Tensor of voxel/value pairs with shape `(N, 1 + D + N_f)`, where
            `N` is the total number of voxels, the leading column stores the
            batch ID, `D` is the image dimensionality and `N_f` is the number
            of features. The features must also contain the labels needed to
            build clusters on the fly.
        coord_label : TensorBatch, optional
            (P, 1 + 2*D + 2) Tensor of label points

        Returns
        -------
        clusts : IndexBatch
            (C, N_c, N_{c,i}) Cluster indexes

        Raises
        ------
        ValueError
            If no `nodes` configuration is available to build the clusters
        """
        if self.dbscan is not None:
            # Use the DBSCAN fragmenter to build the clusters
            seg_label = TensorBatch(data.tensor[:, SHAPE_COL], data.counts)
            clusts, _ = self.dbscan(data, seg_label, coord_label)
        else:
            # Without a `nodes` block there is no source column to cluster on
            if self.node_source is None:
                raise ValueError(
                    "Must provide clusts or a `nodes` configuration to build them."
                )

            # Use the label tensor to build the clusters
            clusts = form_clusters_batch(
                data.to_numpy(),
                self.node_min_size,
                self.node_source,
                shapes=self.node_type,
            )

        return clusts

    def _get_shapes(self, data, clusts, shapes=None):
        """Return per-cluster semantic labels if the graph logic needs them.

        Parameters
        ----------
        data : TensorBatch
            Tensor of voxel/value pairs with shape `(N, 1 + D + N_f)`.
        clusts : IndexBatch
            (C) List of indexes corresponding to each cluster
        shapes : TensorBatch, optional
            (C) Explicit semantic label per cluster

        Returns
        -------
        TensorBatch or None
            Cluster semantic labels, or `None` if they are not needed and were
            not provided.
        """
        # Explicitly provided shapes always take precedence
        if shapes is not None:
            return shapes

        # Shapes are only inferred when the graph constructor exposes a
        # sized (non-scalar) `max_length`
        if self.graph_constructor is None or not hasattr(
            self.graph_constructor.max_length, "__len__"
        ):
            return None

        data_np = data.to_numpy()
        if self.node_source == GROUP_COL:
            shapes = get_cluster_primary_label_batch(data_np, clusts, SHAPE_COL)
        else:
            shapes = get_cluster_label_batch(data_np, clusts, SHAPE_COL)

        shapes.data = shapes.data.astype(np.int64)

        return shapes

    def _make_edge_index(self, data, clusts, shapes=None, groups=None):
        """Make the edge index based on the cluster indexes and the graph configuration.

        Parameters
        ----------
        data : TensorBatch
            Tensor of voxel/value pairs with shape `(N, 1 + D + N_f)`, where
            `N` is the total number of voxels, the leading column stores the
            batch ID, `D` is the image dimensionality and `N_f` is the number
            of features. The features must also contain the labels needed to
            build clusters on the fly.
        clusts : IndexBatch
            (C) List of indexes corresponding to each cluster
        shapes : TensorBatch, optional
            (C) List of cluster semantic class used to define the max length
        groups : TensorBatch, optional
            (C) List of node groups, one per cluster. If specified, removes
            connections between nodes that belong to different groups.

        Returns
        -------
        edge_index : EdgeIndexBatch
            (E, 2) Incidence matrix
        closest_index : TensorBatch
            (E) List of closest voxel index for each edge

        Raises
        ------
        ValueError
            If no graph constructor was configured
        """
        # Check that the graph constructor is defined
        if self.graph_constructor is None:
            raise ValueError(
                "Must provide graph configuration to build edge index from clusters."
            )

        # Bring data to numpy for the graph construction
        data_np = data.to_numpy()

        # Initialize the input graph
        edge_index, _, closest_index = self.graph_constructor(
            data_np, clusts, shapes, groups
        )

        return edge_index, closest_index

    def _make_groups(self, result, clusts, edge_index, shapes=None):
        """Make node groups based on edge predictions.

        The group assignments are written back into `result`, under one
        `group{key}_pred` key per `edge{key}_pred` key found in it.

        Parameters
        ----------
        result : dict
            Dictionary containing the output of the model, including edge
            predictions. The edge predictions should be stored under keys
            with the format `edge{key}_pred`, where `key` is the name of the
            prediction head.
        clusts : IndexBatch
            (C) List of indexes corresponding to each cluster
        edge_index : EdgeIndexBatch
            (E, 2) Incidence matrix
        shapes : TensorBatch, optional
            (C) List of cluster semantic class used to restrict track association

        Raises
        ------
        ValueError
            If there are no edge predictions, if the grouping method is not
            recognized or if shapes are needed but not available
        """
        # Fetch the list of edge prediction keys (snapshot taken before the
        # loop, as the loop adds new keys to the dictionary)
        edge_pred_keys = [
            key for key in result if key.startswith("edge") and key.endswith("pred")
        ]
        if not edge_pred_keys:
            raise ValueError(
                "Must provide edge predictions to build node groups. "
                "Edge predictions should be stored under keys with the format "
                "`edge{key}_pred`, where `key` is the name of the prediction head."
            )

        # Loop over the edge predictions, build node groups based on each of them
        for key in edge_pred_keys:
            edge_pred = result[key].to_numpy()
            prefix = "group" + key.replace("edge", "").replace("_pred", "")
            if self.grouping_method == "threshold":
                result[f"{prefix}_pred"] = node_assignment_batch(
                    edge_index, edge_pred, clusts
                )

            elif self.grouping_method == "score":
                if not self.grouping_through_track:
                    result[f"{prefix}_pred"] = node_assignment_score_batch(
                        edge_index, edge_pred, clusts
                    )
                else:
                    if shapes is None:
                        raise ValueError(
                            "Must provide shapes to restrict track association."
                        )
                    track_node = TensorBatch(
                        shapes.data == TRACK_SHP, counts=shapes.counts
                    )
                    result[f"{prefix}_pred"] = node_assignment_score_batch(
                        edge_index, edge_pred, clusts, track_node
                    )

            else:
                raise ValueError(
                    "Group prediction algorithm not recognized: "
                    f"{self.grouping_method}"
                )
[docs]
class GrapPALoss(torch.nn.modules.loss._Loss):
    """Takes the output of the GrapPA and computes the total loss.

    For use in config:

    .. code-block:: yaml

        model:
          name: grappa
          modules:
            grappa_loss:
              node_loss:
                name: <name of the node loss>
                <dictionary of arguments to pass to the loss>
              edge_loss:
                name: <name of the edge loss>
                <dictionary of arguments to pass to the loss>
              global_loss:
                name: <name of the global loss>
                <dictionary of arguments to pass to the loss>

    Each of the specific loss blocks can also contain multiple losses by
    providing a name key in a loss block nested below it. Each loss name of a
    specific type should be provided with a corresponding output from GrapPA.

    See configuration files prefixed with `grappa_` under the `config`
    directory for detailed examples of working configurations.
    """

    def __init__(self, grappa_loss, grappa=None):
        """Initialize the GrapPA loss function.

        Parameters
        ----------
        grappa_loss : dict
            Loss configuration
        grappa : dict, optional
            Model configuration (accepted for signature symmetry with the
            model class, not used by the loss)
        """
        # Initialize the parent class
        super().__init__()

        # Process the loss configuration
        self.process_loss_config(**grappa_loss)

    def process_loss_config(self, node_loss=None, edge_loss=None, global_loss=None):
        """Process the loss configuration.

        Parameters
        ----------
        node_loss : Union[dict, Dict[dict]], optional
            Node loss configuration
        edge_loss : Union[dict, Dict[dict]], optional
            Edge loss configuration
        global_loss : Union[dict, Dict[dict]], optional
            Global loss configuration

        Raises
        ------
        ValueError
            If none of the three loss configurations is provided
        """
        # Check that there is at least one loss to apply
        self.out_types = ("node", "edge", "global")
        if node_loss is None and edge_loss is None and global_loss is None:
            raise ValueError(
                "Must provide at least one of `node_loss`, `edge_loss` or "
                "`global_loss` to the GrapPA loss function."
            )

        # Initialize the node/edge/global losses
        self.process_single_loss_config("node", node_loss, node_loss_factory)
        self.process_single_loss_config("edge", edge_loss, edge_loss_factory)
        self.process_single_loss_config("global", global_loss, global_loss_factory)

    def process_single_loss_config(self, prefix, loss, constructor):
        """Process a loss configuration.

        The names of the losses created here are stored in the
        `<prefix>_loss_keys` attribute (empty list if `loss` is `None`).

        Parameters
        ----------
        prefix : str
            Name of the output type to apply the loss to
        loss : dict, optional
            Loss configuration. Either a single loss block (contains a `name`
            key) or a dictionary which maps names onto loss blocks
        constructor : callable
            Loss constructor function, called with one loss block
        """
        # If the loss is not specified, nothing to do here
        if loss is None:
            setattr(self, f"{prefix}_loss_keys", [])
            return

        # Process the configuration dictionary otherwise
        loss_keys = []
        if "name" in loss:
            # Initialize a single loss
            loss_key = f"{prefix}_loss"
            loss_keys.append(loss_key)
            setattr(self, loss_key, constructor(loss))
        else:
            # Otherwise, initialize one loss per prediction type
            for key, cfg in loss.items():
                loss_key = f"{prefix}_{key}_loss"
                loss_keys.append(loss_key)
                setattr(self, loss_key, constructor(cfg))

        setattr(self, f"{prefix}_loss_keys", loss_keys)

    def forward(
        self, clust_label, coord_label=None, graph_label=None, iteration=None, **output
    ):
        """Apply the node/edge/global losses to the logits from GrapPA.

        Parameters
        ----------
        clust_label : TensorBatch
            (N, 1 + D + N_f) Tensor of voxel/value pairs
            - N is the total number of voxels in the image
            - 1 is the batch ID
            - D is the number of dimensions in the input image
            - N_f is the number of cluster labels
        coord_label : TensorBatch, optional
            (P, 1 + D + 8) Tensor of start/end point labels for each
            true particle in the image
        graph_label : EdgeIndexTensor, optional
            (2, E) Tensor of edges that correspond to physical
            connections between true particle in the image
        iteration : int, optional
            Iteration index
        **output : dict
            Output of the GrapPA model

        Returns
        -------
        dict
            Every entry returned by each loss, stored under the name of the
            loss minus its `_loss` suffix followed by the entry name, along
            with the `loss` and `accuracy` averaged over all applied losses

        Raises
        ------
        ValueError
            If no loss is configured
        """
        # Loop and apply the losses
        result = {}
        num_losses = 0
        loss, accuracy = 0.0, 0.0
        for t in self.out_types:
            loss_keys = getattr(self, f"{t}_loss_keys")
            generic_key = f"{t}_pred"
            for key in loss_keys:
                # A named loss (`<type>_<name>_loss`) applies to the prediction
                # of the same name (`<type>_<name>_pred`), which must be handed
                # to the loss under the generic `<type>_pred` keyword. Only the
                # trailing `loss` is swapped, names which contain it are kept.
                kwargs = dict(output)
                pred_key = key[: -len("loss")] + "pred"
                if pred_key != generic_key and (
                    len(loss_keys) > 1 or pred_key in output
                ):
                    kwargs[generic_key] = output[pred_key]

                # Compute the loss
                out = getattr(self, key)(
                    clust_label=clust_label,
                    coord_label=coord_label,
                    graph_label=graph_label,
                    iteration=iteration,
                    **kwargs,
                )

                # Increment the loss and accuracy
                loss += out["loss"]
                accuracy += out["accuracy"]
                num_losses += 1

                # Update the result dictionary
                prefix = key.rsplit("_", 1)[0]
                for k, v in out.items():
                    result[f"{prefix}_{k}"] = v

        # An average over zero losses is undefined, fail explicitly
        if num_losses == 0:
            raise ValueError("No loss is configured, cannot compute the total loss.")

        # Append the total loss and total accuracy
        result["loss"] = torch.sum(loss) / num_losses
        result["accuracy"] = np.sum(accuracy) / num_losses

        return result