"""Stage-aware HDF5 cache reader."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from warnings import warn
import h5py
import numpy as np
import yaml
from yaml.parser import ParserError
from spine.utils.logger import logger
from .hdf5 import HDF5Reader
__all__ = ["StageHDF5Reader"]
[docs]
class StageHDF5Reader(HDF5Reader):
"""Read products stored under one or more stage groups in a cache file.
The reader exposes the same event-level interface as :class:`HDF5Reader`,
but resolves requested product keys under ``/stages/<stage>`` instead of
the flat top-level namespace.
"""
name = "stage_hdf5"
[docs]
def __init__(
self,
stage: str | None = None,
file_keys: str | list[str] | None = None,
file_list: str | None = None,
limit_num_files: int | None = None,
max_print_files: int = 10,
n_entry: int | None = None,
n_skip: int | None = None,
entry_list: list[int] | None = None,
skip_entry_list: list[int] | None = None,
build_classes: bool = True,
skip_unknown_attrs: bool = False,
allow_missing: bool = False,
keep_open: bool = True,
swmr: bool = False,
ignore_incomplete: bool = False,
stage_map: Mapping[str, str] | None = None,
keys: Sequence[str] | None = None,
) -> None:
"""Initialize the stage-cache reader.
Parameters
----------
stage : str, optional
Default stage from which to load products. If omitted, keys are
searched across all stages and must resolve uniquely.
stage_map : mapping, optional
Explicit map from product keys to stage names. This overrides the
default stage on a per-product basis.
keys : sequence[str], optional
Product keys that should be exposed by the reader. If omitted, all
products from the selected stage(s) are exposed.
file_keys, file_list, limit_num_files, max_print_files, n_entry, n_skip, \
entry_list, skip_entry_list, build_classes, skip_unknown_attrs, \
allow_missing, keep_open, swmr, ignore_incomplete
See :class:`spine.io.read.HDF5Reader`. These options control file
discovery, entry selection, object reconstruction, file-handle
lifetime, and incomplete-stage handling.
"""
self.stage = stage
self.stage_map = dict(stage_map or {})
self.requested_keys = tuple(keys) if keys is not None else None
self.process_file_paths(file_keys, file_list, limit_num_files, max_print_files)
self.keep_open = keep_open
self.swmr = swmr
self.ignore_incomplete = ignore_incomplete
self._handle_pid = None
self._file_handles = {}
self._resolved_products: dict[int, dict[str, str]] = {}
self._source_info: dict[int, dict[str, object]] = {}
file_index = []
self.num_entries = 0
self.file_offsets = np.empty(len(self.file_paths), dtype=np.int64)
for i, path in enumerate(self.file_paths):
with h5py.File(path, "r") as in_file:
self._source_info[i] = self.read_source_info(in_file)
product_stage_map = self.resolve_product_stages(in_file, path)
self._resolved_products[i] = product_stage_map
stage_lengths = self.get_stage_lengths(in_file, path, product_stage_map)
num_entries = self.validate_stage_lengths(path, stage_lengths)
file_index.append(i * np.ones(num_entries, dtype=np.int64))
self.file_offsets[i] = self.num_entries
self.num_entries += num_entries
logger.info("Total number of entries in the file(s): %d\n", self.num_entries)
self.file_index = (
np.concatenate(file_index) if file_index else np.empty(0, dtype=np.int64)
)
self.run_info = None
self.run_map = None
self.process_entry_list(
n_entry,
n_skip,
entry_list,
skip_entry_list,
None,
None,
allow_missing,
)
self.build_classes = build_classes
self.skip_unknown_attrs = skip_unknown_attrs
self.cfg = self.process_cfg()
self.version = self.process_version()
[docs]
@staticmethod
def get_stages_group(in_file: h5py.File, path: str) -> h5py.Group:
"""Return the top-level ``stages`` group.
Parameters
----------
in_file : h5py.File
Open cache file handle.
path : str
File path used to build informative error messages.
"""
assert "stages" in in_file, f"Stage-cache file '{path}' is missing 'stages'."
stages = in_file["stages"]
assert isinstance(
stages, h5py.Group
), f"'stages' in '{path}' must be a group, got {type(stages)}."
return stages
[docs]
@classmethod
def get_stage_group(cls, in_file: h5py.File, path: str, stage: str) -> h5py.Group:
"""Return one named stage group.
Parameters
----------
in_file : h5py.File
Open cache file handle.
path : str
File path used to build informative error messages.
stage : str
Name of the stage group to load under ``/stages``.
"""
stages = cls.get_stages_group(in_file, path)
assert (
stage in stages
), f"Stage-cache file '{path}' does not contain stage '{stage}'."
stage_group = stages[stage]
assert isinstance(
stage_group, h5py.Group
), f"Stage '{stage}' in '{path}' must be a group, got {type(stage_group)}."
return stage_group
[docs]
@staticmethod
def read_source_info(in_file: h5py.File) -> dict[str, object]:
"""Return top-level source provenance stored in the cache file.
Parameters
----------
in_file : h5py.File
Open cache file handle.
Returns
-------
dict[str, object]
File-level provenance dictionary. If the cache predates the source
group convention, this returns an empty dictionary.
"""
if "source" not in in_file:
return {}
source_group = in_file["source"]
assert isinstance(
source_group, h5py.Group
), f"Expected 'source' to be a group, got {type(source_group)}."
file_name = source_group.attrs["file_name"]
if isinstance(file_name, bytes):
file_name = file_name.decode()
return {
"source_file_name": file_name,
"source_file_size": int(source_group.attrs["file_size"]),
"source_file_mtime_ns": int(source_group.attrs["file_mtime_ns"]),
}
[docs]
def list_stage_keys(self, stage_group: h5py.Group) -> tuple[str, ...]:
"""List product keys stored in one stage group.
This excludes the administrative ``info`` and ``events`` members.
"""
return tuple(key for key in stage_group.keys() if key not in {"info", "events"})
[docs]
def resolve_product_stages(self, in_file: h5py.File, path: str) -> dict[str, str]:
"""Resolve each requested product key to one stage.
Resolution order is:
1. explicit ``stage_map`` entry for the key
2. dataset-level default ``stage``
3. automatic discovery across all available stages
Automatic discovery requires a unique match. If the same product name
appears in multiple stages, the caller must disambiguate it.
"""
if (
self.stage is not None
and self.requested_keys is None
and not self.stage_map
):
stage_group = self.get_stage_group(in_file, path, self.stage)
self.check_stage_complete(stage_group, path, self.stage)
return {key: self.stage for key in self.list_stage_keys(stage_group)}
stages = self.get_stages_group(in_file, path)
required_keys = (
tuple(self.requested_keys)
if self.requested_keys is not None
else tuple(
key
for stage_name in stages
for key in self.list_stage_keys(
self.get_stage_group(in_file, path, stage_name)
)
)
)
resolved: dict[str, str] = {}
for key in required_keys:
if key in self.stage_map:
stage_name = self.stage_map[key]
stage_group = self.get_stage_group(in_file, path, stage_name)
self.check_stage_complete(stage_group, path, stage_name)
if key not in stage_group:
raise KeyError(
f"Requested product '{key}' does not exist in stage "
f"'{stage_name}' of '{path}'."
)
resolved[key] = stage_name
continue
if self.stage is not None:
stage_group = self.get_stage_group(in_file, path, self.stage)
self.check_stage_complete(stage_group, path, self.stage)
if key not in stage_group:
raise KeyError(
f"Requested product '{key}' does not exist in stage "
f"'{self.stage}' of '{path}'."
)
resolved[key] = self.stage
continue
candidates = []
for stage_name in stages:
stage_group = self.get_stage_group(in_file, path, stage_name)
if key in stage_group:
self.check_stage_complete(stage_group, path, stage_name)
candidates.append(stage_name)
if not candidates:
raise KeyError(
f"Could not find requested product '{key}' in any stage of '{path}'."
)
if len(candidates) > 1:
raise ValueError(
f"Requested product '{key}' appears in multiple stages of '{path}': "
f"{candidates}. Specify its stage explicitly."
)
resolved[key] = candidates[0]
return resolved
[docs]
def check_stage_complete(
self, stage_group: h5py.Group, path: str, stage: str
) -> None:
"""Reject incomplete stages unless explicitly allowed.
Parameters
----------
stage_group : h5py.Group
Resolved stage group.
path : str
Cache file path.
stage : str
Stage name used in the error message.
"""
if (
"info" in stage_group
and "complete" in stage_group["info"].attrs
and not stage_group["info"].attrs["complete"]
and not self.ignore_incomplete
):
raise RuntimeError(
f"Stage '{stage}' in '{path}' is marked incomplete. "
"Pass ignore_incomplete=True to override."
)
[docs]
def get_stage_lengths(
self, in_file: h5py.File, path: str, product_stage_map: Mapping[str, str]
) -> dict[str, int]:
"""Return the event count of each referenced stage.
Parameters
----------
in_file : h5py.File
Open cache file handle.
path : str
Cache file path.
product_stage_map : mapping
Mapping from requested raw product key to resolved stage name.
"""
stage_lengths: dict[str, int] = {}
for stage_name in set(product_stage_map.values()):
stage_group = self.get_stage_group(in_file, path, stage_name)
events = stage_group["events"]
assert isinstance(
events, h5py.Dataset
), f"Stage '{stage_name}' in '{path}' is missing an 'events' dataset."
stage_lengths[stage_name] = len(events)
return stage_lengths
[docs]
@staticmethod
def validate_stage_lengths(path: str, stage_lengths: Mapping[str, int]) -> int:
"""Ensure all referenced stages in one file have the same length.
Returns
-------
int
Shared number of entries across all referenced stages.
"""
lengths = list(stage_lengths.values())
if not lengths:
return 0
if any(length != lengths[0] for length in lengths[1:]):
raise ValueError(
f"Referenced stages in '{path}' do not expose the same number of entries: "
f"{dict(stage_lengths)}."
)
return lengths[0]
[docs]
def process_cfg(self) -> dict[str, object] | None:
"""Return the stored configuration for the referenced stage(s), if any.
Returns
-------
dict or object or None
Parsed YAML configuration stored under stage metadata. A single
stage yields its parsed object directly; multiple stages return a
mapping from stage name to parsed object.
"""
with h5py.File(self.file_paths[0], "r") as in_file:
stage_names = sorted(set(self._resolved_products[0].values()))
cfg_map: dict[str, object | None] = {}
for stage_name in stage_names:
stage_group = self.get_stage_group(
in_file, self.file_paths[0], stage_name
)
if "info" not in stage_group or "cfg" not in stage_group["info"].attrs:
cfg_map[stage_name] = None
continue
cfg_str = stage_group["info"].attrs["cfg"]
try:
assert isinstance(cfg_str, str), "'cfg' attribute is not a string."
cfg_map[stage_name] = yaml.safe_load(cfg_str)
except ParserError:
warn(
"Parsing stage configuration failed, returning None for "
f"stage '{stage_name}'."
)
cfg_map[stage_name] = None
if len(cfg_map) == 1:
return next(iter(cfg_map.values()))
return cfg_map
[docs]
def get(self, idx: int) -> dict[str, object]:
"""Return one merged cache entry.
Parameters
----------
idx : int
Dataset entry index in the staged cache.
Returns
-------
dict[str, object]
Raw merged event dictionary containing standard metadata plus all
requested stage products for the selected entry.
"""
if idx < 0 or idx >= len(self):
raise IndexError(
f"Index {idx} out of bounds for dataset of size {len(self)}."
)
file_idx = self.get_file_index(idx)
entry_idx = self.get_file_entry_index(idx)
data: dict[str, object] = {
"file_index": file_idx,
"file_entry_index": entry_idx,
}
data.update(self._source_info.get(file_idx, {}))
product_stage_map = self._resolved_products[file_idx]
in_file, should_close = self._open_file(file_idx)
try:
for stage_name in sorted(set(product_stage_map.values())):
stage_group = self.get_stage_group(
in_file, self.file_paths[file_idx], stage_name
)
events = stage_group["events"]
assert isinstance(
events, h5py.Dataset
), f"Stage '{stage_name}' is missing an 'events' dataset."
event = events[entry_idx]
names = getattr(getattr(event, "dtype", None), "names", None)
if names is None:
raise ValueError(
f"Stage '{stage_name}' event entry does not have named fields."
)
for key in names:
if product_stage_map.get(key) != stage_name:
continue
self.load_key(stage_group, event, data, key)
finally:
if should_close:
in_file.close()
data["index"] = idx
return data