Source code for spine.model.uresnet_ppn

"""Module that defines a model and a loss to jointly train the semantic
segmentation task and the point proposal task."""

import torch

from .layer.cnn.ppn import PPN, PPNLoss
from .uresnet import SegmentationLoss, UResNetSegmentation

__all__ = ["UResNetPPN", "UResNetPPNLoss"]


[docs] class UResNetPPN(torch.nn.Module): """A model made of a UResNet backbone and PPN layers. Typical configuration: .. code-block:: yaml model: name: uresnet_ppn_chain modules: uresnet: # Your backbone uresnet config here ppn: # Your ppn config here See Also -------- :class:`UResNetSegmentation`, :class:`PPN` """ MODULES = ["uresnet", "ppn"] def __init__(self, uresnet, ppn, uresnet_loss=None, ppn_loss=None): """Initialize the UResNet+PPN model. Parameters ---------- uresnet : dict UResNet configuration dictionary ppn : dict PPN configuration dictionary uresnet_loss : dict, optional UResNet loss configuration ppn_loss : dict, optional PPN loss configuration """ # Initialize the parent class super().__init__() # Initialize the UResNet backbone self.uresnet = UResNetSegmentation(uresnet) # Initialize the PPN layers self.ppn = PPN(uresnet, ppn) # Check that the UResNet and PPN configurations are compatible assert self.uresnet.ghost == self.ppn.ghost self.ghost = self.uresnet.ghost
[docs] def forward(self, data, seg_label=None): """Run a batch of data through the foward function. Parameters ---------- data : TensorBatch (N, 1 + D + N_f) tensor of voxel/value pairs - N is the the total number of voxels in the image - 1 is the batch ID - D is the number of dimensions in the input image - N_f is the number of features per voxel seg_label : TensorBatch, optional (N, 1 + D + 1) tensor of voxel/ghost label pairs """ # Pass the input through the backbone UResNet model result = self.uresnet(data) # Pass the decoder feature tensors to the PPN layers if self.ghost: # Deghost if self.ppn.use_true_ghost_mask: # Use the true ghost labels assert seg_label is not None, ( "If `use_true_ghost_mask` is set to `True`, must " "provide the `seg_label` tensor." ) result_ppn = self.ppn( result["final_tensor"], result["decoder_tensors"], result["ghost_tensor"], seg_label, ) else: # Use the ghost predictions result_ppn = self.ppn( result["final_tensor"], result["decoder_tensors"], result["ghost_tensor"], ) else: # No ghost result_ppn = self.ppn(result["final_tensor"], result["decoder_tensors"]) result.update(result_ppn) return result
[docs] class UResNetPPNLoss(torch.nn.Module): """Loss for amodel made of a UResNet backbone and PPN layers. It includes a segmentation loss and a PPN loss. Typical configuration: .. code-block:: yaml model: name: uresnet_ppn_chain modules: uresnet: # Your backbone uresnet config goes here ppn: # Your ppn config goes here ppn_loss: # Your ppn loss config goes here See Also -------- :class:`spine.model.uresnet.SegmentationLoss`, :class:`spine.model.layer.cnn.ppn.PPNLoss` """ def __init__(self, uresnet, ppn, ppn_loss, uresnet_loss=None): """Initialize the UResNet+PPN model loss. Parameters ---------- uresnet : dict UResNet configuration dictionary ppn : dict PPN configuration dictionary uresnet_loss : dict, optional UResNet loss configuration ppn_loss : dict, optional PPN loss configuration """ # Initialize the parent clas super().__init__() # Initialize the segmentation loss self.seg_loss = SegmentationLoss(uresnet, uresnet_loss) # Initialize the point proposal loss self.ppn_loss = PPNLoss(uresnet, ppn, ppn_loss)
[docs] def forward(self, seg_label, ppn_label, clust_label=None, weights=None, **result): """Run a batch of data through the loss function. Parameters ---------- seg_label : TensorBatch (N, 1 + D + 1) Tensor of segmentation labels for the batch ppn_label : TensorBatch (N, 1 + D + N_l) Tensor of PPN labels for the batch clust_label : TensorBatch, optional (N, 1 + D + N_c) Tensor of cluster labels - N_c is is the number of cluster labels weights : torch.Tensor, optional (N) Tensor of segmentation weights for each pixel in the batch **result : dict Outputs of the UResNet + PPN forward function Returns ------- TODO """ # Apply the segmentation loss result_seg = self.seg_loss(seg_label, weights=weights, **result) # Apply the PPN loss result_ppn = self.ppn_loss(ppn_label, clust_label=clust_label, **result) # Combine the two outputs result = { "loss": result_seg["loss"] + result_ppn["loss"], "accuracy": (result_seg["accuracy"] + result_ppn["accuracy"]) / 2, } result.update({"uresnet_" + k: v for k, v in result_seg.items()}) result.update({"ppn_" + k: v for k, v in result_ppn.items()}) return result