spine.model.uresnet

UResNet segmentation model and its loss.

Classes

SegmentationLoss(*args, **kwargs)

Loss definition for semantic segmentation.

UResNetSegmentation(*args, **kwargs)

UResNet for semantic segmentation.

class spine.model.uresnet.UResNetSegmentation(*args: Any, **kwargs: Any)

UResNet for semantic segmentation.

Typical configuration should look like:

model:
  name: uresnet
  modules:
    uresnet:
      # Your config goes here

See setup_cnn_configuration() for available parameters for the backbone UResNet architecture.

See configuration file(s) prefixed with uresnet_ under the config directory for detailed examples of working configurations.

Methods

__call__(*args, **kwargs)

Call self as a function.

forward(data)

Run a batch of data through the forward function.

process_model_config(num_classes[, ghost])

Initialize the underlying UResNet model.

INPUT_SCHEMA = [['sparse3d', (<class 'float'>,), (3, 1)]]
MODULES = ['uresnet']
process_model_config(num_classes, ghost=False, **backbone)

Initialize the underlying UResNet model.

Parameters:
  • num_classes (int) – Number of classes to classify the voxels as

  • ghost (bool, default False) – Whether to add a deghosting step in the classification model

  • **backbone (dict) – UResNet backbone configuration
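As a minimal sketch, assuming the uresnet block of the YAML configuration is passed to process_model_config as keyword arguments, the named parameters are captured explicitly and every other key ends up in the backbone dictionary. The function split_model_config is a hypothetical stand-in that only mirrors the documented signature, and the backbone keys depth and filters are hypothetical placeholders (see setup_cnn_configuration() for the real ones).

```python
from typing import Any, Dict, Tuple


def split_model_config(num_classes: int, ghost: bool = False,
                       **backbone: Any) -> Tuple[int, bool, Dict[str, Any]]:
    """Hypothetical stand-in mirroring process_model_config's signature.

    num_classes and ghost are captured by name; all remaining keys are
    collected into the backbone configuration dictionary.
    """
    return num_classes, ghost, backbone


# Python equivalent of the YAML `model.modules.uresnet` block.
# `depth` and `filters` are hypothetical backbone keys.
config = {
    "model": {
        "name": "uresnet",
        "modules": {
            "uresnet": {"num_classes": 5, "ghost": True,
                        "depth": 5, "filters": 32},
        },
    },
}

num_classes, ghost, backbone = split_model_config(
    **config["model"]["modules"]["uresnet"])
print(num_classes, ghost, backbone)
```

With this split, the backbone dictionary only contains keys that are not consumed by the segmentation head itself.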

forward(data)

Run a batch of data through the forward function.

Parameters:

data (TensorBatch) – (N, 1 + D + N_f) tensor of voxel/value pairs, where:
  • N is the total number of voxels in the batch
  • 1 is the batch ID column
  • D is the number of dimensions in the input image
  • N_f is the number of features per voxel
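The column layout described above can be illustrated with a plain NumPy array standing in for the TensorBatch (the TensorBatch API itself is not used here, so this is only a sketch of the layout): column 0 holds the batch ID, the next D columns hold the voxel coordinates, and the remaining N_f columns hold the features.

```python
import numpy as np

D = 3  # number of spatial dimensions (3D images)

# Toy (N, 1 + D + N_f) input: N = 4 voxels, 2 batch entries, N_f = 1 feature
data = np.array([
    # batch, x, y, z, value
    [0,      1, 2, 3, 0.5],
    [0,      4, 5, 6, 1.2],
    [1,      0, 0, 1, 0.7],
    [1,      2, 2, 2, 0.1],
])

batch_ids = data[:, 0].astype(int)       # column 0: batch ID
coords = data[:, 1:1 + D].astype(int)    # columns 1..D: voxel coordinates
features = data[:, 1 + D:]               # remaining N_f columns: features

# Number of voxels belonging to each image of the batch
counts = np.bincount(batch_ids)
print(counts)  # [2 2]
```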

Returns:

Dictionary of outputs

Return type:

dict

class spine.model.uresnet.SegmentationLoss(*args: Any, **kwargs: Any)

Loss definition for semantic segmentation.

For a regular UResNet, it is a cross-entropy loss. With deghosting, it depends on the configuration parameter ghost:

  • If ghost=True, we first compute the cross-entropy loss on the ghost point classification (weighted on the fly with sample statistics). Then we build a mask of all true non-ghost points (based on the labels) and, within this mask, compute a cross-entropy loss over the remaining classes.

  • If ghost=False, we compute an (N+1)-class cross-entropy loss, where N is the number of classes, not counting the ghost point class.
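The two flavors above can be sketched in NumPy as follows. This is not the library's implementation and rests on several assumptions: ghost points carry the label N (the last class), the ghost logits are ordered (non-ghost, ghost), the "sample statistics" weighting is an inverse class frequency over the batch, and the two terms of the ghost=True flavor are summed with equal weight.

```python
import numpy as np


def cross_entropy(logits, labels, weights=None):
    """Mean (optionally weighted) cross-entropy computed from raw logits."""
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    if weights is None:
        return float(nll.mean())
    return float((weights * nll).sum() / weights.sum())


def balanced_weights(labels, num_classes):
    """Assumed 'sample statistics' weighting: inverse class frequency."""
    counts = np.bincount(labels, minlength=num_classes).astype(float)
    class_w = counts.sum() / np.maximum(counts, 1.0)
    return class_w[labels]


def segmentation_loss(seg_logits, labels, num_classes, ghost_logits=None):
    """Sketch of both flavors; ghost points are assumed to carry label num_classes."""
    if ghost_logits is None:
        # ghost=False: one (N+1)-class cross-entropy, ghost class last
        return cross_entropy(seg_logits, labels)

    # ghost=True, step 1: ghost vs. non-ghost loss, reweighted on the fly
    is_ghost = (labels == num_classes).astype(int)
    ghost_loss = cross_entropy(ghost_logits, is_ghost,
                               balanced_weights(is_ghost, 2))

    # Step 2: N-class loss restricted to the true non-ghost points
    mask = is_ghost == 0
    seg_loss = cross_entropy(seg_logits[mask], labels[mask]) if mask.any() else 0.0
    # Equal weighting of the two terms is an assumption of this sketch
    return ghost_loss + seg_loss
```

In the ghost=True branch, the segmentation logits only need N columns, since the ghost class is handled by the separate binary head.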

Methods

__call__(*args, **kwargs)

Call self as a function.

forward(seg_label, segmentation[, ...])

Computes the cross-entropy loss of the semantic segmentation predictions.

get_distance_weights(seg_label, point_label)

Define weights for each of the points in the image based on their distance from points of interest (typically vertices, but user defined).

get_loss_accuracy(logits, labels[, weights])

Computes the loss, global and classwise accuracy.

process_loss_config([loss, ghost_label, ...])

Process the loss function parameters.

process_model_config(num_classes[, ghost])

Process the parameters of the upstream model needed in the loss.

INPUT_SCHEMA = [['parse_sparse3d', (<class 'int'>,), (3, 1)]]
process_model_config(num_classes, ghost=False, **kwargs)

Process the parameters of the upstream model needed in the loss.

Parameters:
  • num_classes (int) – Number of classes to classify the voxels as

  • ghost (bool, default False) – Whether to add a deghosting step in the classification model

  • **kwargs (dict, optional) – Leftover model configuration (not needed in the loss)

process_loss_config(loss='ce', ghost_label=-1, alpha=1.0, beta=1.0, balance_loss=False, upweight_points=False, upweight_radius=20)

Process the loss function parameters.

Parameters:
  • loss (str, default 'ce') – Loss function used for semantic segmentation

  • ghost_label (int, default -1) – ID of ghost points. If specified (> -1), classify ghosts only

  • alpha (float, default 1.0) – Classification loss prefactor

  • beta (float, default 1.0) – Ghost mask loss prefactor

  • balance_loss (bool, default False) – Whether to weight the loss to account for class imbalance

  • upweight_points (bool, default False) – Whether to weight the loss higher near specific points (to be provided as point_label as a loss input)

  • upweight_radius (float, default 20) – Radius around the points of interest within which to upweight the loss
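A minimal sketch of how these parameters could interact, assuming (not confirmed above) that a ghost_label > -1 reduces the targets to a binary ghost vs. non-ghost problem, and that alpha and beta linearly scale the classification and ghost mask terms. The LossConfig dataclass and its methods are hypothetical; only the field names and defaults mirror the documented signature.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class LossConfig:
    """Hypothetical container mirroring the documented defaults."""
    loss: str = "ce"
    ghost_label: int = -1
    alpha: float = 1.0
    beta: float = 1.0
    balance_loss: bool = False
    upweight_points: bool = False
    upweight_radius: float = 20

    def targets(self, labels: np.ndarray) -> np.ndarray:
        """Assumed behavior: with ghost_label > -1, classify ghosts only."""
        if self.ghost_label > -1:
            return (labels == self.ghost_label).astype(int)
        return labels

    def total(self, seg_loss: float, ghost_loss: float = 0.0) -> float:
        """Assumed combination: alpha scales classification, beta the ghost mask."""
        return self.alpha * seg_loss + self.beta * ghost_loss


cfg = LossConfig(ghost_label=5, beta=2.0)
print(cfg.targets(np.array([0, 5, 3, 5])))  # [0 1 0 1]
```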

forward(seg_label, segmentation, point_label=None, ghost=None, weights=None, **kwargs)

Computes the cross-entropy loss of the semantic segmentation predictions.

Parameters:
  • seg_label (TensorBatch) – (N, 1 + D + 1) Tensor of segmentation labels for the batch

  • segmentation (TensorBatch) – (N, N_c) Tensor of logits from the segmentation model

  • point_label (TensorBatch, optional) – (P, 1 + D + 1) Tensor of points of interest for the batch. This is used to upweight the loss near specific points.

  • ghost (TensorBatch, optional) – (N, 2) Tensor of ghost logits from the segmentation model

  • weights (TensorBatch, optional) – (N) Tensor of weights for each pixel in the batch

  • **kwargs (dict, optional) – Other outputs of the upstream model which are not relevant here

Returns:

Dictionary of accuracies and losses

Return type:

dict

get_distance_weights(seg_label, point_label)

Define weights for each of the points in the image based on their distance from points of interest (typically vertices, but user defined).

Parameters:
  • seg_label (TensorBatch) – (N, 1 + D + 1) Tensor of segmentation labels for the batch

  • point_label (TensorBatch) – (P, 1 + D + 1) Tensor of points of interest for the batch. This is used to upweight the loss of points near a vertex.

Returns:

(N) Array of weights associated with each point

Return type:

torch.Tensor
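One way to build such weights is sketched below with plain NumPy arrays in the (N, 1 + D + 1) layout. The step-function shape (a constant boost factor for every voxel closer than the radius to a point of interest in the same image, weight 1 elsewhere) and the function name distance_weights are assumptions; the default radius of 20 mirrors the documented upweight_radius default.

```python
import numpy as np


def distance_weights(seg_label, point_label, radius=20.0, boost=2.0, dim=3):
    """Hypothetical weighting: `boost` within `radius` of any point of interest.

    Both inputs follow the (N, 1 + D + 1) layout: batch ID, D coordinates,
    then a label column (ignored here).
    """
    weights = np.ones(len(seg_label))
    for b in np.unique(seg_label[:, 0]):
        vox_idx = np.where(seg_label[:, 0] == b)[0]
        pts = point_label[point_label[:, 0] == b, 1:1 + dim]
        if len(pts) == 0:
            continue  # no points of interest in this image
        vox = seg_label[vox_idx, 1:1 + dim]
        # Pairwise distances (n_vox, n_pts), then the nearest point per voxel
        dists = np.linalg.norm(vox[:, None, :] - pts[None, :, :], axis=-1)
        close = dists.min(axis=1) < radius
        weights[vox_idx[close]] = boost
    return weights
```

The brute-force pairwise distance matrix keeps the sketch short; for large images a spatial index would be the more scalable choice.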

get_loss_accuracy(logits, labels, weights=None)

Computes the loss, global and classwise accuracy.

Parameters:
  • logits (torch.Tensor) – (N, N_c) Output logits from the network for each voxel

  • labels (torch.Tensor) – (N) Target values for each voxel

  • weights (torch.Tensor, optional) – (N) Tensor of weights for each pixel in the batch

Returns:

  • torch.Tensor – Cross-entropy loss value

  • float – Global accuracy

  • np.ndarray – (N_c) Vector of class-wise accuracy

  • torch.Tensor – (N) Updated set of weights for each pixel in the batch
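The four return values can be sketched in NumPy as below. This is an illustration under assumptions, not the method itself: missing weights default to ones, the returned weights are simply those used for the loss (the actual method may update them, e.g. for class balancing), and classes absent from the batch get a NaN class-wise accuracy.

```python
import numpy as np


def loss_accuracy(logits, labels, weights=None):
    """Sketch returning (loss, global accuracy, class-wise accuracy, weights)."""
    num_classes = logits.shape[1]
    if weights is None:
        weights = np.ones(len(labels))

    # Weighted mean cross-entropy with a stable log-softmax
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    loss = float((weights * nll).sum() / weights.sum())

    # Global accuracy: fraction of voxels whose argmax matches the label
    correct = logits.argmax(axis=1) == labels
    accuracy = float(correct.mean())

    # Class-wise accuracy; classes absent from the batch are left as NaN
    class_acc = np.full(num_classes, np.nan)
    for c in range(num_classes):
        in_class = labels == c
        if in_class.any():
            class_acc[c] = correct[in_class].mean()
    return loss, accuracy, class_acc, weights
```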