Source code for spine.data.batch.index

"""Module with a dataclass targeted at a batch index or list of indexes."""

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from warnings import warn

import numpy as np

from .base import ArrayLike, BatchBase

# Public API of this module: the only name exported by a star-import.
__all__ = ["IndexBatch"]


@dataclass(eq=False)
class IndexBatch(BatchBase):
    """Batched index with the necessary methods to slice it.

    Attributes
    ----------
    data : Union[np.ndarray, torch.Tensor, List[Union[np.ndarray, torch.Tensor]]]
        Simple batched index or list of indexes
    counts : Union[np.ndarray, torch.Tensor]
        (B) Number of indexes in each entry of the batch
    edges : Union[np.ndarray, torch.Tensor]
        (B+1) Boundaries between successive entries, derived from ``counts``
    batch_size : int
        Number of entries in the batch
    spans : Union[np.ndarray, torch.Tensor]
        (B) Per-entry parent spans used to build the batch offsets. This is
        the same quantity as the parser-side ``span`` and may be required
        when serializing unwrapped indexes for later rebatching.
    offsets : Union[np.ndarray, torch.Tensor]
        (B) Offsets between successive indexes in the batch, computed from
        the cumulative sum of ``spans``.
    single_counts : Union[np.ndarray, torch.Tensor]
        (I) Number of index elements per index in the index list. This
        is the same as counts if the underlying data is a single index
    """

    data: ArrayLike | Sequence[ArrayLike]
    counts: ArrayLike
    edges: ArrayLike
    batch_size: int
    spans: ArrayLike
    offsets: ArrayLike
    single_counts: ArrayLike

    def __init__(
        self,
        data: ArrayLike | Sequence[ArrayLike],
        spans: Sequence[int] | ArrayLike,
        counts: Sequence[int] | ArrayLike | None = None,
        single_counts: Sequence[int] | ArrayLike | None = None,
        batch_ids: Sequence[int] | ArrayLike | None = None,
        batch_size: int | None = None,
        default: ArrayLike | None = None,
    ) -> None:
        """Initialize the attributes of the class.

        Parameters
        ----------
        data : Union[np.ndarray, torch.Tensor, List[Union[np.ndarray, torch.Tensor]]]
            Simple batched index or list of indexes
        spans : Union[List[int], np.ndarray, torch.Tensor]
            (B) Per-entry parent spans used to derive ``offsets``.
        counts : Union[List[int], np.ndarray, torch.Tensor], optional
            (B) Number of indexes in the batch
        single_counts : Union[List[int], np.ndarray, torch.Tensor], optional
            (I) Number of index elements per index in the index list. This
            is the same as counts if the underlying data is a single index
        batch_ids : Union[List[int], np.ndarray, torch.Tensor], optional
            (I) Batch index of each of the clusters. If not specified, the
            assumption is that each count corresponds to a specific entry
        batch_size : int, optional
            Number of entries in the batch. Must be specified along batch_ids
        default : Union[np.ndarray, torch.Tensor], optional
            Empty-index prototype used when initializing an empty index list

        Raises
        ------
        ValueError
            If neither `counts` nor both of `batch_ids` and `batch_size` are
            provided, if `single_counts` is missing for an index list or does
            not match its length, if `counts` does not add up to the length
            of `data`, or if `spans` and `counts` differ in length.
        """
        # Check whether the input is a single index or a list of indexes
        is_list = (
            isinstance(data, (list, tuple)) or getattr(data, "dtype", None) == object
        )

        # Initialize the base class with a representative index, so that the
        # underlying array type can be inferred even for index lists
        if not is_list:
            init_data = data
        elif len(data):
            init_data = data[0]
        else:
            if default is None:
                warn(
                    "The input index data is an empty list without a default "
                    "index. Will use numpy as an underlying representation."
                )
                default = np.empty(0, dtype=np.int64)
            init_data = default

        super().__init__(init_data, is_list=is_list)

        # Get the counts if they are not provided for free
        if counts is None:
            if batch_ids is None or batch_size is None:
                raise ValueError("Must provide `batch_size` alongside `batch_ids`.")
            counts = self.get_counts(batch_ids, batch_size)
        else:
            batch_size = len(counts)

        # Get the number of index elements per entry in the batch
        if single_counts is None:
            if self.is_list:
                raise ValueError(
                    "When initializing an index list, provide `single_counts`."
                )
            single_counts = counts
        else:
            if len(single_counts) != len(data):
                raise ValueError(
                    "There must be one single count per index in the list."
                )

        # Cast
        counts = self._as_long(counts)
        single_counts = self._as_long(single_counts)
        spans = self._as_long(spans)

        # Do a couple of basic sanity checks
        if self._sum(counts) != len(data):
            raise ValueError("The `counts` provided must add up to the index length.")
        if len(counts) != len(spans):
            raise ValueError("The number of `spans` must match the number of `counts`.")

        # Compute the offsets from the per-entry spans (exclusive cumulative
        # sum: the first entry has no offset)
        offsets = self._zeros(len(spans), None if self.is_numpy else spans.device)
        offsets[1:] = self._cumsum(spans)[:-1]

        # Get the boundaries between successive index using the counts
        edges = self.get_edges(counts)

        # Store the attributes
        self.data = data
        self.counts = counts
        self.single_counts = single_counts
        self.edges = edges
        self.spans = spans
        self.offsets = offsets
        self.batch_size = batch_size

    def __getitem__(self, batch_id: int) -> ArrayLike | list[ArrayLike]:
        """Returns a subset of the index corresponding to one entry.

        The entry offset is subtracted from the returned index (or from each
        index in the returned list).

        Parameters
        ----------
        batch_id : int
            Entry index. Negative values count from the end of the batch,
            as for built-in sequences.

        Returns
        -------
        Union[np.ndarray, torch.Tensor, List[Union[np.ndarray, torch.Tensor]]]
            Index (or list of indexes) of the requested entry

        Raises
        ------
        IndexError
            If `batch_id` falls outside of `[-batch_size, batch_size)`
        """
        # Make sure the batch_id is sensible. Negative IDs must be caught
        # here: used as is, they would pair up the wrong edges and offsets
        if not -self.batch_size <= batch_id < self.batch_size:
            raise IndexError(
                f"Index {batch_id} out of bound for a batch size "
                f"of ({self.batch_size})"
            )
        if batch_id < 0:
            batch_id += self.batch_size

        # Return
        lower, upper = self.edges[batch_id], self.edges[batch_id + 1]
        if not self.is_list:
            return self._index_data[lower:upper] - self.offsets[batch_id]
        else:
            return [
                index - self.offsets[batch_id]
                for index in self._index_list[lower:upper]
            ]

    @property
    def _index_data(self) -> ArrayLike:
        """Underlying single index with index-list cases excluded."""
        if self.is_list:
            raise TypeError("IndexBatch data is an index list.")
        return self.data

    @property
    def _index_list(self) -> Sequence[ArrayLike]:
        """Underlying index list with single-index cases excluded."""
        if not self.is_list:
            raise TypeError("IndexBatch data is a single index.")
        # Object arrays are converted to a plain list of indexes
        if isinstance(self.data, np.ndarray):
            return self.data.tolist()
        return self.data

    @property
    def index(self) -> ArrayLike:
        """Alias for the underlying data stored.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            Underlying index

        Raises
        ------
        ValueError
            If the underlying data is an index list
        """
        if self.is_list:
            raise ValueError("Underlying data is not a single index, use `index_list`")
        return self._index_data

    @property
    def index_list(self) -> Sequence[ArrayLike]:
        """Alias for the underlying data list stored.

        Returns
        -------
        List[Union[np.ndarray, torch.Tensor]]
            Underlying index list

        Raises
        ------
        ValueError
            If the underlying data is a single index
        """
        if not self.is_list:
            raise ValueError("Underlying data is a single index, use `index`")
        return self._index_list

    @property
    def full_index(self) -> ArrayLike:
        """Returns the index combining all sub-indexes, if relevant.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (N) Complete concatenated index
        """
        if not self.is_list:
            return self._index_data
        else:
            index_list = self._index_list
            return self._cat(index_list) if len(index_list) else self._empty(0)

    @property
    def index_ids(self) -> ArrayLike:
        """Returns the ID of the index in the list each element belongs to.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (M) List of index IDs for each element

        Raises
        ------
        ValueError
            If the underlying data is a single index
        """
        if not self.is_list:
            raise ValueError("Underlying data must be a list of index")
        return self._repeat(self._arange(len(self.data)), self.single_counts)

    @property
    def full_counts(self) -> ArrayLike:
        """Returns the total number of elements in each batch entry.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (B) Number of elements in each batch entry
        """
        if not self.is_list:
            return self.counts
        else:
            # Sum the sizes of the indexes that belong to each entry
            full_counts = self._empty(self.batch_size)
            for b in range(self.batch_size):
                lower, upper = self.edges[b], self.edges[b + 1]
                full_counts[b] = self._sum(self.single_counts[lower:upper])
            return self._as_long(full_counts)

    @property
    def batch_ids(self) -> ArrayLike:
        """Returns the batch ID of each index in the list.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (I) Batch ID array, one per index in the list
        """
        return self._repeat(self._arange(self.batch_size), self.counts)

    @property
    def full_batch_ids(self) -> ArrayLike:
        """Returns the batch ID of each element in the full index list.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (N) Complete batch ID array, one per element
        """
        return self._repeat(self._arange(self.batch_size), self.full_counts)

    def split(self) -> list[ArrayLike] | list[list[ArrayLike]]:
        """Breaks up the index batch into its constituents.

        The entry offset is subtracted from each returned index.

        Returns
        -------
        List[List[Union[np.ndarray, torch.Tensor]]]
            List of list of indexes per entry in the batch
        """
        if self.is_list:
            indexes = []
            for batch_id in range(self.batch_size):
                lower, upper = self.edges[batch_id], self.edges[batch_id + 1]
                indexes.append(
                    [index - self.offsets[batch_id] for index in self.data[lower:upper]]
                )
            return indexes

        indexes = list(self._split(self.data, self.splits))
        for batch_id in range(self.batch_size):
            indexes[batch_id] = indexes[batch_id] - self.offsets[batch_id]
        return indexes

    def merge(self, index_batch: "IndexBatch") -> "IndexBatch":
        """Merge this index batch with another.

        Parameters
        ----------
        index_batch : IndexBatch
            Other index batch object to merge with

        Returns
        -------
        IndexBatch
            Merged index batch

        Raises
        ------
        ValueError
            If the two batches do not carry the same spans, or if one holds
            a single index while the other holds an index list
        """
        # Basic cross-checks
        if not (self.spans == index_batch.spans).all():
            raise ValueError("Both index batches should carry the same spans.")
        if self.is_list != index_batch.is_list:
            raise ValueError(
                "Cannot merge a single index batch with an index list batch."
            )

        counts = self.counts + index_batch.counts

        # Stack the indexes entry-wise in the batch
        if self.is_list:
            # Fetch the lists once: the property may rebuild them at each access
            own_list, other_list = self.index_list, index_batch.index_list
            indexes, single_counts = [], []
            for b in range(self.batch_size):
                lower, upper = self.edges[b], self.edges[b + 1]
                indexes.extend(own_list[lower:upper])
                single_counts.extend(self.single_counts[lower:upper])

                lower, upper = index_batch.edges[b], index_batch.edges[b + 1]
                indexes.extend(other_list[lower:upper])
                single_counts.extend(index_batch.single_counts[lower:upper])

            return IndexBatch(
                indexes,
                self.spans,
                counts,
                single_counts,
            )

        own_index, other_index = self.index, index_batch.index
        indexes = []
        for b in range(self.batch_size):
            lower, upper = self.edges[b], self.edges[b + 1]
            indexes.append(own_index[lower:upper])

            lower, upper = index_batch.edges[b], index_batch.edges[b + 1]
            indexes.append(other_index[lower:upper])

        return IndexBatch(self._cat(indexes), self.spans, counts)

    def to_numpy(self) -> "IndexBatch":
        """Cast underlying index to a `np.ndarray` and return a new instance.

        Returns
        -------
        IndexBatch
            New `IndexBatch` object with an underlying np.ndarray index, or
            this object itself if it is already backed by numpy.
        """
        # If the underlying data is of the right type, nothing to do
        if self.is_numpy:
            return self

        if not self.is_list:
            data = self._to_numpy(self.data)
        else:
            data = [self._to_numpy(d) for d in self.data]

        counts = self._to_numpy(self.counts)
        spans = self._to_numpy(self.spans)
        single_counts = None
        if self.is_list:
            single_counts = self._to_numpy(self.single_counts)

        return IndexBatch(data, spans, counts, single_counts)

    def to_tensor(self, dtype: Any = None, device: Any = None) -> "IndexBatch":
        """Cast underlying index to a `torch.tensor` and return a new instance.

        Parameters
        ----------
        dtype : torch.dtype, optional
            Data type of the tensor to create
        device : torch.device, optional
            Device on which to put the tensor

        Returns
        -------
        IndexBatch
            New `IndexBatch` object with an underlying torch.Tensor index, or
            this object itself if it is not backed by numpy.
        """
        # If the underlying data is of the right type, nothing to do
        if not self.is_numpy:
            return self

        if not self.is_list:
            data = self._to_tensor(self.data, dtype, device)
        else:
            data = [self._to_tensor(d, dtype, device) for d in self.data]

        counts = self._to_tensor(self.counts, dtype, device)
        spans = self._to_tensor(self.spans, dtype, device)
        single_counts = None
        if self.is_list:
            single_counts = self._to_tensor(self.single_counts, dtype, device)

        return IndexBatch(data, spans, counts, single_counts)