"""Centralize all methods associated with a machine-learning model."""
import glob
import os
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
import numpy as np
from spine.data import EdgeIndexBatch, IndexBatch, TensorBatch
from spine.utils.conditional import TORCH_AVAILABLE, torch
from spine.utils.logger import logger
from spine.utils.stopwatch import StopwatchManager
from spine.utils.torch.training import lr_sched_factory, optim_factory
from .factories import model_factory
[docs]
class ModelManager:
"""Groups all relevant functions to construct a model and its loss."""
def __init__(
    self,
    name,
    modules,
    network_input,
    loss_input=None,
    weight_path=None,
    weight_list=None,
    train: Mapping[str, Any] | None = None,
    to_numpy=False,
    time_dependent_loss=False,
    dtype="float32",
    distributed=False,
    rank=None,
    detect_anomaly=False,
    find_unused_parameters=False,
    iter_per_epoch=None,
):
    """Process the model configuration.

    Parameters
    ----------
    name : str
        Name of the model as specified under spine.model.factories
    modules : dict
        Dictionary of modules that make up the model. It is not modified:
        the manager works on private copies of it.
    network_input : Dict[str, str]
        Maps each argument name of the model forward onto the key of the
        parsed object to feed it with
    loss_input : Dict[str, str], optional
        Maps each argument name of the loss forward onto the key of the
        parsed object to feed it with
    weight_path : str, optional
        Path to global model weights to load. If it is not an existing
        file, it is interpreted as a wildcard pattern (inference only)
    weight_list : str, optional
        Path to a text file containing a list of weight file paths to load
    train : dict, default None
        Training regimen configuration. Training is enabled if provided
    to_numpy : bool, default False
        Cast model output to numpy ndarray
    time_dependent_loss : bool, default False
        Handles time-dependent loss, such as KL divergence annealing
    dtype : str, default 'float32'
        Data type of the model parameters and input data
    distributed : bool, default False
        Whether the model is part of a distributed training process
    rank : int, optional
        Process rank in a torch distributed process
    detect_anomaly : bool, default False
        Whether to attempt to detect a torch anomaly
    find_unused_parameters : bool, default False
        Attempts to detect unused model parameters in the forward pass
    iter_per_epoch : int, optional
        Number of iterations per epoch (relevant for training)

    Raises
    ------
    ImportError
        If PyTorch is not available
    AssertionError
        If `network_input` or `loss_input` is not a dictionary
    ValueError
        If the weight file specification is invalid or matches no file
    """
    # Check that torch is available for model operations
    if not TORCH_AVAILABLE:
        raise ImportError(
            "PyTorch is required to use the model manager. "
            "Install with: pip install spine[model]"
        )

    # Validate the input specifications up front, so that a malformed
    # configuration fails before the model is built and weights are loaded
    assert isinstance(network_input, dict), (
        "Must specify `network_input` as a dictionary mapping model "
        "input keys onto data loader product keys."
    )
    assert loss_input is None or isinstance(loss_input, dict), (
        "Must specify `loss_input` as a dictionary mapping loss "
        "input keys onto data loader product keys."
    )

    # Save parameters
    self.train: bool = train is not None
    self.to_numpy = to_numpy
    self.time_dependant = time_dependent_loss
    self.dtype = getattr(torch, dtype)
    self.distributed = distributed
    self.rank = rank  # Global rank (process ID in distributed group)
    self.main_process = rank is None or rank == 0

    # Determine device: use current_device() which setup_ddp() already configured
    if self.rank is None:
        self.device = "cpu"
        self.device_id = None
    else:
        # In distributed mode, setup_ddp() already called torch.cuda.set_device(local_rank)
        # This ensures we use the correct local GPU index, not global rank
        self.device_id = torch.cuda.current_device()
        self.device = f"cuda:{self.device_id}"

    # Initialize the timers and the configuration dictionary
    self.watch = StopwatchManager()
    self.watch.initialize("forward")
    if self.train:
        self.watch.initialize(["backward", "save"])

    # If anomaly detection is requested, set it
    if detect_anomaly:
        torch.autograd.set_detect_anomaly(True, check_nan=True)

    # Keep the full configuration (with weight loading/freezing keys) for
    # the manager, and strip those keys from a second private copy to feed
    # to the model/loss classes. The caller's dictionary is left untouched.
    self.model_name = name
    self.model_cfg = deepcopy(modules)
    module_cfg = deepcopy(modules)
    self.clean_config(module_cfg)

    # Initialize the model network and loss functions
    net_cls, loss_cls = model_factory(name)
    try:
        self.net = net_cls(**module_cfg)
        self.net.to(device=self.device, dtype=self.dtype)
    except Exception as err:
        msg = f"Failed to instantiate {net_cls}"
        raise type(err)(f"{err}\n{msg}") from err

    try:
        self.loss_fn = loss_cls(**module_cfg)
        self.loss_fn.to(device=self.device, dtype=self.dtype)
    except Exception as err:
        msg = f"Failed to instantiate {loss_cls}"
        raise type(err)(f"{err}\n{msg}") from err

    # If requested, initialize the training process
    if train is not None:
        self.initialize_train(**train, iter_per_epoch=iter_per_epoch)
    else:
        self.net.eval()

    # If requested, freeze some/all the model weights
    self.freeze_weights()

    # Parse the list of weight files to consider for loading
    self.weight_path = weight_path
    if weight_path is not None:
        # If a path is provided, check if it is a simple path or a wildcard pattern
        if weight_list is not None:
            raise ValueError("Cannot specify both `weight_path` and `weight_list`.")

        if not os.path.isfile(weight_path):
            # Wildcards are only supported at inference time. The matches
            # are sorted because glob order depends on the file system.
            matches = sorted(glob.glob(weight_path))
            if self.train or not matches:
                raise ValueError(f"Weight file not found: {weight_path}")
            self.weight_path = matches

    elif weight_list is not None:
        with open(weight_list, "r", encoding="utf-8") as f:
            self.weight_path = [line.strip() for line in f if line.strip()]
        if not self.weight_path:
            raise ValueError(f"No weight paths found in {weight_list}.")

    # Load the weights only if a single weight file is provided. If multiple weight
    # files are provided, the loading will be handled in a loop in the main driver.
    if self.weight_path is None or isinstance(self.weight_path, str):
        self.load_weights(self.weight_path)

    # If the execution is distributed, wrap with DDP
    if self.distributed:
        self.net = torch.nn.parallel.DistributedDataParallel(
            self.net,
            device_ids=[self.device_id],
            output_device=self.device_id,
            find_unused_parameters=find_unused_parameters,
        )

    # Store the list of input keys to the forward/loss functions. These
    # are dictionaries mapping the name of the argument in the
    # forward/loss function to a data product name.
    self.input_dict = network_input
    self.loss_dict = loss_input
