Source code for spine.io.dataset.hdf5

"""Dataset wrapper around :class:`spine.io.read.HDF5Reader`."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any, ClassVar

from spine.utils.conditional import TORCH_AVAILABLE
from spine.utils.factory import instantiate, module_dict
from spine.utils.logger import logger

from ..parse import hdf5 as parse_hdf5
from ..read import HDF5Reader, StageHDF5Reader
from .base import BaseDataset, DataDict

__all__ = ["HDF5Dataset"]

PARSER_DICT = module_dict(parse_hdf5)


class HDF5Dataset(BaseDataset):
    """Torch dataset wrapper around flat or staged HDF5 readers.

    The dataset can operate in two modes:

    - flat HDF5 mode, backed by :class:`spine.io.read.HDF5Reader`
    - staged cache mode, backed by :class:`spine.io.read.StageHDF5Reader`

    In both cases the dataset exposes a uniform parser-driven interface to the
    DataLoader layer. Reader-produced metadata such as entry indexes and
    source provenance are forwarded automatically alongside any parsed
    products.
    """

    # Dataset type name
    name: ClassVar[str] = "hdf5"

    # Parsers keyed by output product name (empty in raw-product mode)
    parsers: dict[str, Any]
    # Backend reader, selected by the `staged` constructor flag
    reader: HDF5Reader | StageHDF5Reader

    def __init__(
        self,
        dtype: str | None = None,
        staged: bool = False,
        stage: str | None = None,
        schema: Mapping[str, Mapping[str, Any]] | None = None,
        keys: Sequence[str] | None = None,
        skip_keys: Sequence[str] | None = None,
        data_types: Mapping[str, str] | None = None,
        overlay_methods: Mapping[str, str] | None = None,
        augment: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Instantiate the HDF5-backed dataset.

        Parameters
        ----------
        dtype : str, optional
            Floating-point dtype forwarded to parser factories
        staged : bool, default False
            If `True`, use :class:`StageHDF5Reader` as the backend instead of
            the flat :class:`HDF5Reader`
        stage : str, optional
            Default stage name to read when `staged=True`. Individual schema
            entries may override this with their own ``stage`` field.
        schema : mapping, optional
            Parser schema used to reconstruct higher-level products
        keys : sequence[str], optional
            Explicit list of raw HDF5 products to keep
        skip_keys : sequence[str], optional
            Explicit list of raw HDF5 products to drop
        data_types : mapping, optional
            Explicit collate type overrides for raw-product mode
        overlay_methods : mapping, optional
            Explicit overlay-method overrides for raw-product mode
        augment : mapping, optional
            Augmentation applied to each loaded sample
        **kwargs : Any
            Reader-specific keyword arguments forwarded to the selected HDF5
            backend reader

        Raises
        ------
        ImportError
            If PyTorch is not available
        ValueError
            If both `keys` and `skip_keys` are provided; if a `stage` is
            provided (top-level or in a schema entry) without `staged=True`;
            if `schema` is provided without `dtype`; or if two schema entries
            assign the same raw product to different stages
        """
        # Initialize parent class
        super().__init__()

        # Validate the configuration and prepare reader arguments before
        # instantiating the backend.
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch is required to use HDF5Dataset.")
        if keys is not None and skip_keys is not None:
            raise ValueError("Provide either `keys` or `skip_keys`, not both.")
        if not staged and stage is not None:
            raise ValueError("`stage` can only be provided when `staged=True`.")

        self.keys = set(keys) if keys is not None else None
        self.skip_keys = set(skip_keys) if skip_keys is not None else set()
        self.parsers = {}
        self._data_types = dict(data_types) if data_types is not None else None
        self._overlay_methods = (
            dict(overlay_methods) if overlay_methods is not None else None
        )
        reader_stage_map: dict[str, str] = {}

        # If a parser schema is provided, instantiate the parsers and collect
        # the raw HDF5 products they require. In staged mode, also validate
        # schema-level stage assignments and build the reader key-to-stage map.
        if schema is not None:
            if dtype is None:
                raise ValueError("An explicit `dtype` is required when using `schema`.")

            # Ordered, de-duplicated list of raw products needed by parsers
            inferred_keys: list[str] = []
            for data_product, parser_cfg in schema.items():
                # Copy so popping `stage` does not mutate the caller's schema
                parser_cfg = dict(parser_cfg)

                # A per-entry stage is as meaningless in flat mode as the
                # top-level one; reject it instead of silently dropping it.
                if not staged and parser_cfg.get("stage") is not None:
                    raise ValueError(
                        f"Schema entry '{data_product}' sets `stage`, which "
                        "can only be provided when `staged=True`."
                    )
                parser_stage = parser_cfg.pop("stage", stage)

                parser = instantiate(
                    PARSER_DICT, parser_cfg, alt_name="parser", dtype=dtype
                )
                self.parsers[data_product] = parser

                for key in parser.tree_keys:
                    if key not in inferred_keys:
                        inferred_keys.append(key)

                    # A raw product shared by several parsers must be read
                    # from a single stage.
                    if staged and parser_stage is not None:
                        existing_stage = reader_stage_map.get(key)
                        if (
                            existing_stage is not None
                            and existing_stage != parser_stage
                        ):
                            raise ValueError(
                                f"Conflicting staged HDF5 schema for raw product '{key}': "
                                f"'{existing_stage}' vs '{parser_stage}'."
                            )
                        reader_stage_map[key] = parser_stage

            if self.keys is None:
                self.keys = set(inferred_keys)
            else:
                self.keys.update(inferred_keys)

        # Initialize the appropriate reader backend
        if staged:
            self.reader = StageHDF5Reader(
                stage=stage,
                stage_map=reader_stage_map,
                # Sorted so the reader sees a deterministic key order; raw
                # set iteration order varies between runs (str hash seed).
                keys=tuple(sorted(self.keys)) if self.keys is not None else None,
                **kwargs,
            )
        else:
            self.reader = HDF5Reader(**kwargs)

        # Initialize the augmenter
        self.build_augmenter(augment)

    def __len__(self) -> int:
        """Return the number of entries exposed by the backend reader."""
        return len(self.reader)

    def __getitem__(self, idx: int) -> DataDict:
        """Return one cached dataset entry.

        Parameters
        ----------
        idx : int
            Dataset entry index.

        Returns
        -------
        dict
            Either the raw reader output (optionally filtered to ``keys``) or
            a parsed dictionary containing standard metadata plus the products
            described by ``schema``.
        """
        result = self.reader[idx]

        # Filter into a fresh dict rather than popping from the reader's
        # output, so the reader-owned mapping is never mutated.
        if self.keys is not None or self.skip_keys:
            keep = (
                self.keys.union(self._index_keys, self._source_keys)
                if self.keys is not None
                else None
            )
            result = {
                key: val
                for key, val in result.items()
                if (keep is None or key in keep) and key not in self.skip_keys
            }

        # Raw-product mode: no parsers, return the (filtered) reader output
        if not self.parsers:
            return self.apply_augmenter(result)

        # Parser mode: start from the metadata, then add each parsed product
        parsed = self.metadata_dict(result)
        for name, parser in self.parsers.items():
            try:
                parsed[name] = parser(result)
            except Exception:
                logger.error("Failed to produce %s using %s", name, parser)
                raise

        return self.apply_augmenter(parsed)

    @property
    def data_types(self) -> dict[str, str]:
        """Return the collate type for each exposed HDF5 product.

        Returns
        -------
        dict[str, str]
            Mapping from dataset output key to collate type.
        """
        data_types = self.index_data_types()
        if self.parsers:
            for name, parser in self.parsers.items():
                data_types[name] = parser.returns
        elif self._data_types is not None:
            data_types.update(self._data_types)
        else:
            # No explicit types: load the first entry to discover which keys
            # are exposed and default any unknown ones to "list".
            sample = self[0] if len(self) else {}
            for key in sample:
                if key not in data_types:
                    data_types[key] = "list"

        return data_types

    @property
    def overlay_methods(self) -> dict[str, str]:
        """Return the overlay method for each exposed HDF5 product.

        Returns
        -------
        dict[str, str]
            Mapping from dataset output key to overlay strategy.
        """
        overlay_methods = self.index_overlay_methods()
        if self.parsers:
            for name, parser in self.parsers.items():
                overlay_methods[name] = parser.overlay

        # NOTE(review): unlike `data_types`, explicit overrides are applied in
        # parser mode too (they can replace a parser's own overlay method).
        # Confirm this asymmetry is intended.
        if self._overlay_methods is not None:
            overlay_methods.update(self._overlay_methods)

        return overlay_methods

    @property
    def data_keys(self) -> tuple[str, ...]:
        """Return the names of all data products exposed by the dataset.

        Returns
        -------
        tuple[str, ...]
            Ordered tuple of metadata and parser-product keys.
        """
        if self.parsers:
            return (*self._index_keys, *self._source_keys, *self.parsers.keys())

        return tuple(self.data_types.keys())