"""Functions that instantiate IO tools from configuration blocks."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from warnings import warn
from spine.utils.conditional import TORCH_AVAILABLE
from spine.utils.factory import instantiate, module_dict
from . import read, write
# Lookup tables built once at import time from the reader and writer
# submodules; they are handed to ``instantiate`` by the factories below to
# resolve the name given in a configuration block.
READER_DICT = module_dict(read)
WRITER_DICT = module_dict(write)
# Public factory functions exported by this module
__all__ = [
"reader_factory",
"writer_factory",
"loader_factory",
"dataset_factory",
"sampler_factory",
"collate_factory",
]
[docs]
def reader_factory(reader_cfg: Mapping[str, Any] | str) -> Any:
    """Build a reader object from its configuration block.

    The work is delegated to ``instantiate``, which is given the table of
    readers collected from :mod:`spine.io.read`; the configured ``name``
    must match one of the reader classes exported there.

    Parameters
    ----------
    reader_cfg : Mapping[str, Any] or str
        Reader configuration mapping or the short reader name.

    Returns
    -------
    object
        Instantiated reader object.
    """
    reader = instantiate(READER_DICT, reader_cfg)
    return reader
[docs]
def writer_factory(
    writer_cfg: Mapping[str, Any] | str,
    prefix: str | list[str] | None = None,
    split: bool = False,
) -> Any:
    """Build a writer object from its configuration block.

    The configured ``name`` must match a writer class exported from
    :mod:`spine.io.write`.

    Parameters
    ----------
    writer_cfg : Mapping[str, Any] or str
        Writer configuration mapping or the short writer name.
    prefix : str or list[str], optional
        Input file prefix or per-file list of prefixes used to derive output
        names when the writer supports prefix-based naming. Forwarded only
        when it is not ``None``.
    split : bool, default False
        Request one output file per input file. Forwarded only when truthy;
        otherwise it is omitted from the call to ``instantiate``. Writers
        that do not support unsplit output may reject ``split=False``
        explicitly.

    Returns
    -------
    object
        Instantiated writer object.
    """
    # Each candidate is (keyword, value, whether it was actually requested);
    # only requested options are passed on to the writer constructor
    candidates = (
        ("prefix", prefix, prefix is not None),
        ("split", split, bool(split)),
    )
    overrides = {key: value for key, value, requested in candidates if requested}

    return instantiate(WRITER_DICT, writer_cfg, **overrides)
[docs]
def loader_factory(
    dataset: Mapping[str, Any] | str,
    dtype: str,
    batch_size: int | None = None,
    minibatch_size: int | None = None,
    shuffle: bool = True,
    sampler: Mapping[str, Any] | str | None = None,
    num_workers: int = 0,
    collate_fn: Mapping[str, Any] | str | None = None,
    entry_list: list[int] | None = None,
    distributed: bool = False,
    world_size: int = 0,
    rank: int | None = None,
    **kwargs: Any,
) -> Any:
    """Instantiate a PyTorch ``DataLoader`` from configuration.

    The dataset, the optional sampler and the optional collate function are
    each built through their dedicated factory before being handed to the
    loader.

    Parameters
    ----------
    dataset : mapping or str
        Dataset configuration mapping or short dataset name.
    dtype : str
        Floating-point dtype passed to the dataset factory.
    batch_size : int, optional
        Global batch size. Mutually exclusive with ``minibatch_size``. The
        per-process size is derived as ``batch_size // max(world_size, 1)``.
    minibatch_size : int, optional
        Per-process batch size. Mutually exclusive with ``batch_size``. The
        global size is derived as ``minibatch_size * max(world_size, 1)``.
    shuffle : bool, default True
        Whether to shuffle batches in the underlying loader. Forwarded to
        the ``DataLoader`` unchanged, including when a sampler is provided.
    sampler : mapping or str, optional
        Sampler configuration mapping or short sampler name.
    num_workers : int, default 0
        Number of loader worker processes.
    collate_fn : mapping or str, optional
        Collate function configuration mapping or short collate name.
    entry_list : list[int], optional
        Explicit subset of dataset entries to expose.
    distributed : bool, default False
        If ``True``, wrap the sampler for distributed loading. Only used
        when a ``sampler`` is configured.
    world_size : int, default 0
        Number of distributed processes/devices. ``0`` is treated as a
        single process when converting between batch sizes.
    rank : int, optional
        Distributed process rank. Required when ``distributed=True``.
    **kwargs : dict
        Extra keyword arguments forwarded to ``torch.utils.data.DataLoader``.

    Returns
    -------
    torch.utils.data.DataLoader
        Instantiated data loader.

    Raises
    ------
    ImportError
        If PyTorch is not available.
    ValueError
        If both or neither of ``batch_size`` and ``minibatch_size`` are
        provided, if ``batch_size`` is not a multiple of a non-zero
        ``world_size``, or if the dataset exposes a truthy ``joint``
        attribute and no sampler is configured.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch is required to use loader_factory.")

    # Deferred import so that this module can be loaded without PyTorch
    from torch.utils.data import DataLoader

    # Exactly one of the two batch size arguments must be provided; each
    # failure mode gets its own message so the user knows what to change
    if batch_size is not None and minibatch_size is not None:
        raise ValueError("Provide either `batch_size` or `minibatch_size`, not both.")
    if batch_size is None and minibatch_size is None:
        raise ValueError("One of `batch_size` or `minibatch_size` must be provided.")

    # Derive the missing batch size from the one which was provided
    n_procs = max(world_size, 1)
    if batch_size is not None:
        if world_size != 0 and (batch_size % world_size) != 0:
            raise ValueError("The batch_size must be a multiple of the number of GPUs.")
        minibatch_size = batch_size // n_procs
    else:
        batch_size = minibatch_size * n_procs

    # Initialize the dataset
    torch_dataset = dataset_factory(dataset, entry_list, dtype)

    # Initialize the sampler
    if sampler is None and getattr(torch_dataset, "joint", False):
        raise ValueError("JointDataset requires an explicit joint sampler.")
    if sampler is not None:
        # NOTE(review): the global `batch_size` is passed here although the
        # third parameter of `sampler_factory` is named `minibatch_size` --
        # confirm this is intended.
        sampler = sampler_factory(
            sampler, torch_dataset, batch_size, distributed, world_size, rank
        )

    # Initialize the collate function
    if collate_fn is not None:
        collate_fn = collate_factory(
            collate_fn, torch_dataset.data_types, torch_dataset.overlay_methods
        )

    # Initialize the loader, which operates on the per-process batch size
    # NOTE(review): `DataLoader` documents `sampler` and `shuffle` as mutually
    # exclusive; callers configuring a sampler presumably pass `shuffle=False`
    # -- confirm.
    return DataLoader(
        torch_dataset,
        batch_size=minibatch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        **kwargs,
    )
