spine.model.uresnet_ppn

Module that defines a model and a loss to jointly train the semantic segmentation and point proposal (PPN) tasks.

Classes

UResNetPPN(*args, **kwargs)

A model made of a UResNet backbone and PPN layers.

UResNetPPNLoss(*args, **kwargs)

Loss for a model made of a UResNet backbone and PPN layers.

class spine.model.uresnet_ppn.UResNetPPN(*args: Any, **kwargs: Any)

A model made of a UResNet backbone and PPN layers.

Typical configuration:

model:
  name: uresnet_ppn_chain
  modules:
    uresnet:
      # Your backbone uresnet config here
    ppn:
      # Your ppn config here

See also

UResNetSegmentation, PPN

Methods

__call__(*args, **kwargs)

Call self as a function.

forward(data[, seg_label])

Run a batch of data through the forward function.

MODULES = ['uresnet', 'ppn']
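`MODULES` lists the sub-module names this model expects under `modules` in its configuration. The following is a minimal sketch, assuming the configuration has already been parsed into a nested Python dictionary mirroring the YAML shown above; the helper `check_modules` is hypothetical and is not part of `spine`.

```python
# Sub-module names, copied from UResNetPPN.MODULES
MODULES = ['uresnet', 'ppn']


def check_modules(model_cfg: dict, expected=MODULES) -> list:
    """Return the expected sub-module names missing from a model config.

    Hypothetical helper: it only inspects the dictionary layout and does
    not reproduce how spine itself parses its configuration.
    """
    modules = model_cfg.get('modules', {})
    return [name for name in expected if name not in modules]


cfg = {
    'name': 'uresnet_ppn_chain',
    'modules': {
        'uresnet': {},  # backbone UResNet parameters would go here
        'ppn': {},      # PPN parameters would go here
    },
}
print(check_modules(cfg))                          # -> []
print(check_modules({'modules': {'uresnet': {}}}))  # -> ['ppn']
```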
forward(data, seg_label=None)

Run a batch of data through the forward function.

Parameters:
  • data (TensorBatch) – (N, 1 + D + N_f) tensor of voxel/value pairs, where:
    - N is the total number of voxels in the image
    - 1 is the batch ID column
    - D is the number of dimensions in the input image
    - N_f is the number of features per voxel

  • seg_label (TensorBatch, optional) – (N, 1 + D + 1) tensor of voxel/ghost label pairs
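The column layout of `data` can be illustrated with a NumPy array. This is a minimal sketch: it assumes the tensor wrapped by a `TensorBatch` follows the (N, 1 + D + N_f) layout described above, and it does not construct a `TensorBatch` itself; `make_input` is a hypothetical helper. `seg_label` follows the same convention, with a single label column in place of the N_f feature columns.

```python
import numpy as np


def make_input(voxels_per_event, features_per_event):
    """Stack per-event voxels into one (N, 1 + D + N_f) array.

    Column 0 holds the batch ID, columns 1..D the voxel coordinates and
    the last N_f columns the per-voxel features.
    """
    rows = []
    for batch_id, (coords, feats) in enumerate(zip(voxels_per_event, features_per_event)):
        coords = np.asarray(coords, dtype=np.float32)
        feats = np.asarray(feats, dtype=np.float32)
        ids = np.full((len(coords), 1), batch_id, dtype=np.float32)
        rows.append(np.hstack([ids, coords, feats]))
    return np.vstack(rows)


# Two 3D events (D = 3) with one feature per voxel (N_f = 1)
coords = [[[0, 0, 0], [1, 2, 3]], [[4, 5, 6]]]
feats = [[[0.5], [1.0]], [[2.0]]]
data = make_input(coords, feats)

D = 3
batch_ids = data[:, 0]             # shape (N,)
voxel_coords = data[:, 1:1 + D]    # shape (N, D)
voxel_feats = data[:, 1 + D:]      # shape (N, N_f)
print(data.shape)  # -> (3, 5)
```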

class spine.model.uresnet_ppn.UResNetPPNLoss(*args: Any, **kwargs: Any)

Loss for a model made of a UResNet backbone and PPN layers.

It includes a segmentation loss and a PPN loss.
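The combination of the two terms can be sketched in plain Python. This assumes, without confirmation from this page, that the total loss is the sum of the segmentation and PPN losses and that the reported accuracy is their average; `combine_losses` and the `'loss'`/`'accuracy'` keys are hypothetical and may not match what `spine` actually returns.

```python
def combine_losses(seg_result: dict, ppn_result: dict) -> dict:
    """Merge the outputs of a segmentation loss and a PPN loss.

    Assumption: the total loss is the plain sum of the two terms and the
    reported accuracy is their mean; spine may weight or report them
    differently. Each input is a dict with 'loss' and 'accuracy' keys.
    """
    result = {
        'loss': seg_result['loss'] + ppn_result['loss'],
        'accuracy': (seg_result['accuracy'] + ppn_result['accuracy']) / 2,
    }
    # Keep the individual terms under prefixed keys for monitoring
    for prefix, sub in (('seg', seg_result), ('ppn', ppn_result)):
        for key, value in sub.items():
            result[f'{prefix}_{key}'] = value
    return result


out = combine_losses({'loss': 0.75, 'accuracy': 0.75},
                     {'loss': 0.25, 'accuracy': 0.5})
print(out['loss'], out['accuracy'])  # -> 1.0 0.625
```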

Typical configuration:

model:
  name: uresnet_ppn_chain
  modules:
    uresnet:
      # Your backbone uresnet config goes here
    ppn:
      # Your ppn config goes here
    ppn_loss:
      # Your ppn loss config goes here

See also

spine.model.uresnet.SegmentationLoss, spine.model.layer.cnn.ppn.PPNLoss

Methods

__call__(*args, **kwargs)

Call self as a function.

forward(seg_label, ppn_label[, clust_label, ...])

Run a batch of data through the loss function.

forward(seg_label, ppn_label, clust_label=None, weights=None, **result)

Run a batch of data through the loss function.

Parameters:
  • seg_label (TensorBatch) – (N, 1 + D + 1) Tensor of segmentation labels for the batch

  • ppn_label (TensorBatch) – (N, 1 + D + N_l) Tensor of PPN labels for the batch

  • clust_label (TensorBatch, optional) – (N, 1 + D + N_c) Tensor of cluster labels, where N_c is the number of cluster labels

  • weights (torch.Tensor, optional) – (N) Tensor of segmentation weights for each voxel in the batch

  • **result (dict) – Outputs of the UResNet + PPN forward function

Return type:

TODO
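The role of `seg_label` and `weights` in the segmentation term can be illustrated with a weighted cross-entropy in NumPy. This is a minimal sketch, not the `SegmentationLoss` implementation: it assumes the class label sits in the last column of `seg_label`, that `weights` holds one value per voxel, and that the weighted loss is normalized by the sum of the weights; `weighted_seg_loss` and the `logits` array are hypothetical.

```python
import numpy as np


def weighted_seg_loss(logits, seg_label, weights=None):
    """Per-voxel cross-entropy, optionally averaged with voxel weights.

    logits: (N, N_classes) raw segmentation scores
    seg_label: (N, 1 + D + 1) array; the last column is the class label
    weights: optional (N,) array of per-voxel weights
    """
    labels = seg_label[:, -1].astype(np.int64)
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-likelihood of the true class for each voxel
    nll = -log_probs[np.arange(len(labels)), labels]
    if weights is None:
        return nll.mean()
    weights = np.asarray(weights, dtype=nll.dtype)
    return (weights * nll).sum() / weights.sum()


# Two 3D voxels (D = 3), three classes
logits = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
seg_label = np.array([[0, 0, 0, 0, 0], [0, 1, 1, 1, 2]], dtype=np.float32)
print(weighted_seg_loss(logits, seg_label))
print(weighted_seg_loss(logits, seg_label, weights=[1.0, 3.0]))
```

Normalizing by the sum of the weights keeps the loss on the same scale as the unweighted mean when all weights are equal.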