[docs]
def initialize_train(
self,
optimizer,
weight_prefix="snapshot",
restore_optimizer=False,
save_step=None,
save_epoch=None,
lr_scheduler=None,
iter_per_epoch=None,
):
"""Initialize the training regimen.
Parameters
----------
optimizer : dict
Configuration of the optimizer
weight_prefix : str, default 'snapshot'
Path + name of the weight file prefix
save_step : int, optional
Number of iterations before recording the model weights
save_epoch : float, optional
Fraction of epoch to train on before recording the model weights
restore_optimizer : bool, default False
Whether to load the opimizer state from the torch checkpoint
lr_scheduler : dict, optional
Configuration of the learning rate scheduler
iter_per_epoch : int, optional
Number of iterations per epoch (relevant for training)
"""
# Turn train on
self.train = True
self.net.train()
# Store parameters
self.weight_prefix = weight_prefix
self.restore_optimizer = restore_optimizer
# Store the saving parameters
if save_step is not None and save_epoch is not None:
raise ValueError("Cannot specify both `save_step` and `save_epoch`.")
self.save_step = save_step
if save_epoch is not None:
# Convert the save epoch to a save step
self.save_step = max(1, int(save_epoch * iter_per_epoch))
# Make a directory for the weight files, if need be
save_dir = os.path.dirname(weight_prefix)
if save_dir and not os.path.isdir(save_dir):
os.makedirs(save_dir, exist_ok=True)
# Initiliaze the optimizer
self.optimizer = optim_factory(optimizer, self.net.parameters())
# Initialize the learning rate scheduler
self.lr_scheduler = None
if lr_scheduler is not None:
self.lr_scheduler = lr_sched_factory(lr_scheduler, self.optimizer)
def __call__(self, data, iteration=None, epoch=None):
"""Calls the forward (and backward) function on a batch of data.
Parameters
----------
data : dict
Dictionary of input data product keys which each map to its
associated batched data product
iteration : int, optional
Iteration number (relevant for training)
epoch : float, optional
Epoch fractional count (relevant for training)
Returns
-------
dict
Dictionary of model and loss outputs
"""
# Reset active stopwatches
self.watch.reset_if_active()
# Reset the gradient accumulation, free memory
if self.train:
self.optimizer.zero_grad(set_to_none=True)
# Run the model forward
self.watch.start("forward")
result = self.forward(data, iteration)
self.watch.stop("forward")
# If traning run the backward pass and update the weigths
if self.train:
assert (
"loss" in result
), "Every model must return a `loss` value to be trained."
self.watch.start("backward")
self.backward(result["loss"])
self.watch.stop("backward")
# If training and at an appropriate iteration, save model state
if self.train:
self.watch.start("save")
assert (
iteration is not None
), "Must provide iteration information to save the model state."
if self.save_step is not None and self.main_process:
if ((iteration + 1) % self.save_step) == 0:
self.save_state(iteration, epoch)
self.watch.stop("save")
# If requested, cast the result dictionary to numpy
if self.to_numpy:
self.cast_to_numpy(result)
return result
[docs]
def clean_config(self, config):
"""Remove model loading/freezing keys from all level of a dictionary.
This is used to remove the weight loading/freezing from the input
configuration before it is fed to the model/loss classes.
Parameters
----------
config : dict
Dictionary to remove the keys from
"""
keys = ["model_name", "weight_path", "freeze_weights"]
if isinstance(config, dict):
for k in keys:
if k in config:
del config[k]
for val in config.values():
self.clean_config(val)
[docs]
def freeze_weights(self):
"""Freeze the weights of certain model components.
Breadth-first search for `freeze_weights` parameters in the model
configuration. If `freeze_weights` is `True` under a module block,
`requires_grad` is set to `False` for its parameters. The batch
normalization and dropout layers are set to evaluation mode.
"""
# Loop over all the module blocks in the model configuration
module_items = list(self.model_cfg.items())
while len(module_items) > 0:
# Get the module name and its configuration block
module, config = module_items.pop()
# If the module is to be frozen, apply
if config.get("freeze_weights", False):
# Fetch the module name to be found in the state dictionary
model_name = config.get("model_name", module)
# Set BN and DO layers to evaluation mode
getattr(self.net, module).eval()
# Freeze all the weights of this module
count = 0
for name, param in self.net.named_parameters():
if module in name:
key = name.replace(f".{module}.", f".{model_name}.")
if key in self.net.state_dict().keys():
param.requires_grad = False
count += 1
# Throw if no weights were found to freeze
assert count, f"Could not find any weights to freeze for {module}"
logger.info("Froze %d weights in module %s", count, module)
# Keep the BFS going by adding the nested blocks
for key in config:
if isinstance(config[key], dict):
module_items.append((key, config[key]))
[docs]
def load_weights(self, full_weight_path):
"""Load the weights of certain model components.
Breadth-first search for `weight_path` parameters in the model
configuration. If 'weight_path' is found under a module block,
the weights are loaded for its parameters.
If a `weight_path` is not found for a given module, load the overall
weights from `weight_path` under `trainval` for that module instead.
Parameters
----------
full_weight_path : str
Path to the weights for the full model
"""
# If a general model path is provided, add it to the loading list first
weight_paths = []
if full_weight_path:
weight_paths = [(self.model_name, full_weight_path, "")]
# Find the list of sub-module weights to subsequently load
module_items = list(self.model_cfg.items())
while len(module_items) > 0:
module, config = module_items.pop()
if config.get("weight_path", "") != "":
model_name = config.get("model_name", module)
weight_paths.append((module, config["weight_path"], model_name))
for key in config:
if isinstance(config[key], dict):
module_items.append((key, config[key]))
# If no pre-trained weights are requested, nothing to do here
self.start_iteration = 0
if not weight_paths:
return
# Loop over provided model paths
for module, weight_path, model_name in weight_paths:
# Module-level weight paths must resolve to a single checkpoint.
if not os.path.isfile(weight_path):
raise ValueError(
"Weight file not found for module " f"{module}: {weight_path}"
)
# Load weight file into existing model
logger.info(
"Restoring weights for module %s from %s...", module, weight_path
)
with open(weight_path, "rb") as f:
# Read checkpoint
try:
checkpoint = torch.load(
f, map_location=self.device, weights_only=True
)
except TypeError as err:
if "weights_only" not in str(err):
raise
f.seek(0)
checkpoint = torch.load(f, map_location=self.device)
state_dict = checkpoint["state_dict"]
# Check that all the needed weights are provided
missing_keys = []
if module == self.model_name:
for name in self.net.state_dict():
if not name in state_dict.keys():
missing_keys.append((name, name))
else:
# Update the key names according to the name used to store
state_dict = {}
for name in self.net.state_dict():
if f"{module}." in name:
suffix = "." if len(model_name) > 0 else ""
key = name.replace(f"{module}.", f"{model_name}{suffix}")
if key in checkpoint["state_dict"].keys():
state_dict[name] = checkpoint["state_dict"][key]
else:
missing_keys.append((name, key))
# If some necessary keys were not found, throw
if missing_keys:
logger.critical("These necessary parameters could not be found:")
for name, key in missing_keys:
logger.critical("Parameter %s is missing for %s.", key, name)
raise ValueError(
"To be loaded, a set of weights "
"must provide all necessary parameters."
)
# Load checkpoint. Check that all weights are used
bad_keys = self.net.load_state_dict(state_dict, strict=False)
if len(bad_keys.unexpected_keys) > 0:
logger.warning(
"This weight file contains parameters that could "
"not be loaded, indicating that the weight file "
"contains more than needed. This might be ok."
)
logger.warning("Unexpected keys: %s", bad_keys.unexpected_keys)
# Load the optimizer state from the main weight file only
if self.train and module == self.model_name and self.restore_optimizer:
self.optimizer.load_state_dict(checkpoint["optimizer"])
# Get the latest iteration from the main weight file only
if module == self.model_name:
self.start_iteration = checkpoint["global_step"] + 1
logger.info("Done.")
[docs]
def prepare_data(self, data):
"""Fetches the necessary data products to form the input to the forward
function and the input to the loss function.
Parameters
----------
data : dict
Dictionary of input data product keys, each of which maps to its
associated batched data product
Returns
-------
input_dict : dict
Input to the forward pass of the model
loss_dict : dict
Labels to be used in the loss computation
"""
# Fetch the requested data products
input_dict, loss_dict = {}, {}
with torch.set_grad_enabled(self.train):
# Load the data products for the model forward
input_dict = {}
for param, name in self.input_dict.items():
assert name in data, (
f"Must provide `{name}` in the dataloader schema to "
"input into the model forward."
)
value = data[name]
if isinstance(value, TensorBatch):
value = data[name].to_tensor(device=self.device, dtype=self.dtype)
input_dict[param] = value
# Load the data products for the loss function
loss_dict = {}
if self.loss_dict is not None:
for param, name in self.loss_dict.items():
assert name in data, (
f"Must provide `{name}` in the dataloader schema "
"to input into the loss function."
)
value = data[name]
if isinstance(value, TensorBatch):
value = data[name].to_tensor(
device=self.device, dtype=self.dtype
)
loss_dict[param] = value
return input_dict, loss_dict
[docs]
def forward(self, data, iteration=None):
"""Pass one minibatch of data through the network and the loss.
Load one minibatch of data. pass it through the network forward
function and the loss computation. Store the output.
Parameters
----------
data : dict
Dictionary of input data product keys which each map to its
associated batched data product
iteration : int, optional
Iteration number (relevant for time-dependant losses)
Returns
-------
dict
Dictionary of model and loss outputs
"""
# Prepare the input to the forward and loss functions
input_dict, loss_dict = self.prepare_data(data)
# If in train mode, record the gradients for backward step
with torch.set_grad_enabled(self.train):
# Apply the model forward
result = self.net(**input_dict)
# Compute the loss if one is specified, append results
if self.loss_dict:
if not self.time_dependant:
result.update(self.loss_fn(**loss_dict, **result))
else:
result.update(
self.loss_fn(iteration=iteration, **loss_dict, **result)
)
return result
[docs]
def backward(self, loss):
"""Run the backward step on the model.
Parameters
----------
loss : torch.tensor
Scalar loss value to step the model weights
"""
# Run the model backward
loss.backward()
# Step the optimizer
self.optimizer.step()
# Step the learning rate scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
# If the model has a buffer that needs to be updated, do it after
# the trainable parameter update
if hasattr(self.net, "update_buffers"):
logger.info("Updating buffers")
self.net.update_buffers()
[docs]
def cast_to_numpy(self, result):
"""Casts the model output data products to numpy object in place.
Parameters
----------
result : dict
Dictionary of model and loss outputs
"""
# Loop over the key, value pairs in the result dictionary
for key, value in result.items():
# Cast to numpy or python scalars
if np.isscalar(value):
# Scalar
result[key] = value
elif isinstance(value, torch.Tensor) and value.numel() == 1:
# Scalar tensor
result[key] = value.item()
elif isinstance(value, (TensorBatch, IndexBatch, EdgeIndexBatch)):
# Batch of data
result[key] = value.to_numpy()
elif (
isinstance(value, list)
and len(value)
and isinstance(value[0], TensorBatch)
):
# List of tensor batches
result[key] = [v.to_numpy() for v in value]
else:
dtype = type(value)
raise ValueError(f"Cannot cast output {key} of type {dtype} to numpy.")
[docs]
def save_state(self, iteration, epoch):
"""Save the model state.
Save three things from the model:
- global_step (iteration)
- global_epoch (epoch progress)
- state_dict (model parameter values)
- optimizer (optimizer parameter values)
Parameters
----------
iteration : int
Iteration step index
"""
# Make sure that the weight prefix is valid
assert self.weight_prefix, "Must provide a weight prefix to store them."
filename = f"{self.weight_prefix}-{iteration:d}.ckpt"
model = self.net if not self.distributed else self.net.module
torch.save(
{
"global_step": iteration,
"global_epoch": epoch,
"state_dict": model.state_dict(),
"optimizer": self.optimizer.state_dict(),
},
filename,
)