"""Module with methods to overlay multiple events."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from warnings import warn
import numpy as np
from .parse.clean_data import clean_sparse_data
from .parse.data import (
ParserEdgeIndex,
ParserIndex,
ParserIndexList,
ParserObjectList,
ParserTensor,
)
# One event (or one overlay) worth of parsed data: maps a data product key
# onto the corresponding parsed payload.
SampleDict = dict[str, Any]
# A batch is a sequence of such dictionaries, one per event.
BatchType = Sequence[SampleDict]
# Only the overlay class is part of the public API of this module.
__all__ = ["Overlayer"]
class Overlayer:
    """Generic class to produce data overlays.

    An overlay merges several consecutive entries of a batch into a single
    entry. How many entries go into each overlay is driven by the mode:

    - `constant` uses a fixed multiplicity.
    - `uniform` samples multiplicities `M_i` from a uniform distribution and
      adjusts them so that, for a batch size `B`, `sum_i M_i = B`.
    - `poisson` samples multiplicities from a Poisson distribution with mean
      set by `multiplicity` and adjusts them the same way.

    In the sampled modes the adjustment is a truncation of the last overlay
    so that the assignments exactly fill the batch.
    """

    # List of recognized overlay modes
    _modes = ("constant", "uniform", "poisson")

    def __init__(
        self,
        data_types: Mapping[str, str],
        methods: Mapping[str, str | None],
        multiplicity: int,
        mode: str = "constant",
    ) -> None:
        """Store the overlay parameters.

        Parameters
        ----------
        data_types : mapping
            Types of data returned by the upstream parsers
        methods : mapping
            Maps data products onto overlay methods
        multiplicity : int
            Number of images to stack in the overlay
        mode : str, default 'constant'
            Overlay mode (one of 'constant', 'uniform' or 'poisson')

        Raises
        ------
        ValueError
            If the mode is not recognized or the multiplicity is not
            strictly positive.
        """
        # Check that the overlay mode is recognized
        if mode not in self._modes:
            raise ValueError(
                f"Overlay mode not recognized: {mode}. Must be one of {self._modes}."
            )
        self.mode = mode

        # Check that multiplicity is sensible
        if multiplicity <= 0:
            raise ValueError(
                "Overlay multiplicity should be a non-zero positive integer."
            )
        self.multiplicity = multiplicity

        # Store the data types and methods
        self.data_types = data_types
        self.methods = methods

        # Row selections produced by duplicate removal, keyed by tensor name.
        # Each value is a (selection, pre-cleanup row count) pair which
        # feature-only tensors reuse to stay row-aligned with their reference.
        self._row_selections = {}

    def __call__(self, batch: BatchType) -> list[SampleDict]:
        """Given a batch of data, provides an overlay batching and modifies
        the data in place to avoid indexing conflicts.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.

        Returns
        -------
        List[Dict]
            Overlayed list of dictionaries of parsed information, one per
            overlay. An empty batch yields an empty list.
        """
        # An empty batch has nothing to overlay. Without this guard the
        # grouping below produces one empty group whose first element
        # cannot be fetched.
        batch_size = len(batch)
        if batch_size == 0:
            return []

        # Build the overlay map and group batch positions by overlay ID.
        # Assignments are contiguous, so the first occurrence of each
        # unique ID marks where a new group starts.
        overlay_ids = self.get_assignments(batch_size)
        _, splits = np.unique(overlay_ids, return_index=True)
        indexes = np.split(np.arange(batch_size), splits[1:])

        # Map each supported data type onto the method that merges it.
        # Keys whose data type is not listed here are left out of the overlay.
        mergers = {
            "scalar": self.merge_scalars,
            "object": self.merge_objects,
            "object_list": self.cat_objects,
            "tensor": self.stack_tensors,
        }

        overlay_batch = []
        for index in indexes:
            # Row selections are only meaningful within a single overlay
            self._row_selections = {}

            # If there is only a single index in the overlay, nothing to do
            if len(index) < 2:
                overlay_batch.append(batch[index[0]])
                continue

            # Merge each key, references before the tensors relying on them
            overlay = {}
            for key in self.get_overlay_order(batch, index):
                merger = mergers.get(self.data_types[key])
                if merger is not None:
                    overlay[key] = merger(batch, key, index)

            overlay_batch.append(overlay)

        return overlay_batch

    def get_overlay_order(
        self, batch: BatchType, index: np.ndarray | Sequence[int]
    ) -> list[str]:
        """Order reference tensors before tensors that depend on them.

        Feature-only tensors such as source IDs may be row-aligned to another
        tensor that drops duplicate coordinates during overlay. Processing
        references first lets them define the row selection reused by aligned
        tensors.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        index : np.ndarray or Sequence[int]
            List of indexes to merge into an overlay

        Returns
        -------
        List[str]
            List of keys in the order they should be processed for overlay.

        Raises
        ------
        ValueError
            If a reference is not part of the overlaid products or if the
            references form a cycle.
        """
        ordered = []
        visited = set()
        visiting = set()

        def visit(key: str) -> None:
            # Depth-first traversal: a key is emitted after its reference
            if key in visited:
                return
            if key in visiting:
                raise ValueError(f"Cyclic overlay reference involving `{key}`.")

            # Only the first entry of the overlay is inspected for references
            ref_data = batch[index[0]][key]
            if isinstance(ref_data, ParserTensor) and ref_data.overlay_reference:
                reference = ref_data.overlay_reference
                if reference not in self.data_types:
                    raise ValueError(
                        f"Overlay reference `{reference}` for `{key}` is not "
                        "available in the overlaid products."
                    )
                visiting.add(key)
                visit(reference)
                visiting.remove(key)

            visited.add(key)
            ordered.append(key)

        for key in self.data_types:
            visit(key)

        return ordered

    def get_assignments(self, batch_size: int) -> np.ndarray:
        """Given a data product count, produce batch assignments.

        Parameters
        ----------
        batch_size : int
            Number of entries in the batch

        Returns
        -------
        np.ndarray
            Overlay ID assignments, one per batch entry. IDs start at 0 and
            are contiguous and non-decreasing.

        Raises
        ------
        ValueError
            If the overlay mode is not recognized.
        """
        if self.mode == "constant":
            # Fixed multiplicity; the last overlay is smaller if the batch
            # size is not a multiple of the multiplicity
            if batch_size % self.multiplicity != 0:
                warn(
                    f"The overlay multiplicity ({self.multiplicity}) is not a "
                    f"divider of the batch size ({batch_size}). The overlay "
                    "multiplicity will not be uniform."
                )
            return np.arange(batch_size, dtype=int) // self.multiplicity

        if self.mode in ("poisson", "uniform"):
            # Draw multiplicities until they add up to at least the batch size
            overlay_ids = np.empty(batch_size, dtype=int)
            idx, total = 0, 0
            while total < batch_size:
                if self.mode == "poisson":
                    sample = np.random.poisson(self.multiplicity)
                else:
                    # Upper bound is exclusive: draws fall in [1, multiplicity]
                    sample = np.random.randint(1, self.multiplicity + 1)

                # Empty draws (possible for Poisson) do not create an overlay.
                # The slice is clipped at the end of the array, which
                # truncates the last overlay to fit the batch.
                if sample > 0:
                    overlay_ids[total : total + sample] = idx
                    idx += 1
                    total += sample

            return overlay_ids

        raise ValueError(f"Overlay mode not recognized: {self.mode}.")

    def merge_scalars(
        self, batch: BatchType, key: str, index: np.ndarray | Sequence[int]
    ) -> Any:
        """Merge scalars into one per overlay.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Scalar data product key
        index : np.ndarray
            List of indexes to merge into an overlay

        Returns
        -------
        object
            Single scalar for the overlay, or an array of all the scalars
            when the method is 'cat'.

        Raises
        ------
        ValueError
            If the method is 'match' and the scalars differ, or if the
            method is missing or not recognized.
        """
        scalars = np.array([batch[idx][key] for idx in index])
        method = self.methods[key]

        if method in ("first", "match"):
            # Make sure that all scalars match within the overlay, if needed
            if method == "match" and not np.all(scalars[1:] == scalars[0]):
                raise ValueError(
                    f"The scalar values to overlay do not match for {key}."
                )
            return scalars[0]

        if method == "sum":
            # Sum the values within each overlay
            return np.sum(scalars)

        if method == "cat":
            # Concatenate the scalars in a single array (type change)
            return scalars

        if method is None:
            raise ValueError(f"Scalar overlay method not specified for {key}.")

        raise ValueError(
            f"Scalar overlay method not recognized: {method}. "
            "Must be one of 'first', 'match', 'sum' or 'cat'."
        )

    def merge_objects(
        self, batch: BatchType, key: str, index: np.ndarray | Sequence[int]
    ) -> Any:
        """Merge objects into one per overlay.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Object data product key
        index : np.ndarray
            List of indexes to merge into an overlay

        Returns
        -------
        object
            Single object for the overlay, or a `ParserObjectList` of all
            the objects when the method is 'cat'.

        Raises
        ------
        ValueError
            If the method is 'match' and the objects differ, or if the
            method is missing or not recognized.
        """
        objects = [batch[idx][key] for idx in index]
        method = self.methods[key]

        if method in ("first", "match"):
            # Make sure that all objects match within the overlay, if needed
            if method == "match" and not np.all(
                [obj == objects[0] for obj in objects]
            ):
                raise ValueError(f"The objects to overlay do not match for {key}.")
            return objects[0]

        if method == "cat":
            # Concatenate the objects in a single list (type change)
            return ParserObjectList(objects, default=objects[0])

        if method is None:
            raise ValueError(f"Object overlay method not specified for {key}.")

        raise ValueError(
            f"Object overlay method not recognized: {method}. "
            "Must be one of 'first', 'match' or 'cat'."
        )

    def cat_objects(
        self, batch: BatchType, key: str, index: np.ndarray | Sequence[int]
    ) -> ParserObjectList:
        """Concatenate object lists into one, offset index attributes if needed.

        The objects of every list but the first are shifted in place. The
        index shifts stored on the input lists are left untouched.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Object list data product key
        index : np.ndarray
            List of indexes to merge into an overlay

        Returns
        -------
        ParserObjectList
            Concatenated object list
        """
        # If the objects in the lists contain indexes, must offset them
        ref_list = batch[index[0]][key]
        shifts = None
        if len(ref_list.default.index_attrs) > 0:
            shifts = ref_list.index_shifts
            for idx in index[1:]:
                # Shift indexes in the objects by the running total
                obj_list = batch[idx][key]
                for obj in obj_list:
                    obj.shift_indexes(shifts)

                # Accumulate out of place: `shifts` starts as the very object
                # stored on the first list, so an in-place update would
                # silently alter that input.
                if isinstance(shifts, dict):
                    shifts = {
                        attr: shifts[attr] + obj_list.index_shifts[attr]
                        for attr in shifts
                    }
                else:
                    shifts = shifts + obj_list.index_shifts

        # Concatenate and return
        objects = []
        for idx in index:
            objects.extend(batch[idx][key])

        return ParserObjectList(objects, ref_list.default, shifts)

    def stack_tensors(
        self, batch: BatchType, key: str, index: np.ndarray | Sequence[int]
    ) -> ParserTensor | ParserIndex | ParserIndexList | ParserEdgeIndex:
        """Stack parser payloads together across an overlay.

        The payload type of the first entry in the overlay selects which
        stacking routine is used.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Tensor data product key
        index : np.ndarray
            List of indexes to merge into an overlay

        Returns
        -------
        ParserTensor or ParserIndex or ParserIndexList or ParserEdgeIndex
            Overlayed parser payload of the same logical type as the input.

        Raises
        ------
        TypeError
            If the payload is not one of the supported parser types.
        """
        # Define a reference tensor
        ref_data = batch[index[0]][key]

        if isinstance(ref_data, ParserTensor):
            if ref_data.feats_only:
                return self.stack_feature_tensor_data(batch, key, index, ref_data)
            return self.stack_tensor_data(batch, key, index, ref_data)

        if isinstance(ref_data, ParserIndex):
            return self.stack_flat_index_data(batch, key, index, ref_data)

        if isinstance(ref_data, ParserIndexList):
            return self.stack_index_list_data(batch, key, index, ref_data)

        if isinstance(ref_data, ParserEdgeIndex):
            return self.stack_edge_index_data(batch, key, index, ref_data)

        raise TypeError(
            f"Unsupported parser payload type for `{key}`: {type(ref_data).__name__}."
        )

    def stack_tensor_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserTensor,
    ) -> ParserTensor:
        """Overlay one tensor-like parser payload.

        Index columns of every tensor but the first are offset in place in
        the input batch before stacking.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Tensor data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserTensor
            Reference tensor used to check metadata and index columns, and to
            preserve overlay metadata in the output.

        Returns
        -------
        ParserTensor
            Overlayed parser tensor

        Raises
        ------
        ValueError
            If the metadata differs between tensors, if index columns are
            present without index shifts, or if duplicate removal is
            requested without coordinates.
        """
        # Stack coordinates, if present
        coords = None
        if ref_data.coords is not None:
            # Check that the meta data matches between all images (it must)
            if not np.all([batch[idx][key].meta == ref_data.meta for idx in index]):
                raise ValueError("The metadata must match across all overlayed tensor.")
            coords = np.vstack([batch[idx][key].coords for idx in index])

        # If required, offset indexes in the feature tensor
        index_shifts = None
        if ref_data.feat_index_cols is not None:
            # Apply offsets to the relevant columns only (mixed features)
            if ref_data.index_shifts is None:
                raise ValueError(
                    "Index shifts must be provided if index columns are present."
                )

            # Work on a copy so the reference tensor keeps its own shifts
            index_shifts = ref_data.index_shifts.copy()
            for idx in index[1:]:
                for i, col in enumerate(ref_data.feat_index_cols):
                    # Values that are not greater than -1 are left untouched
                    mask = batch[idx][key].features[:, col] > -1
                    batch[idx][key].features[mask, col] += index_shifts[i]

                # Grow the running offsets once all columns are shifted
                index_shifts += batch[idx][key].index_shifts

        # Stack features
        features = np.vstack([batch[idx][key].features for idx in index])

        # If requested, remove rows corresponding to duplicate coordinates
        if ref_data.remove_duplicates:
            # Check that we have coordinates to make the check
            if coords is None:
                raise ValueError("Must provide coordinates to filter duplicates.")

            # Filter out duplicates, aggregating features when requested.
            selection_size = len(features)
            coords, features, selection = clean_sparse_data(
                coords,
                features,
                sum_cols=ref_data.feat_sum_cols,
                avg_cols=ref_data.feat_avg_cols,
                prec_col=ref_data.feat_prec_col,
                precedence=ref_data.precedence,
                return_index=True,
            )

            # Remember the selection and the pre-cleanup row count so that
            # feature-only tensors referencing this one can be filtered alike
            self._row_selections[key] = (selection, selection_size)

        return self.build_parser_tensor(ref_data, features, coords, index_shifts)

    def stack_feature_tensor_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserTensor,
    ) -> ParserTensor:
        """Overlay one feature-only parser payload.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Tensor data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserTensor
            Reference tensor used to check metadata and index columns, and to
            preserve overlay metadata in the output.

        Returns
        -------
        ParserTensor
            Overlayed parser tensor with feature-only coordinates

        Raises
        ------
        ValueError
            If duplicate removal is requested without an overlay reference,
            or if the stacked row count differs from the row count the
            reference had before its own cleanup.
        """
        # Stack the features
        features = np.vstack([batch[idx][key].features for idx in index])

        # Nothing to do if no duplicate removal is requested
        if not ref_data.remove_duplicates:
            return self.build_parser_tensor(ref_data, features, feats_only=True)

        # If it is requested, we need a reference tensor
        if not ref_data.overlay_reference:
            raise ValueError(
                f"Feature-only tensor `{key}` requires an `overlay_reference` "
                "to remove duplicates during overlay."
            )

        # Feature-only tensors reuse the duplicate policy of their reference.
        row_selection, row_selection_size = self._row_selections.get(
            ref_data.overlay_reference, (None, None)
        )
        if row_selection is None:
            # If the reference tensor has not been cleaned up, nothing to do
            # for the feature-only tensor either.
            return self.build_parser_tensor(ref_data, features, feats_only=True)

        if len(features) != row_selection_size:
            # If the sizes disagree, that is not allowed
            raise ValueError(
                f"Feature-only tensor `{key}` has {len(features)} rows before "
                f"overlay cleanup, but its reference `{ref_data.overlay_reference}` "
                f"has {row_selection_size} rows."
            )

        return self.build_parser_tensor(
            ref_data, features[row_selection], feats_only=True
        )

    @staticmethod
    def build_parser_tensor(
        ref_data: ParserTensor,
        features: np.ndarray,
        coords: np.ndarray | None = None,
        index_shifts: np.ndarray | None = None,
        feats_only: bool | None = None,
    ) -> ParserTensor:
        """Build a parser tensor while preserving overlay metadata.

        Parameters
        ----------
        ref_data : ParserTensor
            Reference tensor used to check metadata and index columns, and to
            preserve overlay metadata in the output.
        features : np.ndarray
            Stacked features for the overlay
        coords : np.ndarray, optional
            Stacked coordinates for the overlay, if present in the reference tensor
        index_shifts : np.ndarray, optional
            Stacked index shifts for the overlay, if present in the reference tensor
        feats_only : bool, optional
            Whether the output tensor should be feature-only. If not provided, will
            be inferred from the reference tensor.

        Returns
        -------
        ParserTensor
            New tensor holding the provided arrays, with the remaining
            attributes copied over from the reference tensor.
        """
        return ParserTensor(
            coords=coords,
            features=features,
            meta=ref_data.meta,
            index_shifts=index_shifts,
            index_cols=ref_data.index_cols,
            sum_cols=ref_data.sum_cols,
            avg_cols=ref_data.avg_cols,
            prec_col=ref_data.prec_col,
            precedence=ref_data.precedence,
            feats_only=ref_data.feats_only if feats_only is None else feats_only,
            overlay_reference=ref_data.overlay_reference,
        )

    @staticmethod
    def _stack_shifted_indexes(
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        span: Any,
    ) -> tuple[np.ndarray, Any]:
        """Offset and concatenate the index arrays of one overlay.

        Shared by the flat index and edge index payloads, which are stacked
        the same way.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event.
        key : str
            Index data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        span : object
            Span of the first payload, used as the initial offset

        Returns
        -------
        np.ndarray
            Index arrays concatenated along the last axis
        object
            Sum of the spans of all the payloads in the overlay
        """
        # The first payload needs no offset
        shifted_indexes = [batch[index[0]][key].features]
        for idx in index[1:]:
            # Work on a copy so the input payload is left untouched; values
            # that are not greater than -1 are not offset
            shifted_index = batch[idx][key].features.copy()
            mask = shifted_index > -1
            shifted_index[mask] += span
            shifted_indexes.append(shifted_index)
            span += batch[idx][key].span

        return np.concatenate(shifted_indexes, axis=-1), span

    def stack_flat_index_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserIndex,
    ) -> ParserIndex:
        """Overlay one flat index payload.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Index data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserIndex
            Reference index which provides the initial span.

        Returns
        -------
        ParserIndex
            Overlayed index data.
        """
        features, span = self._stack_shifted_indexes(
            batch, key, index, ref_data.span
        )
        return ParserIndex(features=features, span=span)

    def stack_index_list_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserIndexList,
    ) -> ParserIndexList:
        """Overlay one jagged index-list payload.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Index list data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserIndexList
            Reference index list which provides the initial span and the
            counts of the first payload.

        Returns
        -------
        ParserIndexList
            Overlayed index list data.
        """
        span = ref_data.span

        # The first payload needs no offset, but is copied to protect the input
        features = [entry.copy() for entry in batch[index[0]][key].features]

        # Use the stored counts when available, the entry lengths otherwise
        single_counts = []
        if ref_data.single_counts is not None:
            single_counts.extend(ref_data.single_counts.tolist())
        else:
            single_counts.extend(len(entry) for entry in features)

        for idx in index[1:]:
            payload = batch[idx][key]

            # Offset each entry; values not greater than -1 are not offset
            shifted_entries = []
            for entry in payload.features:
                shifted_entry = entry.copy()
                mask = shifted_entry > -1
                shifted_entry[mask] += span
                shifted_entries.append(shifted_entry)
            features.extend(shifted_entries)

            if payload.single_counts is not None:
                single_counts.extend(payload.single_counts.tolist())
            else:
                single_counts.extend(len(entry) for entry in shifted_entries)

            span += payload.span

        return ParserIndexList(
            features=features,
            span=span,
            single_counts=np.asarray(single_counts, dtype=np.int64),
        )

    def stack_edge_index_data(
        self,
        batch: BatchType,
        key: str,
        index: np.ndarray | Sequence[int],
        ref_data: ParserEdgeIndex,
    ) -> ParserEdgeIndex:
        """Overlay one edge-index payload.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Edge index data product key
        index : np.ndarray
            List of indexes to merge into an overlay
        ref_data : ParserEdgeIndex
            Reference edge index which provides the initial span.

        Returns
        -------
        ParserEdgeIndex
            Overlayed edge index data.
        """
        features, span = self._stack_shifted_indexes(
            batch, key, index, ref_data.span
        )
        return ParserEdgeIndex(features=features, span=span)