spine.model.uresnet_ppn
Module that defines a model and a loss to jointly train the semantic segmentation task and the point proposal task.
Classes
| UResNetPPN | A model made of a UResNet backbone and PPN layers. |
| UResNetPPNLoss | Loss for a model made of a UResNet backbone and PPN layers. |
- class spine.model.uresnet_ppn.UResNetPPN(*args: Any, **kwargs: Any)
A model made of a UResNet backbone and PPN layers.
Typical configuration:
    model:
      name: uresnet_ppn_chain
      modules:
        uresnet:
          # Your backbone uresnet config here
        ppn:
          # Your ppn config here
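The YAML configuration maps onto a nested dictionary. The sketch below mirrors it in plain Python and checks that every sub-module listed in `UResNetPPN.MODULES` (`['uresnet', 'ppn']`) has a configuration block. The helper `missing_modules` is hypothetical and not part of spine; the empty dictionaries stand in for your actual settings.

```python
# Plain-Python mirror of the typical configuration above.
# The empty dicts are placeholders for your backbone and PPN settings.
config = {
    "model": {
        "name": "uresnet_ppn_chain",
        "modules": {
            "uresnet": {},  # your backbone uresnet config here
            "ppn": {},      # your ppn config here
        },
    }
}

# Sub-modules the model expects, copied from UResNetPPN.MODULES.
MODULES = ["uresnet", "ppn"]


def missing_modules(cfg, required=MODULES):
    """Hypothetical helper: list required module blocks absent from cfg."""
    modules = cfg.get("model", {}).get("modules", {})
    return [name for name in required if name not in modules]


print(missing_modules(config))  # []
```

An empty list means both required blocks are present; a configuration with only `uresnet` would report `['ppn']`.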
See also
UResNetSegmentation, PPN
Methods
| __call__(*args, **kwargs) | Call self as a function. |
| forward(data[, seg_label]) | Run a batch of data through the forward function. |
- MODULES = ['uresnet', 'ppn']
- forward(data, seg_label=None)
Run a batch of data through the forward function.
- Parameters:
data (TensorBatch) – (N, 1 + D + N_f) tensor of voxel/value pairs, where:
  - N is the total number of voxels in the image
  - 1 is the batch ID column
  - D is the number of dimensions in the input image
  - N_f is the number of features per voxel
seg_label (TensorBatch, optional) – (N, 1 + D + 1) tensor of voxel/ghost label pairs
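To make the column layout concrete, here is a framework-free sketch that uses plain Python lists in place of the torch-backed `TensorBatch`. The toy values, `D = 3`, `N_f = 1`, and the helper names `split_columns` and `voxels_per_entry` are illustrative assumptions, not spine API.

```python
from collections import Counter

# Toy batch in the (N, 1 + D + N_f) layout: D = 3 spatial dimensions and
# N_f = 1 feature per voxel. Column 0 is the batch ID, columns 1..D are the
# voxel coordinates, and the remaining columns are the features.
D = 3
data = [
    # [batch_id, x, y, z, value]
    [0, 10, 12, 5, 0.8],
    [0, 11, 12, 5, 1.2],
    [1, 3, 40, 7, 0.5],
]


def split_columns(rows, dim=D):
    """Split each row into (batch_id, coordinates, features)."""
    batch_ids = [int(row[0]) for row in rows]
    coords = [row[1:1 + dim] for row in rows]
    features = [row[1 + dim:] for row in rows]
    return batch_ids, coords, features


def voxels_per_entry(rows):
    """Count how many voxels belong to each batch ID."""
    return dict(Counter(int(row[0]) for row in rows))


batch_ids, coords, features = split_columns(data)
print(len(data), voxels_per_entry(data))  # 3 {0: 2, 1: 1}
```

`seg_label` shares the same leading 1 + D columns, with a single label column in place of the N_f features.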
- class spine.model.uresnet_ppn.UResNetPPNLoss(*args: Any, **kwargs: Any)
Loss for a model made of a UResNet backbone and PPN layers.
It includes a segmentation loss and a PPN loss.
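The sketch below illustrates one plausible way the two terms could be combined: a per-voxel cross-entropy segmentation loss, optionally weighted per voxel (mirroring the `weights` argument of `forward`), added to a PPN loss value. The plain sum, the weighted-mean normalization, and all function names are assumptions for illustration; the actual spine implementation may weight or normalize the terms differently.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy of one voxel's class logits against its integer label."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def segmentation_loss(all_logits, labels, weights=None):
    """Per-voxel cross-entropy, averaged with optional per-voxel weights."""
    losses = [cross_entropy(lg, lb) for lg, lb in zip(all_logits, labels)]
    if weights is None:
        return sum(losses) / len(losses)
    return sum(w * l for w, l in zip(weights, losses)) / sum(weights)


def total_loss(seg_loss, ppn_loss):
    """Hypothetical combination: a plain, unweighted sum of both terms."""
    return seg_loss + ppn_loss


logits = [[2.0, 0.0], [0.0, 2.0]]  # two voxels, two classes
labels = [0, 1]                    # both voxels predicted correctly
seg = segmentation_loss(logits, labels)
print(total_loss(seg, ppn_loss=0.5))
```

Normalizing by the sum of the weights keeps the weighted loss on the same scale as the unweighted mean; other reductions are equally possible.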
Typical configuration:
    model:
      name: uresnet_ppn_chain
      modules:
        uresnet:
          # Your backbone uresnet config goes here
        ppn:
          # Your ppn config goes here
        ppn_loss:
          # Your ppn loss config goes here
See also
spine.model.uresnet.SegmentationLoss, spine.model.layer.cnn.ppn.PPNLoss
Methods
| __call__(*args, **kwargs) | Call self as a function. |
| forward(seg_label, ppn_label[, clust_label, ...]) | Run a batch of data through the loss function. |
- forward(seg_label, ppn_label, clust_label=None, weights=None, **result)
Run a batch of data through the loss function.
- Parameters:
seg_label (TensorBatch) – (N, 1 + D + 1) Tensor of segmentation labels for the batch
ppn_label (TensorBatch) – (N, 1 + D + N_l) Tensor of PPN labels for the batch
clust_label (TensorBatch, optional) – (N, 1 + D + N_c) Tensor of cluster labels, where N_c is the number of cluster labels
weights (torch.Tensor, optional) – Tensor of segmentation weights for each pixel in the batch
**result (dict) – Outputs of the UResNet + PPN forward function
- Return type:
TODO
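`UResNetPPNLoss.forward` receives the model outputs through `**result`, so a training step can unpack the dictionary returned by the model straight into the loss call. The mock classes and the output keys below (`segmentation`, `ppn_points`) are hypothetical stand-ins: they show the calling pattern only, not spine's actual output names or loss computation.

```python
class MockModel:
    """Stand-in for UResNetPPN: forward returns a dictionary of outputs."""

    def forward(self, data, seg_label=None):
        # Hypothetical output keys; the real model defines its own.
        return {
            "segmentation": [[0.1, 0.9] for _ in data],
            "ppn_points": [[0.0, 0.0, 0.0] for _ in data],
        }


class MockLoss:
    """Stand-in for UResNetPPNLoss with the documented signature."""

    def forward(self, seg_label, ppn_label, clust_label=None, weights=None,
                **result):
        # Every model output arrives as a keyword argument inside `result`.
        return {"received": sorted(result), "num_voxels": len(seg_label)}


data = [[0, 1, 2, 3, 0.5], [0, 4, 5, 6, 1.0]]  # (N, 1 + D + N_f), D = 3
seg_label = [[0, 1, 2, 3, 1], [0, 4, 5, 6, 0]]  # (N, 1 + D + 1)
ppn_label = []  # placeholder; real PPN labels are (N, 1 + D + N_l)

result = MockModel().forward(data, seg_label)
loss = MockLoss().forward(seg_label, ppn_label, **result)
print(loss)  # {'received': ['ppn_points', 'segmentation'], 'num_voxels': 2}
```

Passing outputs as keyword arguments lets the loss pick the entries it needs by name and ignore the rest, at the cost of requiring that no output key collides with an explicit parameter such as `seg_label`.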