# Source code for spine.main

"""Main functions that call the Driver class.

This is the first module called when launching a binary script under the `bin`
directory. It takes care of setting up the environment and the `Driver`
object(s) used to execute/train ML models, post-processors, analysis
scripts, writers and profilers.
"""

import os
from typing import Optional, Tuple

from .driver import Driver
from .utils.conditional import TORCH_AVAILABLE, torch
from .utils.logger import configure_rank_logging, logger
from .utils.torch.devices import set_visible_devices


def run(cfg: dict) -> None:
    """Execute a model in one or more processes.

    Picks one of three execution modes from the processed configuration
    and the environment: a single local process, a single process whose
    rank was provided externally, or a set of locally spawned processes.

    Parameters
    ----------
    cfg : dict
        Full driver/trainer configuration

    Raises
    ------
    ValueError
        If the configuration has no 'base' section, or if distributed
        execution is requested without a 'train' entry under 'base'.
    """
    # The 'base' section is what defines the driver world
    if "base" not in cfg:
        raise ValueError("Configuration must contain a 'base' section.")

    base = cfg["base"]
    distributed, world_size, torch_sharing = process_world(base)

    # A RANK environment variable means the rank is assigned externally
    # (multi-node/SLURM setup); it is parsed before dispatching
    env_rank = os.environ.get("RANK")
    rank = None if env_rank is None else int(env_rank)

    # Non-distributed: one process on a single GPU (or CPU if no GPUs)
    if not distributed:
        run_single(cfg)
        return

    # Everything below is distributed, which requires a training config
    if "train" not in base:
        raise ValueError("Distributed execution is only supported for training.")

    # Multi-node: rank provided externally by SLURM/torchrun, run directly
    if rank is not None:
        train_single(rank, cfg, distributed, world_size, torch_sharing)
        return

    # Single-node multi-GPU: spawn one process per device; spawn passes the
    # process index as the first positional argument (the rank)
    torch.multiprocessing.spawn(
        train_single,
        args=(cfg, distributed, world_size, torch_sharing),
        nprocs=world_size,
    )
def run_single(cfg: dict) -> None:
    """Execute a model on a single process.

    Training is selected when the 'base' section of the configuration
    contains a 'train' entry; otherwise the inference path is used.

    Parameters
    ----------
    cfg : dict
        Full driver/trainer configuration
    """
    # No training block: this is an inference job
    if "train" not in cfg["base"]:
        inference_single(cfg)
        return

    # Single-process training runs without a rank
    train_single(cfg=cfg, rank=None)
def train_single(
    rank: Optional[int],
    cfg: dict,
    distributed: bool = False,
    world_size: Optional[int] = None,
    torch_sharing: Optional[str] = None,
) -> None:
    """Train a model in a single process.

    Parameters
    ----------
    rank : int, optional
        Process rank
    cfg : dict
        Full driver/trainer configuration
    distributed : bool, default False
        If `True`, distribute the training process
    world_size : int, optional
        Number of devices to use in the distributed training process
    torch_sharing : str or None, optional
        File sharing strategy for torch distributed training

    Raises
    ------
    ImportError
        If PyTorch is not available.
    ValueError
        If `distributed` is `True` but `rank` or `world_size` is `None`.
    """
    configure_rank_logging(rank)

    # Training always requires torch
    if not TORCH_AVAILABLE:
        raise ImportError(
            "PyTorch is required for training. "
            "Install with: pip install spine[model]"
        )

    # Non-distributed training needs no process group management
    if not distributed:
        driver = Driver(cfg, rank)
        driver.run()
        return

    # Set the torch sharing strategy, if needed
    if torch_sharing is not None:
        torch.multiprocessing.set_sharing_strategy(torch_sharing)

    # Validate explicitly: an `assert` would be stripped under `python -O`
    if rank is None or world_size is None:
        raise ValueError(
            "Distributed training requires both 'rank' and 'world_size'."
        )

    try:
        # Setup the process group, then prepare and run the trainer
        setup_ddp(rank, world_size)
        driver = Driver(cfg, rank)
        driver.run()
    finally:
        # Always release the process group, even if setup or training
        # failed; skip if the group was never successfully initialized
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
def inference_single(cfg: dict) -> None:
    """Execute a model in inference mode in a single process.

    The driver is run once per set of weights. When the model's weight path
    is `None` or a single string, the driver is run once without reloading
    any weights; otherwise each entry is loaded in sorted order before the
    corresponding run. Without a model, the driver is run exactly once.

    Parameters
    ----------
    cfg : dict
        Full driver configuration
    """
    configure_rank_logging()

    # Prepare the driver
    driver = Driver(cfg)

    # Build the list of weights to iterate over
    preloaded = False
    weights = []
    if driver.model is not None:
        weights = driver.model.weight_path
        if weights is None or isinstance(weights, str):
            # A single (or no) weight path: nothing to reload in the loop
            preloaded = True
            weights = [weights]
        else:
            weights = sorted(weights)
            logger.info(
                "Looping over %d set of weights:\n%s",
                len(weights),
                " - " + "\n - ".join(weights),
            )

    # Guarantee at least one pass through the inference loop
    weights = weights or [None]

    # Loop over the weights, run the inference loop
    for weight in weights:
        must_load = not (driver.model is None or weight is None or preloaded)
        if must_load:
            driver.model.load_weights(weight)
            # NOTE(review): the log is assumed to be re-initialized only
            # when a new set of weights is loaded -- confirm against the
            # upstream formatting of this function
            driver.initialize_log()

        driver.run()
