Source code for spine.construct.interaction

"""Class in charge of constructing Interaction objects."""

from __future__ import annotations

from typing import Any
from warnings import warn

import numpy as np

from spine.data.out import RecoInteraction, TruthInteraction

from .base import BuilderBase

__all__ = ["InteractionBuilder"]


[docs] class InteractionBuilder(BuilderBase): """Builds reconstructed and truth interactions. It takes the raw output of the reconstruction chain, extracts the necessary information and builds :class:`RecoInteraction` and :class:`TruthInteraction` objects from it. """ # Builder name name = "interaction" # Types of objects constructed by the builder _reco_type = RecoInteraction _truth_type = TruthInteraction # Necessary/optional data products to build a reconstructed object _build_reco_keys = (("reco_particles", True),) # Necessary/optional data products to build a truth object _build_truth_keys = (("truth_particles", True), ("neutrinos", False)) # Necessary/optional data products to load a reconstructed object _load_reco_keys = (("reco_interactions", True), ("reco_particles", True)) # Necessary/optional data products to load a truth object _load_truth_keys = (("truth_interactions", True), ("truth_particles", True))
[docs] def build_reco(self, data: dict[str, Any]) -> list[RecoInteraction]: """Builds :class:`RecoInteraction` objects from the full chain output. Parameters ---------- data : dict Dictionary of data products Returns ------- List[RecoInteraction] List of constructed reconstructed interaction instances """ return self._build_reco(**data)
def _build_reco(self, reco_particles: list[Any]) -> list[RecoInteraction]: """Builds :class:`RecoInteraction` objects from the full chain output. This class builds an interaction by assembling particles together. Parameters ---------- reco_particles : List[RecoParticle] List of reconstructed particle objects Returns ------- List[RecoInteraction] List of constructed reconstructed interaction instances """ # Loop over unique interaction IDs reco_interactions = [] inter_ids = np.array([part.interaction_id for part in reco_particles]) for i, inter_id in enumerate(np.unique(inter_ids)): # Get the list of particles associates with this interaction if inter_id <= -1: raise ValueError("Invalid reconstructed interaction ID found.") particle_ids = np.where(inter_ids == inter_id)[0] inter_particles = [reco_particles[j] for j in particle_ids] # Build interaction interaction = RecoInteraction.from_particles(inter_particles) interaction.id = i # Match the interaction ID of the constituent particles for part in inter_particles: part.interaction_id = i for frag in part.fragments: frag.interaction_id = i # Append reco_interactions.append(interaction) return reco_interactions
[docs] def build_truth(self, data: dict[str, Any]) -> list[TruthInteraction]: """Builds :class:`TruthInteraction` objects from the full chain output. Parameters ---------- data : dict Dictionary of data products Returns ------- List[TruthInteraction] List of constructed truth interaction instances """ return self._build_truth(**data)
def _build_truth( self, truth_particles: list[Any], neutrinos: list[Any] | None = None, ) -> list[TruthInteraction]: """Builds :class:`TruthInteraction` objects from the full chain output. This class builds an interaction by assembling particles together. Parameters ---------- truth_particles : List[TruthParticle] List of truth particle objects neutrinos : List[Neutrino], optional List of true neutrino information from the generator """ # Loop over unique interaction IDs truth_interactions = [] inter_ids = np.array([part.interaction_id for part in truth_particles]) unique_inter_ids = np.unique(inter_ids) valid_inter_ids = unique_inter_ids[unique_inter_ids > -1] for i, inter_id in enumerate(valid_inter_ids): # Get the list of particles associates with this interaction particle_ids = np.where(inter_ids == inter_id)[0] inter_particles = [truth_particles[j] for j in particle_ids] # Build interaction interaction = TruthInteraction.from_particles(inter_particles) interaction.id = i interaction.orig_id = inter_id # Match the interaction ID of the constituent particles for part in inter_particles: part.orig_interaction_id = inter_id part.interaction_id = i for frag in part.fragments: frag.interaction_id = i # Append the neutrino information, if it is provided nu_ids = [part.nu_id for part in inter_particles] if len(np.unique(nu_ids)) != 1: raise ValueError( "Interaction made up of particles with different " "neutrino IDs. Must be unique." ) interaction.nu_id = nu_ids[0] if neutrinos is not None and nu_ids[0] > -1: if nu_ids[0] >= len(neutrinos): raise ValueError( "Invalid neutrino ID found in truth interaction particles." ) interaction.attach_neutrino(neutrinos[nu_ids[0]]) else: anc_pos = [part.ancestor_position for part in inter_particles] anc_pos = np.unique(anc_pos, axis=0) if len(anc_pos) != 1: warn( "Particles making up a true interaction have " "different ancestor positions." ) anc_pos = np.max(anc_pos, axis=0) interaction.vertex = anc_pos.flatten() # Append truth_interactions.append(interaction) return truth_interactions
[docs] def load_reco(self, data: dict[str, Any]) -> list[RecoInteraction]: """Load :class:`RecoInteraction` objects from their stored versions. Parameters ---------- data : dict Dictionary of data products Returns ------- List[RecoInteraction] List of restored reconstructed interaction instances """ return self._load_reco(**data)
def _load_reco( self, reco_interactions: list[RecoInteraction], reco_particles: list[Any], ) -> list[RecoInteraction]: """Load :class:`RecoInteraction` objects from their stored versions. Parameters ---------- reco_interactions : List[RecoInteraction] List of partial reconstructed interaction objects reco_particles : List[RecoParticle] List of reconstructed particle objects Returns ------- List[RecoInteraction] List of restored reconstructed interaction instances """ # Loop over the dictionaries for i, interaction in enumerate(reco_interactions): # Check that the interaction ID checks out if interaction.id != i: raise ValueError("The ordering of the stored interactions is wrong.") # Fetch and assign the list of particles matched to this interaction inter_particles = [reco_particles[j] for j in interaction.particle_ids] if not len(inter_particles): raise ValueError("Every interaction should contain >= 1 particle.") interaction.particles = inter_particles # Update the interaction with its long-form attributes for attr in interaction._cat_attrs: val_list = [getattr(part, attr) for part in inter_particles] setattr(interaction, attr, np.concatenate(val_list)) return reco_interactions
[docs] def load_truth(self, data: dict[str, Any]) -> list[TruthInteraction]: """Load :class:`TruthInteraction` objects from their stored versions. Parameters ---------- data : dict Dictionary of data products Returns ------- List[TruthInteraction] List of restored truth interaction instances """ return self._load_truth(**data)
def _load_truth( self, truth_interactions: list[TruthInteraction], truth_particles: list[Any], ) -> list[TruthInteraction]: """Load :class:`TruthInteraction` objects from their stored versions. Parameters ---------- data : dict Dictionary of data products Parameters ---------- truth_interactions : List[TruthInteraction] List of partial truth interaction objects truth_particles : List[TruthParticle] List of truth particle objects Returns ------- List[TruthInteraction] List of restored truth interaction instances """ # Loop over the dictionaries for i, interaction in enumerate(truth_interactions): # Check that the interaction ID checks out if interaction.id != i: raise ValueError("The ordering of the stored interactions is wrong.") # Fetch and assign the list of particles matched to this interaction inter_particles = [truth_particles[j] for j in interaction.particle_ids] if not len(inter_particles): raise ValueError("Every interaction should contain >= 1 particle.") interaction.particles = inter_particles # Update the interaction with its long-form attributes for attr in interaction._cat_attrs: val_list = [getattr(part, attr) for part in inter_particles] setattr(interaction, attr, np.concatenate(val_list)) return truth_interactions