"""Base interfaces for data augmentation modules."""
from abc import ABC, abstractmethod
from typing import Any
import numpy as np
from spine.data import Meta
from spine.geo import GeoManager
from spine.io.parse.data import ParserTensor
class AugmentBase(ABC):
    """Base class for augmentation modules.

    Subclasses implement :meth:`apply`; calling an instance forwards to it.
    The static helpers are shared utilities for geometric transforms: pivot
    resolution, voxel <-> cm conversion, parameter parsing, activity
    statistics and crop/mask box sampling.
    """

    # Identifier of the augmentation (empty on the base class); presumably
    # overridden by subclasses and used by the manager — TODO confirm.
    name = ""

    def __call__(
        self,
        data: dict[str, Any],
        meta: Meta,
        keys: list[str],
        context: dict[str, Any],
    ) -> tuple[dict[str, Any], Meta]:
        """Apply an augmentation module.

        Thin wrapper that forwards all arguments unchanged to :meth:`apply`.

        Parameters
        ----------
        data : dict
            Dictionary of event data products to augment
        meta : Meta
            Shared image metadata
        keys : list of str
            Keys corresponding to data products that carry coordinates
        context : dict
            Shared augmentation context built by the manager

        Returns
        -------
        tuple of (dict, Meta)
            Updated data dictionary and shared metadata
        """
        return self.apply(data, meta, keys, context)

    @abstractmethod
    def apply(
        self,
        data: dict[str, Any],
        meta: Meta,
        keys: list[str],
        context: dict[str, Any],
    ) -> tuple[dict[str, Any], Meta]:
        """Apply an augmentation to one event.

        Must be implemented by every concrete augmentation.

        Parameters
        ----------
        data : dict
            Dictionary of event data products to augment
        meta : Meta
            Shared image metadata
        keys : list of str
            Keys corresponding to data products that carry coordinates
        context : dict
            Shared augmentation context built by the manager

        Returns
        -------
        tuple of (dict, Meta)
            Updated data dictionary and shared metadata
        """

    @staticmethod
    def _meta_center(meta: Meta) -> np.ndarray:
        """Return the midpoint of ``meta.lower`` and ``meta.upper`` as float32."""
        return ((meta.lower + meta.upper) / 2.0).astype(np.float32)

    @staticmethod
    def resolve_center(
        meta: Meta,
        center: np.ndarray | None = None,
        use_geo_center: bool = False,
    ) -> np.ndarray:
        """Resolve the pivot center for a geometric transform.

        Priority: an explicit ``center``, then the geometry TPC center when
        ``use_geo_center`` is set, otherwise the midpoint of the image bounds.

        Parameters
        ----------
        meta : Meta
            Current image metadata
        center : np.ndarray, optional
            Explicit center in detector coordinates (cm)
        use_geo_center : bool, default False
            If ``True``, use the detector TPC center from the geometry manager

        Returns
        -------
        np.ndarray
            (3,) float32 pivot center in detector coordinates (cm)

        Raises
        ------
        ValueError
            If both ``center`` and ``use_geo_center`` are given, or if
            ``center`` is not a 3-element point.
        """
        if center is not None and use_geo_center:
            raise ValueError("Cannot provide both `center` and `use_geo_center`.")
        if center is not None:
            center = np.asarray(center, dtype=np.float32)
            if center.shape != (3,):
                raise ValueError("Transform center must be a 3D point in cm.")
            return center
        if use_geo_center:
            return GeoManager.get_instance().tpc.center.astype(np.float32)
        return AugmentBase._meta_center(meta)

    @staticmethod
    def voxel_to_cm(coords: np.ndarray, meta: Meta) -> np.ndarray:
        """Convert voxel indices to detector coordinates at voxel centers.

        Parameters
        ----------
        coords : np.ndarray
            ``(N, 3)`` Array of voxel indices
        meta : Meta
            Metadata used to convert voxel indices to detector coordinates

        Returns
        -------
        np.ndarray
            ``(N, 3)`` Detector coordinates in cm at voxel centers
        """
        return meta.to_cm(coords, center=True)

    @staticmethod
    def cm_to_voxel(coords_cm: np.ndarray, meta: Meta, dtype: np.dtype) -> np.ndarray:
        """Convert detector coordinates at voxel centers back to indices.

        Parameters
        ----------
        coords_cm : np.ndarray
            ``(N, 3)`` Detector coordinates in cm at voxel centers
        meta : Meta
            Metadata used to convert detector coordinates back to pixel space
        dtype : np.dtype
            Output dtype to use for the returned voxel indices

        Returns
        -------
        np.ndarray
            ``(N, 3)`` Array of voxel indices
        """
        # Subtracting 0.5 undoes the half-voxel offset of voxel centers
        # (assumes ``meta.to_px`` returns continuous pixel coordinates — TODO
        # confirm). ``np.rint`` rounds half to even at exact voxel boundaries.
        return np.rint(meta.to_px(coords_cm) - 0.5).astype(dtype)

    @staticmethod
    def parse_optional_vector(
        value: float | list[float] | tuple[float, ...] | np.ndarray | None,
        name: str,
    ) -> np.ndarray | None:
        """Parse an optional scalar-or-vector parameter into a length-3 array.

        Parameters
        ----------
        value : float or sequence or np.ndarray, optional
            Input value to parse. Scalars, including 0-d arrays, are broadcast
            to all three axes.
        name : str
            Parameter name used in validation error messages.

        Returns
        -------
        np.ndarray or None
            Length-3 float32 vector if a value is provided, otherwise ``None``

        Raises
        ------
        ValueError
            If the value is neither a scalar nor a length-3 vector.
        """
        if value is None:
            return None
        array = np.asarray(value, dtype=np.float32)
        # Convert first and test the rank: ``np.isscalar`` would reject 0-d
        # arrays such as ``np.array(1.5)``.
        if array.ndim == 0:
            array = np.full(3, array.item(), dtype=np.float32)
        if array.shape != (3,):
            raise ValueError(f"{name} must be a scalar or a length-3 vector.")
        return array

    @staticmethod
    def resolve_activity_center(
        data: dict[str, Any],
        keys: list[str],
        meta: Meta,
        weighted: bool = False,
        feature_index: int = 0,
    ) -> np.ndarray:
        """Estimate an activity center from all coordinate-carrying tensors.

        Convenience wrapper around :meth:`resolve_activity_stats` that drops
        the spread.

        Parameters
        ----------
        data : dict
            Dictionary of event data products
        keys : list of str
            Keys corresponding to data products that carry coordinates
        meta : Meta
            Shared image metadata
        weighted : bool, default False
            If ``True``, weight the center by the absolute feature value in the
            requested feature column
        feature_index : int, default 0
            Feature column to use when ``weighted=True``

        Returns
        -------
        np.ndarray
            (3,) Activity center in detector coordinates (cm)
        """
        center, _ = AugmentBase.resolve_activity_stats(
            data,
            keys,
            meta,
            weighted=weighted,
            feature_index=feature_index,
        )
        return center

    @staticmethod
    def _activity_weights(
        features: Any, feature_index: int, num_points: int, key: str
    ) -> np.ndarray:
        """Return per-point absolute weights taken from one tensor's features.

        Raises ``ValueError`` if the tensor has no features or if the number
        of weights does not match the number of points.
        """
        if features is None:
            raise ValueError(
                f"Cannot weight activity by features: `{key}` has no features."
            )
        features = np.asarray(features)
        if features.ndim == 1:
            weights = np.abs(features)
        else:
            # Indices past the last column fall back to the last column.
            column = min(feature_index, features.shape[1] - 1)
            weights = np.abs(features[:, column])
        # Weights must line up row-by-row with the coordinates; a mismatch
        # would otherwise be silently broadcast (e.g. a single weight).
        if len(weights) != num_points:
            raise ValueError(
                f"`{key}` has {len(weights)} feature rows but {num_points} points."
            )
        return weights

    @staticmethod
    def resolve_activity_stats(
        data: dict[str, Any],
        keys: list[str],
        meta: Meta,
        weighted: bool = False,
        feature_index: int = 0,
    ) -> tuple[np.ndarray, np.ndarray | None]:
        """Estimate activity center and spread from coordinate-carrying tensors.

        Entries of ``data`` that are not ``ParserTensor`` instances, or that
        have no or empty coordinates, are ignored.

        Parameters
        ----------
        data : dict
            Dictionary of event data products
        keys : list of str
            Keys corresponding to data products that carry coordinates
        meta : Meta
            Shared image metadata
        weighted : bool, default False
            If ``True``, weight the center and spread by the absolute feature
            value in the requested feature column. If every weight is
            (numerically) zero, unweighted statistics are returned instead.
        feature_index : int, default 0
            Feature column to use when ``weighted=True``

        Returns
        -------
        tuple of (np.ndarray, np.ndarray or None)
            ``(3,)`` Activity center and standard deviation in detector
            coordinates (cm). If no activity is available, the center falls
            back to the metadata center and the spread is ``None``.

        Raises
        ------
        ValueError
            In weighted mode, if a used tensor has no features or its feature
            rows do not match its number of points.
        """
        coords_list = []
        weights_list = []
        for key in keys:
            value = data.get(key)
            if not isinstance(value, ParserTensor) or value.coords is None:
                continue
            if len(value.coords) == 0:
                continue
            coords_list.append(AugmentBase.voxel_to_cm(value.coords, meta))
            if weighted:
                weights_list.append(
                    AugmentBase._activity_weights(
                        value.features, feature_index, len(value.coords), key
                    )
                )

        if not coords_list:
            return AugmentBase._meta_center(meta), None

        coords = np.vstack(coords_list)
        if weighted:
            weights = np.concatenate(weights_list).astype(np.float64).reshape(-1)
            total_weight = np.sum(weights, dtype=np.float64, initial=0.0)
            # Vanishing total weight: fall through to unweighted statistics.
            if not np.allclose(total_weight, 0.0):
                column_weights = weights[:, None]
                center = (
                    np.sum(
                        coords * column_weights,
                        axis=0,
                        dtype=np.float64,
                        initial=0.0,
                    )
                    / total_weight
                )
                variance = (
                    np.sum(
                        ((coords - center) ** 2) * column_weights,
                        axis=0,
                        dtype=np.float64,
                        initial=0.0,
                    )
                    / total_weight
                )
                spread = np.sqrt(variance)
                return center.astype(np.float32), spread.astype(np.float32)

        center = np.mean(coords, axis=0)
        spread = np.std(coords, axis=0)
        return center.astype(np.float32), spread.astype(np.float32)

    @staticmethod
    def sample_box_lower(
        lower: np.ndarray,
        upper: np.ndarray,
        dimensions: np.ndarray,
        anchor: np.ndarray | None = None,
        spread: np.ndarray | None = None,
    ) -> np.ndarray:
        """Sample the lower corner of a crop/mask box.

        Without an anchor, the lower corner is drawn uniformly between
        ``lower`` and ``upper - dimensions``. With an anchor, the box center
        is drawn from a normal distribution around the anchor and clamped to
        the admissible center range. Along any axis where the box is larger
        than the region, both modes place the box so that it covers the
        region, i.e. the lower corner lies in ``[upper - dimensions, lower]``.

        Parameters
        ----------
        lower : np.ndarray
            Lower detector bounds of the allowed sampling region in cm
        upper : np.ndarray
            Upper detector bounds of the allowed sampling region in cm
        dimensions : np.ndarray
            Requested crop or mask box dimensions in cm
        anchor : np.ndarray, optional
            Preferred box center in cm. If provided, sampling is biased around
            this center.
        spread : np.ndarray, optional
            Standard deviation of the Gaussian proposal in cm when sampling
            around an anchor. If not provided, a quarter of the admissible
            center range is used.

        Returns
        -------
        np.ndarray
            Lower detector corner of the sampled box in cm
        """
        max_lower = upper - dimensions
        if anchor is None:
            return lower + np.random.rand(3) * (max_lower - lower)

        # Admissible box centers per axis. When the box exceeds the region the
        # two bounds cross; ordering them keeps the interval (and the default
        # spread) well defined — a negative scale would make
        # ``np.random.normal`` raise.
        center_lower = lower + dimensions / 2.0
        center_upper = upper - dimensions / 2.0
        center_min = np.minimum(center_lower, center_upper)
        center_max = np.maximum(center_lower, center_upper)

        anchor = np.clip(anchor, center_min, center_max)
        if spread is None:
            spread = 0.25 * (center_max - center_min)
        sampled_center = np.random.normal(loc=anchor, scale=spread)
        sampled_center = np.clip(sampled_center, center_min, center_max)
        return sampled_center - dimensions / 2.0