# Source code for spine.utils.weighting

"""Module which contains methods to compute class weights.

All methods compute the weights based on the relative abundance of each class.
"""

from typing import Union

import numpy as np
import torch


def get_class_weights(labels, num_classes, mode="const", per_class=True):
    """Computes class-wise weights based on their relative abundance.

    The raw weight of class ``c`` is ``N / num_classes / count_c``, where
    ``N = len(labels)`` and ``count_c`` is the number of entries of
    `labels` equal to ``c``. Classes which do not appear in `labels` keep a
    count of 1, which avoids a division by zero.

    Parameters
    ----------
    labels : Union[np.ndarray, torch.Tensor]
        (N) Array of class values to base the weighting scheme on. The
        values are used as indexes into an array of length `num_classes`,
        so they are expected to be integers in ``[0, num_classes)``
    num_classes : int
        Total number of classes (needed if there are some missing)
    mode : str, default 'const'
        Weighting mode (one of 'const', 'log', 'sqrt'). 'const' keeps the
        raw weights; 'log' and 'sqrt' apply the corresponding function to
        the raw weights
    per_class : bool, default True
        If `True`, returns one value per class. Otherwise, this function
        returns one value per input in the `labels` array

    Returns
    -------
    Union[np.ndarray, torch.Tensor]
        (num_classes) Array of class weights if `per_class` is `True`,
        (N) array of per-label weights otherwise. The output is of the same
        kind (numpy array or torch tensor) as `labels`

    Raises
    ------
    ValueError
        If `mode` is not one of the recognized weighting schemes
    """
    # Select the right functions depending on the input
    if isinstance(labels, torch.Tensor):
        def ones(size):
            # Counts share the dtype/device of the labels
            return torch.ones(size, dtype=labels.dtype, device=labels.device)

        def empty(size):
            return torch.empty(size, dtype=torch.float, device=labels.device)

        unique, log, sqrt = torch.unique, torch.log, torch.sqrt
    else:
        ones, empty = np.ones, np.empty
        unique, log, sqrt = np.unique, np.log, np.sqrt

    # Compute the abundance of each class in the input vector. Starting
    # from ones (not zeros) keeps the weights of absent classes finite.
    counts = ones(num_classes)
    uni, cnts = unique(labels, return_counts=True)
    counts[uni] = cnts

    # Compute the weights
    weights = len(labels) / num_classes / counts
    if mode == "const":
        pass
    elif mode == "log":
        weights = log(weights)
    elif mode == "sqrt":
        weights = sqrt(weights)
    else:
        raise ValueError("Weighting scheme not recognized:", mode)

    # Return one weight per class, if requested
    if per_class:
        return weights

    # Otherwise, broadcast the class weights to each input label
    weight_array = empty(len(labels))
    for c in range(num_classes):
        weight_array[labels == c] = weights[c]

    return weight_array