Source code for spine.model.uresnet

"""UResNet segmentation model and its loss."""

from collections import defaultdict

import MinkowskiEngine as ME
import numpy as np
import torch
import torch.nn as nn

from spine.constants import BATCH_COL, COORD_COLS, GHOST_SHP, VALUE_COL
from spine.data import TensorBatch
from spine.utils.logger import logger
from spine.utils.torch.scripts import cdist_fast

from .layer.cnn.act_norm import act_factory, norm_factory
from .layer.cnn.uresnet_layers import UResNet
from .layer.factories import loss_fn_factory

__all__ = ["UResNetSegmentation", "SegmentationLoss"]


class UResNetSegmentation(nn.Module):
    """Semantic segmentation model built on top of a UResNet backbone.

    Typical configuration should look like:

    .. code-block:: yaml

        model:
          name: uresnet
          modules:
            uresnet:
              # Your config goes here

    See :func:`setup_cnn_configuration` for available parameters for the
    backbone UResNet architecture.

    See configuration file(s) prefixed with `uresnet_` under the `config`
    directory for detailed examples of working configurations.
    """

    INPUT_SCHEMA = [["sparse3d", (float,), (3, 1)]]

    MODULES = ["uresnet"]

    def __init__(self, uresnet, uresnet_loss=None):
        """Build the backbone and the output heads of the model.

        Parameters
        ----------
        uresnet : dict
            Model configuration
        uresnet_loss : dict, optional
            Loss configuration (accepted but not used by the model itself)
        """
        super().__init__()

        # The backbone must exist before the heads: they read its
        # normalization/activation configuration and its feature count
        self.process_model_config(**uresnet)

        # Normalization + activation applied to the last decoder features
        self.output = nn.Sequential(
            norm_factory(self.backbone.norm_cfg, self.num_filters),
            act_factory(self.backbone.act_cfg),
        )

        # Linear head producing one logit per semantic class
        self.linear_segmentation = ME.MinkowskiLinear(
            self.num_filters, self.num_classes
        )

        # Optional binary head dedicated to ghost classification
        if self.ghost:
            logger.debug("Ghost Masking is enabled for UResNet Segmentation")
            self.linear_ghost = ME.MinkowskiLinear(self.num_filters, 2)

    def process_model_config(self, num_classes, ghost=False, **backbone):
        """Store the segmentation settings and instantiate the backbone.

        Parameters
        ----------
        num_classes : int
            Number of classes to classify the voxels as
        ghost : bool, default False
            Whether to add a deghosting step in the classification model
        **backbone : dict
            UResNet backbone configuration
        """
        # Semantic segmentation configuration
        self.ghost = ghost
        self.num_classes = num_classes

        # The leftover keyword arguments configure the UResNet backbone
        self.backbone = UResNet(backbone)
        self.num_filters = self.backbone.num_filters

    def forward(self, data):
        """Run a batch of data through the model.

        Parameters
        ----------
        data : TensorBatch
            (N, 1 + D + N_f) tensor of voxel/value pairs
            - N is the total number of voxels in the image
            - 1 is the batch ID
            - D is the number of dimensions in the input image
            - N_f is the number of features per voxel

        Returns
        -------
        dict
            Dictionary of outputs with keys `segmentation`, `final_tensor`,
            `encoder_tensors` and `decoder_tensors`, plus `ghost` and
            `ghost_tensor` when the ghost head is enabled
        """
        # Only keep the batch, coordinate and requested feature columns
        last_col = 1 + self.backbone.dim + self.backbone.num_input
        backbone_out = self.backbone(data.tensor[:, :last_col])

        # Last decoder layer -> norm/activation -> segmentation logits
        features = self.output(backbone_out["decoder_tensors"][-1])
        logits = self.linear_segmentation(features)

        batch_size = data.batch_size

        def as_sparse_batch(tensor):
            # Wrap a sparse backbone tensor into a batch-aware container
            return TensorBatch(tensor, batch_size=batch_size, is_sparse=True)

        # Store the outputs as tensor batches
        result = {
            "segmentation": TensorBatch(logits.F, data.counts),
            "final_tensor": as_sparse_batch(backbone_out["final_tensor"]),
            "encoder_tensors": [
                as_sparse_batch(t) for t in backbone_out["encoder_tensors"]
            ],
            "decoder_tensors": [
                as_sparse_batch(t) for t in backbone_out["decoder_tensors"]
            ],
        }

        # Without a ghost head, there is nothing more to compute
        if not self.ghost:
            return result

        # Ghost logits, stored both as a dense and as a sparse tensor batch
        ghost_logits = self.linear_ghost(features)
        result["ghost"] = TensorBatch(ghost_logits.F, data.counts)
        result["ghost_tensor"] = TensorBatch(
            ghost_logits, data.counts, is_sparse=True
        )

        return result