[docs]
def dataset_factory(
    dataset_cfg: Mapping[str, Any] | str,
    entry_list: list[int] | None = None,
    dtype: str | None = None,
) -> Any:
    """Build a dataset object from its configuration block.

    Parameters
    ----------
    dataset_cfg : Mapping[str, Any] or str
        Dataset configuration mapping or short dataset name.
    entry_list : list[int], optional
        Explicit subset of dataset entries to expose. When provided here, it
        overrides any ``entry_list`` already present in ``dataset_cfg`` and a
        warning is issued. The caller's mapping is copied, not modified.
    dtype : str, optional
        Floating-point dtype forwarded to the dataset constructor.

    Returns
    -------
    object
        Instantiated dataset object.

    Raises
    ------
    ValueError
        If ``entry_list`` is provided and the dataset name is ``"joint"``
        or ``"JointDataset"``.
    """
    # Deferred import of the dataset submodule
    from . import dataset

    registry = module_dict(dataset)

    if entry_list is not None:
        # The configuration is either a bare name or a mapping holding one
        given_as_name = isinstance(dataset_cfg, str)
        name = dataset_cfg if given_as_name else dataset_cfg.get("name")

        joint_names = ("joint", "JointDataset")
        if name in joint_names:
            raise ValueError(
                "`entry_list` must be configured inside `base`, `primary`, "
                "or `secondary` for JointDataset."
            )

        warn(
            "You are manually overwriting the existing `entry_list` "
            "argument provided in the configuration file."
        )

        # Work on a fresh dictionary so the input configuration is untouched
        if given_as_name:
            dataset_cfg = {"name": dataset_cfg}
        else:
            dataset_cfg = dict(dataset_cfg)
        dataset_cfg["entry_list"] = entry_list

    # The dtype is always forwarded, even when it is None
    return instantiate(registry, dataset_cfg, dtype=dtype)
