# Source code for spine.io.collate

"""Contains implementations of data collation classes.

Collate classes are a middleware between parsers and datasets. They are given
to :class:`torch.utils.data.DataLoader` as the `collate_fn` argument.
"""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any

import numpy as np

from spine.data import EdgeIndexBatch, IndexBatch, TensorBatch
from spine.geo import GeoManager

from .overlay import Overlayer
from .parse.data import ParserEdgeIndex, ParserIndex, ParserIndexList, ParserTensor

SampleDict = dict[str, Any]
BatchType = Sequence[SampleDict]

__all__ = ["CollateAll"]


class CollateAll:
    """General collate function for all data types coming from the parsers.

    Provide it with a list of dictionaries, one per event. For each data
    product key whose declared data type is `"tensor"`, the value can be:

    - A `ParserTensor` with coordinates, features and metadata, merged into
      rows of the form `[batch_id, *coords, *features]`
    - A feature-only `ParserTensor`, whose features are concatenated along
      the first axis (no batch column is added; per-entry counts are kept)
    - A `ParserIndex`, `ParserIndexList` or `ParserEdgeIndex`, merged into an
      offset-adjusted index batch

    Values for any other data type (scalars, lists, objects) are gathered
    into a plain list, one element per event.
    """

    name = "all"

    def __init__(
        self,
        data_types: Mapping[str, str],
        split: bool = False,
        target_id: int = 0,
        source: Mapping[str, str] | None = None,
        overlay: Mapping[str, Any] | None = None,
        overlay_methods: Mapping[str, str] | None = None,
    ) -> None:
        """Initialize the collation parameters.

        Parameters
        ----------
        data_types : mapping
            Mapping from each data product key to the data type returned by
            its parser. Only the value `"tensor"` triggers tensor stacking;
            every other value is collated into a plain list.
        split : bool, default False
            Whether to split the input by module ID (each module gets its own
            batch ID, multiplies the number of batches by `num_modules`)
        target_id : int, default 0
            If split is `True`, specifies where to relocate the points (passed
            to the geometry `split` call)
        source : mapping, optional
            Mapping which maps keys to their corresponding sources. This can
            be used to split feature tensors without having to check the
            geometry. Only used when `split` is `True`.
        overlay : mapping, optional
            Image overlay configuration
        overlay_methods : mapping, optional
            Mapping of overlay methods. Required when `overlay` is provided.

        Raises
        ------
        ValueError
            If `overlay` is provided without `overlay_methods`.
        """
        # Store the data types of each parser output
        self.data_types = data_types

        # Initialize the geometry, if required
        self.split = split
        self.source = None
        if split:
            self.target_id = target_id
            self.geo = GeoManager.get_instance()
            self.num_modules = self.geo.tpc.num_modules
            self.source = source

        # Initialize the overlayer, if required
        self.overlayer = None
        if overlay is not None:
            if overlay_methods is None:
                raise ValueError(
                    "`overlay_methods` must be provided if `overlay` is not None."
                )
            self.overlayer = Overlayer(
                **overlay, data_types=data_types, methods=overlay_methods
            )

    def __call__(self, batch: BatchType) -> dict[str, Any]:
        """Collate a list of per-event dictionaries into one batch dictionary.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.

        Returns
        -------
        Dict
            Dictionary that matches one data key to one batch-worth of data

        Raises
        ------
        TypeError
            If a `"tensor"` data product carries an unsupported payload type.
        """
        # Overlay data (modify batch), if needed
        if self.overlayer is not None:
            batch = self.overlayer(batch)

        # Loop over the data keys, merge all events in a batch
        data = {}
        for key, data_type in self.data_types.items():
            if data_type != "tensor":
                # Non-tensor products are simply gathered into a list
                data[key] = [sample[key] for sample in batch]
                continue

            # Dispatch on the payload type of the first event
            ref_data = batch[0][key]
            if isinstance(ref_data, ParserTensor):
                if ref_data.coords is not None and not ref_data.feats_only:
                    # Coordinates + features + metadata per entry
                    data[key] = self.stack_coord_tensors(batch, key)
                else:
                    # Feature tensor only per entry
                    data[key] = self.stack_feat_tensors(batch, key)
            elif isinstance(
                ref_data, (ParserIndex, ParserIndexList, ParserEdgeIndex)
            ):
                # Index-like payload, needs offset adjustment
                data[key] = self.stack_index_tensors(batch, key)
            else:
                raise TypeError(
                    f"Unsupported parser payload type for `{key}`: "
                    f"{type(ref_data).__name__}."
                )

        return data

    def stack_coord_tensors(self, batch: BatchType, key: str) -> TensorBatch:
        """Stack coordinate tensors together across a batch.

        Each output row has the form `[batch_id, *coords, *features]`. When
        splitting is enabled, each (event, module) pair gets its own batch ID,
        `num_modules * event_index + module_index`.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Data product key

        Returns
        -------
        TensorBatch
            Batched coordinate tensor
        """
        batch_size = len(batch)
        if not self.split:
            # If not split, simply stack everything
            coords = np.vstack([sample[key].coords for sample in batch])
            features = np.vstack([sample[key].features for sample in batch])
            counts = [len(sample[key].coords) for sample in batch]
            batch_ids = np.repeat(np.arange(batch_size, dtype=coords.dtype), counts)
        else:
            # If split, must shift the voxel coordinates and create
            # one batch ID per [batch, volume] pair
            coords_v, features_v, batch_ids_v = [], [], []
            # Zero-initialized so that any (event, module) slot which is not
            # filled below reports 0 rather than uninitialized memory
            counts = np.zeros(batch_size * self.num_modules, dtype=np.int64)
            for s, sample in enumerate(batch):
                # Identify which point belongs to which module
                coords = sample[key].coords
                features = sample[key].features
                num_cols = coords.shape[1]
                coords_wrapped, module_indexes = self.geo.split(
                    coords.reshape(-1, 3), self.target_id, meta=sample[key].meta
                )
                coords = coords_wrapped.reshape(-1, num_cols)
                # Work on a private list copy, as entries are reassigned below
                module_indexes = list(module_indexes)

                # If there is more than one point per row, a row is assigned
                # to the first module (in iteration order) owning any of its
                # points; the choice is arbitrary when points straddle volumes
                if num_cols > 3:
                    num_points = num_cols // 3
                    free = np.ones(len(coords), dtype=bool)
                    for m, module_index in enumerate(module_indexes):
                        mask = np.zeros(len(coords_wrapped), dtype=bool)
                        mask[module_index] = True
                        mask = mask.reshape(-1, num_points).any(axis=1)
                        module_indexes[m] = np.where(free & mask)[0]
                        free[module_indexes[m]] = False

                # Assign a different batch ID to each volume
                for m, module_index in enumerate(module_indexes):
                    idx = self.num_modules * s + m
                    coords_v.append(coords[module_index])
                    features_v.append(features[module_index])
                    batch_ids_v.append(
                        np.full(len(module_index), idx, dtype=coords.dtype)
                    )
                    counts[idx] = len(module_index)

            coords = np.vstack(coords_v)
            features = np.vstack(features_v)
            batch_ids = np.concatenate(batch_ids_v)

        # Stack the coordinates with the features
        tensor = np.hstack([batch_ids[:, None], coords, features])
        coord_cols = np.arange(1, 1 + coords.shape[1])

        return TensorBatch(
            tensor.astype(features.dtype),
            counts,
            has_batch_col=True,
            coord_cols=coord_cols,
        )

    def stack_index_tensors(
        self, batch: BatchType, key: str
    ) -> IndexBatch | EdgeIndexBatch:
        """Stack index tensors together across a batch.

        The indexes of each event are shifted by the cumulative `span` of the
        preceding events so that they address the concatenated object space.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Data product key

        Returns
        -------
        Union[IndexBatch, EdgeIndexBatch]
            `IndexBatch` for index lists and 1D indexes, directed
            `EdgeIndexBatch` for 2D (edge) indexes
        """
        # Start by computing the necessary node ID offsets to apply
        spans = np.asarray([sample[key].span for sample in batch], dtype=np.int64)
        offsets = np.zeros(len(spans), dtype=int)
        offsets[1:] = np.cumsum(spans)[:-1]

        if isinstance(batch[0][key], ParserIndexList):
            index_list, counts, single_counts = [], [], []
            for offset, sample in zip(offsets, batch):
                sample_index_list = [
                    np.asarray(index, dtype=np.int64) + offset
                    for index in sample[key].features
                ]
                index_list.extend(sample_index_list)
                counts.append(len(sample_index_list))
                if sample[key].single_counts is not None:
                    single_counts.extend(sample[key].single_counts.tolist())
                else:
                    single_counts.extend(len(index) for index in sample_index_list)

            return IndexBatch(index_list, spans, counts, single_counts)

        # Stack the indexes, do not add a batch column. 1D indexes are joined
        # along axis 0, 2D (edge) indexes along their last axis
        index_list = [
            sample[key].features + offset for offset, sample in zip(offsets, batch)
        ]
        axis = 0 if index_list[0].ndim == 1 else 1
        index = np.concatenate(index_list, axis=axis)
        counts = [sample[key].features.shape[-1] for sample in batch]

        if index.ndim == 1:
            return IndexBatch(index, spans, counts)

        return EdgeIndexBatch(index, counts, spans, directed=True)

    def stack_feat_tensors(self, batch: BatchType, key: str) -> TensorBatch:
        """Stack feature tensors together across a batch.

        Features are concatenated along the first axis; no batch column is
        added. When splitting is enabled and a source product is registered
        for `key`, rows are regrouped per (event, module) pair using column 0
        of the source tensor as the module ID.

        Parameters
        ----------
        batch : List[Dict]
            List of dictionaries of parsed information, one per event. Each
            dictionary matches one data key to one event-worth of parsed data.
        key : str
            Data product key

        Returns
        -------
        TensorBatch
            Batched feature tensor
        """
        # Fetch the source object, if it exists (only relevant when split)
        sources = None
        if self.split and self.source is not None and key in self.source:
            source_key = self.source[key]
            sources = [sample[source_key].features for sample in batch]

        if sources is None:
            tensor = np.concatenate([sample[key].features for sample in batch])
            counts = [len(sample[key].features) for sample in batch]
        else:
            features_v = []
            counts = np.zeros(len(batch) * self.num_modules, dtype=np.int64)
            for s, (sample, source) in enumerate(zip(batch, sources)):
                features = sample[key].features
                module_ids = source[:, 0]
                for m in range(self.num_modules):
                    module_index = np.where(module_ids == m)[0]
                    features_v.append(features[module_index])
                    counts[self.num_modules * s + m] = len(module_index)

            # `concatenate` (unlike `vstack`) also handles 1D feature arrays,
            # matching the non-split path above
            tensor = np.concatenate(features_v)

        return TensorBatch(tensor, counts)