"""Main functions that call the Driver class.
This is the first module called when launching a binary script under the `bin`
directory. It takes care of setting up the environment and the `Driver`
object(s) used to execute/train ML models, post-processors, analysis
scripts, writers and profilers.
"""
import os
from typing import Optional, Tuple
from .driver import Driver
from .utils.conditional import TORCH_AVAILABLE, torch
from .utils.logger import configure_rank_logging, logger
from .utils.torch.devices import set_visible_devices
[docs]
def run(cfg: dict) -> None:
    """Execute a model in one or more processes.

    The world configuration is obtained from `process_world` and the
    execution path is then picked as follows:

    - not distributed: run in the current process through `run_single`;
    - distributed and a `RANK` environment variable is set: call
      `train_single` directly in this process, using that rank;
    - distributed without `RANK`: spawn `world_size` processes with
      `torch.multiprocessing.spawn`, each one running `train_single`.

    Parameters
    ----------
    cfg : dict
        Full driver/trainer configuration

    Raises
    ------
    ValueError
        If the configuration has no `base` section, or if distributed
        execution is requested without a `train` entry under `base`
    ImportError
        If processes must be spawned but PyTorch is not available
    """
    # Process the configuration to set up the driver world
    if "base" not in cfg:
        raise ValueError("Configuration must contain a 'base' section.")
    base = cfg["base"]
    distributed, world_size, torch_sharing = process_world(base)

    # Check if rank is provided externally (multi-node/SLURM setup)
    rank = int(os.environ["RANK"]) if "RANK" in os.environ else None

    # Run a single process on a single GPU (or CPU if no GPUs available)
    if not distributed:
        run_single(cfg)
        return

    # Both distributed paths below are restricted to training
    if "train" not in base:
        raise ValueError("Distributed execution is only supported for training.")

    # Multi-node: rank provided externally by SLURM/torchrun, run directly
    if rank is not None:
        train_single(rank, cfg, distributed, world_size, torch_sharing)
        return

    # Single-node multi-GPU: spawning goes through torch itself, so fail with
    # the same explicit error as `train_single` rather than on the placeholder
    if not TORCH_AVAILABLE:
        raise ImportError(
            "PyTorch is required for training. "
            "Install with: pip install spine[model]"
        )

    # `spawn` passes the process index as the first positional argument,
    # which `train_single` receives as its `rank`
    torch.multiprocessing.spawn(
        train_single,
        args=(cfg, distributed, world_size, torch_sharing),
        nprocs=world_size,
    )
[docs]
def run_single(cfg: dict) -> None:
    """Run the configured job in the current process.

    The presence of a `train` entry in the `base` configuration block
    selects training; anything else is executed as inference.

    Parameters
    ----------
    cfg : dict
        Full driver/trainer configuration
    """
    # Inference is the fallback when no training block is configured
    if "train" not in cfg["base"]:
        inference_single(cfg)
        return

    # Training without a rank (not part of a distributed group)
    train_single(cfg=cfg, rank=None)
[docs]
def train_single(
    rank: Optional[int],
    cfg: dict,
    distributed: bool = False,
    world_size: Optional[int] = None,
    torch_sharing: Optional[str] = None,
) -> None:
    """Train a model in a single process.

    When `distributed` is set, the process group is initialized through
    `setup_ddp` before the driver is built and is always destroyed on the
    way out, including when building or running the driver raises.

    Parameters
    ----------
    rank : int, optional
        Process rank
    cfg : dict
        Full driver/trainer configuration
    distributed : bool, default False
        If `True`, distribute the training process
    world_size : int, optional
        Number of devices to use in the distributed training process
    torch_sharing : str or None, optional
        File sharing strategy for torch distributed training

    Raises
    ------
    ImportError
        If PyTorch is not available
    ValueError
        If `distributed` is set but `rank` or `world_size` is missing
    """
    configure_rank_logging(rank)

    # Training always requires torch
    if not TORCH_AVAILABLE:
        raise ImportError(
            "PyTorch is required for training. "
            "Install with: pip install spine[model]"
        )

    if distributed:
        # Set the torch sharing strategy, if needed
        if torch_sharing is not None:
            torch.multiprocessing.set_sharing_strategy(torch_sharing)

        # Explicit check rather than `assert`, which is stripped under `-O`
        if rank is None or world_size is None:
            raise ValueError(
                "Distributed training requires both 'rank' and 'world_size'."
            )

        # Setup the process group
        setup_ddp(rank, world_size)

    try:
        # Prepare the trainer and run the training process
        driver = Driver(cfg, rank)
        driver.run()
    finally:
        # The group only exists past a successful `setup_ddp`; tear it down
        # even if the driver failed, so it is not left behind on error
        if distributed:
            torch.distributed.destroy_process_group()
