Source code for spine.model.spice

import torch
import torch.nn as nn

from spine.model.layer.cluster import loss_factory
from spine.model.layer.cluster.embeddings import (
    SPICE as SPICE_base,  # TODO why does this live out of this module?
)


class SPICE(SPICE_base):
    """SPICE model, a thin specialization of ``SPICE_base``.

    ``MODULES`` names the model components declared by this class (network
    base, UResNet encoder, embedding decoder, seediness decoder). How the list
    is consumed is defined outside this file.
    """

    MODULES = [
        "network_base",
        "uresnet_encoder",
        "embedding_decoder",
        "seediness_decoder",
    ]

    def __init__(self, cfg):
        """Build the model by delegating all construction to ``SPICE_base``.

        Parameters
        ----------
        cfg : dict
            Configuration forwarded unchanged to the base class.
        """
        super().__init__(cfg)
class SPICELoss(nn.Module):
    """Loss function for Proposal-Free Mask Generators.

    Wraps a loss selected through ``loss_factory``. Label rows whose class
    (last column) appears in ``skip_classes`` are excluded before calling it.

    Parameters
    ----------
    cfg : dict
        Full configuration. The ``"spice"`` section may provide
        ``skip_classes`` (default ``[2, 3, 4]``). The ``name`` section may
        provide the loss ``"name"`` (default ``"se_lovasz_inter"``). The whole
        ``cfg`` is also passed to the selected loss constructor.
    name : str, default "spice_loss"
        Key of the loss configuration section inside ``cfg``.
    """

    def __init__(self, cfg, name="spice_loss"):
        super().__init__()
        self.model_config = cfg.get("spice", {})
        self.skip_classes = self.model_config.get("skip_classes", [2, 3, 4])
        self.loss_config = cfg.get(name, {})
        self.loss_func_name = self.loss_config.get("name", "se_lovasz_inter")
        # loss_factory maps the name to a constructor, which is then built
        # with the full configuration.
        self.loss_func = loss_factory(self.loss_func_name)(cfg)

    def class_mask(self, cluster_label):
        """Return a mask selecting label rows whose class is not skipped.

        Parameters
        ----------
        cluster_label : torch.Tensor
            2D label tensor. Its last column is compared against
            ``self.skip_classes``.

        Returns
        -------
        torch.Tensor
            Boolean tensor of length ``len(cluster_label)`` on the same
            device. An entry is ``True`` when that row's last-column value
            matches none of ``self.skip_classes``.
        """
        mask = torch.ones(
            len(cluster_label), dtype=torch.bool, device=cluster_label.device
        )
        classes = cluster_label[:, -1]
        for c in self.skip_classes:
            mask &= classes != c
        return mask

    def forward(self, result, cluster_label):
        """Compute the wrapped loss on the non-skipped label rows.

        Parameters
        ----------
        result : object
            Model output, forwarded unchanged to the wrapped loss.
        cluster_label : sequence of torch.Tensor
            Only ``cluster_label[0]`` is used.

        Returns
        -------
        object
            Whatever the wrapped loss function returns.
        """
        labels = cluster_label[0]
        # Gather the kept rows once. Boolean-mask indexing copies, so
        # repeating it per label view would duplicate work and memory.
        kept = labels[self.class_mask(labels)]
        # NOTE(review): columns 0-3 presumably hold batch id + coordinates,
        # column 5 the group id, and the last column the semantic class.
        # Confirm against the label format.
        segment_label = [kept[:, [0, 1, 2, 3, -1]]]
        group_label = [kept[:, [0, 1, 2, 3, 5]]]
        return self.loss_func(result, segment_label, group_label)