class SegmentationLoss(torch.nn.modules.loss._Loss):
    """Loss definition for semantic segmentation.

    For a regular flavor UResNet, it is a cross-entropy loss.

    For deghosting, it depends on a configuration parameter `ghost`:

    - If `ghost=True`, we first compute the cross-entropy loss on the ghost
      point classification (weighted on the fly with sample statistics).
      Then we compute a mask = all non-ghost points (based on true
      information in label) and within this mask, compute a cross-entropy
      loss for the rest of classes.
    - If `ghost=False`, we compute a N+1-classes cross-entropy loss, where N
      is the number of classes, not counting the ghost point class.

    See Also
    --------
    :class:`UResNetSegmentation`
    """

    INPUT_SCHEMA = [["parse_sparse3d", (int,), (3, 1)]]

    def __init__(self, uresnet, uresnet_loss):
        """Initializes the segmentation loss.

        Parameters
        ----------
        uresnet : dict
            Model configuration
        uresnet_loss : dict
            Loss configuration
        """
        # Initialize the parent class
        super().__init__()

        # Initialize what we need from the model configuration
        self.process_model_config(**uresnet)

        # Initialize the loss configuration
        self.process_loss_config(**uresnet_loss)

    def process_model_config(self, num_classes, ghost=False, **kwargs):
        """Process the parameters of the upstream model needed in the loss.

        Parameters
        ----------
        num_classes : int
            Number of classes to classify the voxels as
        ghost : bool, default False
            Whether to add a deghosting step in the classification model
        **kwargs : dict, optional
            Leftover model configuration (not needed in the loss)
        """
        # Store the semantic segmentation configuration
        self.num_classes = num_classes
        self.ghost = ghost

    def process_loss_config(
        self,
        loss="ce",
        ghost_label=-1,
        alpha=1.0,
        beta=1.0,
        balance_loss=False,
        upweight_points=False,
        upweight_radius=20,
    ):
        """Process the loss function parameters.

        Parameters
        ----------
        loss : str, default 'ce'
            Loss function used for semantic segmentation
        ghost_label : int, default -1
            ID of ghost points. If specified (> -1), classify ghosts only
        alpha : float, default 1.0
            Classification loss prefactor
        beta : float, default 1.0
            Ghost mask loss prefactor
        balance_loss : bool, default False
            Whether to weight the loss to account for class imbalance
        upweight_points : bool, default False
            Whether to weight the loss higher near specific points (to be
            provided as `point_label` as a loss input)
        upweight_radius : float, default 20
            Radius around the points of interest for which to upweight the loss
        """
        # Set the loss function. It must return one value per voxel so that
        # per-voxel weights can be applied before reducing.
        self.loss_fn = loss_fn_factory(loss, reduction="none")

        # Store the loss configuration
        self.ghost_label = ghost_label
        self.alpha = alpha
        self.beta = beta
        self.balance_loss = balance_loss
        self.upweight_points = upweight_points
        self.upweight_radius = upweight_radius

        # If a ghost label is provided, it cannot be used in conjunction
        # with a dedicated ghost masking layer
        assert not (self.ghost and self.ghost_label > -1), (
            "Cannot classify ghost exclusively (ghost_label) and "
            "have a dedicated ghost masking layer at the same time."
        )

    def forward(
        self,
        seg_label,
        segmentation,
        point_label=None,
        ghost=None,
        weights=None,
        **kwargs,
    ):
        """Computes the cross-entropy loss of the semantic segmentation
        predictions.

        Parameters
        ----------
        seg_label : TensorBatch
            (N, 1 + D + 1) Tensor of segmentation labels for the batch
        segmentation : TensorBatch
            (N, N_c) Tensor of logits from the segmentation model
        point_label : TensorBatch, optional
            (P, 1 + D + 1) Tensor of points of interests for the batch. This
            is used to upweight the loss near specific points.
        ghost : TensorBatch, optional
            (N, 2) Tensor of ghost logits from the segmentation model
        weights : TensorBatch, optional
            (N) Tensor of weights for each pixel in the batch. The underlying
            tensor is never modified in place.
        **kwargs : dict, optional
            Other outputs of the upstream model which are not relevant here

        Returns
        -------
        dict
            Dictionary of accuracies and losses. If any weighting was
            applied, it also contains a `weights` entry with one weight per
            voxel of `seg_label` (ghost voxels are given a weight of 0 when
            a dedicated ghost layer is used, as they do not contribute to
            the segmentation loss).
        """
        # Get the underlying tensor in each TensorBatch
        seg_label_t = seg_label.tensor
        segmentation_t = segmentation.tensor
        ghost_t = ghost.tensor if ghost is not None else None
        weights_t = weights.tensor if weights is not None else None

        # Make sure that the segmentation output and labels have the same length
        assert len(seg_label_t) == len(segmentation_t), (
            f"The `segmentation` output length ({len(segmentation_t)}) "
            f"and its labels ({len(seg_label_t)}) do not match."
        )
        assert not self.ghost or ghost_t is not None, (
            "The `ghost` output must be provided when a dedicated ghost "
            "masking layer is used."
        )
        assert not self.ghost or len(seg_label_t) == len(ghost_t), (
            f"The `ghost` output length ({len(ghost_t)}) and "
            f"its labels ({len(seg_label_t)}) do not match."
        )
        assert not self.ghost or weights is None, (
            "Providing explicit weights is not compatible when peforming "
            "deghosting in tandem with semantic segmentation."
        )

        # If requested, produce weights based on point-proximity
        if self.upweight_points:
            assert point_label is not None, (
                "If upweighting the loss nearby points of interests, must "
                "provide a list of such points in `point_label`."
            )
            # `get_distance_weights` returns a plain tensor. Combine it
            # out-of-place so the caller's weights are left untouched.
            dist_weights = self.get_distance_weights(seg_label, point_label)
            if weights_t is not None:
                weights_t = weights_t * dist_weights
            else:
                weights_t = dist_weights

        # Check that the labels have sensible values
        if self.ghost_label > -1:
            labels_t = (seg_label_t[:, VALUE_COL] == self.ghost_label).long()
        else:
            labels_t = seg_label_t[:, VALUE_COL].long()

        # Only labels strictly larger than `num_classes` are rejected here;
        # a label equal to `num_classes` passes this check
        if torch.any(labels_t > self.num_classes):
            raise ValueError(
                "The segmentation labels contain labels larger than "
                "the number of logits output by the model."
            )

        # If there is a dedicated ghost layer, apply the mask first
        nonghost = None
        if self.ghost:
            # Binary ghost target: 1 for ghost voxels, 0 otherwise. The
            # weights returned for the ghost loss are not needed.
            ghost_labels_t = (labels_t == GHOST_SHP).long()
            ghost_loss, ghost_acc, ghost_acc_class, _ = self.get_loss_accuracy(
                ghost_t, ghost_labels_t
            )

            # Restrict the segmentation target (and the matching weights,
            # if any) to true non-ghosts
            nonghost = torch.nonzero(ghost_labels_t == 0).flatten()
            segmentation_t = segmentation_t[nonghost]
            labels_t = labels_t[nonghost]
            if weights_t is not None:
                weights_t = weights_t[nonghost]

        # Compute the loss/accuracy of the semantic segmentation step
        seg_loss, seg_acc, seg_acc_class, weights_t = self.get_loss_accuracy(
            segmentation_t, labels_t, weights_t
        )

        # Get the combined loss and accuracies
        result = {}
        if self.ghost:
            result.update(
                {
                    "loss": self.alpha * seg_loss + self.beta * ghost_loss,
                    "accuracy": (seg_acc + ghost_acc) / 2.0,
                    "seg_loss": seg_loss,
                    "seg_accuracy": seg_acc,
                    "ghost_loss": ghost_loss,
                    "ghost_accuracy": ghost_acc,
                    "ghost_accuracy_class_0": ghost_acc_class[0],
                    "ghost_accuracy_class_1": ghost_acc_class[1],
                }
            )
            for c in range(self.num_classes):
                result[f"seg_accuracy_class_{c}"] = seg_acc_class[c]

        else:
            result.update({"loss": seg_loss, "accuracy": seg_acc})
            for c in range(self.num_classes):
                result[f"accuracy_class_{c}"] = seg_acc_class[c]

        if weights_t is not None:
            if nonghost is not None:
                # Weights were computed on non-ghost voxels only. Scatter
                # them back so there is one entry per labeled voxel, as
                # implied by `seg_label.counts`.
                full_weights = weights_t.new_zeros(len(seg_label_t))
                full_weights[nonghost] = weights_t
                weights_t = full_weights

            result["weights"] = TensorBatch(weights_t, seg_label.counts)

        return result

    def get_distance_weights(self, seg_label, point_label):
        """Define weights for each of the points in the image based on their
        distance from points of interests (typically vertices, but user
        defined).

        Voxels are split in two groups (closer than `upweight_radius` to any
        point of interest of their batch entry, or not) and each voxel is
        assigned the weight `N / 2 / N_group`, with N the total number of
        voxels and N_group the population of its group.

        Parameters
        ----------
        seg_label : TensorBatch
            (N, 1 + D + 1) Tensor of segmentation labels for the batch
        point_label : TensorBatch
            (P, 1 + D + 1) Tensor of points of interests for the batch. This
            is used to upweight the loss of points near a vertex.

        Returns
        -------
        torch.Tensor
            (N) Array of weights associated with each point
        """
        # Distances default to infinity (entries with no voxel or no point
        # of interest). Infinity is not representable in integer dtypes, so
        # fall back to the default float type for integer labels.
        label_t = seg_label.tensor
        if label_t.is_floating_point():
            dist_dtype = label_t.dtype
        else:
            dist_dtype = torch.get_default_dtype()
        dists = torch.full_like(label_t[:, 0], float("inf"), dtype=dist_dtype)

        # Loop over the entries in the batch, compute proximity for each point
        for b in range(seg_label.batch_size):
            # Fetch image voxel and point coordinates for this entry
            voxels_b = seg_label[b][:, COORD_COLS]
            points_b = point_label[b][:, COORD_COLS]
            if not len(points_b) or not len(voxels_b):
                continue

            # Compute the minimal distance to any point in this entry
            dist_mat = cdist_fast(voxels_b, points_b)
            dists_b = torch.min(dist_mat, dim=1).values

            # Record information in the batch-wise tensor
            lower, upper = seg_label.edges[b], seg_label.edges[b + 1]
            dists[lower:upper] = dists_b

        # Upweight the points within some distance of the points of interest
        proximity = (dists < self.upweight_radius).long()
        close_count = torch.sum(proximity)
        counts = torch.stack([len(dists) - close_count, close_count])

        # An empty group yields an infinite weight, but no voxel indexes it
        weights = len(proximity) / 2 / counts

        return weights[proximity]

    def get_loss_accuracy(self, logits, labels, weights=None):
        """Computes the loss, global and classwise accuracy.

        Parameters
        ----------
        logits : torch.Tensor
            (N, N_c) Output logits from the network for each voxel
        labels : torch.Tensor
            (N) Target values for each voxel
        weights : torch.Tensor, optional
            (N) Tensor of weights for each pixel in the batch. It is never
            modified in place.

        Returns
        -------
        torch.Tensor
            Loss value (weighted average of the per-voxel loss)
        float
            Global accuracy
        np.ndarray
            (N_c) Vector of class-wise accuracy (1 for absent classes)
        torch.Tensor
            (N) Updated set of weights for each pixel in the batch (`None`
            if no weighting was applied)
        """
        # If there is no input, nothing to do. The sum of an empty tensor is
        # a zero which stays attached to the graph of `logits`.
        num_classes = logits.shape[1]
        if not len(logits):
            return (
                logits.sum(),
                1.0,
                np.ones(num_classes, dtype=np.float32),
                weights,
            )

        # Count the number of voxels in each class
        counts = torch.zeros(num_classes, dtype=torch.long, device=labels.device)
        for c in range(num_classes):
            counts[c] = torch.sum(labels == c)

        # If requested, create a set of weights based on class prevalence
        if self.balance_loss:
            # Absent classes keep a unit weight (avoids a division by zero)
            class_weight = torch.ones(
                num_classes, dtype=logits.dtype, device=logits.device
            )
            present = counts > 0
            class_weight[present] = len(labels) / num_classes / counts[present]

            # Combine out-of-place so the caller's tensor is left untouched
            class_weights = class_weight[labels]
            if weights is not None:
                weights = weights * class_weights
            else:
                weights = class_weights

        # Compute the loss
        voxel_loss = self.loss_fn(logits, labels)
        if weights is None:
            loss = voxel_loss.mean()
        else:
            loss = (weights * voxel_loss).sum() / weights.sum()

        # Compute the accuracies
        with torch.no_grad():
            # Bring the class counts to the host once
            class_counts = counts.tolist()

            # Per-class prediction accuracy
            preds = torch.argmax(logits, dim=-1)
            acc_class = np.ones(num_classes, dtype=np.float32)
            for c, count in enumerate(class_counts):
                if count > 0:
                    mask = torch.nonzero(labels == c).flatten()
                    acc_class[c] = (preds[mask] == c).sum().item() / count

            # Global prediction accuracy
            acc = (preds == labels).sum().item() / sum(class_counts)

        return loss, acc, acc_class, weights