Source code for spine.utils.jit

"""JIT compilation decorators using Numba for performance optimization."""

import inspect
from functools import wraps

import numba as nb
import numpy as np

from .conditional import TORCH_AVAILABLE, torch


def numbafy(cast_args=(), list_args=(), keep_torch=False, ref_arg=None):
    """Wrap a `numba` function with input/output type conversions.

    The returned decorator normalizes every call into keyword arguments
    (positional arguments are mapped onto the parameter names of the wrapped
    function and unset parameters receive their declared defaults), converts
    the requested arguments to numpy arrays or numba typed lists, calls the
    wrapped function with keyword arguments only and, optionally, casts the
    output back to torch.

    Parameters
    ----------
    cast_args : list(str), optional
        List of arguments to be cast to numpy. Arguments which are already
        `np.ndarray` objects are passed through untouched.
    list_args : list(str), optional
        List of arguments which need to be cast to a numba typed list
    keep_torch : bool, default False
        Make the output a torch object, if the reference argument is one
    ref_arg : str, optional
        Reference argument used to assign a type and device to the torch
        output. Must name an argument of the wrapped function if
        `keep_torch` is `True`.

    Returns
    -------
    callable
        Decorator which produces a wrapped function that ensures input type
        compatibility with numba

    Raises
    ------
    AssertionError
        (At call time) if `ref_arg`, or an entry of `cast_args` or
        `list_args`, does not match an argument of the wrapped function.
    TypeError
        (At call time) if an argument listed in `cast_args` is neither a
        numpy array nor a torch tensor, or if the output must be cast to
        torch but is neither a numpy array nor a list.
    """

    def outer(fn):
        # The signature only depends on the wrapped function: inspect it
        # once at decoration time rather than on every call
        parameters = inspect.signature(fn).parameters
        keys = list(parameters)

        @wraps(fn)
        def inner(*args, **kwargs):
            # Convert the positional arguments in args into
            # key:value pairs in kwargs
            for i, val in enumerate(args):
                kwargs[keys[i]] = val

            # Extract the default values for the remaining parameters. The
            # check against the `empty` sentinel must be an identity check:
            # `!=` is evaluated elementwise when the default is a numpy
            # array, which makes the condition ambiguous.
            for key, param in parameters.items():
                if key not in kwargs and param.default is not inspect.Parameter.empty:
                    kwargs[key] = param.default

            # If a torch output is requested, register the input dtype and
            # device location. This must happen before the casting step
            # below, which may replace the reference tensor by a numpy array.
            dtype, device = None, None
            if keep_torch:
                assert (
                    ref_arg in kwargs
                ), "Must provide the `ref_arg` to cast back to torch."

                if TORCH_AVAILABLE and isinstance(kwargs[ref_arg], torch.Tensor):
                    dtype = kwargs[ref_arg].dtype
                    device = kwargs[ref_arg].device

            # If the cast data is not a numpy array, cast it
            for arg in cast_args:
                assert arg in kwargs, (
                    f"Argument `{arg}` appears in `cast_args` but does "
                    "not appear in the function arguments."
                )

                if isinstance(kwargs[arg], np.ndarray):
                    continue

                if TORCH_AVAILABLE and isinstance(kwargs[arg], torch.Tensor):
                    kwargs[arg] = kwargs[arg].detach().cpu().numpy()
                else:
                    raise TypeError(
                        f"Can only convert torch.Tensor to numpy when PyTorch is available. "
                        f"Got a request to cast {type(kwargs[arg])} instead."
                    )

            # If there is a reflected list in the input, type it
            for arg in list_args:
                assert arg in kwargs, (
                    f"Argument `{arg}` appears in `list_args` but does "
                    "not appear in the function arguments."
                )
                kwargs[arg] = nb.typed.List(kwargs[arg])

            # Get the output
            ret = fn(**kwargs)

            # Cast the output back to torch only if the reference argument
            # was a torch tensor (i.e. a dtype was registered above)
            if keep_torch and dtype is not None and TORCH_AVAILABLE:
                if isinstance(ret, np.ndarray):
                    ret = torch.tensor(ret, dtype=dtype, device=device)
                elif isinstance(ret, list):
                    ret = [torch.tensor(r, dtype=dtype, device=device) for r in ret]
                else:
                    raise TypeError("Return type not recognized, cannot cast to torch.")

            return ret

        return inner

    return outer