# Source code for spine.io.write.stage_hdf5

"""Stage-aware HDF5 cache writer."""

from __future__ import annotations

import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any

import h5py
import numpy as np
import yaml

from spine.version import __version__

from .hdf5 import HDF5Writer

# Public API of this module: only the stage-cache writer is exported.
__all__ = ["StageHDF5Writer"]


class StageHDF5Writer(HDF5Writer):
    """Write additive stage caches to one HDF5 file per source file.

    This writer is intended for sequential cache materialization workflows
    where each processing stage writes a self-contained set of products under
    ``/stages/<stage>`` while preserving previously completed stages. Cache
    files are split by source-file provenance automatically.

    Unlike :class:`HDF5Writer`, this class does not use one flat product
    namespace for the entire file. Each stage owns its own ``events`` dataset
    and product datasets, which allows failed later stages to be rewritten
    without modifying earlier completed stages.

    Layout of one cache file as written by this class::

        /info                  (attrs: version, format)
        /source                (attrs: file_name, file_size, file_mtime_ns)
        /stages/<stage>/info   (attrs: complete, optional cfg, custom attrs)
        /stages/<stage>/...    (datasets created by ``initialize_datasets``)
    """

    name = "stage_hdf5"

    # Reader-provided source provenance keys, in a fixed order. They describe
    # the source file as a whole, so they are stored once under ``/source``
    # rather than as per-entry stage products.
    _source_keys = ("source_file_name", "source_file_size", "source_file_mtime_ns")
    _file_source_keys = set(_source_keys)

    @dataclass
    class StageState:
        """In-memory description of one stage schema.

        The regular :class:`HDF5Writer` stores one flat schema for the whole
        file. Stage caches need one schema per stage, so this small dataclass
        carries the state required to keep appending consistently to a given
        stage group.
        """

        # Product keys stored for this stage.
        keys: set[str]
        # Per-key storage format inferred from the first batch.
        type_dict: dict[str, HDF5Writer.DataFormat]
        # Object dtypes inferred alongside ``type_dict``.
        object_dtypes: list[list[tuple[str, type]]]
        # Event dtype, filled in once datasets are initialized/appended.
        event_dtype: np.dtype | list[tuple[str, Any]] | None = None
        # Entries appended since the last flush (see ``flush_frequency``).
        entries_since_flush: int = 0

    def __init__(
        self,
        file_name: str | None = None,
        directory: str | None = None,
        prefix: str | list[str] | None = None,
        suffix: str = "stage",
        stage: str | None = None,
        keys: list[str] | None = None,
        skip_keys: list[str] | None = None,
        split: bool = True,
        lite: bool = False,
        keep_open: bool = True,
        flush_frequency: int | None = None,
        overwrite: bool = False,
    ) -> None:
        """Initialize the stage-cache writer.

        Parameters
        ----------
        file_name : str, optional
            Output cache file name. When ``directory`` is not provided, this
            path also provides the parent directory for source-derived cache
            files. If omitted, the base output path is built from ``prefix``
            and ``suffix`` using the same naming rules as :class:`HDF5Writer`.
        directory : str, optional
            Output directory used for all source-derived cache files. When
            provided, it overrides the directory encoded in ``file_name``.
        prefix : str or list[str], optional
            Input file prefix used to derive the base staged-cache file name
            when ``file_name`` is not specified. A list of more than one
            prefix enables source-routed output from the start.
        suffix : str, default "stage"
            Suffix appended to source file basenames when deriving split cache
            file names.
        stage : str, optional
            Stage name to use for the standard driver-facing writer contract.
            When provided, :meth:`__call__` writes to this stage and
            :meth:`finalize` marks it complete. If omitted, use
            :meth:`write_stage` and :meth:`finalize_stage` directly.
        keys : list[str], optional
            List of data-product keys to persist in each stage. If omitted,
            store every product present in the batch apart from administrative
            source-file metadata.
        skip_keys : list[str], optional
            List of data-product keys to exclude from each stage.
        split : bool, default True
            Stage caches are always written one file per source file. This
            argument is accepted for compatibility with generic writer
            configuration, but it must remain `True`.
        lite : bool, default False
            If `True`, store lite object representations when applicable.
        keep_open : bool, default True
            If `True`, keep one append handle open per cache file in the
            owning process.
        flush_frequency : int, optional
            Flush the file after this many appended entries per stage. If
            `None`, only flush on explicit requests or close/finalize.
        overwrite : bool, default False
            If `True`, replace existing cache files: the base ``file_name`` is
            removed immediately, and every cache file (including
            source-derived ones) is truncated the first time this writer
            touches it.

        Raises
        ------
        ValueError
            If ``split`` is `False`.
        """
        # Handle bookkeeping is set up first so that `close` stays safe on an
        # instance whose initialization failed below.
        self._handle_pid: int | None = None
        self._handles: dict[str, h5py.File] = {}

        if not split:
            raise ValueError(
                "StageHDF5Writer requires `split=True` because staged caches "
                "are written one file per source file."
            )

        # Only ask the base naming logic to split when a prefix is available
        # to derive per-file names from.
        name_split = split if prefix is not None else False
        self.file_name = self.get_file_names(
            file_name=file_name,
            prefix=prefix,
            suffix=suffix,
            split=name_split,
            directory=directory,
        )[0]
        self._route_by_source = isinstance(prefix, list) and len(prefix) > 1
        self.directory = directory
        self.suffix = suffix
        self.stage = stage
        self.lite = lite
        self.keep_open = keep_open
        self.flush_frequency = flush_frequency
        self._overwrite_files = overwrite
        self.source_info: dict[str, Any] | None = None
        self.keys = set(keys) if keys is not None else None
        self.skip_keys = skip_keys

        # Writer-level attributes shared with the HDF5Writer helpers. The
        # per-stage schema is swapped into these while appending.
        self.dummy_ds = None
        self.append = True
        self.split = True
        self.ready = False
        self.object_dtypes = []
        self.type_dict = None
        self.event_dtype = None
        self._cfg: dict[str, Any] | None = None

        self._initialized_files: set[str] = set()
        self._stage_states: dict[str, StageHDF5Writer.StageState] = {}
        self._completed_stages: dict[str, set[str]] = defaultdict(set)
        self._known_files: set[str] = set()
        # Stages created (or rebuilt) by this writer instance, per cache file.
        # Used to refuse appending to stages left over from earlier sessions
        # and to finalize only what this instance actually wrote.
        self._session_stages: dict[str, set[str]] = defaultdict(set)

        if overwrite and os.path.exists(self.file_name):
            os.remove(self.file_name)

    def __call__(self, data: dict[str, Any], cfg: dict[str, Any] | None = None) -> None:
        """Append one batch to the configured stage.

        Parameters
        ----------
        data : dict
            Dictionary of data products
        cfg : dict, optional
            Dictionary containing the complete SPINE configuration

        Raises
        ------
        RuntimeError
            If no ``stage`` was configured at construction time.
        """
        if self.stage is None:
            raise RuntimeError(
                "StageHDF5Writer requires a configured `stage` to be used "
                "through the standard writer call path."
            )
        self.write_stage(self.stage, data, cfg=cfg)

    def finalize(self) -> None:
        """Mark the configured stage as complete across touched cache files.

        Raises
        ------
        RuntimeError
            If no ``stage`` was configured at construction time.
        """
        if self.stage is None:
            raise RuntimeError(
                "StageHDF5Writer requires a configured `stage` to finalize "
                "through the standard writer interface."
            )
        self.finalize_stage(self.stage)

    def close(self) -> None:
        """Close any persistent cache-file handles.

        This only affects handles cached in the current process and may be
        called repeatedly.
        """
        for handle in self._handles.values():
            # Best effort: a failure closing one handle must not prevent the
            # remaining handles from being closed.
            try:
                handle.close()
            except Exception:
                pass
        self._handles = {}
        self._handle_pid = None

    def _check_handle_pid(self) -> None:
        """Ensure persistent writer handles remain process-local.

        Stage caches are not safe to append to through a writer instance that
        has crossed a process boundary. This method enforces the same
        single-process handle ownership contract as :class:`HDF5Writer`.

        Raises
        ------
        RuntimeError
            If called from a process other than the one that first used the
            persistent handles.
        """
        current_pid = os.getpid()
        if self._handle_pid is None:
            self._handle_pid = current_pid
            return
        if self._handle_pid != current_pid:
            raise RuntimeError(
                "StageHDF5Writer file handles are process-local and cannot be "
                "reused across process boundaries."
            )

    def _open_handle(self, file_path: str) -> tuple[h5py.File, bool]:
        """Return an appendable cache-file handle for one output path.

        Returns
        -------
        tuple[h5py.File, bool]
            Open HDF5 handle and a flag indicating whether the caller is
            responsible for closing it immediately.
        """
        self._ensure_file(file_path)
        if not self.keep_open:
            return h5py.File(file_path, "a"), True

        self._check_handle_pid()
        handle = self._handles.get(file_path)
        # Reopen if the cached handle was never opened or has been closed.
        if handle is None or not handle.id.valid:
            handle = h5py.File(file_path, "a")
            self._handles[file_path] = handle
        return handle, False

    def _ensure_file(self, file_path: str) -> None:
        """Initialize one output cache file structure on first use.

        The top-level administrative groups are created lazily because staged
        cache files are derived from source provenance and may not all be
        touched by every write call. When the writer was built with
        ``overwrite=True``, the file is truncated on this first touch so
        source-derived caches are replaced just like the base ``file_name``.
        """
        if file_path in self._initialized_files:
            return

        if self._overwrite_files or not os.path.exists(file_path):
            mode = "w"
        else:
            mode = "a"

        if self.keep_open:
            self._check_handle_pid()
        out_file = h5py.File(file_path, mode)
        if self.keep_open:
            self._handles[file_path] = out_file

        try:
            if "info" not in out_file:
                out_file.create_group("info")
                out_file["info"].attrs["version"] = __version__
                out_file["info"].attrs["format"] = self.name
            if "stages" not in out_file:
                out_file.create_group("stages")
        finally:
            if not self.keep_open:
                out_file.close()

        self._initialized_files.add(file_path)
        self._known_files.add(file_path)

    @staticmethod
    def _native_scalar(value: Any) -> Any:
        """Unwrap NumPy scalars into native Python objects; pass others through."""
        return value.item() if hasattr(value, "item") else value

    def get_batch_source_info(self, data: dict[str, Any]) -> dict[str, Any]:
        """Extract cache-file source provenance from one normalized batch.

        Parameters
        ----------
        data : dict
            Normalized batch dictionary prepared for writing.

        Returns
        -------
        dict[str, Any]
            File-level source identity stored under the cache file's top-level
            ``/source`` group.

        Raises
        ------
        KeyError
            If any source provenance key is missing from ``data``.
        ValueError
            If a provenance entry is empty or the batch spans several source
            files.
        """
        missing = [key for key in self._source_keys if key not in data]
        if missing:
            raise KeyError(
                "StageHDF5Writer requires reader-provided source provenance. "
                f"Missing keys: {missing}."
            )

        values = {}
        for key in self._source_keys:
            value = data[key]
            if np.isscalar(value):
                values[key] = self._native_scalar(value)
                continue

            array = np.asarray(value)
            if array.ndim == 0:
                values[key] = array.item()
                continue
            if len(array) == 0:
                raise ValueError(f"Source provenance key '{key}' is empty.")

            # Every entry of the batch must share the same source file.
            first = self._native_scalar(array[0])
            if any(self._native_scalar(el) != first for el in array[1:]):
                raise ValueError(
                    "StageHDF5Writer expects one source file per cache file. "
                    f"Batch key '{key}' contains multiple values."
                )
            values[key] = first

        return {
            "file_name": values["source_file_name"],
            "file_size": int(values["source_file_size"]),
            "file_mtime_ns": int(values["source_file_mtime_ns"]),
        }

    def ensure_source_group(
        self, out_file: h5py.File, data: dict[str, Any], file_path: str
    ) -> None:
        """Create or validate the top-level source provenance group.

        This enforces the one-cache-file-per-source-file contract. If a later
        stage attempts to write into an existing cache file with mismatched
        source provenance, the writer raises immediately.

        Raises
        ------
        RuntimeError
            If the cached ``/source`` attributes differ from the batch
            provenance.
        """
        source_info = self.get_batch_source_info(data)
        self.source_info = source_info

        if "source" not in out_file:
            source_group = out_file.create_group("source")
            for key, value in source_info.items():
                source_group.attrs[key] = value
            return

        source_group = out_file["source"]
        assert isinstance(
            source_group, h5py.Group
        ), f"Expected 'source' to be a group, got {type(source_group)}."
        for key, value in source_info.items():
            cached_value = source_group.attrs.get(key)
            if cached_value != value:
                raise RuntimeError(
                    f"Cache source mismatch for '{file_path}': '{key}' differs "
                    f"({cached_value!r} != {value!r})."
                )

    def _prepare_batch(self, data: dict[str, Any]) -> tuple[dict[str, Any], int]:
        """Normalize one batch for stage writing.

        This mirrors the flat HDF5 writer behavior by accepting either scalar
        single-entry payloads or already batched payloads and returning a
        uniform list-like representation. The input dictionary is not
        modified.

        Returns
        -------
        tuple[dict, int]
            Normalized batch dictionary and its number of entries.
        """
        data = self.with_source_provenance(data)
        if np.isscalar(data["index"]):
            # Single entry: wrap every value in a length-1 list (new dict, so
            # the caller's payload is left untouched).
            return {key: [value] for key, value in data.items()}, 1
        return data, len(data["index"])

    def _create_stage_state(
        self, stage: str, data: dict[str, Any]
    ) -> StageHDF5Writer.StageState:
        """Infer the schema of one stage from the first written batch.

        Parameters
        ----------
        stage : str
            Stage name whose schema is being initialized.
        data : dict
            Normalized batch dictionary used as the schema template.

        Returns
        -------
        StageHDF5Writer.StageState
            Newly registered schema state for ``stage``.
        """
        # Copy so that the writer-level key selection is never mutated.
        keys = set(self.get_stored_keys(data))
        if "source_file_entry_index" in data:
            keys.add("source_file_entry_index")
        keys.difference_update(self._file_source_keys)

        type_dict, object_dtypes = self.get_data_types(data, keys)
        state = self.StageState(
            keys=keys, type_dict=type_dict, object_dtypes=object_dtypes
        )
        self._stage_states[stage] = state
        return state

    def get_output_path(
        self, source_info: dict[str, Any], multiple_sources: bool = False
    ) -> str:
        """Resolve the cache-file path for one source file.

        Parameters
        ----------
        source_info : dict
            File-level source identity returned by
            :meth:`get_batch_source_info`.
        multiple_sources : bool, default False
            If `True`, derive one output path from the source file basename.
            Otherwise reuse ``self.file_name`` directly unless this writer is
            already in source-routed mode.

        Returns
        -------
        str
            Path of the cache file that stores this source's stages.
        """
        if not (self._route_by_source or multiple_sources):
            if self.directory is None:
                return self.file_name
            return os.path.join(self.directory, os.path.basename(self.file_name))

        dir_name = (
            self.directory
            if self.directory is not None
            else os.path.dirname(self.file_name)
        )
        source_name = source_info["file_name"]
        if isinstance(source_name, bytes):
            source_name = os.fsdecode(source_name)
        # Use only the basename: a full source path would otherwise make
        # `os.path.join` discard `dir_name` (absolute paths) or nest
        # subdirectories (relative paths). Sources sharing a basename map to
        # the same cache file and are rejected by `ensure_source_group`.
        base_name = os.path.splitext(os.path.basename(str(source_name)))[0]
        return os.path.join(dir_name, f"{base_name}_{self.suffix}.h5")

    def split_batch_by_source(
        self, data: dict[str, Any]
    ) -> list[tuple[str, dict[str, Any], dict[str, Any]]]:
        """Split one normalized batch into one subset per source file.

        Returns
        -------
        list[tuple[str, dict, dict]]
            One tuple per source file containing the resolved output file
            path, the batch subset that belongs to that source file, and the
            file-level source provenance dictionary.

        Raises
        ------
        KeyError
            If any source provenance key is missing from ``data``.
        """
        for key in self._source_keys:
            if key not in data:
                raise KeyError(
                    "StageHDF5Writer requires reader-provided source provenance. "
                    f"Missing key: {key}."
                )

        batch_size = len(data["index"])

        def column(key: str) -> Any:
            # A scalar provenance value applies to every entry; indexing it
            # directly would split a file-name string into characters.
            value = data[key]
            return [value] * batch_size if np.isscalar(value) else value

        names = column("source_file_name")
        sizes = column("source_file_size")
        mtimes = column("source_file_mtime_ns")

        groups: dict[tuple[Any, Any, Any], list[int]] = defaultdict(list)
        for batch_id in range(batch_size):
            groups[(names[batch_id], sizes[batch_id], mtimes[batch_id])].append(
                batch_id
            )

        multiple_sources = len(groups) > 1
        # Once any batch spans several sources, keep routing by source.
        self._route_by_source = self._route_by_source or multiple_sources

        result = []
        for (file_name, file_size, file_mtime_ns), batch_ids in groups.items():
            source_info = {
                "file_name": self._native_scalar(file_name),
                "file_size": int(self._native_scalar(file_size)),
                "file_mtime_ns": int(self._native_scalar(file_mtime_ns)),
            }
            # Scalars are shared by all entries; everything else is sliced.
            subset = {
                key: value if np.isscalar(value) else [value[i] for i in batch_ids]
                for key, value in data.items()
            }
            result.append(
                (
                    self.get_output_path(source_info, multiple_sources),
                    subset,
                    source_info,
                )
            )
        return result

    def _ensure_stage_group(
        self,
        out_file: h5py.File,
        file_path: str,
        stage: str,
        state: StageState,
        cfg: dict[str, Any] | None = None,
        attrs: dict[str, Any] | None = None,
        overwrite_stage: bool = False,
    ) -> h5py.Group:
        """Create or fetch one stage group.

        Parameters
        ----------
        out_file : h5py.File
            Open cache-file handle.
        file_path : str
            Output cache-file path used for error messages and bookkeeping.
        stage : str
            Stage name to create or reopen.
        state : StageState
            Inferred schema state for the stage.
        cfg : dict, optional
            Stage configuration to serialize into metadata.
        attrs : dict, optional
            Additional stage metadata attributes.
        overwrite_stage : bool, default False
            If `True`, delete any existing stage group and rebuild it.

        Returns
        -------
        h5py.Group
            The stage group, marked incomplete.

        Raises
        ------
        RuntimeError
            If the stage already exists in ``file_path`` but was not created
            by this writer instance (e.g. left over from an earlier run) and
            ``overwrite_stage`` is `False`.
        """
        stages = out_file["stages"]
        assert isinstance(stages, h5py.Group), "'stages' must be an HDF5 group."
        session_stages = self._session_stages[file_path]

        if stage in stages and overwrite_stage:
            del stages[stage]
            self._completed_stages[file_path].discard(stage)
            session_stages.discard(stage)

        if stage not in stages:
            stage_group = stages.create_group(stage)
            info = stage_group.create_group("info")
            info.attrs["complete"] = False
            if cfg is not None:
                info.attrs["cfg"] = yaml.dump(cfg)
            if attrs is not None:
                for key, value in attrs.items():
                    info.attrs[key] = value

            # Dataset initialization works off the writer-level schema
            # attributes, so install this stage's schema first and capture
            # any event dtype it resolves.
            self.type_dict = state.type_dict
            self.event_dtype = state.event_dtype
            self.initialize_datasets(stage_group, state.type_dict)
            state.event_dtype = self.event_dtype
            session_stages.add(stage)
            return stage_group

        stage_group = stages[stage]
        assert isinstance(
            stage_group, h5py.Group
        ), f"Stage '{stage}' is expected to be a group, got {type(stage_group)}."
        if stage not in session_stages:
            raise RuntimeError(
                f"Stage '{stage}' already exists in '{file_path}'. Reopening and "
                "appending an existing stage across writer sessions is not supported "
                "in this first pass. Pass overwrite_stage=True to rebuild it."
            )

        if "info" in stage_group:
            info = stage_group["info"]
            if attrs is not None:
                for key, value in attrs.items():
                    info.attrs[key] = value
            if cfg is not None:
                info.attrs["cfg"] = yaml.dump(cfg)
            # Appending new entries invalidates any previous completion mark.
            info.attrs["complete"] = False
        self._completed_stages[file_path].discard(stage)
        return stage_group

    def write_stage(
        self,
        stage: str,
        data: dict[str, Any],
        cfg: dict[str, Any] | None = None,
        attrs: dict[str, Any] | None = None,
        overwrite_stage: bool = False,
    ) -> None:
        """Append one batch of products to a named stage.

        Parameters
        ----------
        stage : str
            Stage group name under ``/stages``
        data : dict
            Dictionary of batched data products
        cfg : dict, optional
            Configuration to store alongside this stage
        attrs : dict, optional
            Additional stage metadata to persist under ``stage/info.attrs``
        overwrite_stage : bool, default False
            If `True`, delete any existing stage group with the same name in
            the cache files touched by this call and rebuild it from the
            provided data. This applies on every call, so pass it only with
            the first batch of a stage to avoid discarding earlier batches.

        Raises
        ------
        RuntimeError
            If the stage already exists in a target cache file from an earlier
            writer session and ``overwrite_stage`` is `False`, or if source
            provenance does not match an existing cache file.

        Notes
        -----
        The input batch may span multiple source files. In that case the
        batch is partitioned by source provenance and written into one cache
        file per source file automatically.
        """
        normalized, _ = self._prepare_batch(data)
        state = self._stage_states.get(stage)
        if state is None or overwrite_stage:
            state = self._create_stage_state(stage, normalized)

        # The inherited append helpers are driven by writer-level schema
        # attributes (presumably read inside `append_entry`); install this
        # stage's schema while writing and restore the originals afterwards.
        saved = (self.keys, self.type_dict, self.object_dtypes, self.event_dtype)
        try:
            for file_path, subset, _source_info in self.split_batch_by_source(
                normalized
            ):
                out_file, should_close = self._open_handle(file_path)
                try:
                    self.ensure_source_group(out_file, subset, file_path)
                    stage_group = self._ensure_stage_group(
                        out_file,
                        file_path,
                        stage,
                        state,
                        cfg=cfg,
                        attrs=attrs,
                        overwrite_stage=overwrite_stage,
                    )
                    self.keys = state.keys
                    self.type_dict = state.type_dict
                    self.object_dtypes = state.object_dtypes
                    self.event_dtype = state.event_dtype

                    num_entries = len(subset["index"])
                    for batch_id in range(num_entries):
                        self.append_entry(stage_group, subset, batch_id)
                    state.event_dtype = self.event_dtype

                    if self.flush_frequency is not None:
                        state.entries_since_flush += num_entries
                        if state.entries_since_flush >= self.flush_frequency:
                            out_file.flush()
                            state.entries_since_flush = 0
                finally:
                    if should_close:
                        out_file.close()
        finally:
            self.keys, self.type_dict, self.object_dtypes, self.event_dtype = saved

    def finalize_stage(self, stage: str) -> None:
        """Mark one stage as complete in every cache file it was written to.

        Only cache files where this writer instance created (or rebuilt) the
        stage are marked, so stale partial stages from earlier sessions in
        files touched for other stages are never flagged as complete.

        Parameters
        ----------
        stage : str
            Stage name to finalize across all cache files written by this
            writer instance.
        """
        for file_path in sorted(self._known_files):
            if stage not in self._session_stages.get(file_path, ()):
                continue
            out_file, should_close = self._open_handle(file_path)
            try:
                stages = out_file["stages"]
                assert isinstance(stages, h5py.Group), "'stages' must be an HDF5 group."
                if stage not in stages:
                    continue
                stages[stage]["info"].attrs["complete"] = True
                out_file.flush()
                self._completed_stages[file_path].add(stage)
            finally:
                if should_close:
                    out_file.close()

    def list_stages(self) -> tuple[str, ...]:
        """Return the union of stage-group names across touched cache files.

        Returns
        -------
        tuple[str, ...]
            Sorted tuple of unique stage names seen in all output cache files
            touched by this writer instance.
        """
        stage_names: set[str] = set()
        for file_path in sorted(self._known_files):
            out_file, should_close = self._open_handle(file_path)
            try:
                stages = out_file["stages"]
                assert isinstance(stages, h5py.Group), "'stages' must be an HDF5 group."
                stage_names.update(stages.keys())
            finally:
                if should_close:
                    out_file.close()
        return tuple(sorted(stage_names))