"""Module with a dataclass targeted at a batch index or list of indexes."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from warnings import warn
import numpy as np
from .base import ArrayLike, BatchBase
__all__ = ["IndexBatch"]
@dataclass(eq=False)
class IndexBatch(BatchBase):
    """Batched index with the necessary methods to slice it.

    Attributes
    ----------
    data : Union[np.ndarray, torch.Tensor,
                 List[Union[np.ndarray, torch.Tensor]]]
        Simple batched index or list of indexes
    counts : Union[np.ndarray, torch.Tensor]
        (B) Number of indexes in each entry of the batch
    edges : Union[np.ndarray, torch.Tensor]
        Boundaries between successive entries, derived from ``counts``.
        Entry ``b`` covers ``edges[b]:edges[b + 1]``
    batch_size : int
        Number of entries in the batch
    spans : Union[np.ndarray, torch.Tensor]
        (B) Per-entry parent spans used to build the batch offsets. This is
        the same quantity as the parser-side ``span`` and may be
        required when serializing unwrapped indexes for later rebatching.
    offsets : Union[np.ndarray, torch.Tensor]
        (B) Offsets between successive indexes in the batch, computed from
        the cumulative sum of ``spans``.
    single_counts : Union[np.ndarray, torch.Tensor]
        (I) Number of index elements per index in the index list. This
        is the same as counts if the underlying data is a single index
    """

    data: ArrayLike | Sequence[ArrayLike]
    counts: ArrayLike
    edges: ArrayLike
    batch_size: int
    spans: ArrayLike
    offsets: ArrayLike
    single_counts: ArrayLike

    def __init__(
        self,
        data: ArrayLike | Sequence[ArrayLike],
        spans: Sequence[int] | ArrayLike,
        counts: Sequence[int] | ArrayLike | None = None,
        single_counts: Sequence[int] | ArrayLike | None = None,
        batch_ids: Sequence[int] | ArrayLike | None = None,
        batch_size: int | None = None,
        default: ArrayLike | None = None,
    ) -> None:
        """Initialize the attributes of the class.

        Parameters
        ----------
        data : Union[np.ndarray, torch.Tensor,
                     List[Union[np.ndarray, torch.Tensor]]]
            Simple batched index or list of indexes
        spans : Union[List[int], np.ndarray, torch.Tensor]
            (B) Per-entry parent spans used to derive ``offsets``.
        counts : Union[List[int], np.ndarray, torch.Tensor], optional
            (B) Number of indexes in the batch
        single_counts : Union[List[int], np.ndarray, torch.Tensor], optional
            (I) Number of index elements per index in the index list. This
            is the same as counts if the underlying data is a single index
        batch_ids : Union[List[int], np.ndarray, torch.Tensor], optional
            (I) Batch index of each of the clusters. If not specified, the
            assumption is that each count corresponds to a specific entry
        batch_size : int, optional
            Number of entries in the batch. Must be specified along batch_ids
        default : Union[np.ndarray, torch.Tensor], optional
            Empty-index prototype used when initializing an empty index list

        Raises
        ------
        ValueError
            If neither `counts` nor both `batch_ids` and `batch_size` are
            given, if `single_counts` is missing for an index list or does
            not match the list length, if `counts` do not add up to the
            length of `data`, or if `spans` and `counts` differ in length.
        """
        # Check whether the input is a single index or a list
        is_list = (
            isinstance(data, (list, tuple)) or getattr(data, "dtype", None) == object
        )

        # Initialize the base class from a representative array: the data
        # itself, the first index of the list, or the empty prototype
        if not is_list:
            init_data = data
        elif len(data):
            init_data = data[0]
        else:
            if default is None:
                warn(
                    "The input index data is an empty list without a default "
                    "index. Will use numpy as an underlying representation."
                )
                default = np.empty(0, dtype=np.int64)
            init_data = default

        super().__init__(init_data, is_list=is_list)

        # Get the counts if they are not provided for free
        if counts is None:
            if batch_ids is None or batch_size is None:
                raise ValueError("Must provide `batch_size` alongside `batch_ids`.")
            counts = self.get_counts(batch_ids, batch_size)
        else:
            batch_size = len(counts)

        # Get the number of index elements per entry in the batch
        if single_counts is None:
            if self.is_list:
                raise ValueError(
                    "When initializing an index list, provide `single_counts`."
                )
            single_counts = counts
        elif len(single_counts) != len(data):
            raise ValueError("There must be one single count per index in the list.")

        # Cast
        counts = self._as_long(counts)
        single_counts = self._as_long(single_counts)
        spans = self._as_long(spans)

        # Do a couple of basic sanity checks
        if self._sum(counts) != len(data):
            raise ValueError("The `counts` provided must add up to the index length.")
        if len(counts) != len(spans):
            raise ValueError("The number of `spans` must match the number of `counts`.")

        # Compute the offsets from the per-entry spans: the first entry has
        # no offset, each following one is shifted by the preceding spans
        offsets = self._zeros(len(spans), None if self.is_numpy else spans.device)
        offsets[1:] = self._cumsum(spans)[:-1]

        # Get the boundaries between successive index using the counts
        edges = self.get_edges(counts)

        # Store the attributes
        self.data = data
        self.counts = counts
        self.single_counts = single_counts
        self.edges = edges
        self.spans = spans
        self.offsets = offsets
        self.batch_size = batch_size

    def __getitem__(self, batch_id: int) -> ArrayLike | list[ArrayLike]:
        """Returns a subset of the index corresponding to one entry.

        The offset of the entry is subtracted from the returned index values.

        Parameters
        ----------
        batch_id : int
            Entry index. Negative values count from the end of the batch

        Returns
        -------
        Union[np.ndarray, torch.Tensor, List[Union[np.ndarray, torch.Tensor]]]
            Index (or list of indexes) of the requested entry

        Raises
        ------
        IndexError
            If `batch_id` falls outside of the batch
        """
        # Normalize negative indexes first. Without this, a negative value
        # would pair mismatched `edges` and `offsets` and silently return
        # the wrong slice
        entry = batch_id + self.batch_size if batch_id < 0 else batch_id
        if not 0 <= entry < self.batch_size:
            raise IndexError(
                f"Index {batch_id} out of bound for a batch size "
                f"of ({self.batch_size})"
            )

        # Slice out the entry and remove its offset
        lower, upper = self.edges[entry], self.edges[entry + 1]
        offset = self.offsets[entry]
        if not self.is_list:
            return self._index_data[lower:upper] - offset

        return [index - offset for index in self._index_list[lower:upper]]

    @property
    def _index_data(self) -> ArrayLike:
        """Underlying single index with index-list cases excluded."""
        if self.is_list:
            raise TypeError("IndexBatch data is an index list.")

        return self.data

    @property
    def _index_list(self) -> Sequence[ArrayLike]:
        """Underlying index list with single-index cases excluded."""
        if not self.is_list:
            raise TypeError("IndexBatch data is a single index.")

        # A list stored as a numpy array is handed out as a python list
        if isinstance(self.data, np.ndarray):
            return self.data.tolist()

        return self.data

    @property
    def index(self) -> ArrayLike:
        """Alias for the underlying data stored.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            Underlying index

        Raises
        ------
        ValueError
            If the underlying data is an index list
        """
        if self.is_list:
            raise ValueError("Underlying data is not a single index, use `index_list`")

        return self._index_data

    @property
    def index_list(self) -> Sequence[ArrayLike]:
        """Alias for the underlying data list stored.

        Returns
        -------
        List[Union[np.ndarray, torch.Tensor]]
            Underlying index list

        Raises
        ------
        ValueError
            If the underlying data is a single index
        """
        if not self.is_list:
            raise ValueError("Underlying data is a single index, use `index`")

        return self._index_list

    @property
    def full_index(self) -> ArrayLike:
        """Returns the index combining all sub-indexes, if relevant.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (N) Complete concatenated index
        """
        if not self.is_list:
            return self._index_data

        # Guard the concatenation: an empty list yields an empty index
        index_list = self._index_list
        return self._cat(index_list) if len(index_list) else self._empty(0)

    @property
    def index_ids(self) -> ArrayLike:
        """Returns the ID of the index in the list each element belongs to.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (M) List of index IDs for each element

        Raises
        ------
        ValueError
            If the underlying data is a single index
        """
        if not self.is_list:
            raise ValueError("Underlying data must be a list of index")

        return self._repeat(self._arange(len(self.data)), self.single_counts)

    @property
    def full_counts(self) -> ArrayLike:
        """Returns the total number of elements in each batch entry.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (B) Number of elements in each batch entry
        """
        if not self.is_list:
            return self.counts

        # Sum the sizes of the indexes which belong to each entry
        full_counts = self._empty(self.batch_size)
        for b in range(self.batch_size):
            lower, upper = self.edges[b], self.edges[b + 1]
            full_counts[b] = self._sum(self.single_counts[lower:upper])

        return self._as_long(full_counts)

    @property
    def batch_ids(self) -> ArrayLike:
        """Returns the batch ID of each index in the list.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (I) Batch ID array, one per index in the list
        """
        return self._repeat(self._arange(self.batch_size), self.counts)

    @property
    def full_batch_ids(self) -> ArrayLike:
        """Returns the batch ID of each element in the full index list.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (N) Complete batch ID array, one per element
        """
        return self._repeat(self._arange(self.batch_size), self.full_counts)

    def split(self) -> list[ArrayLike] | list[list[ArrayLike]]:
        """Breaks up the index batch into its constituents.

        The offset of each entry is subtracted from its index values.

        Returns
        -------
        Union[List[Union[np.ndarray, torch.Tensor]],
              List[List[Union[np.ndarray, torch.Tensor]]]]
            One index per entry in the batch (single index) or one list
            of indexes per entry in the batch (index list)
        """
        if self.is_list:
            # Resolve the list once; same accessor as `__getitem__`
            index_list = self._index_list
            indexes = []
            for batch_id in range(self.batch_size):
                lower, upper = self.edges[batch_id], self.edges[batch_id + 1]
                offset = self.offsets[batch_id]
                indexes.append([index - offset for index in index_list[lower:upper]])

            return indexes

        indexes = list(self._split(self.data, self.splits))
        for batch_id in range(self.batch_size):
            indexes[batch_id] = indexes[batch_id] - self.offsets[batch_id]

        return indexes

    def merge(self, index_batch: "IndexBatch") -> "IndexBatch":
        """Merge this index batch with another.

        Within each entry, the indexes of this batch come first, followed
        by those of `index_batch`.

        Parameters
        ----------
        index_batch : IndexBatch
            Other index batch object to merge with

        Returns
        -------
        IndexBatch
            Merged index batch

        Raises
        ------
        ValueError
            If the two batches differ in batch size, in kind (single index
            versus index list) or in spans
        """
        # Basic cross-checks. Size and kind go first so that the
        # element-wise spans comparison only sees matching lengths
        if self.batch_size != index_batch.batch_size:
            raise ValueError("Both index batches should have the same batch size.")
        if self.is_list != index_batch.is_list:
            raise ValueError(
                "Cannot merge a single index batch with an index list batch."
            )
        if not (self.spans == index_batch.spans).all():
            raise ValueError("Both index batches should carry the same spans.")

        counts = self.counts + index_batch.counts
        batches = (self, index_batch)

        if self.is_list:
            # Resolve each list once, the accessor may convert the storage
            lists = (self.index_list, index_batch.index_list)

            # Stack the indexes entry-wise in the batch
            indexes, single_counts = [], []
            for b in range(self.batch_size):
                for batch, index_list in zip(batches, lists):
                    lower, upper = batch.edges[b], batch.edges[b + 1]
                    indexes.extend(index_list[lower:upper])
                    single_counts.extend(batch.single_counts[lower:upper])

            # An empty slice of `spans` serves as the empty-index prototype,
            # so that an empty merged list stays on the current backend
            return IndexBatch(
                indexes,
                self.spans,
                counts,
                single_counts,
                default=self.spans[:0],
            )

        # Stack the indexes entry-wise in the batch
        arrays = (self.index, index_batch.index)
        indexes = []
        for b in range(self.batch_size):
            for batch, index in zip(batches, arrays):
                lower, upper = batch.edges[b], batch.edges[b + 1]
                indexes.append(index[lower:upper])

        return IndexBatch(self._cat(indexes), self.spans, counts)

    def to_numpy(self) -> "IndexBatch":
        """Cast underlying index to a `np.ndarray` and return a new instance.

        Returns
        -------
        IndexBatch
            New `IndexBatch` object with underlying `np.ndarray` data, or
            this object itself if it is already numpy-based
        """
        # If the underlying data is of the right type, nothing to do
        if self.is_numpy:
            return self

        default = None
        if self.is_list:
            data = [self._to_numpy(d) for d in self.data]
            single_counts = self._to_numpy(self.single_counts)
            if not data:
                # Explicit prototype for an empty list; it matches the
                # constructor fallback without triggering its warning
                default = np.empty(0, dtype=np.int64)
        else:
            data = self._to_numpy(self.data)
            single_counts = None

        counts = self._to_numpy(self.counts)
        spans = self._to_numpy(self.spans)

        return IndexBatch(data, spans, counts, single_counts, default=default)

    def to_tensor(self, dtype: Any = None, device: Any = None) -> "IndexBatch":
        """Cast underlying index to a `torch.Tensor` and return a new instance.

        Parameters
        ----------
        dtype : torch.dtype, optional
            Data type of the index tensor(s) to create
        device : torch.device, optional
            Device on which to put the tensor

        Returns
        -------
        IndexBatch
            New `IndexBatch` object with underlying `torch.Tensor` data, or
            this object itself if it is already tensor-based
        """
        # If the underlying data is of the right type, nothing to do
        if not self.is_numpy:
            return self

        # The requested dtype only applies to the index values. The integer
        # bookkeeping arrays keep their own dtype, to avoid a lossy cast
        default = None
        if self.is_list:
            data = [self._to_tensor(d, dtype, device) for d in self.data]
            single_counts = self._to_tensor(self.single_counts, None, device)
            if not data:
                # Without a tensor prototype, an empty list would make the
                # constructor fall back to a numpy representation
                default = self._to_tensor(np.empty(0, dtype=np.int64), dtype, device)
        else:
            data = self._to_tensor(self.data, dtype, device)
            single_counts = None

        counts = self._to_tensor(self.counts, None, device)
        spans = self._to_tensor(self.spans, None, device)

        return IndexBatch(data, spans, counts, single_counts, default=default)