def process_world(base: dict) -> Tuple[bool, int, Optional[str]]:
    """Check on the number of available GPUs and what has been requested.

    Also sets the logger level from the 'verbosity' entry of the base
    configuration (defaults to "info").

    Parameters
    ----------
    base : dict
        Base driver configuration dictionary

    Returns
    -------
    distributed : bool
        If `True`, distribute the training process
    world_size : int
        Number of devices to use in the distributed training process
    torch_sharing : str or None
        File sharing strategy for torch distributed training

    Raises
    ------
    ValueError
        If more than one device is in use while distributed execution is
        disabled, or if the sharing strategy is not a recognized value.
    """
    # Set the verbosity of the logger
    logger.setLevel(base.get("verbosity", "info").upper())

    # Parse information about the world size, set visible CUDA devices
    requested_size = base.get("world_size")
    requested_gpus = base.get("gpus")
    world_size = set_visible_devices(world_size=requested_size, gpus=requested_gpus)

    # More than one GPU in use implies distributed execution by default
    multi_gpu = world_size > 1
    distributed = base.get("distributed", multi_gpu)
    if multi_gpu and not distributed:
        raise ValueError(
            "Multiple GPUs detected but distributed execution is disabled. "
            "Set 'distributed: true' in the configuration to enable it."
        )

    # Validate the file sharing strategy, if one was provided
    valid_strategies = ("file_system", "file_descriptor")
    torch_sharing = base.get("torch_sharing_strategy")
    if not (torch_sharing is None or torch_sharing in valid_strategies):
        raise ValueError(
            "torch_sharing_strategy must be one of: "
            "'file_system', 'file_descriptor', or None"
        )

    return distributed, world_size, torch_sharing
def setup_ddp(rank: int, world_size: int, backend: str = "nccl") -> None:
    """Sets up the DistributedDataParallel environment.

    Parameters
    ----------
    rank : int
        Global rank of this process (0 to world_size-1)
    world_size : int
        Total number of processes across all nodes
    backend : str, default "nccl"
        Distributed backend to use

    Notes
    -----
    For multi-node training, set these environment variables:
    - MASTER_ADDR: IP address of the master node
    - MASTER_PORT: Free port on the master node
    - RANK: Global rank (0 to world_size-1)
    - WORLD_SIZE: Total number of processes
    - LOCAL_RANK (optional): Local rank on this node

    This function itself only reads MASTER_ADDR, MASTER_PORT and
    LOCAL_RANK; the first two are given single-node defaults if unset.
    """
    # Keep any externally provided rendezvous settings, otherwise fall
    # back to the single-node defaults
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")

    # Pick the GPU index for this process
    # In multi-node: LOCAL_RANK is the GPU index on this machine
    # In single-node: rank is the GPU index
    env_local_rank = os.environ.get("LOCAL_RANK")
    local_rank = int(rank if env_local_rank is None else env_local_rank)

    # Initialize the process group, then bind this process to its GPU
    torch.distributed.init_process_group(
        backend=backend, rank=rank, world_size=world_size
    )
    torch.cuda.set_device(local_rank)