"""Supervi dense clustering model and its loss."""
import MinkowskiEngine as ME
import numpy as np
import torch
from spine.constants import (
COORD_COLS,
DELTA_SHP,
MICHL_SHP,
SHAPE_COL,
SHOWR_SHP,
TRACK_SHP,
)
from spine.constants.factory import enum_factory
from spine.data import IndexBatch, TensorBatch
from spine.utils.cluster.graph import ClusterGraphConstructor
from .layer.cluster import kernel_factory, loss_factory
from .layer.cluster.graph_spice_embedder import GraphSPICEEmbedder
__all__ = ["GraphSPICE", "GraphSPICELoss"]
[docs]
class GraphSPICE(torch.nn.Module):
"""Graph Scalable Proposal-free Instance Clustering Engine (Graph-SPICE).
Graph-SPICE has two main components:
- A voxel embedder, implemented as a UNet-like CNN for feature extraction
and embeddings
- An edge probability kernel that maps pairs of node attribute vectors to
edge scores
Prediction proceeds in three stages:
- A neighbor graph such as KNN or a radius graph is constructed
- Edge probabilities are evaluated and low-probability edges are dropped
- Voxels are clustered through connected-component clustering
A typical configuration is broken down into multiple components:
.. code-block:: yaml
model:
name: graph_spice
modules:
graph_spice:
<Basic parameters>
embedder:
<Feature embedding configuration block>
kernel:
<Edge kernel function configuration block>
constructor:
<Graph construction base parameters>
graph:
<Graph configuration block>
orphan:
<Orphan assignment configuration block>
See configuration file(s) prefixed with `graph_spice` under the `config`
directory for detailed examples of working configurations.
"""
MODULES = ["constructor", "embedder", "kernel"]
def __init__(self, graph_spice, graph_spice_loss=None):
"""Initialize the Graph-SPICE model.
Parameters
----------
graph_spice : dict
Graph-SPICE configuration dictionary
graph_spice_loss : dict, optional
Graph-SPICE loss configuration dictionary
"""
# Initialize the parent class
super().__init__()
# Initialize the model configuration
self.process_model_config(**graph_spice)
[docs]
def process_model_config(
self,
embedder,
kernel,
constructor,
shapes=[SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP],
use_raw_features=False,
invert=True,
make_clusters=False,
):
"""Initialize the underlying modules.
Parameters
----------
embedder : dict
Pixel embedding configuration
kernel : dict
Edge kernel configuration
constructor : dict
Edge index construction configuration
shapes : List[str]
List of shape names to construct clusters for
use_raw_features : bool, default True
Use the list of embedder features as is, without the output layers
invert : bool, default True
Invert the edge scores so that 0 is on an 1 is off
make_clusters : bool, default False
If `True`, builds a list of cluster indexes
"""
# Initialize the embedder
self.embedder = GraphSPICEEmbedder(
**embedder, use_raw_features=use_raw_features
)
# Initialize the kernel function (must be owned here to be loaded)
self.kernel_fn = kernel_factory(kernel)
# Initialize the graph constructor
self.constructor = ClusterGraphConstructor(
**constructor,
kernel_fn=self.kernel_fn,
shapes=shapes,
invert=invert,
training=self.training,
)
# Parse the set of shapes to cluster
self.shapes = enum_factory("shape", shapes)
# Store model parameters
self.use_raw_features = use_raw_features
self.invert = invert
self.make_clusters = make_clusters
[docs]
def filter_class(self, data, seg_label, clust_label=None):
"""Filter the list of pixels to those in the list of requested shapes.
Parameters
----------
data : TensorBatch
(N, 1 + D + N_f) tensor of voxel/value pairs
- N is the the total number of voxels in the image
- 1 is the batch ID
- D is the number of dimensions in the input image
- N_f is the number of features per voxel
seg_label : TensorBatch
(N, 1 + D + 1) Tensor of segmentation labels
- 1 is the segmentation label
clust_label : TensorBatch, optional
(N, 1 + D + N_c) Tensor of cluster labels
- N_c is is the number of cluster labels
Returns
-------
data : TensorBatch
(M, 1+ + D + Nf) restricted tensor of voxel/value pairs
seg_label : TensorBatch
(M, 1 + D + 1) restricted tensor of segmentation labels
clust_label : TensorBatch
(M, 1 + D + N_c) Restricted tensor of cluster labels
index : torch.Tensor
(M) Index to narrow down the original tensor
"""
# Convert shapes to a torch tensor for easy comparison
shapes = torch.tensor(self.shapes, device=data.device)
# Create an index of the valid input rows
mask = (seg_label.tensor[:, SHAPE_COL] == shapes.view(-1, 1)).any(dim=0)
index = torch.where(mask)[0]
# Restrict the input
spans = data.counts
data = TensorBatch(
data.tensor[index], batch_size=data.batch_size, has_batch_col=True
)
# Restrict the label tensors
assert seg_label.shape[0] == mask.shape[0], (
"The segmentation label tensor is of the wrong shape: "
f"{seg_label.shape[0]} != {mask.shape[0]}"
)
seg_label = TensorBatch(seg_label.tensor[index], data.counts)
if clust_label is not None:
assert clust_label.shape[0] == mask.shape[0], (
"The cluster label tensor is of the wrong shape: "
f"{clust_label.shape[0]} != {mask.shape[0]}"
)
clust_label = TensorBatch(clust_label.tensor[index], data.counts)
# Store the index as an IndexBatch
index = IndexBatch(index, spans, data.counts)
return data, seg_label, clust_label, index
[docs]
def forward(self, data, seg_label, clust_label=None):
"""Run a batch of data through the forward function.
Parameters
----------
data : TensorBatch
(N, 1 + D + N_f) tensor of voxel/value pairs
- N is the the total number of voxels in the image
- 1 is the batch ID
- D is the number of dimensions in the input image
- N_f is the number of features per voxel
seg_label : TensorBatch
(N, 1 + D + 1) Tensor of segmentation labels
- 1 is the segmentation label
clust_label : TensorBatch, optional
(N, 1 + D + N_c) Tensor of cluster labels
- N_c is is the number of cluster labels
Returns
-------
dict
Dictionary of outputs
"""
# Filter the input down to the requested shapes
data, seg_label, clust_label, index = self.filter_class(
data, seg_label, clust_label
)
# Embed the input pixels into a feature space used for graph clustering
result = self.embedder(data)
# Store the index and the counts to not have to recompute them later
result["filter_index"] = index
# Build the graph on the pixel set
coords = result["coordinates"]
if self.use_raw_features:
features = result["features"]
else:
features = result["hypergraph_features"]
coords = TensorBatch(coords.data[:, coords.coord_cols], coords.counts)
graph = self.constructor(coords, features, seg_label, clust_label)
# If requested, convert edge predictions to node predictions
if self.make_clusters:
clusts, clust_shapes = self.constructor.fit_predict(graph)
result["clusts"] = clusts
result["clust_shapes"] = clust_shapes
# Save the graph dictionary
result.update(graph)
return result
[docs]
class GraphSPICELoss(torch.nn.Module):
"""Loss function for Graph-SPICE.
For use in config:
.. code-block:: yaml
model:
name: graph_spice
modules:
graph_spice_loss:
<Basic parameters>
edge_loss:
<Edge loss configuration block>
See configuration files prefixed with `graph_spice` under the `config`
directory for detailed examples of working configurations.
See Also
--------
:class:`GraphSPICE`
"""
def __init__(self, graph_spice, graph_spice_loss=None):
"""Intialize the Graph-SPICE loss.
Parameters
----------
graph_spice : dict
Graph-SPICE configuration dictionary
graph_spice_loss : dict
Graph-SPICE loss configuration dictionary
"""
# Initialize the parent class
super().__init__()
# Process the loss configuration
self.process_loss_config(**graph_spice_loss)
# Process the main mode configuration for its crucial elements
self.process_model_config(**graph_spice)
[docs]
def process_loss_config(self, evaluate_clustering_metrics=False, **loss):
"""Process the loss configuration
Parameters
----------
evaluate_clustering_metrics : bool, default False
If `True`, evaluates the clustering accuracy directly, rather than
simply reporting an edge assignment acurracy
**loss : dict
Loss configurationd dictionary
"""
# Store basic parameters
self.evaluate_clustering_metrics = evaluate_clustering_metrics
# Initialize the loss function
self.loss_fn = loss_factory(loss)
[docs]
def process_model_config(
self,
constructor,
shapes=[SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP],
invert=True,
**kwargs,
):
"""Process the model configuration
Parameters
----------
constructor : dict, optional
Edge index construction configuration
shapes : List[int], default [0, 1, 2, 3]
List of semantic shapes to run DBSCAN on
invert : bool, default True
Invert the edge scores so that 0 is on an 1 is off
"""
# Initialize the graph constructor (used to produce node assignments)
if self.evaluate_clustering_metrics:
self.constructor = ClusterGraphConstructor(
**constructor, shapes=shapes, invert=invert
)
def filter_class(self, seg_label, clust_label, filter_index):
"""Filter the list of pixels to those in the list of requested shapes.
Parameters
----------
seg_label : TensorBatch
(N, 1 + D + 1) Tensor of segmentation labels
- 1 is the segmentation label
clust_label : TensorBatch, optional
(N, 1 + D + N_c) Tensor of cluster labels
- N_c is is the number of cluster labels
filter_index : IndexBatch
(M) Index to narrow down the original tensor
Parameters
----------
seg_label : TensorBatch
(M, 1 + D + 1) restricted tensor of segmentation labels
clust_label : TensorBatch
(M, 1 + D + N_c) Restricted tnesor of cluster labels
"""
seg_label = TensorBatch(
seg_label.tensor[filter_index.index], filter_index.counts
)
clust_label = TensorBatch(
clust_label.tensor[filter_index.index], filter_index.counts
)
return seg_label, clust_label
[docs]
def forward(self, seg_label, clust_label, filter_index, **output):
"""Run a batch of data through the loss function.
Parameters
----------
seg_label : TensorBatch
(N, 1 + D + 1) Tensor of segmentation labels
- 1 is the segmentation label
clust_label : TensorBatch, optional
(N, 1 + D + N_c) Tensor of cluster labelresul
- N_c is is the number of cluster labels
filter_index : IndexBatch
(M) Index to narrow down the original tensor
**output : dict
Output of the Graph-SPICE model
Returns
-------
dict
Dictionary of outputs
"""
# Narrow down the labels to those corresponding to the relevant shapes
seg_label, clust_label = self.filter_class(seg_label, clust_label, filter_index)
# Pass the output through the loss function
result = self.loss_fn(seg_label=seg_label, clust_label=clust_label, **output)
# If requested, compute clustering metrics
if self.evaluate_clustering_metrics:
# Assign cluster IDs to each of the input points, if not yet done
if "node_pred" not in output:
self.constructor.fit_predict(output)
# Evaluate clustering metrics
metrics = self.constructor.evaluate(output, mean=True)
# Append metrics to the result dictionary
result.update(metrics)
return result