spine.model.graph_spice

Supervised dense clustering model and its loss.

Classes

GraphSPICE(*args, **kwargs)

Graph Scalable Proposal-free Instance Clustering Engine (Graph-SPICE).

GraphSPICELoss(*args, **kwargs)

Loss function for Graph-SPICE.

class spine.model.graph_spice.GraphSPICE(*args: Any, **kwargs: Any)[source]

Graph Scalable Proposal-free Instance Clustering Engine (Graph-SPICE).

Graph-SPICE has two main components:

  • A voxel embedder, implemented as a UNet-like CNN for feature extraction and embeddings

  • An edge probability kernel that maps pairs of node attribute vectors to edge scores

Prediction proceeds in three stages:

  • A neighbor graph such as KNN or a radius graph is constructed

  • Edge probabilities are evaluated and low-probability edges are dropped

  • Voxels are clustered through connected-component clustering
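The three prediction stages can be illustrated with a toy, pure-Python sketch. It is not the Graph-SPICE implementation. The following are all assumptions made for illustration: the radius graph built on voxel coordinates, the Gaussian `toy_kernel` standing in for the learned edge kernel, the 0.5 probability threshold, and every function name.

```python
import math
from typing import Callable, List, Sequence, Tuple

Point = Tuple[float, ...]
Edge = Tuple[int, int]


def radius_graph(points: Sequence[Point], radius: float) -> List[Edge]:
    """Stage 1 (assumed radius graph): connect every pair closer than `radius`."""
    edges = []
    for i in range(len(points)):
        for j in range(i + 1, len(points)):
            if math.dist(points[i], points[j]) < radius:
                edges.append((i, j))
    return edges


def toy_kernel(a: Point, b: Point) -> float:
    """Hypothetical stand-in for the learned kernel: decays with embedding distance."""
    return math.exp(-math.dist(a, b) ** 2)


def connected_components(n: int, edges: Sequence[Edge]) -> List[int]:
    """Stage 3: union-find labelling, returning one dense cluster ID per node."""
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for i, j in edges:
        parent[find(i)] = find(j)
    relabel = {}
    return [relabel.setdefault(find(i), len(relabel)) for i in range(n)]


def cluster(
    coords: Sequence[Point],
    embeddings: Sequence[Point],
    radius: float,
    kernel: Callable[[Point, Point], float] = toy_kernel,
    threshold: float = 0.5,  # assumed cut on the edge probability
) -> List[int]:
    edges = radius_graph(coords, radius)  # Stage 1: neighbor graph
    # Stage 2: score each edge and drop the low-probability ones
    kept = [(i, j) for i, j in edges if kernel(embeddings[i], embeddings[j]) >= threshold]
    return connected_components(len(coords), kept)  # Stage 3


coords = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (10.0, 0.0)]
embeddings = [(0.0,), (0.1,), (3.0,), (3.1,)]
print(cluster(coords, embeddings, radius=1.5))  # [0, 0, 1, 2]
```

Voxels 2 and 3 have nearly identical embeddings but still land in separate clusters, because the neighbor graph never connects them. In this sketch the kernel can only prune candidate edges, never create new ones.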

A typical configuration is broken down into multiple components:

model:
  name: graph_spice
  modules:
    graph_spice:
      <Basic parameters>
      embedder:
        <Feature embedding configuration block>
      kernel:
        <Edge kernel function configuration block>
      constructor:
        <Graph construction base parameters>
        graph:
          <Graph configuration block>
        orphan:
          <Orphan assignment configuration block>

See configuration file(s) prefixed with graph_spice under the config directory for detailed examples of working configurations.
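As a rough illustration of how the blocks nest, the YAML layout above can be mirrored as a Python dictionary. This sketch assumes that <Basic parameters> are the keyword arguments of process_model_config, shown with their documented defaults. The sub-blocks are left empty because their keys are not documented here.

```python
# Hypothetical mirror of the YAML layout; basic parameters are assumed to be
# the keyword arguments of GraphSPICE.process_model_config (documented defaults).
config = {
    "model": {
        "name": "graph_spice",
        "modules": {
            "graph_spice": {
                "shapes": [0, 1, 2, 3],
                "use_raw_features": False,
                "invert": True,
                "make_clusters": False,
                "embedder": {},  # <Feature embedding configuration block>
                "kernel": {},  # <Edge kernel function configuration block>
                "constructor": {  # <Graph construction base parameters>
                    "graph": {},  # <Graph configuration block>
                    "orphan": {},  # <Orphan assignment configuration block>
                },
            }
        },
    }
}

# The three sub-module blocks correspond to GraphSPICE.MODULES.
MODULES = ["constructor", "embedder", "kernel"]
block = config["model"]["modules"]["graph_spice"]
missing = [name for name in MODULES if name not in block]
print("missing blocks:", missing)  # missing blocks: []
```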

Methods

__call__(*args, **kwargs)

Call self as a function.

filter_class(data, seg_label[, clust_label])

Filter the list of pixels to those in the list of requested shapes.

forward(data, seg_label[, clust_label])

Run a batch of data through the forward function.

process_model_config(embedder, kernel, ...)

Initialize the underlying modules.

MODULES = ['constructor', 'embedder', 'kernel']
process_model_config(embedder, kernel, constructor, shapes=[0, 1, 2, 3], use_raw_features=False, invert=True, make_clusters=False)[source]

Initialize the underlying modules.

Parameters:
  • embedder (dict) – Pixel embedding configuration

  • kernel (dict) – Edge kernel configuration

  • constructor (dict) – Edge index construction configuration

  • shapes (List[int], default [0, 1, 2, 3]) – List of semantic shape classes to construct clusters for

  • use_raw_features (bool, default False) – Use the list of embedder features as is, without the output layers

  • invert (bool, default True) – Invert the edge scores so that 0 is on and 1 is off

  • make_clusters (bool, default False) – If True, builds a list of cluster indexes
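The snippet below is a minimal sketch of the two score conventions that invert selects between. It uses a hypothetical `active_edges` helper and an assumed 0.5 cut, and neither is part of the documented API.

```python
from typing import List, Sequence


def active_edges(scores: Sequence[float], invert: bool = True, threshold: float = 0.5) -> List[bool]:
    """Hypothetical helper: return True for every edge that is 'on'.

    invert=True : 0 means 'on' and 1 means 'off', so an edge is on when its
                  score falls below the (assumed) threshold.
    invert=False: a high score means 'on'.
    """
    if invert:
        return [s < threshold for s in scores]
    return [s >= threshold for s in scores]


scores = [0.1, 0.4, 0.9]
print(active_edges(scores))                # [True, True, False]
print(active_edges(scores, invert=False))  # [False, False, True]
```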

filter_class(data, seg_label, clust_label=None)[source]

Filter the list of pixels to those in the list of requested shapes.

Parameters:
  • data (TensorBatch) – (N, 1 + D + N_f) tensor of voxel/value pairs, where N is the total number of voxels in the image, 1 is the batch ID column, D is the number of dimensions of the input image and N_f is the number of features per voxel

  • seg_label (TensorBatch) – (N, 1 + D + 1) tensor of segmentation labels, where the final column is the segmentation label

  • clust_label (TensorBatch, optional) – (N, 1 + D + N_c) tensor of cluster labels, where N_c is the number of cluster labels

Returns:

  • data (TensorBatch) – (M, 1 + D + N_f) restricted tensor of voxel/value pairs

  • seg_label (TensorBatch) – (M, 1 + D + 1) restricted tensor of segmentation labels

  • clust_label (TensorBatch) – (M, 1 + D + N_c) restricted tensor of cluster labels

  • index (torch.Tensor) – (M) index used to narrow down the original tensor
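The filtering step can be sketched with plain lists standing in for TensorBatch objects. That substitution is an assumption for illustration, and `filter_rows` is a hypothetical name. Following the (N, 1 + D + 1) layout above, the segmentation label is read from the last column. The returned index maps each of the M kept rows back to its row in the original tensor.

```python
from typing import List, Optional, Sequence, Tuple

Row = List[float]


def filter_rows(
    data: Sequence[Row],
    seg_label: Sequence[Row],
    shapes: Sequence[int],
    clust_label: Optional[Sequence[Row]] = None,
) -> Tuple[List[Row], List[Row], Optional[List[Row]], List[int]]:
    """Hypothetical sketch: keep rows whose segmentation label is in `shapes`."""
    allowed = set(shapes)
    # (M,) index of the kept rows in the original N rows
    index = [i for i, row in enumerate(seg_label) if int(row[-1]) in allowed]
    data_m = [data[i] for i in index]
    seg_m = [seg_label[i] for i in index]
    clust_m = [clust_label[i] for i in index] if clust_label is not None else None
    return data_m, seg_m, clust_m, index


# Three voxels in D = 3 with one feature each; the second one has shape 4.
data = [[0, 0, 0, 0, 1.0], [0, 1, 0, 0, 2.0], [1, 5, 5, 5, 3.0]]
seg_label = [[0, 0, 0, 0, 0], [0, 1, 0, 0, 4], [1, 5, 5, 5, 2]]
data_m, seg_m, _, index = filter_rows(data, seg_label, shapes=[0, 1, 2, 3])
print(index)  # [0, 2]
```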

forward(data, seg_label, clust_label=None)[source]

Run a batch of data through the forward function.

Parameters:
  • data (TensorBatch) – (N, 1 + D + N_f) tensor of voxel/value pairs, where N is the total number of voxels in the image, 1 is the batch ID column, D is the number of dimensions of the input image and N_f is the number of features per voxel

  • seg_label (TensorBatch) – (N, 1 + D + 1) tensor of segmentation labels, where the final column is the segmentation label

  • clust_label (TensorBatch, optional) – (N, 1 + D + N_c) tensor of cluster labels, where N_c is the number of cluster labels

Returns:

Dictionary of outputs

Return type:

dict
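The hypothetical helper below makes the (N, 1 + D + N_f) column layout of data concrete. It splits a single row into its batch ID, its D coordinates and its N_f features, with a plain list standing in for a tensor row.

```python
from typing import List, Sequence, Tuple


def split_columns(row: Sequence[float], dim: int) -> Tuple[int, List[float], List[float]]:
    """Hypothetical helper: split a (1 + D + N_f) row into its three parts."""
    batch_id = int(row[0])          # column 0: batch ID
    coords = list(row[1 : 1 + dim])  # next D columns: voxel coordinates
    features = list(row[1 + dim :])  # remaining N_f columns: features
    return batch_id, coords, features


print(split_columns([2, 10.0, 11.0, 12.0, 0.5], dim=3))  # (2, [10.0, 11.0, 12.0], [0.5])
```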

class spine.model.graph_spice.GraphSPICELoss(*args: Any, **kwargs: Any)[source]

Loss function for Graph-SPICE.

For use in config:

model:
  name: graph_spice
  modules:
    graph_spice_loss:
      <Basic parameters>
      edge_loss:
        <Edge loss configuration block>

See configuration files prefixed with graph_spice under the config directory for detailed examples of working configurations.

See also

GraphSPICE

Methods

__call__(*args, **kwargs)

Call self as a function.

filter_class

forward(seg_label, clust_label, ...)

Run a batch of data through the loss function.

process_loss_config([...])

Process the loss configuration

process_model_config(constructor[, shapes, ...])

Process the model configuration

process_loss_config(evaluate_clustering_metrics=False, **loss)[source]

Process the loss configuration

Parameters:
  • evaluate_clustering_metrics (bool, default False) – If True, evaluates the clustering accuracy directly, rather than simply reporting an edge assignment accuracy

  • **loss (dict) – Loss configuration dictionary
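The sketch below shows one plausible reading of an edge assignment accuracy. It assumes that an edge's true state is "on" when both endpoints carry the same cluster ID. It then scores the fraction of edges whose predicted state matches that truth. The actual loss may define both quantities differently, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]


def edge_truth(edges: Sequence[Edge], cluster_ids: Sequence[int]) -> List[bool]:
    """Assumed ground truth: an edge is 'on' when both endpoints share a cluster ID."""
    return [cluster_ids[i] == cluster_ids[j] for i, j in edges]


def edge_accuracy(pred_on: Sequence[bool], true_on: Sequence[bool]) -> float:
    """Fraction of edges whose predicted on/off state matches the truth."""
    if len(pred_on) != len(true_on) or not pred_on:
        raise ValueError("need two non-empty sequences of equal length")
    return sum(p == t for p, t in zip(pred_on, true_on)) / len(pred_on)


edges = [(0, 1), (1, 2), (2, 3)]
truth = edge_truth(edges, cluster_ids=[0, 0, 1, 1])  # [True, False, True]
print(edge_accuracy([True, True, True], truth))  # two of the three edges match
```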

process_model_config(constructor, shapes=[0, 1, 2, 3], invert=True, **kwargs)[source]

Process the model configuration

Parameters:
  • constructor (dict) – Edge index construction configuration

  • shapes (List[int], default [0, 1, 2, 3]) – List of semantic shapes to construct clusters for

  • invert (bool, default True) – Invert the edge scores so that 0 is on and 1 is off

forward(seg_label, clust_label, filter_index, **output)[source]

Run a batch of data through the loss function.

Parameters:
  • seg_label (TensorBatch) – (N, 1 + D + 1) tensor of segmentation labels, where the final column is the segmentation label

  • clust_label (TensorBatch) – (N, 1 + D + N_c) tensor of cluster labels, where N_c is the number of cluster labels

  • filter_index (IndexBatch) – (M) index used to narrow down the original tensor

  • **output (dict) – Output of the Graph-SPICE model

Returns:

Dictionary of outputs

Return type:

dict
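The model may run on a filtered subset of voxels, so filter_index is what aligns the full (N, ...) labels with the (M, ...) model output. The sketch below shows that alignment with plain lists. It rests on two assumptions: that the IndexBatch can be treated as a list of row indices, and that the cluster ID is the first label column after the coordinates. The helper names are hypothetical.

```python
from typing import List, Sequence

Row = List[float]


def restrict(labels: Sequence[Row], filter_index: Sequence[int]) -> List[Row]:
    """Select the rows of an (N, ...) label tensor that the model kept."""
    return [list(labels[i]) for i in filter_index]


def cluster_ids(clust_rows: Sequence[Row], dim: int) -> List[int]:
    """Hypothetical: read the first label column after the D coordinates."""
    return [int(row[1 + dim]) for row in clust_rows]


# (N, 1 + D + N_c) cluster labels for N = 3 voxels, D = 3, N_c = 1
clust_label = [[0, 0, 0, 0, 7], [0, 1, 0, 0, 8], [1, 5, 5, 5, 9]]
filter_index = [0, 2]  # e.g. the index produced when the voxels were filtered
print(cluster_ids(restrict(clust_label, filter_index), dim=3))  # [7, 9]
```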