Source code for spine.math.base

"""Numba JIT compiled implementation of basic functions.

Most of these functions are implemented here because vanilla numba does not
support optional arguments, such as `axis` for most functions or
`return_counts` for the `unique` function.
"""

import numba as nb
import numpy as np

# Public API: the names exported by a wildcard import of this module
__all__ = [
    "seed",
    "unique",
    "sum",
    "mean",
    "mode",
    "argmax",
    "argmin",
    "amax",
    "amin",
    "all",
    "softmax",
    "log_loss",
]


@nb.njit(cache=True)
def seed(seed_value: int) -> None:
    """Seed the numpy random generator from inside Numba-compiled code.

    Calling `np.random.seed` from outside a Numba jitted function does *not*
    set the seed of Numba functions, which is why the call is wrapped in a
    jitted function here.

    Parameters
    ----------
    seed_value : int
        Random number generator seed
    """
    np.random.seed(seed_value)
@nb.njit(cache=True)
def unique(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Numba implementation of `np.unique(x, return_counts=True)`.

    The input is flattened before processing and the work buffers are sized
    from the flattened length, so multi-dimensional input is handled safely.

    Parameters
    ----------
    x : np.ndarray
        (N,) array of values (arrays with more dimensions are flattened)

    Returns
    -------
    np.ndarray
        (U,) array of unique values, in sorted order
    np.ndarray
        (U,) array of counts of each unique value in the original array
    """
    # Flatten first: the buffers must hold one slot per *element*, not one
    # slot per entry along the first axis
    flat = x.flatten()
    n = len(flat)

    # Worst case is that every element is distinct
    uniques = np.empty(n, dtype=x.dtype)
    counts = np.empty(n, dtype=np.int64)

    # Nothing to do if the input is empty
    if n == 0:
        return uniques, counts

    # After sorting, equal values are contiguous; walk the runs and record
    # each run's value and length
    flat = np.sort(flat)
    uniques[0] = flat[0]
    num_unique = 1
    run_start = 0
    for i in range(1, n):
        if flat[i] != flat[i - 1]:
            counts[num_unique - 1] = i - run_start
            uniques[num_unique] = flat[i]
            run_start = i
            num_unique += 1

    # Close the last run
    counts[num_unique - 1] = n - run_start

    # Narrow the buffers down to the number of unique values found
    return uniques[:num_unique], counts[:num_unique]
@nb.njit(cache=True)
def sum(x: np.ndarray, axis: int) -> np.ndarray:
    """Numba implementation of `np.sum(x, axis)` for 2D arrays.

    The output array is allocated with the dtype of `x`, so each sum is cast
    back to that dtype when it is stored.

    Parameters
    ----------
    x : np.ndarray
        (N, M) array of values
    axis : int
        Array axis ID (0 or 1) to sum over

    Returns
    -------
    np.ndarray
        (M,) array of column sums if `axis` is 0, (N,) array of row sums
        if `axis` is 1
    """
    assert axis == 0 or axis == 1

    totals = np.empty(x.shape[1 - axis], dtype=x.dtype)
    if axis == 1:
        # One sum per row
        for row in range(x.shape[0]):
            totals[row] = np.sum(x[row])
    else:
        # One sum per column
        for col in range(x.shape[1]):
            totals[col] = np.sum(x[:, col])

    return totals
@nb.njit(cache=True)
def mean(x: np.ndarray, axis: int) -> np.ndarray:
    """Numba implementation of `np.mean(x, axis)` for 2D arrays.

    The output dtype is the one `np.mean` produces for each slice rather
    than the dtype of `x`, so the means of integer or boolean arrays are
    not truncated back to integers.

    Parameters
    ----------
    x : np.ndarray
        (N, M) array of values
    axis : int
        Array axis ID (0 or 1) to average over

    Returns
    -------
    np.ndarray
        (M,) array of column means if `axis` is 0, (N,) array of row means
        if `axis` is 1
    """
    assert axis == 0 or axis == 1

    # Array comprehensions let Numba infer the element type from `np.mean`
    # itself, instead of forcing the result into a buffer of dtype `x.dtype`
    if axis == 0:
        return np.array([np.mean(x[:, i]) for i in range(x.shape[1])])

    return np.array([np.mean(x[i]) for i in range(x.shape[0])])
@nb.njit(cache=True)
def mode(x: np.ndarray) -> int:
    """Numba implementation of `scipy.stats.mode(x)`.

    The value is looked up as the entry of `unique(x)` with the largest
    count.

    Parameters
    ----------
    x : np.ndarray
        (N,) array of values

    Returns
    -------
    int
        Most-probable value in the array
    """
    candidates, occurrences = unique(x)
    most_frequent = np.argmax(occurrences)

    return candidates[most_frequent]
@nb.njit(cache=True)
def argmin(x: np.ndarray, axis: int) -> np.ndarray:
    """Numba implementation of `np.argmin(x, axis)` for 2D arrays.

    Indexes are stored in an `np.int32` array.

    Parameters
    ----------
    x : np.ndarray
        (N, M) array of values
    axis : int
        Array axis ID (0 or 1) along which to locate the minimum

    Returns
    -------
    np.ndarray
        (M,) array of row indexes if `axis` is 0, (N,) array of column
        indexes if `axis` is 1
    """
    assert axis == 0 or axis == 1

    indexes = np.empty(x.shape[1 - axis], dtype=np.int32)
    if axis == 1:
        # Position of the minimum within each row
        for row in range(x.shape[0]):
            indexes[row] = np.argmin(x[row])
    else:
        # Position of the minimum within each column
        for col in range(x.shape[1]):
            indexes[col] = np.argmin(x[:, col])

    return indexes
@nb.njit(cache=True)
def argmax(x: np.ndarray, axis: int) -> np.ndarray:
    """Numba implementation of `np.argmax(x, axis)` for 2D arrays.

    Indexes are stored in an `np.int32` array.

    Parameters
    ----------
    x : np.ndarray
        (N, M) array of values
    axis : int
        Array axis ID (0 or 1) along which to locate the maximum

    Returns
    -------
    np.ndarray
        (M,) array of row indexes if `axis` is 0, (N,) array of column
        indexes if `axis` is 1
    """
    assert axis == 0 or axis == 1

    indexes = np.empty(x.shape[1 - axis], dtype=np.int32)
    if axis == 1:
        # Position of the maximum within each row
        for row in range(x.shape[0]):
            indexes[row] = np.argmax(x[row])
    else:
        # Position of the maximum within each column
        for col in range(x.shape[1]):
            indexes[col] = np.argmax(x[:, col])

    return indexes
@nb.njit(cache=True)
def amin(x: np.ndarray, axis: int) -> np.ndarray:
    """Numba implementation of `np.amin(x, axis)` for 2D arrays.

    The output array has the same dtype as `x`.

    Parameters
    ----------
    x : np.ndarray
        (N, M) array of values
    axis : int
        Array axis ID (0 or 1) to reduce over

    Returns
    -------
    np.ndarray
        (M,) array of column minima if `axis` is 0, (N,) array of row
        minima if `axis` is 1
    """
    assert axis == 0 or axis == 1

    lowest = np.empty(x.shape[1 - axis], dtype=x.dtype)
    if axis == 1:
        # Smallest value of each row
        for row in range(x.shape[0]):
            lowest[row] = np.min(x[row])
    else:
        # Smallest value of each column
        for col in range(x.shape[1]):
            lowest[col] = np.min(x[:, col])

    return lowest
@nb.njit(cache=True)
def amax(x: np.ndarray, axis: int) -> np.ndarray:
    """Numba implementation of `np.amax(x, axis)` for 2D arrays.

    The output array has the same dtype as `x`.

    Parameters
    ----------
    x : np.ndarray
        (N, M) array of values
    axis : int
        Array axis ID (0 or 1) to reduce over

    Returns
    -------
    np.ndarray
        (M,) array of column maxima if `axis` is 0, (N,) array of row
        maxima if `axis` is 1
    """
    assert axis == 0 or axis == 1

    highest = np.empty(x.shape[1 - axis], dtype=x.dtype)
    if axis == 1:
        # Largest value of each row
        for row in range(x.shape[0]):
            highest[row] = np.max(x[row])
    else:
        # Largest value of each column
        for col in range(x.shape[1]):
            highest[col] = np.max(x[:, col])

    return highest
@nb.njit(cache=True)
def all(x: np.ndarray, axis: int) -> np.ndarray:
    """Numba implementation of `np.all(x, axis)` for 2D arrays.

    Parameters
    ----------
    x : np.ndarray
        (N, M) array of values
    axis : int
        Array axis ID (0 or 1) to reduce over

    Returns
    -------
    np.ndarray
        (M,) boolean array with one `all` output per column if `axis` is 0,
        (N,) boolean array with one `all` output per row if `axis` is 1
    """
    assert axis == 0 or axis == 1

    flags = np.empty(x.shape[1 - axis], dtype=np.bool_)
    if axis == 1:
        # One flag per row
        for row in range(x.shape[0]):
            flags[row] = np.all(x[row])
    else:
        # One flag per column
        for col in range(x.shape[1]):
            flags[col] = np.all(x[:, col])

    return flags
@nb.njit(cache=True)
def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    """Numba implementation of `scipy.special.softmax(x, axis)`.

    The scores are computed as `exp(x - logsumexp(x))`, where the
    log-sum-exp is evaluated after subtracting the maximum along `axis`
    from the exponent.

    Parameters
    ----------
    x : np.ndarray
        (N, M) array of values
    axis : int
        Array axis ID (0 or 1) along which the scores are normalized

    Returns
    -------
    np.ndarray
        (N, M) array of softmax scores
    """
    assert axis == 0 or axis == 1

    if axis == 1:
        # Per-row quantities, kept as (N, 1) columns so that they
        # broadcast against the rows of `x`
        row_max = amax(x, axis=1).reshape(-1, 1)
        row_sum = np.sum(np.exp(x - row_max), axis=1)
        row_lse = np.log(row_sum).reshape(-1, 1) + row_max
        return np.exp(x - row_lse)

    # Per-column quantities, (M,) vectors broadcast against each row of `x`
    col_max = amax(x, axis=0)
    col_sum = np.sum(np.exp(x - col_max), axis=0)
    col_lse = np.log(col_sum) + col_max
    return np.exp(x - col_lse)
@nb.njit(cache=True)
def log_loss(label: np.ndarray, pred: np.ndarray) -> float:
    """Numba implementation of cross-entropy loss.

    Computes the mean over all samples of `-log(pred)` where the label is
    set and `-log(1 - pred)` where it is not. Labels are cast to booleans
    first, so both boolean arrays and integer arrays of 0/1 are accepted.

    Parameters
    ----------
    label : np.ndarray
        (N,) array of boolean labels (0 or 1)
    pred : np.ndarray
        (N,) array of float scores (between 0 and 1)

    Returns
    -------
    float
        Cross-entropy loss (0.0 if there are no samples)
    """
    num_samples = len(label)
    if num_samples == 0:
        return 0.0

    # Cast to a boolean mask: with integer 0/1 labels, `pred[label]` would
    # be positional indexing and `~label` a bitwise NOT, silently producing
    # a wrong loss
    positive = label.astype(np.bool_)

    pos_term = np.sum(np.log(pred[positive]))
    neg_term = np.sum(np.log(1.0 - pred[~positive]))

    return -(pos_term + neg_term) / num_samples