spine.model.manager

Centralize all methods associated with a machine-learning model.

Classes

ModelManager(name, modules, network_input[, ...])

Groups all relevant functions to construct a model and its loss.

class spine.model.manager.ModelManager(name, modules, network_input, loss_input=None, weight_path=None, weight_list=None, train: Mapping[str, Any] | None = None, to_numpy=False, time_dependent_loss=False, dtype='float32', distributed=False, rank=None, detect_anomaly=False, find_unused_parameters=False, iter_per_epoch=None)

Groups all relevant functions to construct a model and its loss.

Methods

__call__(data[, iteration, epoch])

Calls the forward (and backward) function on a batch of data.

backward(loss)

Run the backward step on the model.

cast_to_numpy(result)

Casts the model output data products to numpy objects in place.

clean_config(config)

Remove model loading/freezing keys from all levels of a dictionary.

forward(data[, iteration])

Pass one minibatch of data through the network and the loss.

freeze_weights()

Freeze the weights of certain model components.

initialize_train(optimizer[, weight_prefix, ...])

Initialize the training regimen.

load_weights(full_weight_path)

Load the weights of certain model components.

prepare_data(data)

Fetches the necessary data products to form the input to the forward function and the input to the loss function.

save_state(iteration, epoch)

Save the model state.

initialize_train(optimizer, weight_prefix='snapshot', restore_optimizer=False, save_step=None, save_epoch=None, lr_scheduler=None, iter_per_epoch=None)

Initialize the training regimen.

Parameters:
  • optimizer (dict) – Configuration of the optimizer

  • weight_prefix (str, default 'snapshot') – Path and file-name prefix of the saved weight files

  • save_step (int, optional) – Number of iterations between successive saves of the model weights

  • save_epoch (float, optional) – Fraction of an epoch to train between successive saves of the model weights

  • restore_optimizer (bool, default False) – Whether to load the optimizer state from the torch checkpoint

  • lr_scheduler (dict, optional) – Configuration of the learning rate scheduler

  • iter_per_epoch (int, optional) – Number of iterations per epoch (relevant for training)
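To show how save_step, save_epoch and iter_per_epoch can interact, here is a minimal sketch that converts them into a single save interval measured in iterations. The helpers resolve_save_interval and should_save are hypothetical and not part of spine. The sketch assumes three things: save_epoch is turned into iterations by multiplying it by iter_per_epoch, the two save options are mutually exclusive, and iterations are counted from 0.

```python
import math
from typing import Optional


def resolve_save_interval(save_step: Optional[int] = None,
                          save_epoch: Optional[float] = None,
                          iter_per_epoch: Optional[int] = None) -> Optional[int]:
    """Return the number of iterations between weight saves, or None.

    Hypothetical helper: treats save_step and save_epoch as mutually
    exclusive (an assumption) and converts save_epoch using iter_per_epoch.
    """
    if save_step is not None and save_epoch is not None:
        raise ValueError("Specify at most one of save_step or save_epoch")
    if save_step is not None:
        return save_step
    if save_epoch is not None:
        if iter_per_epoch is None:
            raise ValueError("save_epoch requires iter_per_epoch")
        # Round up so that a tiny epoch fraction never yields a zero interval
        return max(1, math.ceil(save_epoch * iter_per_epoch))
    return None


def should_save(iteration: int, interval: Optional[int]) -> bool:
    """True if the weights should be recorded after this 0-based iteration."""
    return interval is not None and (iteration + 1) % interval == 0
```

For example, save_epoch=0.5 with iter_per_epoch=1000 gives an interval of 500 iterations. The first save then happens after iteration index 499.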

clean_config(config)

Remove model loading/freezing keys from all levels of a dictionary.

This is used to remove the weight loading/freezing from the input configuration before it is fed to the model/loss classes.

Parameters:

config (dict) – Dictionary to remove the keys from
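The recursive key removal can be sketched as follows. The set of keys to strip is an assumption: it uses only weight_path and freeze_weights, which this page names as the loading and freezing keys, and the real method may remove others. The sketch modifies the dictionary in place and also returns it for convenience. The documentation does not say whether the actual method copies its input first.

```python
from typing import Any, Iterable

# Assumed key set: the two loading/freezing keys named in this documentation
MODEL_LOADING_KEYS = ("weight_path", "freeze_weights")


def clean_config(config: Any, keys: Iterable[str] = MODEL_LOADING_KEYS) -> Any:
    """Recursively remove `keys` from every level of a nested config, in place."""
    keys = tuple(keys)
    if isinstance(config, dict):
        for key in keys:
            config.pop(key, None)  # no error if the key is absent
        for value in config.values():
            clean_config(value, keys)
    elif isinstance(config, list):
        for item in config:
            clean_config(item, keys)
    return config
```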

freeze_weights()

Freeze the weights of certain model components.

Perform a breadth-first search for freeze_weights keys in the model configuration. If freeze_weights is True under a module block, requires_grad is set to False for all of that module's parameters, and the module's batch normalization and dropout layers are set to evaluation mode.
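The configuration-search half of this method can be sketched with a queue. The helper find_frozen_modules is hypothetical. It assumes the model configuration is a nested dictionary of module blocks and returns the dotted paths of the blocks flagged for freezing, with shallower blocks listed first.

```python
from collections import deque
from typing import Any, Dict, List


def find_frozen_modules(model_config: Dict[str, Any]) -> List[str]:
    """Breadth-first search for module blocks with freeze_weights set to True.

    Returns dotted paths such as "backbone.encoder" for each flagged block.
    """
    frozen = []
    queue = deque(model_config.items())
    while queue:
        path, block = queue.popleft()
        if not isinstance(block, dict):
            continue
        if block.get("freeze_weights", False):
            frozen.append(path)
        # Enqueue child blocks so that shallower modules are visited first
        for key, value in block.items():
            if isinstance(value, dict):
                queue.append((f"{path}.{key}", value))
    return frozen
```

In PyTorch, freezing a matched module would then amount to setting p.requires_grad = False for every p in module.parameters() and calling module.eval(). The eval() call switches its batch normalization and dropout layers to evaluation mode.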

load_weights(full_weight_path)

Load the weights of certain model components.

Perform a breadth-first search for weight_path keys in the model configuration. If weight_path is found under a module block, the weights for that module's parameters are loaded from that path.

If no weight_path is found for a given module, the weights for that module are instead loaded from the full-model weight file (the weight_path under trainval).

Parameters:

full_weight_path (str) – Path to the weights for the full model
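The lookup-with-fallback logic can be sketched as below. The helper resolve_weight_paths is hypothetical and simplifies the search. Each top-level module takes the first weight_path that a breadth-first search of its own block finds. Any module without one falls back to full_weight_path, which plays the role of the weight_path under trainval. The sketch only decides where each module's weights would come from and does not load anything.

```python
from collections import deque
from typing import Any, Dict, Optional


def resolve_weight_paths(model_config: Dict[str, Any],
                         full_weight_path: Optional[str]) -> Dict[str, Optional[str]]:
    """Map each top-level module to the checkpoint its weights would come from."""
    sources = {}
    for name, block in model_config.items():
        path = None
        queue = deque([block])
        # Breadth-first: the shallowest weight_path inside this module wins
        while queue:
            node = queue.popleft()
            if not isinstance(node, dict):
                continue
            if "weight_path" in node:
                path = node["weight_path"]
                break
            queue.extend(v for v in node.values() if isinstance(v, dict))
        # Fall back to the full-model weights when no module-level path exists
        sources[name] = path if path is not None else full_weight_path
    return sources
```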

prepare_data(data)

Fetches the necessary data products to form the input to the forward function and the input to the loss function.

Parameters:

data (dict) – Dictionary of input data product keys, each of which maps to its associated batched data product

Returns:

  • input_dict (dict) – Input to the forward pass of the model

  • loss_dict (dict) – Labels to be used in the loss computation
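As an illustration, a minimal sketch of the split into input_dict and loss_dict follows. It assumes that network_input and loss_input, from the constructor, map argument names of the forward and loss functions to keys of the batched data dictionary. The exact format spine expects is not specified on this page. A key missing from data raises a KeyError in this sketch.

```python
from typing import Any, Dict, Mapping, Optional, Tuple


def prepare_data(data: Dict[str, Any],
                 network_input: Mapping[str, str],
                 loss_input: Optional[Mapping[str, str]] = None
                 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a batch into forward-pass inputs and loss labels.

    Assumes network_input / loss_input map argument names to data keys.
    """
    input_dict = {arg: data[key] for arg, key in network_input.items()}
    loss_dict = {arg: data[key] for arg, key in (loss_input or {}).items()}
    return input_dict, loss_dict
```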

forward(data, iteration=None)

Pass one minibatch of data through the network and the loss.

Load one minibatch of data, pass it through the network forward function and the loss computation, and store the output.

Parameters:
  • data (dict) – Dictionary of input data product keys, each of which maps to its associated batched data product

  • iteration (int, optional) – Iteration number (relevant for time-dependent losses)

Returns:

Dictionary of model and loss outputs

Return type:

dict
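The network-then-loss flow can be sketched with plain callables. run_forward is a hypothetical helper that takes the two dictionaries produced by prepare_data. The sketch makes two assumptions. First, the model and loss both return dictionaries. Second, the loss receives the labels together with the network outputs, plus iteration when one is given for time-dependent losses.

```python
from typing import Any, Callable, Dict, Optional


def run_forward(input_dict: Dict[str, Any],
                loss_dict: Dict[str, Any],
                model: Callable[..., Dict[str, Any]],
                loss_fn: Optional[Callable[..., Dict[str, Any]]] = None,
                iteration: Optional[int] = None) -> Dict[str, Any]:
    """Run the network, then the loss, and merge both outputs into one dict."""
    result = dict(model(**input_dict))
    if loss_fn is not None:
        # The loss sees the labels and the network outputs (outputs win on clashes)
        loss_kwargs = {**loss_dict, **result}
        if iteration is not None:
            # Only forwarded when given, e.g. for time-dependent losses
            loss_kwargs["iteration"] = iteration
        result.update(loss_fn(**loss_kwargs))
    return result
```

Returning one merged dictionary matches the documented return value, a dictionary of model and loss outputs.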

backward(loss)

Run the backward step on the model.

Parameters:

loss (torch.Tensor) – Scalar loss value to step the model weights

cast_to_numpy(result)

Casts the model output data products to numpy objects in place.

Parameters:

result (dict) – Dictionary of model and loss outputs
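The in-place behaviour can be sketched as follows. This version assumes the outputs are plain torch tensors or lists of them, and it converts them with the standard PyTorch idiom tensor.detach().cpu().numpy(). Any other value is left untouched. The real method may also handle spine-specific data classes, which this sketch ignores. The function returns None because it rewrites the values of result directly.

```python
from typing import Any, Dict


def to_numpy(value: Any) -> Any:
    """Convert a torch-like tensor (or a list of them) to numpy; pass others through."""
    if hasattr(value, "detach") and hasattr(value, "numpy"):
        # Drop autograd history and move to host memory before converting
        return value.detach().cpu().numpy()
    if isinstance(value, list):
        return [to_numpy(v) for v in value]
    return value


def cast_to_numpy(result: Dict[str, Any]) -> None:
    """Replace each value of `result` by its numpy counterpart, in place."""
    for key, value in list(result.items()):
        result[key] = to_numpy(value)
```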

save_state(iteration, epoch)

Save the model state.

Save four things from the model:

  • global_step (iteration)

  • global_epoch (epoch progress)

  • state_dict (model parameter values)

  • optimizer (optimizer parameter values)

Parameters:

iteration (int) – Iteration step index
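The four recorded items can be sketched as a plain dictionary. build_checkpoint and checkpoint_path are hypothetical helpers. In particular, the "<prefix>-<iteration>.ckpt" file-name scheme is an assumption meant only to show how weight_prefix could be combined with the iteration. spine's actual naming is not given here.

```python
from typing import Any, Dict


def build_checkpoint(iteration: int, epoch: float,
                     model_state: Dict[str, Any],
                     optimizer_state: Dict[str, Any]) -> Dict[str, Any]:
    """Assemble the four items recorded by save_state."""
    return {
        "global_step": iteration,       # iteration index
        "global_epoch": epoch,          # epoch progress
        "state_dict": model_state,      # model parameter values
        "optimizer": optimizer_state,   # optimizer parameter values
    }


def checkpoint_path(weight_prefix: str, iteration: int) -> str:
    """Hypothetical naming scheme: <weight_prefix>-<iteration>.ckpt."""
    return f"{weight_prefix}-{iteration}.ckpt"
```

With PyTorch, the states would typically come from model.state_dict() and optimizer.state_dict(), and the dictionary would be written with torch.save(checkpoint, path).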