# Source code for spine.construct.fragment

"""Classes in charge of constructing Fragment objects."""

from __future__ import annotations

from typing import Any

import numpy as np
from scipy.special import softmax

from spine.constants import CLUST_COL, PART_COL, TRACK_SHP
from spine.data.out import RecoFragment, TruthFragment

from .base import BuilderBase

__all__ = ["FragmentBuilder"]


class FragmentBuilder(BuilderBase):
    """Builds reconstructed and truth fragments.

    It takes the raw output of the reconstruction chain, extracts the
    necessary information and builds :class:`RecoFragment` and
    :class:`TruthFragment` objects from it.
    """

    # Builder name
    name = "fragment"

    # Types of objects constructed by the builder
    _reco_type = RecoFragment
    _truth_type = TruthFragment

    # Necessary/optional data products to build a reconstructed object
    _build_reco_keys = (
        ("fragment_clusts", True),
        ("fragment_shapes", True),
        ("fragment_start_points", False),
        ("fragment_end_points", False),
        ("fragment_group_pred", False),
        ("fragment_node_pred", False),
        *BuilderBase._build_reco_keys,
    )

    # Necessary/optional data products to build a truth object
    _build_truth_keys = (("particles", False), *BuilderBase._build_truth_keys)

    # Necessary/optional data products to load a reconstructed object
    _load_reco_keys = (("reco_fragments", True), *BuilderBase._load_reco_keys)

    # Necessary/optional data products to load a truth object
    _load_truth_keys = (("truth_fragments", True), *BuilderBase._load_truth_keys)

    def build_reco(self, data: dict[str, Any]) -> list[RecoFragment]:
        """Builds :class:`RecoFragment` objects from the full chain output.

        Parameters
        ----------
        data : dict
            Dictionary of data products

        Returns
        -------
        List[RecoFragment]
            List of constructed reconstructed fragment instances
        """
        return self._build_reco(**data)

    def _build_reco(
        self,
        points: np.ndarray,
        depositions: np.ndarray,
        fragment_clusts: list[np.ndarray],
        fragment_shapes: np.ndarray,
        fragment_start_points: np.ndarray | None = None,
        fragment_end_points: np.ndarray | None = None,
        fragment_group_pred: np.ndarray | None = None,
        fragment_node_pred: np.ndarray | None = None,
        sources: np.ndarray | None = None,
        orig_index: np.ndarray | None = None,
    ) -> list[RecoFragment]:
        """Builds :class:`RecoFragment` objects from the full chain output.

        Parameters
        ----------
        points : np.ndarray
            (N, 3) Set of deposition coordinates in the image
        depositions : np.ndarray
            (N) Set of deposition values
        fragment_clusts : List[np.ndarray]
            (P) List of indexes, each corresponding to a fragment instance
        fragment_shapes : np.ndarray
            (P) List of fragment shapes (shower, track, etc.)
        fragment_start_points : np.ndarray, optional
            (P, 3) List of fragment start point coordinates
        fragment_end_points : np.ndarray, optional
            (P, 3) List of fragment end point coordinates (only assigned to
            fragments whose shape is a track)
        fragment_group_pred : np.ndarray, optional
            (P) Interaction group each fragment belongs to
        fragment_node_pred : np.ndarray, optional
            (P, C) Node classification logits of each fragment. A softmax is
            applied along axis 1; a non-zero argmax marks the fragment as
            primary.
        sources : np.ndarray, optional
            (N, 2) Tensor which contains the module/tpc information
        orig_index : np.ndarray, optional
            (N) Tensor which contains the indexes in the original point
            cloud (before any filtering or deghosting)

        Returns
        -------
        List[RecoFragment]
            List of constructed reconstructed fragment instances
        """
        # Convert the logits to softmax scores and the scores to a prediction,
        # once for all fragments (outside of the loop)
        primary_scores: np.ndarray | None = None
        primary_pred: np.ndarray | None = None
        if fragment_node_pred is not None:
            primary_scores = softmax(fragment_node_pred, axis=1)
            primary_pred = np.argmax(primary_scores, axis=1)

        # Loop over the fragment instances
        reco_fragments = []
        for i, index in enumerate(fragment_clusts):
            # Initialize
            fragment = RecoFragment(
                id=i,
                shape=fragment_shapes[i],
                index=index,
                points=points[index],
                depositions=depositions[index],
            )

            # Add optional point-wise attributes
            if sources is not None:
                fragment.sources = sources[index]
            if orig_index is not None:
                fragment.orig_index = orig_index[index]

            # Add optional fragment-wise attributes
            if fragment_start_points is not None:
                fragment.start_point = fragment_start_points[i]
            # End points are only meaningful for track-like fragments
            if fragment_end_points is not None and fragment.shape == TRACK_SHP:
                fragment.end_point = fragment_end_points[i]
            if fragment_group_pred is not None:
                fragment.particle_id = fragment_group_pred[i]
            if primary_scores is not None:
                fragment.primary_scores = primary_scores[i]
                fragment.is_primary = bool(primary_pred[i])

            # Append
            reco_fragments.append(fragment)

        return reco_fragments

    def build_truth(self, data: dict[str, Any]) -> list[TruthFragment]:
        """Builds :class:`TruthFragment` objects from the full chain output.

        Parameters
        ----------
        data : dict
            Dictionary of data products

        Returns
        -------
        List[TruthFragment]
            List of constructed true fragment instances
        """
        return self._build_truth(**data)

    def _build_truth(
        self,
        label_tensor: np.ndarray,
        points_label: np.ndarray,
        depositions_label: np.ndarray,
        depositions_q_label: np.ndarray | None = None,
        label_adapt_tensor: np.ndarray | None = None,
        points: np.ndarray | None = None,
        depositions: np.ndarray | None = None,
        label_g4_tensor: np.ndarray | None = None,
        points_g4: np.ndarray | None = None,
        depositions_g4: np.ndarray | None = None,
        sources_label: np.ndarray | None = None,
        sources: np.ndarray | None = None,
        orig_index_label: np.ndarray | None = None,
        particles: list[Any] | None = None,
    ) -> list[TruthFragment]:
        """Builds :class:`TruthFragment` objects from the full chain output.

        Parameters
        ----------
        label_tensor : np.ndarray
            Tensor which contains the cluster labels of each deposition
        points_label : np.ndarray
            (N', 3) Set of deposition coordinates in the label image (identical
            for pixel TPCs, different if deghosting is involved)
        depositions_label : np.ndarray
            (N') Set of true deposition values in MeV
        depositions_q_label : np.ndarray, optional
            (N') Set of true deposition values in ADC, if relevant
        label_adapt_tensor : np.ndarray, optional
            Tensor which contains the cluster labels of each deposition,
            adapted to the semantic segmentation prediction.
        points : np.ndarray, optional
            (N, 3) Set of deposition coordinates in the image
        depositions : np.ndarray, optional
            (N) Set of deposition values
        label_g4_tensor : np.ndarray, optional
            Tensor which contains the cluster labels of each deposition
            in the Geant4 image (before the detector simulation)
        points_g4 : np.ndarray, optional
            (N'', 3) Set of deposition coordinates in the Geant4 image
        depositions_g4 : np.ndarray, optional
            (N'') Set of deposition values in the Geant4 image
        sources_label : np.ndarray, optional
            (N', 2) Tensor which contains the label module/tpc information
        sources : np.ndarray, optional
            (N, 2) Tensor which contains the module/tpc information
        orig_index_label : np.ndarray, optional
            (N') Tensor which contains the indexes in the original point
            cloud (before any filtering or deghosting)
        particles : List[Particle], optional
            List of true particles

        Returns
        -------
        List[TruthFragment]
            List of constructed true fragment instances

        Raises
        ------
        ValueError
            If a fragment label points to a particle index outside of
            `particles`, or if a long-form product required by the selected
            mode (Geant4 or adapted) is missing.
        """
        # Check once if the fragment labels have been altered, i.e. if the
        # cluster IDs no longer coincide with the particle IDs
        broken = (label_tensor[:, CLUST_COL] != label_tensor[:, PART_COL]).any()
        truth_only = label_adapt_tensor is None

        # If the adapted labels are available (the full chain was run), use
        # those as a basis to form fragments (fragments depend on upstream
        # segmentation). Use original labels otherwise (pure truth mode)
        ref_tensor = label_tensor if truth_only else label_adapt_tensor
        unique_fragment_ids = np.unique(ref_tensor[:, CLUST_COL]).astype(int)
        valid_fragment_ids = unique_fragment_ids[unique_fragment_ids > -1]

        truth_fragments = []
        for i, frag_id in enumerate(valid_fragment_ids):
            # Initialize fragment
            fragment = TruthFragment(id=i)

            # Find the particle which matches this fragment best
            index_ref = np.where(ref_tensor[:, CLUST_COL] == frag_id)[0]
            if particles is not None:
                # With pristine labels, fragment and particle IDs coincide;
                # otherwise use the most common particle ID in the fragment
                part_id = int(frag_id)
                if not truth_only or broken:
                    part_ids, counts = np.unique(
                        ref_tensor[index_ref, PART_COL], return_counts=True
                    )
                    part_id = int(part_ids[np.argmax(counts)])

                if part_id > -1:
                    # Load the MC particle information
                    if part_id >= len(particles):
                        raise ValueError(
                            "Invalid particle ID found in fragment labels."
                        )
                    particle = particles[part_id]
                    fragment = TruthFragment(**particle.as_dict(include_derived=False))

                    # Re-index the fragment relative to the fragment list,
                    # preserving the particle-level IDs in the orig_* fields
                    fragment.orig_id = part_id
                    fragment.orig_group_id = particle.group_id
                    fragment.orig_parent_id = particle.parent_id
                    fragment.orig_interaction_id = particle.interaction_id
                    fragment.orig_children_id = particle.children_id

                    fragment.id = i
                    fragment.group_id = i
                    fragment.parent_id = i
                    fragment.children_id = np.empty(
                        0, dtype=fragment.orig_children_id.dtype
                    )

            # Fill long-form attributes
            if truth_only:
                # Fill the true long-form attributes, if there was no adaptation
                fragment.index = index_ref
                fragment.points = points_label[index_ref]
                fragment.depositions = depositions_label[index_ref]
                if depositions_q_label is not None:
                    fragment.depositions_q = depositions_q_label[index_ref]
                if sources_label is not None:
                    fragment.sources = sources_label[index_ref]
                if orig_index_label is not None:
                    fragment.orig_index = orig_index_label[index_ref]

                # If the fragments are not broken, can match to G4 info
                # (the fragment ID is then a valid cluster ID in the G4 labels)
                if not broken and label_g4_tensor is not None:
                    if points_g4 is None or depositions_g4 is None:
                        raise ValueError(
                            "Geant4 points and depositions must be provided "
                            "if label_g4_tensor is given."
                        )
                    index_g4 = np.where(label_g4_tensor[:, CLUST_COL] == frag_id)[0]
                    fragment.index_g4 = index_g4
                    fragment.points_g4 = points_g4[index_g4]
                    fragment.depositions_g4 = depositions_g4[index_g4]

            else:
                # Fill the adapted long-form attributes otherwise
                if points is None or depositions is None:
                    raise ValueError(
                        "Points and depositions must be provided to build "
                        "adapted truth fragments."
                    )
                fragment.index_adapt = index_ref
                fragment.points_adapt = points[index_ref]
                fragment.depositions_adapt = depositions[index_ref]
                if sources is not None:
                    fragment.sources_adapt = sources[index_ref]

            # Append
            truth_fragments.append(fragment)

        return truth_fragments

    def load_reco(self, data: dict[str, Any]) -> list[RecoFragment]:
        """Load :class:`RecoFragment` objects from their stored versions.

        Parameters
        ----------
        data : dict
            Dictionary of data products

        Returns
        -------
        List[RecoFragment]
            List of restored reconstructed fragment instances
        """
        return self._load_reco(**data)

    def _load_reco(
        self,
        reco_fragments: list[RecoFragment],
        points: np.ndarray | None = None,
        depositions: np.ndarray | None = None,
        sources: np.ndarray | None = None,
    ) -> list[RecoFragment]:
        """Load :class:`RecoFragment` objects from their stored versions.

        Parameters
        ----------
        reco_fragments : List[RecoFragment]
            (F) List of partial reconstructed fragments
        points : np.ndarray, optional
            (N, 3) Set of deposition coordinates in the image
        depositions : np.ndarray, optional
            (N) Set of deposition values
        sources : np.ndarray, optional
            (N, 2) Tensor which contains the module/tpc information

        Returns
        -------
        List[RecoFragment]
            List of restored reconstructed fragment instances

        Raises
        ------
        ValueError
            If the stored fragment IDs are not 0, 1, 2, ... in list order,
            or if `points` is provided without `depositions`.
        """
        # Loop over the dictionaries
        for i, fragment in enumerate(reco_fragments):
            # Check that the fragment ID checks out
            if fragment.id != i:
                raise ValueError("The ordering of the stored fragments is wrong.")

            # Update the fragment with its long-form attributes
            if points is not None:
                if depositions is None:
                    raise ValueError(
                        "Depositions must be provided to load reco fragments if "
                        "points are provided."
                    )
                fragment.points = points[fragment.index]
                fragment.depositions = depositions[fragment.index]
                if sources is not None:
                    fragment.sources = sources[fragment.index]

        return reco_fragments

    def load_truth(self, data: dict[str, Any]) -> list[TruthFragment]:
        """Load :class:`TruthFragment` objects from their stored versions.

        Parameters
        ----------
        data : dict
            Dictionary of data products

        Returns
        -------
        List[TruthFragment]
            List of restored true fragment instances
        """
        return self._load_truth(**data)

    def _load_truth(
        self,
        truth_fragments: list[TruthFragment],
        points_label: np.ndarray | None = None,
        depositions_label: np.ndarray | None = None,
        depositions_q_label: np.ndarray | None = None,
        points: np.ndarray | None = None,
        depositions: np.ndarray | None = None,
        points_g4: np.ndarray | None = None,
        depositions_g4: np.ndarray | None = None,
        sources_label: np.ndarray | None = None,
        sources: np.ndarray | None = None,
    ) -> list[TruthFragment]:
        """Load :class:`TruthFragment` objects from their stored versions.

        Parameters
        ----------
        truth_fragments : List[TruthFragment]
            (F) List of partial truth fragments
        points_label : np.ndarray, optional
            (N', 3) Set of deposition coordinates in the label image (identical
            for pixel TPCs, different if deghosting is involved)
        depositions_label : np.ndarray, optional
            (N') Set of true deposition values in MeV
        depositions_q_label : np.ndarray, optional
            (N') Set of true deposition values in ADC, if relevant
        points : np.ndarray, optional
            (N, 3) Set of deposition coordinates in the image
        depositions : np.ndarray, optional
            (N) Set of deposition values
        points_g4 : np.ndarray, optional
            (N'', 3) Set of deposition coordinates in the Geant4 image
        depositions_g4 : np.ndarray, optional
            (N'') Set of deposition values in the Geant4 image
        sources_label : np.ndarray, optional
            (N', 2) Tensor which contains the label module/tpc information
        sources : np.ndarray, optional
            (N, 2) Tensor which contains the module/tpc information

        Returns
        -------
        List[TruthFragment]
            List of restored true fragment instances

        Raises
        ------
        ValueError
            If the stored fragment IDs are not 0, 1, 2, ... in list order,
            or if a set of points is provided without its depositions.
        """
        # Loop over the dictionaries
        for i, fragment in enumerate(truth_fragments):
            # Check that the fragment ID checks out
            if fragment.id != i:
                raise ValueError("The ordering of the stored fragments is wrong.")

            # Update the fragment with its label long-form attributes
            if points_label is not None:
                if depositions_label is None:
                    raise ValueError(
                        "Depositions must be provided to load truth fragments if "
                        "label points are provided."
                    )
                fragment.points = points_label[fragment.index]
                fragment.depositions = depositions_label[fragment.index]
                if depositions_q_label is not None:
                    fragment.depositions_q = depositions_q_label[fragment.index]
                if sources_label is not None:
                    fragment.sources = sources_label[fragment.index]

            # Update the fragment with its adapted long-form attributes
            if points is not None:
                if depositions is None:
                    raise ValueError(
                        "Depositions must be provided to load adapted truth "
                        "fragments if points are provided."
                    )
                fragment.points_adapt = points[fragment.index_adapt]
                fragment.depositions_adapt = depositions[fragment.index_adapt]
                if sources is not None:
                    fragment.sources_adapt = sources[fragment.index_adapt]

            # Update the fragment with its Geant4 long-form attributes
            if points_g4 is not None:
                if depositions_g4 is None:
                    raise ValueError(
                        "Depositions must be provided to load Geant4 truth "
                        "fragments if points are provided."
                    )
                fragment.points_g4 = points_g4[fragment.index_g4]
                fragment.depositions_g4 = depositions_g4[fragment.index_g4]

        return truth_fragments