spine.model.uresnet
UResNet segmentation model and its loss.
Classes
SegmentationLoss(*args, **kwargs) – Loss definition for semantic segmentation.
UResNetSegmentation(*args, **kwargs) – UResNet for semantic segmentation.
- class spine.model.uresnet.UResNetSegmentation(*args: Any, **kwargs: Any)
UResNet for semantic segmentation.
A typical configuration looks like:

    model:
      name: uresnet
      modules:
        uresnet:
          # Your config goes here
See setup_cnn_configuration() for available parameters for the backbone UResNet architecture. See the configuration file(s) prefixed with uresnet_ under the config directory for detailed examples of working configurations.
Methods
__call__(*args, **kwargs) – Call self as a function.
forward(data) – Run a batch of data through the forward function.
process_model_config(num_classes[, ghost]) – Initialize the underlying UResNet model.
- INPUT_SCHEMA = [['sparse3d', (<class 'float'>,), (3, 1)]]
- MODULES = ['uresnet']
- process_model_config(num_classes, ghost=False, **backbone)
Initialize the underlying UResNet model.
- Parameters:
num_classes (int) – Number of classes to classify the voxels as
ghost (bool, default False) – Whether to add a deghosting step in the classification model
**backbone (dict) – UResNet backbone configuration
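The way configuration keys reach process_model_config can be sketched in plain Python. This assumes that the keys under modules: uresnet: are passed to process_model_config as keyword arguments, so num_classes and ghost bind explicitly and everything else lands in **backbone. The stand-in function is hypothetical and only mirrors the documented signature; the backbone keys depth and filters are placeholders, not verified setup_cnn_configuration() parameters.

```python
# Hypothetical stand-in mirroring the documented signature
# process_model_config(num_classes, ghost=False, **backbone).
def process_model_config(num_classes, ghost=False, **backbone):
    """Return how the keyword arguments are split (illustration only)."""
    return {"num_classes": num_classes, "ghost": ghost, "backbone": backbone}


# Python equivalent of the YAML configuration layout. The keys 'depth'
# and 'filters' are hypothetical placeholders for backbone parameters.
cfg = {
    "model": {
        "name": "uresnet",
        "modules": {
            "uresnet": {
                "num_classes": 5,
                "ghost": True,
                "depth": 5,
                "filters": 32,
            }
        },
    }
}

# Explicit parameters are bound by name; leftover keys go to **backbone
split = process_model_config(**cfg["model"]["modules"]["uresnet"])
print(split)
# {'num_classes': 5, 'ghost': True, 'backbone': {'depth': 5, 'filters': 32}}
```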
- forward(data)
Run a batch of data through the forward function.
- Parameters:
data (TensorBatch) – (N, 1 + D + N_f) tensor of voxel/value pairs, where:
  - N is the total number of voxels across all images in the batch
  - 1 is the batch ID column
  - D is the number of dimensions in the input image
  - N_f is the number of features per voxel
- Returns:
Dictionary of outputs
- Return type:
dict
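As an illustration of the (N, 1 + D + N_f) layout, the sketch below uses a plain NumPy array as a stand-in for a TensorBatch (it does not use the TensorBatch API), with D = 3 and N_f = 1 chosen arbitrarily, and slices it into batch IDs, coordinates and features.

```python
import numpy as np

# Toy batch with D = 3 spatial dimensions and N_f = 1 feature per voxel.
# Columns: [batch_id, x, y, z, value]; rows are voxels from all images.
data = np.array([
    [0, 10, 12, 3, 0.5],
    [0, 11, 12, 3, 1.2],
    [1,  4,  7, 9, 0.8],
])

D = 3
batch_ids = data[:, 0].astype(int)  # (N,) image index of each voxel
coords = data[:, 1:1 + D]           # (N, D) voxel coordinates
features = data[:, 1 + D:]          # (N, N_f) per-voxel features

# Number of voxels belonging to each image of the batch
counts = np.bincount(batch_ids)
print(coords.shape, features.shape, counts)  # (3, 3) (3, 1) [2 1]
```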
- class spine.model.uresnet.SegmentationLoss(*args: Any, **kwargs: Any)
Loss definition for semantic segmentation.
For a regular UResNet, the loss is a cross-entropy loss. With deghosting, it depends on the configuration parameter ghost:
If ghost=True, we first compute the cross-entropy loss on the ghost point classification (weighted on the fly with sample statistics). We then build a mask of all true non-ghost points (based on the label) and, within this mask, compute a cross-entropy loss over the remaining classes.
If ghost=False, we compute an (N+1)-class cross-entropy loss, where N is the number of classes, not counting the ghost point class.
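A minimal NumPy sketch of the two-stage ghost=True loss described above. Assumptions, not taken from the source: the labels mark ghosts with a single hypothetical ID (ghost_class), index 1 of the ghost logits means "ghost", and the on-the-fly weighting is inverse-frequency class weighting computed from the batch itself. SegmentationLoss may weight or normalize differently.

```python
import numpy as np


def cross_entropy(logits, labels, weights=None):
    """Mean (optionally weighted) cross-entropy computed from raw logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    if weights is None:
        return float(nll.mean())
    return float((weights * nll).sum() / weights.sum())


def ghost_two_stage_loss(ghost_logits, seg_logits, labels, ghost_class):
    """Sketch of the ghost=True scheme; ghost_class is a hypothetical ID."""
    # Stage 1: binary target, 1 = ghost, 0 = non-ghost (assumed convention)
    is_ghost = (labels == ghost_class).astype(int)
    # Inverse-frequency class weights from this batch's own statistics
    counts = np.bincount(is_ghost, minlength=2).astype(float)
    class_w = len(is_ghost) / (2.0 * np.maximum(counts, 1.0))
    ghost_loss = cross_entropy(ghost_logits, is_ghost, class_w[is_ghost])

    # Stage 2: keep true non-ghost points only, classify the remaining classes
    mask = is_ghost == 0
    seg_loss = cross_entropy(seg_logits[mask], labels[mask])
    return ghost_loss, seg_loss


labels = np.array([0, 2, 5, 5, 1])  # 5 plays the ghost class here
rng = np.random.default_rng(0)
ghost_logits = rng.normal(size=(5, 2))
seg_logits = rng.normal(size=(5, 5))  # scores for non-ghost classes 0..4
g, s = ghost_two_stage_loss(ghost_logits, seg_logits, labels, ghost_class=5)
```

For ghost=False, the same cross_entropy helper would instead be applied once to (N, N+1) logits, with the ghost class treated as one more label.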
See also
Methods
__call__(*args, **kwargs) – Call self as a function.
forward(seg_label, segmentation[, ...]) – Computes the cross-entropy loss of the semantic segmentation predictions.
get_distance_weights(seg_label, point_label) – Define weights for each of the points in the image based on their distance from points of interest (typically vertices, but user defined).
get_loss_accuracy(logits, labels[, weights]) – Computes the loss, global and class-wise accuracy.
process_loss_config([loss, ghost_label, ...]) – Process the loss function parameters.
process_model_config(num_classes[, ghost]) – Process the parameters of the upstream model needed in the loss.
- INPUT_SCHEMA = [['parse_sparse3d', (<class 'int'>,), (3, 1)]]
- process_model_config(num_classes, ghost=False, **kwargs)
Process the parameters of the upstream model needed in the loss.
- Parameters:
num_classes (int) – Number of classes to classify the voxels as
ghost (bool, default False) – Whether to add a deghosting step in the classification model
**kwargs (dict, optional) – Leftover model configuration (not needed in the loss)
- process_loss_config(loss='ce', ghost_label=-1, alpha=1.0, beta=1.0, balance_loss=False, upweight_points=False, upweight_radius=20)
Process the loss function parameters.
- Parameters:
loss (str, default 'ce') – Loss function used for semantic segmentation
ghost_label (int, default -1) – ID of ghost points. If specified (> -1), classify ghosts only
alpha (float, default 1.0) – Classification loss prefactor
beta (float, default 1.0) – Ghost mask loss prefactor
balance_loss (bool, default False) – Whether to weight the loss to account for class imbalance
upweight_points (bool, default False) – Whether to weight the loss higher near specific points (to be provided as point_label as a loss input)
upweight_radius (float, default 20) – Radius around the points of interest within which to upweight the loss
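The balance_loss and alpha/beta options can be illustrated with a short sketch. It rests on assumptions: balancing uses inverse-frequency weights normalized so that every class present in the batch contributes equally, and the total loss is alpha times the classification term plus beta times the ghost-mask term. The helper names are hypothetical and the exact formulas used by SegmentationLoss may differ.

```python
import numpy as np


def balanced_weights(labels, num_classes):
    """Per-voxel inverse-frequency weights (one common balancing choice)."""
    counts = np.bincount(labels, minlength=num_classes).astype(float)
    present = counts > 0
    per_class = np.zeros(num_classes)
    # Each present class gets the same total weight; weights sum to N
    per_class[present] = len(labels) / (present.sum() * counts[present])
    return per_class[labels]


def total_loss(seg_loss, ghost_loss, alpha=1.0, beta=1.0):
    """Assumed combination: alpha scales classification, beta the ghost mask."""
    return alpha * seg_loss + beta * ghost_loss


labels = np.array([0, 0, 0, 1])
w = balanced_weights(labels, num_classes=3)  # about 0.667 for class 0, 2.0 for class 1
```

With this normalization the rare class 1 voxel carries as much total weight as the three class 0 voxels combined.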
- forward(seg_label, segmentation, point_label=None, ghost=None, weights=None, **kwargs)
Computes the cross-entropy loss of the semantic segmentation predictions.
- Parameters:
seg_label (TensorBatch) – (N, 1 + D + 1) Tensor of segmentation labels for the batch
segmentation (TensorBatch) – (N, N_c) Tensor of logits from the segmentation model
point_label (TensorBatch, optional) – (P, 1 + D + 1) Tensor of points of interest for the batch. This is used to upweight the loss near specific points.
ghost (TensorBatch, optional) – (N, 2) Tensor of ghost logits from the segmentation model
weights (TensorBatch, optional) – Tensor of weights for each voxel in the batch
**kwargs (dict, optional) – Other outputs of the upstream model which are not relevant here
- Returns:
Dictionary of accuracies and losses
- Return type:
dict
- get_distance_weights(seg_label, point_label)
Define weights for each of the points in the image based on their distance from points of interest (typically vertices, but user defined).
- Parameters:
seg_label (TensorBatch) – (N, 1 + D + 1) Tensor of segmentation labels for the batch
point_label (TensorBatch) – (P, 1 + D + 1) Tensor of points of interests for the batch. This is used to upweight the loss of points near a vertex.
- Returns:
Array of weights associated with each point
- Return type:
torch.Tensor
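A sketch of distance-based upweighting in the spirit of get_distance_weights, using NumPy arrays in place of TensorBatch. The radius plays the role of upweight_radius; the constant boost factor, the hard cutoff and the assumption that the last column of both inputs holds a label or point type are illustrative choices, not the documented behavior, and the real method may use a different weighting profile.

```python
import numpy as np


def distance_weights(seg_label, point_label, radius=20.0, boost=2.0):
    """Upweight voxels within `radius` of any point of interest in the same image.

    seg_label:   (N, 1 + D + 1) array [batch_id, coords..., label]
    point_label: (P, 1 + D + 1) array [batch_id, coords..., point type]
    `boost` is a hypothetical upweighting factor, not a documented parameter.
    """
    weights = np.ones(len(seg_label))
    for b in np.unique(seg_label[:, 0]):
        vox_idx = np.where(seg_label[:, 0] == b)[0]
        pts = point_label[point_label[:, 0] == b, 1:-1]
        if len(pts) == 0:
            continue  # no points of interest in this image
        vox = seg_label[vox_idx, 1:-1]
        # (n_vox, n_pts) pairwise Euclidean distances, then the closest point
        dists = np.linalg.norm(vox[:, None, :] - pts[None, :, :], axis=-1)
        close = dists.min(axis=1) < radius
        weights[vox_idx[close]] = boost
    return weights


seg_label = np.array([[0, 0, 0, 0, 1], [0, 30, 0, 0, 2], [1, 0, 0, 0, 1]], dtype=float)
point_label = np.array([[0, 1, 0, 0, 0]], dtype=float)
w = distance_weights(seg_label, point_label, radius=20.0)  # [2.0, 1.0, 1.0]
```

Grouping by batch ID keeps a vertex in one image from upweighting voxels in another image that happen to share coordinates.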
- get_loss_accuracy(logits, labels, weights=None)
Computes the loss, global and class-wise accuracy.
- Parameters:
logits (torch.Tensor) – (N, N_c) Output logits from the network for each voxel
labels (torch.Tensor) – Target values for each voxel
weights (torch.Tensor, optional) – Tensor of weights for each voxel in the batch
- Returns:
torch.Tensor – Cross-entropy loss value
float – Global accuracy
np.ndarray – (N_c) Vector of class-wise accuracy
torch.Tensor – Updated set of weights for each pixel in the batch
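The quantities returned by get_loss_accuracy can be sketched in NumPy. Assumptions: the weighted loss is normalized by the sum of the weights, missing weights default to ones (and are returned as the "updated" weights), and classes absent from labels get a NaN class-wise accuracy. The function name loss_accuracy is hypothetical.

```python
import numpy as np


def loss_accuracy(logits, labels, weights=None):
    """Weighted cross-entropy, global accuracy and class-wise accuracy."""
    num_classes = logits.shape[1]
    if weights is None:
        weights = np.ones(len(labels))
    # Log-softmax with max subtraction for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    loss = float((weights * nll).sum() / weights.sum())

    correct = logits.argmax(axis=1) == labels
    global_acc = float(correct.mean())
    # Fraction of voxels of each true class that are classified correctly
    class_acc = np.full(num_classes, np.nan)  # NaN for absent classes (assumption)
    for c in range(num_classes):
        in_class = labels == c
        if in_class.any():
            class_acc[c] = correct[in_class].mean()
    return loss, global_acc, class_acc, weights


logits = np.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]])
labels = np.array([0, 1, 0, 2])
loss, global_acc, class_acc, w = loss_accuracy(logits, labels)
# global_acc is 0.75; class_acc is [0.5, 1.0, 1.0]
```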