"""Module with a data class objects which represent output interactions."""
from dataclasses import dataclass, field
from typing import cast
from warnings import warn
import numpy as np
from spine.constants import PID_LABELS, PID_TAGS, SHOWR_SHP
from spine.data.decorator import stored_property
from spine.data.field import FieldMetadata
from spine.data.larcv.neutrino import Neutrino
from .base import OutBase, RecoBase, TruthBase
from .particle import ParticleBase, RecoParticle, TruthParticle
__all__ = ["RecoInteraction", "TruthInteraction"]
@dataclass(eq=False, repr=False)
class InteractionBase(OutBase):
"""Base interaction-specific information.
Attributes
----------
particles: List[ParticleBase]
List of particles in the interaction (defined in subclasses with specific types)
primary_particles: List[ParticleBase]
List of primary particles associated with the interaction
particle_ids : np.ndarray
List of Particle IDs that make up this interaction
primary_particle_ids : np.ndarray
List of primary Particle IDs associated with this interaction
num_particles : int
Number of particles that make up this interaction
num_primary_particles : int
Number of primary particles associated with this interaction
particle_counts : np.ndarray
(P) Number of particles of each species in this interaction
primary_particle_counts : np.ndarray
(P) Number of primary particles of each species in this interaction
vertex : np.ndarray
(3) Coordinates of the interaction vertex
is_fiducial : bool
Whether this interaction vertex is inside the fiducial volume
is_flash_matched : bool
True if the interaction was matched to an optical flash
flash_ids : np.ndarray
(F) Indices of the optical flashes the interaction is matched to
flash_volume_ids : np.ndarray
(F) Indices of the optical volumes the flashes where recorded in
flash_times : np.ndarray
(F) Times at which the flashes occurred in microseconds
flash_scores : np.ndarray
(F) Flash matching quality scores reported for each match
flash_total_pe : float
Total number of photoelectrons associated with the flash
flash_hypo_pe : float
Total number of photoelectrons expected to be produced by the interaction
is_crt_matched : bool
True if any particle in the interaction was matched to a CRT hit
crt_ids : np.ndarray
(C) Indices of the CRT hits the interaction is matched to
crt_times : np.ndarray
(C) Times at which the CRT hits occurred in microseconds
crt_scores : np.ndarray
(C) Quality metric associated with the CRT matches
topology : str
String representing the interaction topology
"""
# Scalar attributes
is_fiducial: bool = False
is_flash_matched: bool = False
flash_total_pe: float = np.nan
flash_hypo_pe: float = np.nan
# Object list attributes
# Note: Subclasses override this with specific List[RecoParticle/TruthParticle]
particles: list[ParticleBase] = field(
default_factory=list,
metadata=FieldMetadata(skip=True),
)
# Vector attributes
particle_ids: np.ndarray = field(
default_factory=lambda: np.empty(0, dtype=np.int32),
metadata=FieldMetadata(dtype=np.int32),
)
vertex: np.ndarray = field(
default_factory=lambda: np.full(3, np.nan, dtype=np.float32),
metadata=FieldMetadata(
length=3,
dtype=np.float32,
position=True,
units="instance",
),
)
flash_ids: np.ndarray = field(
default_factory=lambda: np.empty(0, dtype=np.int32),
metadata=FieldMetadata(dtype=np.int32),
)
flash_volume_ids: np.ndarray = field(
default_factory=lambda: np.empty(0, dtype=np.int32),
metadata=FieldMetadata(dtype=np.int32),
)
flash_times: np.ndarray = field(
default_factory=lambda: np.empty(0, dtype=np.float32),
metadata=FieldMetadata(dtype=np.float32, units="us"),
)
flash_scores: np.ndarray = field(
default_factory=lambda: np.empty(0, dtype=np.float32),
metadata=FieldMetadata(dtype=np.float32),
)
def __str__(self) -> str:
"""Human-readable string representation of the interaction object.
Results
-------
str
Basic information about the interaction properties
"""
match = self.match_ids[0] if len(self.match_ids) > 0 else -1
info = (
f"Interaction(ID: {self.id:<3} "
f"| Size: {self.size:<5} | Topology: {self.topology:<10} "
f"| Match: {match:<3})"
)
if len(self.particles) > 0:
info += "\n" + len(info) * "-"
for particle in self.particles:
info += "\n" + str(particle)
return info
def reset_flash_match(self) -> None:
"""Reset all the flash matching attributes."""
self.is_flash_matched = False
self.flash_total_pe = np.nan
self.flash_hypo_pe = np.nan
self.flash_ids = np.empty(0, dtype=np.int32)
self.flash_volume_ids = np.empty(0, dtype=np.int32)
self.flash_times = np.empty(0, dtype=np.float32)
self.flash_scores = np.empty(0, dtype=np.float32)
@property
def primary_particles(self) -> list[ParticleBase]:
"""List of primary particles associated with this interaction.
Returns
-------
List[ParticleBase]
List of primary Particle objects associated with this interaction
"""
return [part for part in self.particles if part.is_primary]
@property
@stored_property(dtype=np.int32)
def primary_particle_ids(self) -> np.ndarray:
"""List of primary Particle IDs associated with this interaction.
Returns
-------
np.darray
List of primary Particle IDs associated with this interaction
"""
return np.array([part.id for part in self.primary_particles], dtype=np.int32)
@property
@stored_property
def num_particles(self) -> int:
"""Number of particles that make up this interaction.
Returns
-------
int
Number of particles that make up the interaction instance
"""
return len(self.particle_ids)
@property
@stored_property
def num_primary_particles(self) -> int:
"""Number of primary particles associated with this interaction.
Returns
-------
int
Number of primary particles associated with the interaction instance
"""
return len(self.primary_particle_ids)
@property
@stored_property(dtype=np.int32, length=len(PID_LABELS) - 1)
def particle_counts(self) -> np.ndarray:
"""Number of particles of each PID species in this interaction.
Returns
-------
np.ndarray
(P) Number of particles of each PID
"""
counts = np.zeros(len(PID_LABELS) - 1, dtype=np.int32)
for part in self.particles:
if part.pid > -1 and part.is_valid:
counts[part.pid] += 1
return counts
@property
@stored_property(dtype=np.int32, length=len(PID_LABELS) - 1)
def primary_particle_counts(self) -> np.ndarray:
"""Number of primary particles of each PID species in this interaction.
Returns
-------
np.ndarray
(P) Number of primary particles of each PID
"""
counts = np.zeros(len(PID_LABELS) - 1, dtype=np.int32)
for part in self.primary_particles:
if part.pid > -1 and part.is_valid:
counts[part.pid] += 1
return counts
@property
@stored_property
def is_crt_matched(self) -> bool:
"""Checks if any particle in the interaction was matched to a CRT hit.
Returns
-------
bool
`True` if any of the particle was matched to a CRT hit
"""
return bool(np.any([part.is_crt_matched for part in self.particles]))
@property
@stored_property(dtype=np.int32)
def crt_ids(self) -> np.ndarray:
"""Returns the list of CRT hit IDs matched to this interaction.
Returns
-------
np.ndarray
(C) List of CRT hit IDs matched to this interaction
"""
if len(self.particles) > 0:
return np.concatenate([part.crt_ids for part in self.particles])
return np.empty(0, dtype=np.int32)
@property
@stored_property(dtype=np.float32, units="us")
def crt_times(self) -> np.ndarray:
"""Returns the list of CRT hit times matched to this interaction.
Returns
-------
np.ndarray
(C) List of CRT hit times matched to this interaction
"""
if len(self.particles) > 0:
return np.concatenate([part.crt_times for part in self.particles])
return np.empty(0, dtype=np.float32)
@property
@stored_property(dtype=np.float32)
def crt_scores(self) -> np.ndarray:
"""Returns the list of quality metrics of CRT hits matched to this interaction.
Returns
-------
np.ndarray
(C) List of quality metrics of CRT hits matched to this interaction
"""
if len(self.particles) > 0:
return np.concatenate([part.crt_scores for part in self.particles])
return np.empty(0, dtype=np.float32)
@property
@stored_property
def topology(self) -> str:
"""String representing the interaction topology.
Returns
-------
str
String listing the number of primary particles in this interaction
"""
topology = ""
for i, count in enumerate(self.primary_particle_counts):
if count > 0:
topology += f"{count}{PID_TAGS[i]}"
return topology
@classmethod
def from_particles(cls, particles: list[ParticleBase]):
"""Builds an Interaction instance from its constituent Particle objects.
Parameters
----------
particles : List[ParticleBase]
List of Particle objects that make up the Interaction
Returns
-------
InteractionBase
Interaction built from the particle list
"""
# Construct interaction object
interaction = cls()
# Fill unique attributes which must be shared between particles
unique_attrs = ("is_truth", "units")
for attr in unique_attrs:
if hasattr(particles[0], attr):
if len(np.unique([getattr(p, attr) for p in particles])) >= 2:
raise ValueError(f"{attr} must be unique in the list of particles.")
setattr(interaction, attr, getattr(particles[0], attr))
# Attach particle list
interaction.particles = particles
interaction.particle_ids = np.array([p.id for p in particles])
# Build long-form attributes
for attr in interaction._cat_attrs:
val_list = [getattr(p, attr) for p in particles]
setattr(interaction, attr, np.concatenate(val_list))
return interaction
[docs]
@dataclass(eq=False, repr=False)
class RecoInteraction(InteractionBase, RecoBase):
"""Reconstructed interaction information.
Attributes
----------
particles : List[RecoParticle]
List of particles that make up the interaction
"""
# Object list attributes
particles: list[RecoParticle] = field( # type: ignore[assignment]
default_factory=lambda: [],
metadata=FieldMetadata(skip=True),
)
def __str__(self):
"""Human-readable string representation of the interaction object.
Results
-------
str
Basic information about the interaction properties
"""
return "Reco" + super().__str__()
@property
def leading_shower(self) -> RecoParticle | None:
"""Leading primary shower of this interaction.
Returns
-------
RecoParticle
Primary shower with the highest kinetic energy
"""
showers = [
part
for part in self.particles
if part.is_primary and part.shape == SHOWR_SHP
]
if len(showers) == 0:
return None
return max(showers, key=lambda x: cast(float, x.ke))
[docs]
@dataclass(eq=False, repr=False)
class TruthInteraction(Neutrino, InteractionBase, TruthBase):
"""Truth interaction information.
This inherits all of the attributes of :class:`Interaction`, which contains
the G4 truth information for the interaction.
Attributes
----------
particles : List[TruthParticle]
List of particles that make up the interaction
nu_id : int
Index of the neutrino matched to this interaction
reco_vertex : np.ndarray
(3) Coordinates of the reconstructed interaction vertex
"""
# Scalar attributes
nu_id: int = -1
# Object list attributes
particles: list[TruthParticle] = field( # type: ignore[assignment]
default_factory=lambda: [],
metadata=FieldMetadata(skip=True),
)
# Vector attributes
reco_vertex: np.ndarray = field(
default_factory=lambda: np.full(3, np.nan, dtype=np.float32),
metadata=FieldMetadata(
length=3,
dtype=np.float32,
position=True,
units="instance",
),
)
def __str__(self) -> str:
"""Human-readable string representation of the interaction object.
Results
-------
str
Basic information about the interaction properties
"""
return "Truth" + super().__str__()
[docs]
def attach_neutrino(self, neutrino) -> None:
"""Attach neutrino generator information to this interaction.
Parameters
----------
neutrino : Neutrino
Neutrino to fetch the attributes from
"""
# Transfer all the neutrino attributes
for attr, val in neutrino.as_dict(include_derived=False).items():
if attr != "id":
setattr(self, attr, val)
else:
if neutrino.id != self.nu_id:
warn(
"The neutrino ID as stored in the larcv.Neutrino "
"object does not match its index."
)
# Set the interaction vertex position
self.vertex = neutrino.position