Source code for spine.io.parse.hdf5.tensor

"""Lightweight parsers for cached HDF5 tensor products."""

from __future__ import annotations

from typing import Any

import numpy as np

from spine.constants import COORD_COLS_LO, VALUE_COL
from spine.data import Meta

from ..base import ParserBase
from ..data import ParserTensor

# Public parser classes exported by this module (parsers for cached HDF5
# tensor products).
__all__ = [
    "HDF5TensorParser",
    "HDF5ClusterTensorParser",
    "HDF5FeatureTensorParser",
]


class HDF5TensorParser(ParserBase):
    """Build a sparse-tensor :class:`ParserTensor` from a cached HDF5 tensor."""

    name = "tensor"
    returns = "tensor"

    def __init__(
        self,
        dtype: str,
        has_batch_col: bool = True,
        coord_start_col: int = COORD_COLS_LO,
        feature_start_col: int = VALUE_COL,
        meta_event: str | None = None,
        feature_cols: list[int] | tuple[int, ...] | np.ndarray | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the cached tensor parser.

        Parameters
        ----------
        dtype : str
            Floating-point dtype used by parser outputs.
        has_batch_col : bool, default True
            If `True`, the cached tensor is assumed to store a leading
            batch-id column before the coordinates.
        coord_start_col : int, default 1
            Column index at which the coordinate block starts.
        feature_start_col : int, default 4
            Column index at which the feature block starts. Coordinates are
            read from ``[coord_start_col, feature_start_col)``.
        meta_event : str, optional
            HDF5 product name that stores the metadata object to inject into
            the returned :class:`ParserTensor`.
        feature_cols : sequence[int], optional
            Optional feature-column indices to keep after splitting
            coordinates and features. Indices are relative to the feature
            block, not to the full cached tensor.
        **kwargs : dict, optional
            Parser configuration forwarded to :class:`ParserBase`.
        """
        super().__init__(dtype, meta_event=meta_event, **kwargs)
        self.has_batch_col = has_batch_col
        self.coord_start_col = coord_start_col
        self.feature_start_col = feature_start_col
        self.feature_cols = None
        if feature_cols is not None:
            self.feature_cols = np.asarray(feature_cols, dtype=np.int64)

    def __call__(self, trees: dict[str, Any]) -> ParserTensor:
        """Parse one cached entry into a sparse-tensor parser payload."""
        return self.process(**self.get_input_data(trees))

    def process(
        self, tensor_event: np.ndarray, meta_event: Meta | None = None
    ) -> ParserTensor:
        """Split one cached tensor into coordinates, features, and metadata.

        Parameters
        ----------
        tensor_event : np.ndarray
            Cached 2D tensor for one entry. Columns
            ``[coord_start_col, feature_start_col)`` hold the coordinates and
            columns from ``feature_start_col`` onward hold the features.
        meta_event : Meta, optional
            Metadata object attached to the returned parser tensor.

        Returns
        -------
        ParserTensor
            Parser tensor with ``coords`` cast to the parser integer dtype,
            ``features`` cast to the parser float dtype, and ``meta``.

        Raises
        ------
        ValueError
            If ``coord_start_col < 1`` while ``has_batch_col`` is set, or if
            the cached tensor is not 2D.
        """
        # Reject an inconsistent layout up front, before doing any work, so a
        # configuration error is not masked by an indexing error further down.
        if self.has_batch_col and self.coord_start_col < 1:
            raise ValueError(
                "`coord_start_col` must be at least 1 when `has_batch_col=True`."
            )

        tensor = np.asarray(tensor_event)
        if tensor.ndim != 2:
            raise ValueError(
                "Cached sparse tensors must be 2D. "
                f"Received an array with shape {tensor.shape}."
            )

        # Cast coordinates straight from the stored values to the integer
        # dtype. Going through the (possibly single-precision) float dtype
        # first can silently corrupt large integer coordinates.
        coords = tensor[:, self.coord_start_col : self.feature_start_col].astype(
            self.itype
        )

        # Select the requested feature columns before casting, so unused
        # columns are never converted. `copy=False` keeps the view when the
        # stored dtype already matches.
        features = tensor[:, self.feature_start_col :]
        if self.feature_cols is not None:
            features = features[:, self.feature_cols]
        features = features.astype(self.ftype, copy=False)

        return ParserTensor(coords=coords, features=features, meta=meta_event)
class HDF5ClusterTensorParser(HDF5TensorParser):
    """Build a cluster-label :class:`ParserTensor` from cached HDF5 tensors."""

    name = "cluster_tensor"

    def __init__(
        self,
        dtype: str,
        index_cols: list[int] | tuple[int, ...] | np.ndarray | None = None,
        sum_cols: list[int] | tuple[int, ...] | np.ndarray | None = None,
        avg_cols: list[int] | tuple[int, ...] | np.ndarray | None = None,
        prec_col: int | None = None,
        precedence: list[int] | tuple[int, ...] | np.ndarray | None = None,
        remove_duplicates: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize the cached cluster-tensor parser.

        Parameters
        ----------
        dtype : str
            Floating-point dtype used by parser outputs.
        index_cols : sequence[int], optional
            Feature columns that carry indices and should be shifted when
            collating batches.
        sum_cols : sequence[int], optional
            Feature columns that should be summed when duplicate coordinates
            are merged.
        avg_cols : sequence[int], optional
            Feature columns that should be averaged when duplicate coordinates
            are merged.
        prec_col : int, optional
            Feature column used to resolve duplicate-coordinate precedence.
        precedence : sequence[int], optional
            Ordering applied to ``prec_col`` when duplicate coordinates are
            merged.
        remove_duplicates : bool, default True
            If `True`, mark the returned parser tensor for duplicate removal.
        **kwargs : dict, optional
            Tensor-parser configuration forwarded to
            :class:`HDF5TensorParser`.
        """
        super().__init__(dtype, **kwargs)

        # Store the integer sequences as int64 arrays, the same way
        # `feature_cols` is handled. Without an explicit dtype, an empty list
        # becomes a float64 array, which NumPy refuses to use as an index.
        self.index_cols = (
            None if index_cols is None else np.asarray(index_cols, dtype=np.int64)
        )
        self.sum_cols = (
            None if sum_cols is None else np.asarray(sum_cols, dtype=np.int64)
        )
        self.avg_cols = (
            None if avg_cols is None else np.asarray(avg_cols, dtype=np.int64)
        )
        self.prec_col = prec_col
        self.precedence = (
            None if precedence is None else np.asarray(precedence, dtype=np.int64)
        )
        self.remove_duplicates = remove_duplicates

    def process(
        self, tensor_event: np.ndarray, meta_event: Meta | None = None
    ) -> ParserTensor:
        """Split one cached cluster tensor and restore cluster parser semantics.

        Parameters
        ----------
        tensor_event : np.ndarray
            Cached 2D cluster tensor for one entry.
        meta_event : Meta, optional
            Metadata object attached to the returned parser tensor.

        Returns
        -------
        ParserTensor
            The tensor produced by :meth:`HDF5TensorParser.process`, with the
            cluster-merging configuration (index/sum/avg columns, precedence,
            duplicate removal flag) attached as attributes.
        """
        tensor = super().process(tensor_event=tensor_event, meta_event=meta_event)

        # Attach the duplicate-merging configuration to the returned tensor.
        tensor.index_cols = self.index_cols
        tensor.sum_cols = self.sum_cols
        tensor.avg_cols = self.avg_cols
        tensor.prec_col = self.prec_col
        tensor.precedence = self.precedence
        tensor.remove_duplicates = self.remove_duplicates

        return tensor
class HDF5FeatureTensorParser(ParserBase):
    """Build a feature-only :class:`ParserTensor` from a cached HDF5 array."""

    name = "feature_tensor"
    returns = "tensor"

    def __init__(
        self,
        dtype: str,
        feature_cols: list[int] | tuple[int, ...] | np.ndarray | None = None,
        remove_duplicates: bool = False,
        overlay_reference: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Configure the cached feature-tensor parser.

        Parameters
        ----------
        dtype : str
            Floating-point dtype used by parser outputs.
        feature_cols : sequence[int], optional
            Column indices to keep from the cached array. When set, this acts
            as a feature-ablation step applied before the parser tensor is
            built. The indices are stored as an int64 array.
        remove_duplicates : bool, default False
            If `True`, require an ``overlay_reference`` when overlaying this
            feature-only tensor.
        overlay_reference : str, optional
            Product key whose duplicate-cleaning row selection should be used
            for this tensor when overlaying.
        **kwargs : dict, optional
            Parser configuration forwarded to :class:`ParserBase`.
        """
        super().__init__(dtype, **kwargs)
        self.feature_cols = (
            None if feature_cols is None
            else np.asarray(feature_cols, dtype=np.int64)
        )
        self.remove_duplicates = remove_duplicates
        self.overlay_reference = overlay_reference

    def __call__(self, trees: dict[str, Any]) -> ParserTensor:
        """Turn one cached entry into a feature-only parser tensor.

        Parameters
        ----------
        trees : dict
            Mapping from configured HDF5 product names to cached entry
            values.

        Returns
        -------
        ParserTensor
            Result of :meth:`process` applied to the configured inputs.
        """
        inputs = self.get_input_data(trees)
        return self.process(**inputs)

    def process(self, tensor_event: np.ndarray) -> ParserTensor:
        """Convert one cached per-entry array into a feature-only parser tensor.

        Parameters
        ----------
        tensor_event : np.ndarray
            Cached feature array for one event entry.

        Returns
        -------
        ParserTensor
            Tensor flagged ``feats_only=True`` whose ``features`` hold the
            array cast to the parser float dtype. Only the ``feature_cols``
            columns are kept when that option is set.

        Raises
        ------
        ValueError
            If ``feature_cols`` is configured but the cached array is not 2D.
        """
        values = np.asarray(tensor_event, dtype=self.ftype)
        keep = self.feature_cols

        # Column selection only makes sense on a (rows, columns) array.
        if keep is not None and values.ndim != 2:
            raise ValueError(
                "Feature ablation requires a 2D cached feature tensor. "
                f"Received an array with shape {values.shape}."
            )
        if keep is not None:
            values = values[:, keep]

        return ParserTensor(
            features=values,
            remove_duplicates=self.remove_duplicates,
            feats_only=True,
            overlay_reference=self.overlay_reference,
        )