# Source code for spine.utils.inference

"""Module with helper functions to run inference on a model configuration."""

from copy import deepcopy

import yaml


def get_inference_cfg(
    cfg, file_keys=None, weight_path=None, batch_size=None, num_workers=None, cpu=False
):
    """Turns a training configuration into an inference configuration.

    This function does the following:
    - Removes the `train` block from `base` (no training section is kept)
    - Turns on output unwrapping (`base.unwrap = True`)
    - Converts the model output to numpy (`model.to_numpy = True`)
    - Removes the loader `sampler`, so the training (random) sampler is not
      used and data is read sequentially
    - Loads the specified validation file_keys, if requested
    - Loads the specified weight_path, if requested
    - Resets the batch_size to a different value, if requested
    - Sets num_workers to a different value, if requested
    - Makes the model run in CPU mode, if requested

    The input configuration is never modified in place. A dictionary is
    deep-copied, and a path is parsed into a fresh dictionary.

    Parameters
    ----------
    cfg : Union[str, dict]
        Configuration dictionary or path to the YAML configuration file
    file_keys : str, optional
        Path to the dataset to use for inference
    weight_path : str, optional
        Path to the weights to use for inference
    batch_size : int, optional
        Number of data samples per batch
    num_workers : int, optional
        Number of workers that load data
    cpu : bool, default False
        Whether or not to execute the inference on CPU

    Returns
    -------
    dict
        Dictionary of parameters to initialize handlers

    Raises
    ------
    KeyError
        If the configuration lacks a `base`, `model` or `io.loader` section.
        Also raised if `io.loader.dataset` is missing while `file_keys` is
        provided.
    """
    # Fetch the training configuration
    if isinstance(cfg, dict):
        # Work on a copy so the caller's configuration is left untouched
        cfg = deepcopy(cfg)
    else:
        # Use a context manager so the file handle is always closed
        with open(cfg, "r", encoding="utf-8") as cfg_file:
            cfg = yaml.safe_load(cfg_file)

    base = cfg["base"]
    loader = cfg["io"]["loader"]

    # Remove the training block, if present
    base.pop("train", None)

    # Turn on unwrapper
    base["unwrap"] = True

    # Convert model output to numpy
    cfg["model"]["to_numpy"] = True

    # Get rid of random sampler, if present
    loader.pop("sampler", None)

    # Change the batch_size, if requested
    if batch_size is not None:
        loader["batch_size"] = batch_size

    # Change dataset, if requested
    if file_keys is not None:
        loader["dataset"]["file_keys"] = file_keys

    # Set the number of workers, if requested
    if num_workers is not None:
        loader["num_workers"] = num_workers

    # Change weights, if requested
    if weight_path is not None:
        cfg["model"]["weight_path"] = weight_path

    # Put the network in CPU mode, if requested.
    # NOTE(review): presumably a world size of 0 means "no GPUs" for the
    # driver; confirm against the consumer of `base.world_size`.
    if cpu:
        base["world_size"] = 0

    return cfg