"""Module with a dataclass targeted at a batched edge index.
An edge index is a sparse representation of a graph incidence matrix.
"""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from .base import ArrayLike, BatchBase
__all__ = ["EdgeIndexBatch"]
@dataclass(eq=False)
class EdgeIndexBatch(BatchBase):
    """Batched edge index with the necessary methods to slice it.

    Attributes
    ----------
    data : Union[np.ndarray, torch.Tensor]
        (2, E) Batched edge index
    counts : Union[np.ndarray, torch.Tensor]
        (B) Number of index elements per entry in the batch
    edges : Union[np.ndarray, torch.Tensor]
        Boundaries between successive entries in the index, as returned
        by ``get_edges(counts)``
    batch_size : int
        Number of entries in the batch (length of ``counts``)
    spans : Union[np.ndarray, torch.Tensor]
        (B) Per-entry parent spans used to build the batch offsets.
    offsets : Union[np.ndarray, torch.Tensor]
        (B) Offsets between successive indexes in the batch, computed from
        the cumulative sum of ``spans`` (entry ``i`` holds the sum of the
        spans of all entries before it).
    directed : bool
        Whether the edge index is directed or undirected
    """

    data: ArrayLike
    counts: ArrayLike
    edges: ArrayLike
    batch_size: int
    spans: ArrayLike
    offsets: ArrayLike
    directed: bool

    def __init__(
        self,
        data: ArrayLike,
        counts: Sequence[int] | ArrayLike,
        spans: Sequence[int] | ArrayLike,
        directed: bool,
    ) -> None:
        """Initialize the attributes of the class.

        If the edge index corresponds to an undirected graph, each edge
        should have its reciprocal edge immediately after, e.g.

        .. code-block:: python

            [[0,1,0,2,0,3,...],
             [1,0,2,0,3,0,...]]

        Parameters
        ----------
        data : Union[np.ndarray, torch.Tensor]
            (2, E) Batched edge index
        counts : Union[List[int], np.ndarray, torch.Tensor]
            (B) Number of index elements per entry in the batch
        spans : Union[List[int], np.ndarray, torch.Tensor]
            (B) Per-entry parent spans used to derive ``offsets``.
        directed : bool
            Whether the edge index is directed or undirected

        Raises
        ------
        ValueError
            If ``counts`` does not sum to the index length, if ``counts``
            and ``spans`` differ in length, or if an undirected index
            holds an odd number of edges.
        """
        # Initialize the base class
        super().__init__(data)

        # Cast the counts and spans with the base class helper
        counts = self._as_long(counts)
        spans = self._as_long(spans)

        # Do a couple of basic sanity checks
        if self._sum(counts) != data.shape[1]:
            raise ValueError("The `counts` provided do not add up to the index length")
        if len(counts) != len(spans):
            raise ValueError(
                "The number of `spans` does not match the number of `counts`"
            )
        if not directed and data.shape[1] % 2 != 0:
            raise ValueError(
                "If the edge index is undirected, it should have an "
                "even number of edges"
            )

        # Compute the offsets from the per-entry spans: the first entry is
        # left untouched by the assignment below, every other entry receives
        # the cumulative span of the entries that precede it
        offsets = self._zeros(len(spans), None if self.is_numpy else spans.device)
        offsets[1:] = self._cumsum(spans)[:-1]

        # Get the boundaries between successive index using the counts
        edges = self.get_edges(counts)

        # Store the attributes
        self.data = data
        self.counts = counts
        self.spans = spans
        self.edges = edges
        self.offsets = offsets
        self.directed = directed
        self.batch_size = len(counts)

    def __getitem__(self, batch_id: int) -> ArrayLike:
        """Returns a subset of the index corresponding to one entry.

        The offset of the entry is subtracted from the returned index.

        Parameters
        ----------
        batch_id : int
            Entry index. Negative values count from the end of the batch,
            as for regular Python sequences.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            Transposed edge index of the requested entry

        Raises
        ------
        IndexError
            If ``batch_id`` falls outside of ``[-batch_size, batch_size)``
        """
        # Resolve negative indexes explicitly: left as is, `-1` would pair
        # `edges[-1]` with `edges[0]` and silently return a wrong slice
        position = batch_id + self.batch_size if batch_id < 0 else batch_id

        # Make sure the batch_id is sensible
        if not 0 <= position < self.batch_size:
            raise IndexError(
                f"Index {batch_id} out of bound for a batch size "
                f"of ({self.batch_size})"
            )

        # Slice the entry out of the batch and remove its offset
        lower, upper = self.edges[position], self.edges[position + 1]
        index = self.data[:, lower:upper] - self.offsets[position]

        return self._transpose(index)

    @property
    def index(self) -> ArrayLike:
        """Alias for the underlying data stored.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (2, E) Underlying batch of edge indexes
        """
        return self.data

    @property
    def index_t(self) -> ArrayLike:
        """Alias for the underlying data stored, transposed.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (E, 2) Underlying batch of edge indexes, transposed
        """
        return self._transpose(self.data)

    @property
    def batch_ids(self) -> ArrayLike:
        """Returns the batch ID of each element in the full index list.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            Complete batch ID array, one per element of the index
        """
        return self._repeat(self._arange(self.batch_size), self.counts)

    @property
    def directed_index(self) -> ArrayLike:
        """Index of the directed graph. If a graph is undirected, it only
        returns one of the two edges corresponding to a connection.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (2, E) Underlying batch of edge indexes if the graph is
            directed, (2, E//2) otherwise
        """
        # If the graph is directed, nothing to do
        if self.directed:
            return self.data

        # Otherwise, skip every second edge in the index (the reciprocal
        # edge is expected to immediately follow each edge)
        return self.data[:, ::2]

    @property
    def directed_index_t(self) -> ArrayLike:
        """Index of the directed graph, transposed. If the graph is undirected,
        it only returns one of the two edges corresponding to a connection.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (E, 2) Underlying batch of edge indexes, transposed, if the
            graph is directed, (E//2, 2) otherwise
        """
        return self._transpose(self.directed_index)

    @property
    def directed_counts(self) -> ArrayLike:
        """Returns the number of edges per entry, counting edges once even
        if they are bidirectional.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            (B) Number of edges per entry in the batch, with each
            bidirectional connection counted once
        """
        # If the graph is directed, the counts are exact
        if self.directed:
            return self.counts

        # Otherwise, indexes are twice as long
        return self.counts // 2

    @property
    def directed_batch_ids(self) -> ArrayLike:
        """Returns the batch ID of each element in the directed index.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            Complete batch ID array, one per element of the directed index
        """
        return self._repeat(self._arange(self.batch_size), self.directed_counts)

    def split(self) -> list[ArrayLike]:
        """Breaks up the index batch into its constituents.

        The offset of each entry is subtracted from its index.

        Returns
        -------
        List[Union[np.ndarray, torch.Tensor]]
            List of one index per entry in the batch
        """
        # Split the transposed index at the entry boundaries
        indexes = list(self._split(self._transpose(self.index), self.splits))

        # Remove the offset of each entry (builds new arrays, the
        # underlying batch data is not modified)
        for batch_id in range(self.batch_size):
            indexes[batch_id] = indexes[batch_id] - self.offsets[batch_id]

        return indexes

    def to_numpy(self) -> EdgeIndexBatch:
        """Cast underlying index to a `np.ndarray` and return a new instance.

        Returns
        -------
        EdgeIndexBatch
            New `EdgeIndexBatch` object with an underlying `np.ndarray`
            index, or this very object if it is already backed by numpy.
        """
        # If the underlying data is of the right type, nothing to do
        if self.is_numpy:
            return self

        data = self._to_numpy(self.data)
        counts = self._to_numpy(self.counts)
        spans = self._to_numpy(self.spans)

        return EdgeIndexBatch(data, counts, spans, self.directed)

    def to_tensor(self, dtype: Any = None, device: Any = None) -> EdgeIndexBatch:
        """Cast underlying index to a `torch.tensor` and return a new instance.

        Parameters
        ----------
        dtype : torch.dtype, optional
            Data type of the tensor to create
        device : torch.device, optional
            Device on which to put the tensor

        Returns
        -------
        EdgeIndexBatch
            New `EdgeIndexBatch` object with an underlying `torch.Tensor`
            index, or this very object if it is not backed by numpy.
        """
        # If the underlying data is of the right type, nothing to do
        if not self.is_numpy:
            return self

        data = self._to_tensor(self.data, dtype, device)
        counts = self._to_tensor(self.counts, dtype, device)
        spans = self._to_tensor(self.spans, dtype, device)

        return EdgeIndexBatch(data, counts, spans, self.directed)