# Source code for spine.io.overlay

"""Module with methods to overlay multiple events."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any
from warnings import warn

import numpy as np

from .parse.clean_data import clean_sparse_data
from .parse.data import (
    ParserEdgeIndex,
    ParserIndex,
    ParserIndexList,
    ParserObjectList,
    ParserTensor,
)

# One event's worth of parsed data: maps each data product key onto its value
SampleDict = dict[str, Any]
# A batch is a sequence of such per-event dictionaries
BatchType = Sequence[SampleDict]

# Names exported by this module
__all__ = ["Overlayer"]


class Overlayer:
    """Generic class to produce data overlays.

    This class supports three overlay modes:

    - `constant` uses a fixed multiplicity.
    - `uniform` samples multiplicities `M_i` from a uniform distribution and
      adjusts them so that, for a batch size `B`, `sum_i M_i = B`.
    - `poisson` samples multiplicities from a Poisson distribution with mean
      set by `multiplicity` and adjusts them the same way.
    """

    # List of recognized overlay modes
    _modes = ("constant", "uniform", "poisson")

    def __init__(
        self,
        data_types: Mapping[str, str],
        methods: Mapping[str, str | None],
        multiplicity: int,
        mode: str = "constant",
    ) -> None:
        """Store the overlay parameters.

        Parameters
        ----------
        data_types : mapping
            Types of data returned by the upstream parsers
        methods : mapping
            Maps data products onto overlay methods
        multiplicity : int
            Number of images to stack in the overlay
        mode : str, default 'constant'
            Overlay mode (one of 'constant', 'uniform' or 'poisson')

        Raises
        ------
        ValueError
            If the mode is not recognized or the multiplicity is not
            strictly positive.
        """
        # Check that the overlay mode is recognized
        if mode not in self._modes:
            raise ValueError(
                f"Overlay mode not recognized: {mode}. Must be one of {self._modes}."
            )
        self.mode = mode

        # Check that multiplicity is sensible
        if multiplicity <= 0:
            raise ValueError(
                "Overlay multiplicity should be a non-zero positive integer."
            )
        self.multiplicity = multiplicity

        # Store the data types and methods
        self.data_types = data_types
        self.methods = methods

        # Initialize row selection references for feature-only tensors
        self._row_selections = {}

    def __call__(self, batch: BatchType) -> list[SampleDict]:
        """Given a batch of data, provides an overlay batching and modifies
        the data in place to avoid indexing conflicts.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.

        Returns
        -------
        List[Dict]
            Overlayed list of dictionaries of parsed information, one per
            overlay. An empty batch yields an empty list.
        """
        # Fetch the batch size; an empty batch has nothing to overlay. Without
        # this guard, `np.split` returns a single empty group and indexing
        # its first element below would raise an IndexError.
        batch_size = len(batch)
        if batch_size == 0:
            return []

        # Build an overlay map
        overlay_ids = self.get_assignments(batch_size)

        # Loop over the unique overlay indexes
        overlay_batch = []
        _, splits = np.unique(overlay_ids, return_index=True)
        indexes = np.split(np.arange(batch_size), splits[1:])
        for index in indexes:
            # Initialize row selection references for feature-only tensors
            self._row_selections = {}

            # If there is only a single index in the overlay, nothing to do
            if len(index) < 2:
                overlay_batch.append(batch[index[0]])
                continue

            # Loop over the keys to overlay
            overlay = {}
            for key in self.get_overlay_order(batch, index):
                # Load up the data type for this key
                data_type = self.data_types[key]

                # Dispatch and fill the overlay
                # NOTE(review): keys with any other data type are silently
                # left out of the overlay - confirm this is intended.
                if data_type == "scalar":
                    # Check whether scalars can be harmonized
                    overlay[key] = self.merge_scalars(batch, key, index)
                elif data_type == "object":
                    # Check that objects are compatible when overlaying
                    overlay[key] = self.merge_objects(batch, key, index)
                elif data_type == "object_list":
                    # Offset object list index attributes if needed
                    overlay[key] = self.cat_objects(batch, key, index)
                elif data_type == "tensor":
                    # Stack tensors, offset index columns if needed
                    overlay[key] = self.stack_tensors(batch, key, index)

            # Add overlay to the batch
            overlay_batch.append(overlay)

        return overlay_batch

    def get_overlay_order(
        self, batch: BatchType, index: np.ndarray | Sequence[int]
    ) -> list[str]:
        """Order reference tensors before tensors that depend on them.

        Feature-only tensors such as source IDs may be row-aligned to another
        tensor that drops duplicate coordinates during overlay. Processing
        references first lets them define the row selection reused by aligned
        tensors.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        index : np.ndarray or Sequence[int]
            List of indexes to merge into an overlay

        Returns
        -------
        List[str]
            List of keys in the order they should be processed for overlay.

        Raises
        ------
        ValueError
            If the references form a cycle or point to a key which is not
            part of the overlaid products.
        """
        ordered = []
        visited = set()
        visiting = set()

        def visit(key: str) -> None:
            # Depth-first traversal: a key is appended after its reference
            if key in visited:
                return
            if key in visiting:
                raise ValueError(f"Cyclic overlay reference involving `{key}`.")

            # Only the first entry of the overlay is inspected for references
            ref_data = batch[index[0]][key]
            if isinstance(ref_data, ParserTensor) and ref_data.overlay_reference:
                reference = ref_data.overlay_reference
                if reference not in self.data_types:
                    raise ValueError(
                        f"Overlay reference `{reference}` for `{key}` is not "
                        "available in the overlaid products."
                    )
                visiting.add(key)
                visit(reference)
                visiting.remove(key)

            visited.add(key)
            ordered.append(key)

        for key in self.data_types:
            visit(key)

        return ordered

    def get_assignments(self, batch_size: int) -> np.ndarray:
        """Given a data product count, produce batch assignments.

        Parameters
        ----------
        batch_size : int
            Number of entries in the batch

        Returns
        -------
        np.ndarray
            Overlay ID assignments

        Raises
        ------
        ValueError
            If the overlay mode is not recognized.
        """
        # Dispatch
        if self.mode == "constant":
            # Uniform multiplicity of overlays
            if batch_size % self.multiplicity != 0:
                warn(
                    f"The overlay multiplicity ({self.multiplicity}) is not a "
                    f"divider of the batch size ({batch_size}). The overlay "
                    "multiplicity will not be uniform."
                )
            overlay_ids = np.arange(batch_size, dtype=int) // self.multiplicity

        elif self.mode in ("poisson", "uniform"):
            # Sample multiplicities until they add up to the batch size
            overlay_ids = np.empty(batch_size, dtype=int)
            idx, total = 0, 0
            while total < batch_size:
                # Sample distribution
                if self.mode == "poisson":
                    sample = np.random.poisson(self.multiplicity)
                else:
                    sample = np.random.randint(1, self.multiplicity + 1)

                # Assign overlay indices. The slice is clipped at the end of
                # the array, which truncates the last overlay if it overshoots.
                if sample > 0:
                    overlay_ids[total : total + sample] = idx
                    idx += 1
                    total += sample

        else:
            raise ValueError(f"Overlay mode not recognized: {self.mode}.")

        # Return
        return overlay_ids

    def merge_scalars(
        self, batch: BatchType, key: str, index: np.ndarray | Sequence[int]
    ) -> Any:
        """Merge scalars into one per overlay.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Scalar data product key
        index : np.ndarray
            List of indexes to merge into an overlay

        Returns
        -------
        object
            Single scalar for the overlay, or an array of all the scalars
            when the overlay method is 'cat'

        Raises
        ------
        ValueError
            If the method is 'match' and the scalars differ, or if the
            method is missing or not recognized.
        """
        scalars = np.array([batch[idx][key] for idx in index])
        method = self.methods[key]
        if method in ["first", "match"]:
            # Make sure that all scalars match within the overlay, if needed
            if method == "match":
                if not np.all(scalars[1:] == scalars[0]):
                    raise ValueError(
                        f"The scalar values to overlay do not match for {key}."
                    )

            return scalars[0]

        elif method == "sum":
            # Sum the values within each overlay
            return np.sum(scalars)

        elif method == "cat":
            # Concatenate the scalars in a single array (type change)
            return scalars

        else:
            if method is None:
                raise ValueError(f"Scalar overlay method not specified for {key}.")
            raise ValueError(
                f"Scalar overlay method not recognized: {method}. "
                "Must be one of 'first', 'match', 'sum' or 'cat'."
            )

    def merge_objects(
        self, batch: BatchType, key: str, index: np.ndarray | Sequence[int]
    ) -> Any:
        """Merge objects into one per overlay.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Object data product key
        index : np.ndarray
            List of indexes to merge into an overlay

        Returns
        -------
        object
            Single object for the overlay, or a `ParserObjectList` of all
            the objects when the overlay method is 'cat'

        Raises
        ------
        ValueError
            If the method is 'match' and the objects differ, or if the
            method is missing or not recognized.
        """
        objects = [batch[idx][key] for idx in index]
        method = self.methods[key]
        if method in ["first", "match"]:
            # Make sure that all objects match within the overlay, if needed
            if method == "match":
                if not np.all([obj == objects[0] for obj in objects]):
                    raise ValueError(f"The objects to overlay do not match for {key}.")

            return objects[0]

        elif method == "cat":
            # Concatenate the objects in a single list (type change)
            return ParserObjectList(objects, default=objects[0])

        else:
            if method is None:
                raise ValueError(f"Object overlay method not specified for {key}.")
            raise ValueError(
                f"Object overlay method not recognized: {method}. "
                "Must be one of 'first', 'match' or 'cat'."
            )

    def cat_objects(
        self, batch: BatchType, key: str, index: np.ndarray | Sequence[int]
    ) -> ParserObjectList:
        """Concatenate object lists into one, offset index attributes if needed.

        The objects of every list but the first are shifted in place. The
        index shifts stored on the input lists are left untouched.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Object list data product key
        index : np.ndarray
            List of indexes to merge into an overlay

        Returns
        -------
        ParserObjectList
            Concatenated object list
        """
        # If the objects in the lists contain indexes, must offset them
        ref_list = batch[index[0]][key]
        shifts = None
        if len(ref_list.default.index_attrs) > 0:
            # Work on a private copy so that accumulating the shifts below
            # never alters the index shifts held by the reference list
            shifts = ref_list.index_shifts
            if isinstance(shifts, dict):
                shifts = dict(shifts)

            for idx in index[1:]:
                # Shift indexes in the objects
                obj_list = batch[idx][key]
                for obj in obj_list:
                    obj.shift_indexes(shifts)

                # Increment shifts (out of place, to avoid aliasing)
                if isinstance(shifts, dict):
                    for attr in shifts:
                        shifts[attr] = shifts[attr] + obj_list.index_shifts[attr]
                else:
                    shifts = shifts + obj_list.index_shifts

        # Concatenate and return
        merged = []
        for idx in index:
            merged.extend(batch[idx][key])

        return ParserObjectList(merged, ref_list.default, shifts)

    def stack_tensors(
        self, batch: BatchType, key: str, index: np.ndarray | Sequence[int]
    ) -> ParserTensor | ParserIndex | ParserIndexList | ParserEdgeIndex:
        """Stack parser payloads together across an overlay.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Tensor data product key
        index : np.ndarray
            List of indexes to merge into an overlay

        Returns
        -------
        ParserTensor or ParserIndex or ParserIndexList or ParserEdgeIndex
            Overlayed parser payload of the same logical type as the input.

        Raises
        ------
        TypeError
            If the payload is not one of the supported parser types.
        """
        # Define a reference payload; its type selects the stacking routine
        ref_data = batch[index[0]][key]
        if isinstance(ref_data, ParserTensor):
            if ref_data.feats_only:
                return self.stack_feature_tensor_data(batch, key, index, ref_data)
            return self.stack_tensor_data(batch, key, index, ref_data)
        if isinstance(ref_data, ParserIndex):
            return self.stack_flat_index_data(batch, key, index, ref_data)
        if isinstance(ref_data, ParserIndexList):
            return self.stack_index_list_data(batch, key, index, ref_data)
        if isinstance(ref_data, ParserEdgeIndex):
            return self.stack_edge_index_data(batch, key, index, ref_data)

        raise TypeError(
            f"Unsupported parser payload type for `{key}`: {type(ref_data).__name__}."
        )

    def stack_tensor_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserTensor,
    ) -> ParserTensor:
        """Overlay one tensor-like parser payload.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Tensor data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserTensor
            Reference tensor used to check metadata and index columns, and to
            preserve overlay metadata in the output.

        Returns
        -------
        ParserTensor
            Overlayed parser tensor

        Raises
        ------
        ValueError
            If the metadata differ across the overlay, if index columns are
            present without index shifts, or if duplicate removal is requested
            without coordinates.
        """
        # Stack coordinates, if present
        coords = None
        if ref_data.coords is not None:
            # Check that the meta data matches between all images (it must)
            if not np.all([batch[idx][key].meta == ref_data.meta for idx in index]):
                raise ValueError("The metadata must match across all overlayed tensor.")

            coords = np.vstack([batch[idx][key].coords for idx in index])

        # If required, offset indexes in the feature tensor
        index_shifts = None
        if ref_data.feat_index_cols is not None:
            # Apply offsets to the relevant columns only (mixed features)
            if ref_data.index_shifts is None:
                raise ValueError(
                    "Index shifts must be provided if index columns are present."
                )

            # Copy so that the reference tensor keeps its own shifts
            index_shifts = ref_data.index_shifts.copy()
            for idx in index[1:]:
                for i, col in enumerate(ref_data.feat_index_cols):
                    # Negative entries are left untouched; the input features
                    # are modified in place
                    mask = batch[idx][key].features[:, col] > -1
                    batch[idx][key].features[mask, col] += index_shifts[i]

                index_shifts += batch[idx][key].index_shifts

        # Stack features
        features = np.vstack([batch[idx][key].features for idx in index])

        # If requested, remove rows corresponding to duplicate coordinates
        if ref_data.remove_duplicates:
            # Check that we have coordinates to make the check
            if coords is None:
                raise ValueError("Must provide coordinates to filter duplicates.")

            # Filter out duplicates, aggregating features when requested.
            selection_size = len(features)
            coords, features, selection = clean_sparse_data(
                coords,
                features,
                sum_cols=ref_data.feat_sum_cols,
                avg_cols=ref_data.feat_avg_cols,
                prec_col=ref_data.feat_prec_col,
                precedence=ref_data.precedence,
                return_index=True,
            )

            # Remember the selection so that feature-only tensors which
            # reference this key can apply the same row filter
            self._row_selections[key] = (selection, selection_size)

        return self.build_parser_tensor(ref_data, features, coords, index_shifts)

    def stack_feature_tensor_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserTensor,
    ) -> ParserTensor:
        """Overlay one feature-only parser payload.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Tensor data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserTensor
            Reference tensor used to check metadata and index columns, and to
            preserve overlay metadata in the output.

        Returns
        -------
        ParserTensor
            Overlayed parser tensor with feature-only coordinates

        Raises
        ------
        ValueError
            If duplicate removal is requested without an overlay reference,
            or if the row count disagrees with that of the reference.
        """
        # Stack the features
        features = np.vstack([batch[idx][key].features for idx in index])

        # Nothing to do if no duplicate removal is requested
        if not ref_data.remove_duplicates:
            return self.build_parser_tensor(ref_data, features, feats_only=True)

        # If it is requested, we need a reference tensor
        if not ref_data.overlay_reference:
            raise ValueError(
                f"Feature-only tensor `{key}` requires an `overlay_reference` "
                "to remove duplicates during overlay."
            )

        # Feature-only tensors reuse the duplicate policy of their reference.
        row_selection, row_selection_size = self._row_selections.get(
            ref_data.overlay_reference, (None, None)
        )
        if row_selection is None:
            # If the reference tensor has not been cleaned up, nothing to do
            # for the feature-only tensor either.
            return self.build_parser_tensor(ref_data, features, feats_only=True)

        if len(features) != row_selection_size:
            # If the sizes disagree, that is not allowed
            raise ValueError(
                f"Feature-only tensor `{key}` has {len(features)} rows before "
                f"overlay cleanup, but its reference `{ref_data.overlay_reference}` "
                f"has {row_selection_size} rows."
            )

        return self.build_parser_tensor(
            ref_data, features[row_selection], feats_only=True
        )

    @staticmethod
    def build_parser_tensor(
        ref_data: ParserTensor,
        features: np.ndarray,
        coords: np.ndarray | None = None,
        index_shifts: np.ndarray | None = None,
        feats_only: bool | None = None,
    ) -> ParserTensor:
        """Build a parser tensor while preserving overlay metadata.

        Parameters
        ----------
        ref_data : ParserTensor
            Reference tensor used to check metadata and index columns, and to
            preserve overlay metadata in the output.
        features : np.ndarray
            Stacked features for the overlay
        coords : np.ndarray, optional
            Stacked coordinates for the overlay, if present in the reference
            tensor
        index_shifts : np.ndarray, optional
            Stacked index shifts for the overlay, if present in the reference
            tensor
        feats_only : bool, optional
            Whether the output tensor should be feature-only. If not provided,
            will be inferred from the reference tensor.

        Returns
        -------
        ParserTensor
            Tensor holding the overlay data and the reference's metadata
        """
        return ParserTensor(
            coords=coords,
            features=features,
            meta=ref_data.meta,
            index_shifts=index_shifts,
            index_cols=ref_data.index_cols,
            sum_cols=ref_data.sum_cols,
            avg_cols=ref_data.avg_cols,
            prec_col=ref_data.prec_col,
            precedence=ref_data.precedence,
            feats_only=ref_data.feats_only if feats_only is None else feats_only,
            overlay_reference=ref_data.overlay_reference,
        )

    @staticmethod
    def _shift_and_concatenate(
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        span: Any,
    ) -> tuple[np.ndarray, Any]:
        """Offset each entry's indexes by the running span and concatenate.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event.
        key : str
            Index data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        span : object
            Span of the first entry, used as the starting offset

        Returns
        -------
        np.ndarray
            Shifted indexes, concatenated along the last axis
        object
            Sum of the spans of all the entries in the overlay
        """
        # The first entry needs no offset
        shifted_indexes = [batch[index[0]][key].features]
        for idx in index[1:]:
            # Copy to leave the input untouched; negative entries are not shifted
            shifted_index = batch[idx][key].features.copy()
            mask = shifted_index > -1
            shifted_index[mask] += span
            shifted_indexes.append(shifted_index)
            span = span + batch[idx][key].span

        return np.concatenate(shifted_indexes, axis=-1), span

    def stack_flat_index_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserIndex,
    ) -> ParserIndex:
        """Overlay one flat index payload.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Index data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserIndex
            Reference index, which provides the span of the first entry.

        Returns
        -------
        ParserIndex
            Overlayed index data.
        """
        features, span = self._shift_and_concatenate(batch, key, index, ref_data.span)

        return ParserIndex(features=features, span=span)

    def stack_index_list_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserIndexList,
    ) -> ParserIndexList:
        """Overlay one jagged index-list payload.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Index list data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserIndexList
            Reference index list, which provides the span and the single
            counts of the first entry.

        Returns
        -------
        ParserIndexList
            Overlayed index list data.
        """
        span = ref_data.span
        features = [entry.copy() for entry in batch[index[0]][key].features]

        # Fall back on the length of each index when counts are not provided
        single_counts = []
        if ref_data.single_counts is not None:
            single_counts.extend(ref_data.single_counts.tolist())
        else:
            single_counts.extend(len(entry) for entry in features)

        for idx in index[1:]:
            # Copy to leave the input untouched; negative entries are not shifted
            shifted_entries = []
            for entry in batch[idx][key].features:
                shifted_entry = entry.copy()
                mask = shifted_entry > -1
                shifted_entry[mask] += span
                shifted_entries.append(shifted_entry)

            features.extend(shifted_entries)
            if batch[idx][key].single_counts is not None:
                single_counts.extend(batch[idx][key].single_counts.tolist())
            else:
                single_counts.extend(len(entry) for entry in shifted_entries)

            span = span + batch[idx][key].span

        return ParserIndexList(
            features=features,
            span=span,
            single_counts=np.asarray(single_counts, dtype=np.int64),
        )

    def stack_edge_index_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserEdgeIndex,
    ) -> ParserEdgeIndex:
        """Overlay one edge-index payload.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Edge index data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserEdgeIndex
            Reference edge index, which provides the span of the first entry.

        Returns
        -------
        ParserEdgeIndex
            Overlayed edge index data.
        """
        features, span = self._shift_and_concatenate(batch, key, index, ref_data.span)

        return ParserEdgeIndex(features=features, span=span)