[docs]
def sampler_factory(
    sampler_cfg: Mapping[str, Any] | str,
    dataset: Any,
    minibatch_size: int,
    distributed: bool = False,
    num_replicas: int = 1,
    rank: int | None = None,
) -> Any:
    """Build a sampler object from its configuration block.

    Parameters
    ----------
    sampler_cfg : mapping or str
        Sampler configuration mapping or short sampler name.
    dataset : object
        Dataset instance used to initialize the sampler.
    minibatch_size : int
        Batch size passed to the sampler as its ``batch_size`` argument.
    distributed : bool, default False
        If ``True``, wrap the sampler in ``DistributedProxySampler``.
    num_replicas : int, default 1
        Number of distributed processes/devices.
    rank : int, optional
        Distributed process rank. Required when ``distributed=True``.

    Returns
    -------
    object
        Instantiated sampler object, optionally wrapped for distributed
        loading.

    Raises
    ------
    ImportError
        If PyTorch is not available.
    ValueError
        If ``distributed`` is set without a ``rank``, or if the ``joint``
        attributes of the dataset and of the sampler disagree.
    """

    def _kind(joint_flag: Any) -> str:
        """Name the sampler/dataset flavor for the error message."""
        return "joint" if joint_flag else "standard"

    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch is required to use sampler_factory.")
    if distributed and rank is None:
        raise ValueError("A distributed sampler requires an explicit integer `rank`.")

    # Deferred import of the sampler submodule
    from . import sample

    registry = module_dict(sample)
    sampler_obj = instantiate(
        registry, sampler_cfg, dataset=dataset, batch_size=minibatch_size
    )

    # Joint datasets consume tuple indexes while standard datasets consume
    # scalars, so both objects must agree; a missing flag counts as False
    dataset_flag = getattr(dataset, "joint", False)
    sampler_flag = getattr(sampler_obj, "joint", False)
    if dataset_flag != sampler_flag:
        expected, got = _kind(dataset_flag), _kind(sampler_flag)
        raise ValueError(
            f"Cannot use a {got} sampler with a {expected} dataset. "
            "Use a joint sampler with JointDataset and a standard sampler "
            "with standard datasets."
        )

    # Outside of a distributed environment, the bare sampler is returned
    if not distributed:
        return sampler_obj

    return sample.DistributedProxySampler(sampler_obj, num_replicas, rank)
[docs]
def collate_factory(
    collate_cfg: Mapping[str, Any] | str,
    data_types: Mapping[str, str],
    overlay_methods: Mapping[str, str],
) -> Any:
    """Build a collate callable from its configuration block.

    Parameters
    ----------
    collate_cfg : Mapping[str, Any] or str
        Collate configuration mapping or short collate function name.
    data_types : Mapping[str, str]
        Mapping from parser output keys to their declared data type.
    overlay_methods : Mapping[str, str]
        Mapping from parser output keys to the overlay method used when
        combining data from multiple sources.

    Returns
    -------
    collections.abc.Callable
        Instantiated collate callable.
    """
    # Deferred import of the collate submodule
    from . import collate

    registry = module_dict(collate)
    collate_kwargs = {"data_types": data_types, "overlay_methods": overlay_methods}

    # NOTE(review): the positional "collate_fn" is presumably an alternate
    # name key understood by `instantiate` -- confirm in spine.utils.factory
    return instantiate(registry, collate_cfg, "collate_fn", **collate_kwargs)