"""Image classification module.
This module includes:
- Single full image classification
- Individual cluster classification
- UQ implementations of the full image classification
"""
from collections import Counter, OrderedDict, defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Batch, Data
from spine.constants import PID_COL
from spine.data import TensorBatch
from spine.utils.gnn.cluster import form_clusters, get_cluster_label
# from .experimental.layer.pointmlp import PointMLPEncoder
from .experimental.bayes.encoder import MCDropoutEncoder
from .experimental.bayes.evidential import EVDLoss
from .experimental.layer.pointnet import PointNetEncoder
from .image import ImageClassifier
from .layer.cnn.act_norm import act_factory
from .layer.cnn.configuration import setup_cnn_configuration
from .layer.cnn.encoder import SparseResidualEncoder
from .layer.factories import loss_fn_factory
[docs]
class MultiParticleImageClassifier(ImageClassifier):
MODULES = ["particle_image_classifier", "network_base", "mink_encoder"]
def __init__(self, cfg, name="particle_image_classifier"):
super(MultiParticleImageClassifier, self).__init__(cfg, name)
model_cfg = cfg.get(name, {})
self.batch_col = model_cfg.get("batch_col", 0)
self.split_col = model_cfg.get("split_col", 6)
self.num_classes = model_cfg.get("num_classes", 5)
self.skip_invalid = model_cfg.get("skip_invalid", True)
self.target_col = model_cfg.get("target_col", 9)
self.invalid_id = model_cfg.get("invalid_id", -1)
self.split_input_mode = model_cfg.get("split_input_as_tg_batch", False)
[docs]
def split_input_as_tg_batch(self, point_cloud, clusts=None):
point_cloud_cpu = point_cloud.detach().cpu().numpy()
batches, bcounts = np.unique(
point_cloud_cpu[:, self.batch_col], return_counts=True
)
if clusts is None:
clusts = form_clusters(point_cloud_cpu, column=self.split_col)
if not len(clusts):
return Batch()
if self.skip_invalid:
target_ids = get_cluster_label(
point_cloud_cpu, clusts, column=self.target_col
)
clusts = [
c for i, c in enumerate(clusts) if target_ids[i] != self.invalid_id
]
if not len(clusts):
return Batch()
data_list = []
for i, c in enumerate(clusts):
x = point_cloud[c, 4].view(-1, 1)
pos = point_cloud[c, 1:4]
data = Data(x=x, pos=pos)
data_list.append(data)
split_data = Batch.from_data_list(data_list)
return split_data, clusts
[docs]
def split_input(self, point_cloud, clusts=None):
point_cloud_cpu = point_cloud.detach().cpu().numpy()
batches, bcounts = np.unique(
point_cloud_cpu[:, self.batch_col], return_counts=True
)
if clusts is None:
clusts = form_clusters(point_cloud_cpu, column=self.split_col)
if not len(clusts):
return point_cloud, [np.array([]) for _ in batches], []
if self.skip_invalid:
target_ids = get_cluster_label(
point_cloud_cpu, clusts, column=self.target_col
)
clusts = [
c for i, c in enumerate(clusts) if target_ids[i] != self.invalid_id
]
if not len(clusts):
return point_cloud, [np.array([]) for _ in batches], []
split_point_cloud = point_cloud.clone()
split_point_cloud[:, self.batch_col] = -1
for i, c in enumerate(clusts):
split_point_cloud[c, self.batch_col] = i
batch_ids = get_cluster_label(point_cloud_cpu, clusts, column=self.batch_col)
clusts_split, cbids = split_clusts(clusts, batch_ids, batches, bcounts)
return (
split_point_cloud[split_point_cloud[:, self.batch_col] > -1],
clusts_split,
cbids,
)
[docs]
def forward(self, input, clusts=None):
res = {}
(point_cloud,) = input
# It is possible that pid = 5 appears in the 9th column.
# In that case, it is observed that the training crashses with a
# integer overflow numel error.
mask = point_cloud[:, PID_COL] < self.num_classes
valid_points = point_cloud[mask]
if self.split_input_mode:
batch, clusts = self.split_input_as_tg_batch(valid_points, clusts)
out = self.encoder(batch)
out = self.final_layer(out)
res["clusts"] = [clusts]
res["logits"] = [out]
else:
out, clusts_split, cbids = self.split_input(valid_points, clusts)
res["clusts"] = [clusts_split]
out = self.encoder(out)
out = self.final_layer(out)
res["logits"] = [[out[b] for b in cbids]]
return res
[docs]
class DUQParticleClassifier(ImageClassifier):
"""
Uncertainty Estimation Using a Single Deep Deterministic Neural Network
https://arxiv.org/pdf/2003.02037.pdf
Joost van Amersfoort, Lewis Smith, Yee Whye Teh, Yarin Gal.
Pytorch Implementation for SparseConvNets with MinkowskiEngine backend.
"""
MODULES = ["network_base", "particle_image_classifier", "mink_encoder"]
def __init__(self, cfg, name="duq_particle_classifier"):
super(DUQParticleClassifier, self).__init__(cfg, name=name)
self.model_config = cfg.get(name, {})
self.final_layer = None
self.gamma = self.model_config.get("gamma", 0.999)
self.sigma = self.model_config.get("sigma", 0.3)
self.embedding_dim = self.model_config.get("embedding_dim", 64)
self.latent_size = self.model_config.get("latent_size", 256)
self.w = nn.Parameter(
torch.zeros(self.embedding_dim, self.num_classes, self.latent_size)
)
nn.init.kaiming_normal_(self.w, nonlinearity="relu")
self.register_buffer("N", torch.ones(self.num_classes) * 20)
self.register_buffer(
"m", torch.normal(torch.zeros(self.embedding_dim, self.num_classes), 0.05)
)
self.m = self.m * self.N.unsqueeze(0)
[docs]
def embed(self, x):
feats = self.encoder(x)
out = torch.einsum("ij,mnj->imn", feats, self.w)
return out
[docs]
def bilinear(self, z):
embeddings = self.m / self.N.unsqueeze(0)
diff = z - embeddings.unsqueeze(0)
y_pred = (-(diff**2)).mean(1).div(2 * self.sigma**2).exp()
return y_pred
[docs]
def forward(self, input):
(point_cloud,) = input
if self.training:
point_cloud.requires_grad_(True)
z = self.embed(point_cloud)
y_pred = self.bilinear(z)
res = {
"score": [y_pred],
"embedding": [z],
"input": [point_cloud],
"centroids": [
self.m.detach().cpu().numpy() / self.N.detach().cpu().numpy()
],
}
self.z = z
self.y_pred = y_pred
return res
[docs]
def update_buffers(self):
with torch.no_grad():
# normalizing value per class, assumes y is one_hot encoded
self.N = torch.max(
self.gamma * self.N + (1 - self.gamma) * self.y_pred.sum(0),
torch.ones_like(self.N),
)
# compute sum of embeddings on class by class basis
features_sum = torch.einsum("ijk,ik->jk", self.z, self.y_pred)
self.m = self.gamma * self.m + (1 - self.gamma) * features_sum
[docs]
class EvidentialParticleClassifier(ImageClassifier):
MODULES = ["network_base", "particle_image_classifier", "mink_encoder"]
def __init__(self, cfg, name="evidential_image_classifier"):
super(EvidentialParticleClassifier, self).__init__(cfg, name=name)
self.final_layer_name = cfg.get(name, {}).get("final_layer_name", "relu")
if self.final_layer_name == "relu":
self.final_layer = nn.Sequential(
nn.Linear(self.encoder.latent_size, self.num_classes), nn.ReLU()
)
elif self.final_layer_name == "softplus":
self.final_layer = nn.Sequential(
nn.Linear(self.encoder.latent_size, self.num_classes), nn.Softplus()
)
else:
raise Exception(
"Unknown output activation name %s provided" % self.final_layer_name
)
self.eps = cfg.get(name, {}).get("eps", 0.0)
[docs]
def forward(self, input):
(point_cloud,) = input
out = self.encoder(point_cloud)
evidence = self.final_layer(out)
# print("Evidence = ", evidence)
concentration = evidence + 1.0
S = torch.sum(concentration, dim=1, keepdim=True)
uncertainty = self.num_classes / (S + self.eps)
res = {
"evidence": [evidence],
"uncertainty": [uncertainty],
"concentration": [concentration],
"expected_probability": [concentration / S],
}
return res
[docs]
class BayesianParticleClassifier(torch.nn.Module):
MODULES = ["network_base", "mcdropout_encoder"]
def __init__(self, cfg, name="bayesian_particle_classifier"):
super(BayesianParticleClassifier, self).__init__()
setup_cnn_configuration(self, cfg, "network_base")
self.model_config = cfg.get(name, {})
self.num_classes = self.model_config.get("num_classes", 5)
self.encoder = MCDropoutEncoder(cfg)
self.mode = self.model_config.get("mode", "mc_dropout")
if self.mode == "evidential":
self.logit_layer = nn.Sequential(
nn.Linear(self.encoder.latent_size, self.num_classes), nn.Softplus()
)
else:
self.logit_layer = nn.Sequential(
nn.ReLU(), nn.Linear(self.encoder.latent_size, self.num_classes)
)
self.num_samples = self.model_config.get("num_samples", 20)
self.eps = self.model_config.get("eps", 0.0)
print("Dropout network will run inference on {} mode".format(self.mode))
[docs]
def evidential_forward(self, input):
(point_cloud,) = input
out = self.encoder(point_cloud)
out = self.logit_layer(out) + self.eps
concentration = out + 1.0
S = torch.sum(concentration, dim=1, keepdim=True)
uncertainty = self.num_classes / (S + 0.000001)
res = {}
res["evidence"] = [out]
res["uncertainty"] = [uncertainty]
res["concentration"] = [concentration]
res["expected_probability"] = [concentration / S]
return res
[docs]
def mc_forward(self, input, num_samples=None):
with torch.no_grad():
if num_samples is None:
num_samples = self.num_samples
print("Number of Samples = {}".format(num_samples))
(point_cloud,) = input
device = point_cloud.device
for m in self.modules():
if m.__class__.__name__ == "Dropout":
m.train()
num_batch = torch.unique(point_cloud[:, 0].int()).shape[0]
pvec = torch.zeros((num_batch, self.num_classes)).to(device)
logits = torch.zeros((num_batch, self.num_classes)).to(device)
discrete = torch.zeros((num_batch, self.num_classes)).to(device)
eye = torch.eye(self.num_classes).int().to(device)
for i in range(num_samples):
x = self.encoder(point_cloud)
out = self.logit_layer(x)
logits += out
pred = torch.argmax(out, dim=1)
pvec += F.softmax(out, dim=1)
discrete += eye[pred]
mc_dist = discrete / float(num_samples)
softmax_probs = pvec / float(num_samples)
logits = logits / float(num_samples)
# logits = torch.logit(softmax_probs)
res = {"softmax": [softmax_probs], "logits": [logits], "mc_dist": [mc_dist]}
return res
[docs]
def standard_forward(self, input, verbose=False):
print("Forwarding using weight averaging (standard dropout) ...")
(point_cloud,) = input
out = self.encoder(point_cloud)
out = self.logit_layer(out)
res = {"logits": [out]}
return res
[docs]
def forward(self, input):
if (not self.training) and (self.mode == "mc_dropout"):
return self.mc_forward(input)
elif self.mode == "evidential":
return self.evidential_forward(input)
else:
return self.standard_forward(input)
[docs]
class MultiParticleTypeLoss(nn.Module):
def __init__(self, cfg, name="particle_type_loss"):
super(MultiParticleTypeLoss, self).__init__()
loss_cfg = cfg.get(name, {})
self.num_classes = loss_cfg.get("num_classes", 5)
self.batch_col = loss_cfg.get("batch_col", 0)
self.target_col = loss_cfg.get("target_col", 9)
self.balance_classes = loss_cfg.get("balance_classes", False)
reduction = "mean" if not self.balance_classes else "sum"
self.xentropy = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction)
self.split_input_mode = loss_cfg.get("split_input_as_tg_batch", False)
[docs]
def forward_tg(self, out, valid_labels):
logits = out["logits"][0]
clusts = out["clusts"][0]
labels = get_cluster_label(valid_labels, clusts, self.target_col)
return [logits], [labels]
[docs]
def forward(self, out, type_labels):
valid_labels = type_labels[0][type_labels[0][:, 9] < self.num_classes]
if self.split_input_mode:
logits, labels = self.forward_tg(out, valid_labels)
else:
logits = out["logits"][0]
clusts = out["clusts"][0]
labels = [
get_cluster_label(
valid_labels[valid_labels[:, self.batch_col] == b],
clusts[b],
self.target_col,
)
for b in range(len(clusts))
if len(clusts[b])
]
if not len(labels):
res = {
"loss": torch.tensor(
0.0, requires_grad=True, device=valid_labels.device
),
"accuracy": 1.0,
}
for c in range(self.num_classes):
res[f"accuracy_class_{c}"] = 1.0
return res
labels = torch.tensor(
np.concatenate(labels), dtype=torch.long, device=valid_labels.device
)
logits = torch.cat(logits, axis=0)
if not self.balance_classes:
loss = self.xentropy(logits, labels)
else:
classes, counts = labels[labels > -1].unique(return_counts=True)
weights = torch.sum(counts) / counts / self.num_classes
loss = 0.0
for i, c in enumerate(classes):
class_mask = labels == c
loss += (
weights[i]
* self.xentropy(logits[class_mask], labels[class_mask])
/ torch.sum(counts)
)
pred = torch.argmax(logits, dim=1)
accuracy = float(torch.sum(pred[labels > -1] == labels[labels > -1])) / float(
labels[labels > -1].shape[0]
)
res = {"loss": loss, "accuracy": accuracy}
for c in range(self.num_classes):
mask = labels == c
res[f"accuracy_class_{c}"] = (
float(torch.sum(pred[mask] == labels[mask])) / float(torch.sum(mask))
if torch.sum(mask)
else 1.0
)
return res
[docs]
class MultiLabelCrossEntropy(nn.Module):
def __init__(self, cfg, name="duq_particle_classifier"):
super(MultiLabelCrossEntropy, self).__init__()
self.xentropy = nn.BCELoss(reduction="none")
self.num_classes = 5
model_cfg = cfg.get(name, {})
self.grad_w = model_cfg.get("grad_w", 0.0)
self.grad_penalty = model_cfg.get("grad_penalty", True)
[docs]
@staticmethod
def calc_gradient_penalty(x, y_pred):
"""
Code From the DUQ main Github Repository:
https://github.com/y0ast/deterministic-uncertainty-quantification
Author: Joost van Amersfoort
"""
gradients = torch.autograd.grad(
outputs=y_pred,
inputs=x,
grad_outputs=torch.ones_like(y_pred),
create_graph=True,
)[0]
gradients = gradients.flatten(start_dim=1)
# L2 norm
grad_norm = gradients.norm(2, dim=1)
# Two sided penalty
gradient_penalty = ((grad_norm - 1) ** 2).mean()
# One sided penalty - down
# gradient_penalty = F.relu(grad_norm - 1).mean()
return gradient_penalty
[docs]
def forward(self, out, type_labels):
probas = out["score"][0]
device = probas.device
labels_one_hot = torch.eye(self.num_classes)[type_labels[0][:, 0].long()].to(
device=device
)
loss1 = self.xentropy(probas, labels_one_hot)
pred = torch.argmax(probas, dim=1)
labels = type_labels[0][:, 0].long()
# Comptue gradient penalty
loss2 = 0
if self.grad_penalty:
loss2 = self.calc_gradient_penalty(out["input"][0], probas)
loss1 = loss1.sum(dim=1).mean()
loss = loss1 + self.grad_w * loss2
accuracy = float(torch.sum(pred == labels)) / float(labels.shape[0])
res = {
"loss": loss,
"loss_embedding": float(loss1),
"loss_grad_penalty": float(loss2),
"accuracy": accuracy,
}
print(res)
acc_types = {}
for c in labels.unique():
mask = labels == c
acc_types["accuracy_class_{}".format(int(c))] = float(
torch.sum(pred[mask] == labels[mask])
) / float(torch.sum(mask))
return res
[docs]
class EvidentialLearningLoss(nn.Module):
def __init__(self, cfg, name="evidential_learning_loss"):
super(EvidentialLearningLoss, self).__init__()
self.loss_config = cfg.get(name, {})
self.evd_loss_name = self.loss_config.get("evd_loss_name", "edl_sumsq")
self.num_classes = self.loss_config.get("num_classes", 5)
self.num_total_iter = self.loss_config.get("num_total_iter", 50000)
self.loss_fn = EVDLoss(self.evd_loss_name, "mean", T=self.num_total_iter)
[docs]
def forward(self, out, type_labels, iteration=0):
alpha = out["concentration"][0]
probs = out["expected_probability"][0]
device = alpha.device
labels = type_labels[0][:, 0].to(dtype=torch.long)
labels_onehot = torch.eye(self.num_classes, device=device)[labels]
loss_batch = self.loss_fn(alpha, labels_onehot, t=iteration)
loss = loss_batch.mean()
pred = torch.argmax(probs, dim=1)
accuracy = float(torch.sum(pred == labels)) / float(labels.shape[0])
res = {"loss": loss, "accuracy": accuracy}
acc_types = {}
for c in labels.unique():
mask = labels == c
acc_types["accuracy_class_{}".format(int(c))] = float(
torch.sum(pred[mask] == labels[mask])
) / float(torch.sum(mask))
return res