Source code for spine.model.bayes_uresnet

from collections import defaultdict

import MinkowskiEngine as ME
import torch
import torch.nn as nn
import torch.nn.functional as F

from .experimental.bayes.decoder import MCDropoutDecoder
from .experimental.bayes.encoder import MCDropoutEncoder
from .experimental.bayes.evidential import EVDLoss
from .layer.cnn.act_norm import act_factory, norm_factory
from .layer.cnn.configuration import setup_cnn_configuration
from .layer.cnn.uresnet_layers import UResNet


[docs] class BayesianUResNet(torch.nn.Module): """ UResNet with Uncertainty Quantification The backbone model consists of UResNet Encoder-Decoder format with standard residual layers for the shallow half and dropout residual layers for the deep half of the network. Configuration ------------- mode: str string indicator for slight changes in network behavior/architecture. Supports three options: - standard: standard dropout segmentation network. This also includes MCDropout segnet, since training behavior is identical for both standard and mcdropout networks. - evd: Changes network into evidential segmentation network num_samples: int if used as MCDropout Segnet, the number of stochastic forward samples to be taken. num_classes: int number of segmentation classes (default: 5) """ MODULES = [] def __init__(self, cfg, name="mcdropout_uresnet"): super(BayesianUResNet, self).__init__() setup_cnn_configuration(self, cfg, "network_base") self.model_config = cfg.get(name, {}) self.num_classes = self.model_config.get("num_classes", 5) self.num_samples = self.model_config.get("num_samples", 20) self.encoder = MCDropoutEncoder(cfg) self.decoder = MCDropoutDecoder(cfg) self.mode = self.model_config.get("mode", "standard") if "edl" in self.model_config.get("loss_fn", "cross_entropy"): self.classifier = nn.Sequential( ME.MinkowskiLinear(self.encoder.num_filters, self.num_classes), ME.MinkowskiSoftplus(), ) else: self.classifier = ME.MinkowskiLinear( self.encoder.num_filters, self.num_classes )
[docs] def mc_forward(self, input, num_samples=None): """ Forwarding operation for MC Dropout segmentation network. Args: num_samples: number of stochastic forward samples to be taken """ res = defaultdict(list) if num_samples is None: num_samples = self.num_samples for m in self.modules(): if m.__class__.__name__ == "Dropout": m.train() for igpu, x in enumerate(input): num_voxels = x.shape[0] device = x.device x_sparse = ME.SparseTensor( coordinates=x[:, :4].int(), features=x[:, -1].view(-1, 1) ) pvec = torch.zeros((num_voxels, self.num_classes)).to(device) logits = torch.zeros((num_voxels, self.num_classes)).to(device) for i in range(num_samples): res_encoder = self.encoder.encoder(x_sparse) decoderTensors = self.decoder( res_encoder["finalTensor"], res_encoder["encoderTensors"] ) feats = decoderTensors[-1] out = self.classifier(feats) logits += out.F pvec += F.softmax(out.F, dim=1) logits /= num_samples softmax_probs = pvec / num_samples res["softmax"].append(softmax_probs) res["segmentation"].append(logits) return res
[docs] def evidential_forward(self, input): """ Forwarding operation for evidential segmentation network. """ out = defaultdict(list) for igpu, x in enumerate(input): x = ME.SparseTensor( coordinates=x[:, :4].int(), features=x[:, -1].view(-1, 1) ) res_encoder = self.encoder.encoder(x) print([t.F.shape for t in res_encoder["encoderTensors"]]) decoderTensors = self.decoder( res_encoder["finalTensor"], res_encoder["encoderTensors"] ) feats = decoderTensors[-1] # For evidential models, logits correspond to collected evidence. logits = self.classifier(feats) ev = logits.F concentration = ev + 1.0 S = torch.sum(concentration, dim=1, keepdim=True) uncertainty = self.num_classes / (S + 0.000001) out["segmentation"].append(ev) out["evidence"].append(ev) out["uncertainty"].append(uncertainty) out["concentration"].append(concentration) out["expected_probability"].append(concentration / S) return out
[docs] def standard_forward(self, input): """ Forwarding operation for standard dropout segmentation network. """ out = defaultdict(list) for igpu, x in enumerate(input): x = ME.SparseTensor( coordinates=x[:, :4].int(), features=x[:, -1].view(-1, 1) ) res_encoder = self.encoder.encoder(x) print([t.F.shape for t in res_encoder["encoderTensors"]]) decoderTensors = self.decoder( res_encoder["finalTensor"], res_encoder["encoderTensors"] ) feats = decoderTensors[-1] # For evidential models, logits correspond to collected evidence. logits = self.classifier(feats) out["segmentation"].append(logits.F) return out
[docs] def forward(self, input): """ """ if self.mode == "mc_dropout": return self.mc_forward(input) elif self.mode == "evidential": return self.evidential_forward(input) else: return self.standard_forward(input)
[docs] class DUQUResNet(torch.nn.Module): """ Single Pass Deep Uncertainty Quantification Network Original Paper: https://arxiv.org/abs/2003.02037 Implementation adapted from the DUQ main Github Repository: https://github.com/y0ast/deterministic-uncertainty-quantification Author: Joost van Amersfoort """ MODULES = [] def __init__(self, cfg, name="duq_uresnet"): super(DUQUResNet, self).__init__() setup_cnn_configuration(self, cfg, name) self.model_config = cfg.get(name, {}) self.num_classes = self.model_config.get("num_classes", 5) self.num_samples = self.model_config.get("num_samples", 20) self.net = UResNet(cfg) self.gamma = self.model_config.get("gamma", 0.999) self.sigma = self.model_config.get("sigma", 0.3) self.embedding_dim = self.model_config.get("embedding_dim", 10) self.latent_size = self.model_config.get("latent_size", 32) self.w = nn.Parameter( torch.zeros(self.embedding_dim, self.num_classes, self.latent_size) ) nn.init.kaiming_normal_(self.w, nonlinearity="relu") self.register_buffer("N", torch.ones(self.num_classes) * 20) self.register_buffer( "m", torch.normal(torch.zeros(self.embedding_dim, self.num_classes), 0.05) ) self.m = self.m * self.N.unsqueeze(0)
[docs] def embed(self, x): res = self.net(x) feats = res["decoderTensors"][-1] print(feats.F) out = torch.einsum("ij,mnj->imn", feats.F, self.w) return out
[docs] def bilinear(self, z): embeddings = self.m / self.N.unsqueeze(0) diff = z - embeddings.unsqueeze(0) y_pred = (-(diff**2)).mean(1).div(2 * self.sigma**2).exp() return y_pred
[docs] def forward(self, input): (point_cloud,) = input if self.training: point_cloud.requires_grad_(True) z = self.embed(point_cloud) y_pred = self.bilinear(z) res = { "score": [y_pred], "embedding": [z], "input": [point_cloud], "centroids": [ self.m.detach().cpu().numpy() / self.N.detach().cpu().numpy() ], } self.z = z self.y_pred = y_pred return res
[docs] def update_buffers(self): with torch.no_grad(): # normalizing value per class, assumes y is one_hot encoded self.N = torch.max( self.gamma * self.N + (1 - self.gamma) * self.y_pred.sum(0), torch.ones_like(self.N), ) # compute sum of embeddings on class by class basis features_sum = torch.einsum("ijk,ik->jk", self.z, self.y_pred) self.m = self.gamma * self.m + (1 - self.gamma) * features_sum
[docs] class SegmentationLoss(nn.Module): def __init__(self, cfg, name="mcdropout_uresnet"): super(SegmentationLoss, self).__init__() self.loss_config = cfg.get(name, {}) self.loss_fn_name = self.loss_config.get("loss_fn", "edl_sumsq") self.loss_fn_args = self.loss_config.get("loss_fn_args", {}) if "edl" in self.loss_fn_name: self.loss_fn = EVDLoss(self.loss_fn_name, **self.loss_fn_args) elif self.loss_fn_name == "cross_entropy": self.loss_fn = torch.nn.functional.cross_entropy else: raise ValueError( "Loss function {} not recognized".format(self.loss_fn_name) ) self.one_hot = self.loss_config.get("one_hot", False) self.num_classes = self.loss_config.get("num_classes", 5)
[docs] def forward(self, outputs, label, iteration=0, weight=None): """ segmentation[0], label and weight are lists of size #gpus = batch_size. segmentation has as many elements as UResNet returns. label[0] has shape (N, dim + batch_id + 1) where N is #pts across minibatch_size events. """ # TODO Add weighting logits = outputs["segmentation"] if "edl" in self.loss_fn_name: segmentation = [ logits[0] + 1.0 ] # convert evidence to alpha concentration params. else: segmentation = logits device = segmentation[0].device assert len(segmentation) == len(label) # if weight is not None: # assert len(data) == len(weight) batch_ids = [d[:, 0] for d in label] total_loss = 0 total_acc = 0 count = 0 # Loop over GPUS for i in range(len(segmentation)): for b in batch_ids[i].unique(): batch_index = batch_ids[i] == b event_segmentation = segmentation[i][batch_index] event_label = label[i][:, -1][batch_index] event_label = torch.squeeze(event_label, dim=-1).long() loss_label = event_label if self.one_hot: loss_label = torch.eye(self.num_classes, device=device)[event_label] loss_seg = self.loss_fn(event_segmentation, loss_label, t=iteration) else: loss_seg = self.loss_fn(event_segmentation, loss_label) if weight is not None: event_weight = weight[i][batch_index] event_weight = torch.squeeze(event_weight, dim=-1) total_loss += torch.mean(loss_seg * event_weight) else: total_loss += torch.mean(loss_seg) # Accuracy predicted_labels = torch.argmax(event_segmentation, dim=-1) acc = (predicted_labels == event_label).sum().item() / float( predicted_labels.nelement() ) total_acc += acc count += 1 return {"accuracy": total_acc / count, "loss": total_loss / count}
[docs] class DUQSegmentationLoss(nn.Module): def __init__(self, cfg, name="duq_uresnet"): super(DUQSegmentationLoss, self).__init__() self.xentropy = nn.BCELoss(reduction="none") self.num_classes = 5 self.grad_w = cfg.get(name, {}).get("grad_w", 0.0) self.grad_penalty = cfg.get(name, {}).get("grad_penalty", True)
[docs] @staticmethod def calc_gradient_penalty(x, y_pred): """ Code From the DUQ main Github Repository: https://github.com/y0ast/deterministic-uncertainty-quantification Author: Joost van Amersfoort """ gradients = torch.autograd.grad( outputs=y_pred, inputs=x, grad_outputs=torch.ones_like(y_pred), create_graph=True, )[0] gradients = gradients.flatten(start_dim=1) # L2 norm grad_norm = gradients.norm(2, dim=1) # Two sided penalty gradient_penalty = ((grad_norm - 1) ** 2).mean() # One sided penalty - down # gradient_penalty = F.relu(grad_norm - 1).mean() return gradient_penalty
[docs] def forward(self, out, type_labels): # print(type_labels) probas = out["score"][0] device = probas.device labels = type_labels[0][:, -1].long() labels_one_hot = torch.eye(self.num_classes)[labels].to(device=device) loss1 = self.xentropy(probas, labels_one_hot) pred = torch.argmax(probas, dim=1) # Comptue gradient penalty loss2 = 0 if self.grad_penalty: loss2 = self.calc_gradient_penalty(out["input"][0], probas) loss1 = loss1.sum(dim=1).mean() loss = loss1 + self.grad_w * loss2 accuracy = float(torch.sum(pred == labels)) / float(labels.shape[0]) res = { "loss": loss, "loss_embedding": float(loss1), "loss_grad_penalty": float(loss2), "accuracy": accuracy, } print(res) acc_types = {} for c in labels.unique(): mask = labels == c acc_types["accuracy_{}".format(int(c))] = float( torch.sum(pred[mask] == labels[mask]) ) / float(torch.sum(mask)) return res