[docs]
def inference_single(cfg: dict) -> None:
    """Run inference in the current process, once per set of weights.

    The driver's model (if any) exposes a `weight_path`. When it is `None`
    or a single string it is treated as already loaded and the driver runs
    once. Otherwise the entries are sorted and the driver is run once for
    each of them, loading the weights beforehand.

    Parameters
    ----------
    cfg : dict
        Full driver configuration
    """
    configure_rank_logging()

    # Prepare the driver
    driver = Driver(cfg)

    # Work out which weights to iterate over
    preloaded = False
    weights = []
    if driver.model is not None:
        weights = driver.model.weight_path
        preloaded = weights is None or isinstance(weights, str)
        if preloaded:
            weights = [weights]
        else:
            weights = sorted(weights)
            logger.info(
                "Looping over %d set of weights:\n%s",
                len(weights),
                " - " + "\n - ".join(weights),
            )

    # Guarantee at least one pass through the loop below
    if not weights:
        weights = [None]

    # Run the driver once per entry
    for weight in weights:
        skip_load = driver.model is None or weight is None or preloaded
        if not skip_load:
            driver.model.load_weights(weight)
            # NOTE(review): the log is assumed to be re-initialized only when
            # a new set of weights is loaded -- confirm against the driver
            driver.initialize_log()

        driver.run()
[docs]
def process_world(base: dict) -> Tuple[bool, int, Optional[str]]:
    """Resolve the execution world from the base configuration.

    Sets the logger level from `verbosity`, hands the requested
    `world_size` and `gpus` to `set_visible_devices`, and validates the
    `distributed` and `torch_sharing_strategy` options.

    Parameters
    ----------
    base : dict
        Base driver configuration dictionary

    Returns
    -------
    distributed : bool
        If `True`, distribute the training process
    world_size : int
        Number of devices to use in the distributed training process
    torch_sharing : str or None
        File sharing strategy for torch distributed training

    Raises
    ------
    ValueError
        If more than one device is in use while `distributed` is disabled,
        or if `torch_sharing_strategy` is not a recognized value
    """
    # Logger verbosity comes first
    level = base.get("verbosity", "info").upper()
    logger.setLevel(level)

    # Let the device helper resolve the world size and visible devices
    requested_size = base.get("world_size")
    requested_gpus = base.get("gpus")
    world_size = set_visible_devices(world_size=requested_size, gpus=requested_gpus)

    # More than one device implies distribution unless explicitly configured
    multi_device = world_size > 1
    distributed = base.get("distributed", multi_device)
    if multi_device and not distributed:
        raise ValueError(
            "Multiple GPUs detected but distributed execution is disabled. "
            "Set 'distributed: true' in the configuration to enable it."
        )

    # Only two sharing strategies are accepted, besides leaving it unset
    allowed_strategies = ("file_system", "file_descriptor")
    torch_sharing = base.get("torch_sharing_strategy")
    if not (torch_sharing is None or torch_sharing in allowed_strategies):
        raise ValueError(
            "torch_sharing_strategy must be one of: "
            "'file_system', 'file_descriptor', or None"
        )

    return distributed, world_size, torch_sharing
[docs]
def setup_ddp(rank: int, world_size: int, backend: str = "nccl") -> None:
    """Initialize the torch distributed process group for this process.

    Parameters
    ----------
    rank : int
        Global rank of this process (0 to world_size-1)
    world_size : int
        Total number of processes across all nodes
    backend : str, default "nccl"
        Distributed backend to use

    Notes
    -----
    For multi-node training, set these environment variables:

    - MASTER_ADDR: IP address of the master node
    - MASTER_PORT: Free port on the master node
    - RANK: Global rank (0 to world_size-1)
    - WORLD_SIZE: Total number of processes
    - LOCAL_RANK (optional): Local rank on this node

    `MASTER_ADDR` and `MASTER_PORT` fall back to `localhost` and `12355`
    when they are not already present in the environment.
    """
    # Keep externally provided rendezvous settings, fill in single-node defaults
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")

    # Device index on this machine: LOCAL_RANK when provided (multi-node),
    # otherwise the rank itself (single-node)
    if "LOCAL_RANK" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
    else:
        local_rank = int(rank)

    # Join the process group, then bind this process to its device
    torch.distributed.init_process_group(
        backend=backend, rank=rank, world_size=world_size
    )
    torch.cuda.set_device(local_rank)