"""Module with a dataclass targeted at batched matrix/tensors."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import numpy as np
from spine.constants import BATCH_COL, COORD_COLS
from spine.utils.conditional import ME, is_sparse_tensor_like, torch
from .base import ArrayLike, BatchBase, SparseTensorLike
__all__ = ["TensorBatch"]
[docs]
@dataclass(eq=False)
class TensorBatch(BatchBase):
"""Batched tensor with the necessary methods to slice it."""
data: ArrayLike | SparseTensorLike
counts: ArrayLike
edges: ArrayLike
batch_size: int
has_batch_col: bool
coord_cols: Sequence[int] | np.ndarray | None
@property
def _array_data(self) -> ArrayLike:
"""Dense tensor data with sparse cases excluded."""
if is_sparse_tensor_like(self.data):
raise TypeError("TensorBatch data is sparse.")
return self.data
@property
def _sparse_data(self) -> SparseTensorLike:
"""Sparse tensor data with dense cases excluded."""
if not is_sparse_tensor_like(self.data):
raise TypeError("TensorBatch data is not sparse.")
return self.data
[docs]
def __init__(
self,
data: ArrayLike | SparseTensorLike,
counts: Sequence[int] | ArrayLike | None = None,
batch_size: int | None = None,
is_sparse: bool = False,
has_batch_col: bool = False,
coord_cols: Sequence[int] | np.ndarray | None = None,
) -> None:
"""Initialize the attributes of the class.
Parameters
----------
data : Union[np.ndarray, torch.Tensor, ME.SparseTensor]
(N, C) Batched tensors
counts : Union[List[int], np.ndarray, torch.Tensor]
(B) Number of data rows in each entry
batch_size : int, optional
Number of entries that make up the batched data
is_sparse : bool, default False
If initializing from an ME sparse data, flip to True
has_batch_col : bool, default False
Wheather the tensor has a column specifying the batch ID
coord_cols : Union[List[int], np.ndarray], optional
List of columns specifying coordinates
"""
# Initialize the base class
super().__init__(data, is_sparse=is_sparse)
# Should provide either the counts, or the batch size
if (counts is not None) == (batch_size is not None):
raise ValueError("Provide either `counts` or `batch_size`, not both.")
# If the data is sparse, it must have a batch column and coordinates
if is_sparse:
has_batch_col = True
coord_cols = COORD_COLS
# If the counts are not provided, must build them once
if counts is None:
# Define the array functions depending on the input type
if not has_batch_col:
raise ValueError("Cannot get the counts without a batch column.")
if batch_size is None: # pragma: no cover
raise ValueError("Must provide `batch_size` to infer counts.")
batch_size_value = batch_size
if is_sparse:
if not is_sparse_tensor_like(data): # pragma: no cover
raise TypeError(
"Sparse tensor batches must be initialized with "
"MinkowskiEngine-like sparse tensor data."
)
ref = data.C
else:
if is_sparse_tensor_like(data):
raise TypeError("Sparse tensor data must set `is_sparse=True`.")
ref = data
counts = self.get_counts(ref[:, BATCH_COL], batch_size_value)
else:
# If the number of batches is not provided, get it from the counts
batch_size_value = len(counts)
# Cast
counts = self._as_long(counts)
if self._sum(counts) != len(data):
raise ValueError(
"The `counts` provided do not add up to the tensor length."
)
# Get the boundaries between entries in the batch
edges = self.get_edges(counts)
# Store the attributes
self.data = data
self.counts = counts
self.edges = edges
self.batch_size = batch_size_value
self.has_batch_col = has_batch_col
self.coord_cols = coord_cols
def __getitem__(self, batch_id: int) -> ArrayLike | SparseTensorLike:
"""Returns a subset of the tensor corresponding to one entry.
Parameters
----------
batch_id : int
Entry index
"""
# Make sure the batch_id is sensible
if batch_id >= self.batch_size:
raise IndexError(
f"Index {batch_id} out of bound for a batch size "
f"of ({self.batch_size})"
)
# Return
lower, upper = self.edges[batch_id], self.edges[batch_id + 1]
if not self.is_sparse:
return self._array_data[lower:upper]
else:
data = self._sparse_data
return ME.SparseTensor(data.F[lower:upper], coordinates=data.C[lower:upper])
@property
def tensor(self) -> ArrayLike | SparseTensorLike:
"""Alias for the underlying data stored.
Returns
-------
Union[np.ndarray, torch.Tensor, ME.SparseTensor]
Underlying tensor of data
"""
return self.data
@property
def batch_ids(self) -> ArrayLike:
"""Returns the batch ID of each of the elements in the tensor.
Returns
-------
Union[np.ndarray, torch.Tensor]
(N) Batch ID of each element in the tensor
"""
return self._repeat(self._arange(self.batch_size), self.counts)
[docs]
def split(self) -> list[ArrayLike | SparseTensorLike]:
"""Breaks up the tensor batch into its constituents.
Returns
-------
List[Union[np.ndarray, torch.Tensor]]
List of one tensor per entry in the batch
"""
if not self.is_sparse:
return self._split(self._array_data, self.splits)
else:
data = self._sparse_data
coords = self._split(data.C, self.splits)
feats = self._split(data.F, self.splits)
return [
ME.SparseTensor(feats[i], coordinates=coords[i])
for i in range(self.batch_size)
]
[docs]
def apply_mask(self, mask: ArrayLike) -> None:
"""Apply a global mask to the underlying tensor, update batching.
Parameters
----------
mask : Union[np.ndarray, torch.Tensor]
(N) Boolean mask to apply to the underlying tensor
"""
# Update underlying tensor in place
self.data = self.data[mask]
# Update batching information
batch_ids = self.batch_ids[mask]
self.counts = self.get_counts(batch_ids, self.batch_size)
self.edges = self.get_edges(self.counts)
[docs]
def merge(self, tensor_batch: "TensorBatch") -> "TensorBatch":
"""Merge this tensor batch with another.
Parameters
----------
tensor_batch : TensorBatch
Other tensor batch object to merge with
Returns
-------
TensorBatch
Merged tensor batch
"""
# Stack the tensors entry-wise in the batch
entries = []
for b in range(self.batch_size):
entries.append(self[b])
entries.append(tensor_batch[b])
tensor = self._cat(entries)
counts = self.counts + tensor_batch.counts
return TensorBatch(tensor, counts)
[docs]
def to_numpy(self) -> "TensorBatch":
"""Cast underlying tensor to a `np.ndarray` and return a new instance.
Returns
-------
TensorBatch
New `TensorBatch` object with an underlying np.ndarray tensor.
"""
# If the underlying data is of the right type, nothing to do
if self.is_numpy:
return self
data = self.data
if self.is_sparse:
sparse_data = self._sparse_data
data = torch.cat(
[sparse_data.C.to(dtype=sparse_data.F.dtype), sparse_data.F], dim=1
)
data = self._to_numpy(data)
counts = self._to_numpy(self.counts)
return TensorBatch(
data, counts, has_batch_col=self.has_batch_col, coord_cols=self.coord_cols
)
[docs]
def to_tensor(self, dtype: Any = None, device: Any = None) -> "TensorBatch":
"""Cast underlying tensor to a `torch.tensor` and return a new instance.
Parameters
----------
dtype : torch.dtype, optional
Data type of the tensor to create
device : torch.device, optional
Device on which to put the tensor
Returns
-------
TensorBatch
New `TensorBatch` object with an underlying np.ndarray tensor.
"""
# If the underlying data is of the right type, nothing to do
if not self.is_numpy:
return self
data = self._to_tensor(self.data, dtype, device)
counts = self._to_tensor(self.counts, dtype, device)
return TensorBatch(
data, counts, has_batch_col=self.has_batch_col, coord_cols=self.coord_cols
)
[docs]
def to_cm(self, meta: Any) -> None:
"""Converts the pixel coordinates of the tensor to cm.
Parameters
----------
meta : Meta
Metadata information about the rasterized image
"""
if not self.is_numpy:
raise ValueError("Can only convert units of numpy arrays.")
data = self._array_data
data[:, COORD_COLS] = meta.to_cm(data[:, COORD_COLS], center=True)
[docs]
def to_px(self, meta: Any) -> None:
"""Converts the coordinates of the tensor to pixel indexes.
Parameters
----------
meta : Meta
Metadata information about the rasterized image
"""
if not self.is_numpy:
raise ValueError("Can only convert units of numpy arrays.")
data = self._array_data
data[:, COORD_COLS] = meta.to_px(data[:, COORD_COLS], floor=True)
[docs]
@classmethod
def from_list(cls, data_list: Sequence[ArrayLike]) -> "TensorBatch":
"""Builds a batch from a list of tensors.
Parameters
----------
data_list : List[Union[np.ndarray, torch.Tensor]]
List of tensors, exactly one per batch
"""
# Check that we are not fed an empty list of tensors
if not len(data_list):
raise ValueError("Must provide at least one tensor to build a tensor batch")
is_numpy = not isinstance(data_list[0], torch.Tensor)
# Compute the counts from the input list
counts = [len(t) for t in data_list]
# Concatenate input
if is_numpy:
return cls(np.concatenate(data_list, axis=0), counts)
else:
return cls(torch.cat(data_list, dim=0), counts)