"""Classes in charge of constructing Fragment objects."""
from __future__ import annotations
from typing import Any
import numpy as np
from scipy.special import softmax
from spine.constants import CLUST_COL, PART_COL, TRACK_SHP
from spine.data.out import RecoFragment, TruthFragment
from .base import BuilderBase
__all__ = ["FragmentBuilder"]
# (stray Sphinx "[docs]" source-link marker from HTML extraction; not part of the module)
class FragmentBuilder(BuilderBase):
    """Builds reconstructed and truth fragments.

    It takes the raw output of the reconstruction chain, extracts the
    necessary information and builds :class:`RecoFragment` and
    :class:`TruthFragment` objects from it.
    """

    # Builder name
    name = "fragment"

    # Types of objects constructed by the builder
    _reco_type = RecoFragment
    _truth_type = TruthFragment

    # Necessary/optional data products to build a reconstructed object.
    # Each entry is a (key, flag) pair; the flag presumably marks the
    # product as necessary (True) or optional (False) -- confirm against
    # BuilderBase, which consumes these tuples.
    _build_reco_keys = (
        ("fragment_clusts", True),
        ("fragment_shapes", True),
        ("fragment_start_points", False),
        ("fragment_end_points", False),
        ("fragment_group_pred", False),
        ("fragment_node_pred", False),
        *BuilderBase._build_reco_keys,
    )

    # Necessary/optional data products to build a truth object
    _build_truth_keys = (("particles", False), *BuilderBase._build_truth_keys)

    # Necessary/optional data products to load a reconstructed object
    _load_reco_keys = (("reco_fragments", True), *BuilderBase._load_reco_keys)

    # Necessary/optional data products to load a truth object
    _load_truth_keys = (("truth_fragments", True), *BuilderBase._load_truth_keys)

    def build_reco(self, data: dict[str, Any]) -> list[RecoFragment]:
        """Builds :class:`RecoFragment` objects from the full chain output.

        Parameters
        ----------
        data : dict
            Dictionary of data products; it is unpacked as keyword
            arguments into :meth:`_build_reco`, so its keys must match
            that method's parameter names

        Returns
        -------
        List[RecoFragment]
            List of constructed reconstructed fragment instances
        """
        return self._build_reco(**data)

    def _build_reco(
        self,
        points: np.ndarray,
        depositions: np.ndarray,
        fragment_clusts: list[np.ndarray],
        fragment_shapes: np.ndarray,
        fragment_start_points: np.ndarray | None = None,
        fragment_end_points: np.ndarray | None = None,
        fragment_group_pred: np.ndarray | None = None,
        fragment_node_pred: np.ndarray | None = None,
        sources: np.ndarray | None = None,
        orig_index: np.ndarray | None = None,
    ) -> list[RecoFragment]:
        """Builds :class:`RecoFragment` objects from the full chain output.

        Parameters
        ----------
        points : np.ndarray
            (N, 3) Set of deposition coordinates in the image
        depositions : np.ndarray
            (N) Set of deposition values
        fragment_clusts : List[np.ndarray]
            (P) List of indexes, each corresponding to a fragment instance
        fragment_shapes : np.ndarray
            (P) List of fragment shapes (shower, track, etc.)
        fragment_start_points : np.ndarray, optional
            (P, 3) List of fragment start point coordinates
        fragment_end_points : np.ndarray, optional
            (P, 3) List of fragment end point coordinates; only assigned
            to fragments whose shape is `TRACK_SHP`
        fragment_group_pred : np.ndarray, optional
            (P) Interaction group each fragment belongs to; stored as the
            fragment `particle_id`
        fragment_node_pred : np.ndarray, optional
            (P, C) Per-fragment logits. They are converted to softmax
            scores along axis 1 and the argmax of the scores is cast to
            a boolean primary flag (so presumably C = 2 -- confirm)
        sources : np.ndarray, optional
            (N, 2) Tensor which contains the module/tpc information
        orig_index : np.ndarray, optional
            (N) Tensor which contains the indexes in the original
            point cloud (before any filtering or deghosting)

        Returns
        -------
        List[RecoFragment]
            List of constructed reconstructed fragment instances
        """
        # Convert the logits to softmax scores and the scores to a
        # prediction, once for all fragments
        primary_scores: np.ndarray | None = None
        primary_pred: np.ndarray | None = None
        if fragment_node_pred is not None:
            primary_scores = softmax(fragment_node_pred, axis=1)
            primary_pred = np.argmax(primary_scores, axis=1)

        # Loop over the fragment instances
        reco_fragments = []
        for i, index in enumerate(fragment_clusts):
            # Initialize
            fragment = RecoFragment(
                id=i,
                shape=fragment_shapes[i],
                index=index,
                points=points[index],
                depositions=depositions[index],
            )

            # Add optional arguments
            if sources is not None:
                fragment.sources = sources[index]
            if orig_index is not None:
                fragment.orig_index = orig_index[index]
            if fragment_start_points is not None:
                fragment.start_point = fragment_start_points[i]
            if fragment_end_points is not None and fragment.shape == TRACK_SHP:
                fragment.end_point = fragment_end_points[i]
            if fragment_group_pred is not None:
                fragment.particle_id = fragment_group_pred[i]
            if primary_scores is not None and primary_pred is not None:
                fragment.primary_scores = primary_scores[i]
                fragment.is_primary = bool(primary_pred[i])

            # Append
            reco_fragments.append(fragment)

        return reco_fragments

    def build_truth(self, data: dict[str, Any]) -> list[TruthFragment]:
        """Builds :class:`TruthFragment` objects from the full chain output.

        Parameters
        ----------
        data : dict
            Dictionary of data products; it is unpacked as keyword
            arguments into :meth:`_build_truth`, so its keys must match
            that method's parameter names

        Returns
        -------
        List[TruthFragment]
            List of constructed true fragment instances
        """
        return self._build_truth(**data)

    def _build_truth(
        self,
        label_tensor: np.ndarray,
        points_label: np.ndarray,
        depositions_label: np.ndarray,
        depositions_q_label: np.ndarray | None = None,
        label_adapt_tensor: np.ndarray | None = None,
        points: np.ndarray | None = None,
        depositions: np.ndarray | None = None,
        label_g4_tensor: np.ndarray | None = None,
        points_g4: np.ndarray | None = None,
        depositions_g4: np.ndarray | None = None,
        sources_label: np.ndarray | None = None,
        sources: np.ndarray | None = None,
        orig_index_label: np.ndarray | None = None,
        particles: list[Any] | None = None,
    ) -> list[TruthFragment]:
        """Builds :class:`TruthFragment` objects from the full chain output.

        Parameters
        ----------
        label_tensor : np.ndarray
            Tensor which contains the cluster labels of each deposition
        points_label : np.ndarray
            (N', 3) Set of deposition coordinates in the label image (identical
            for pixel TPCs, different if deghosting is involved)
        depositions_label : np.ndarray
            (N') Set of true deposition values in MeV
        depositions_q_label : np.ndarray, optional
            (N') Set of true deposition values in ADC, if relevant
        label_adapt_tensor : np.ndarray, optional
            Tensor which contains the cluster labels of each deposition,
            adapted to the semantic segmentation prediction. When given,
            it is used instead of `label_tensor` to define the fragments
        points : np.ndarray, optional
            (N, 3) Set of deposition coordinates in the image
        depositions : np.ndarray, optional
            (N) Set of deposition values
        label_g4_tensor : np.ndarray, optional
            Tensor which contains the cluster labels of each deposition
            in the Geant4 image (before the detector simulation)
        points_g4 : np.ndarray, optional
            (N'', 3) Set of deposition coordinates in the Geant4 image
        depositions_g4 : np.ndarray, optional
            (N'') Set of deposition values in the Geant4 image
        sources_label : np.ndarray, optional
            (N', 2) Tensor which contains the label module/tpc information
        sources : np.ndarray, optional
            (N, 2) Tensor which contains the module/tpc information
        orig_index_label : np.ndarray, optional
            (N') Tensor which contains the indexes in the original
            point cloud (before any filtering or deghosting)
        particles : List[Particle], optional
            List of true particles

        Returns
        -------
        List[TruthFragment]
            List of constructed true fragment instances

        Raises
        ------
        ValueError
            If a fragment is matched to a particle ID outside of the
            `particles` list, if `label_g4_tensor` is used without
            `points_g4`/`depositions_g4`, or if adapted labels are used
            without `points`/`depositions`
        """
        # Check once if the fragment labels have been altered
        broken = (label_tensor[:, CLUST_COL] != label_tensor[:, PART_COL]).any()
        truth_only = label_adapt_tensor is None

        # If the adapted labels are available (the full chain was run), use
        # those as a basis to form fragments (fragments depend on upstream
        # segmentation). Use original labels otherwise (pure truth mode)
        truth_fragments = []
        ref_tensor = label_tensor if truth_only else label_adapt_tensor
        unique_fragment_ids = np.unique(ref_tensor[:, CLUST_COL]).astype(int)
        # Negative cluster IDs are not turned into fragments
        valid_fragment_ids = unique_fragment_ids[unique_fragment_ids > -1]
        for i, frag_id in enumerate(valid_fragment_ids):
            # Initialize fragment (replaced below if a particle is matched)
            fragment = TruthFragment(id=i)

            # Find the particle which matches this fragment best
            index_ref = np.where(ref_tensor[:, CLUST_COL] == frag_id)[0]
            if particles is not None:
                # With unaltered truth labels, the cluster ID is the
                # particle ID; otherwise use the most represented
                # particle ID among the fragment's depositions
                part_id = frag_id
                if not truth_only or broken:
                    part_ids, counts = np.unique(
                        ref_tensor[index_ref, PART_COL], return_counts=True
                    )
                    part_id = int(part_ids[np.argmax(counts)])

                if part_id > -1:
                    # Load the MC particle information
                    if part_id >= len(particles):
                        raise ValueError(
                            "Invalid particle ID found in fragment labels."
                        )
                    particle = particles[part_id]
                    fragment = TruthFragment(**particle.as_dict(include_derived=False))

                    # Override the indexes of the fragment but preserve them
                    fragment.orig_id = part_id
                    fragment.orig_group_id = particle.group_id
                    fragment.orig_parent_id = particle.parent_id
                    fragment.orig_interaction_id = particle.interaction_id
                    fragment.orig_children_id = particle.children_id
                    fragment.id = i
                    fragment.group_id = i
                    fragment.parent_id = i
                    fragment.children_id = np.empty(
                        0, dtype=fragment.orig_children_id.dtype
                    )

            # Fill long-form attributes
            if truth_only:
                # Fill the true long-form attributes, if there was no adaptation
                fragment.index = index_ref
                fragment.points = points_label[index_ref]
                fragment.depositions = depositions_label[index_ref]
                if depositions_q_label is not None:
                    fragment.depositions_q = depositions_q_label[index_ref]
                if sources_label is not None:
                    fragment.sources = sources_label[index_ref]
                if orig_index_label is not None:
                    fragment.orig_index = orig_index_label[index_ref]

                # If the fragments are not broken, can match to G4 info
                if not broken:
                    # If available, append the Geant4 information
                    if label_g4_tensor is not None:
                        if points_g4 is None or depositions_g4 is None:
                            raise ValueError(
                                "Geant4 points and depositions must be provided "
                                "if label_g4_tensor is given."
                            )
                        index_g4 = np.where(label_g4_tensor[:, CLUST_COL] == frag_id)[0]
                        fragment.index_g4 = index_g4
                        fragment.points_g4 = points_g4[index_g4]
                        fragment.depositions_g4 = depositions_g4[index_g4]

            else:
                # Fill the adapted long-form attributes otherwise
                if points is None or depositions is None:
                    raise ValueError(
                        "Points and depositions must be provided to build "
                        "adapted truth fragments."
                    )
                fragment.index_adapt = index_ref
                fragment.points_adapt = points[index_ref]
                fragment.depositions_adapt = depositions[index_ref]
                if sources is not None:
                    fragment.sources_adapt = sources[index_ref]

            # Append
            truth_fragments.append(fragment)

        return truth_fragments

    def load_reco(self, data: dict[str, Any]) -> list[RecoFragment]:
        """Load :class:`RecoFragment` objects from their stored versions.

        Parameters
        ----------
        data : dict
            Dictionary of data products; it is unpacked as keyword
            arguments into :meth:`_load_reco`, so its keys must match
            that method's parameter names

        Returns
        -------
        List[RecoFragment]
            List of restored reconstructed fragment instances
        """
        return self._load_reco(**data)

    def _load_reco(
        self,
        reco_fragments: list[RecoFragment],
        points: np.ndarray | None = None,
        depositions: np.ndarray | None = None,
        sources: np.ndarray | None = None,
    ) -> list[RecoFragment]:
        """Load :class:`RecoFragment` objects from their stored versions.

        The fragments are updated in place and the input list is returned.

        Parameters
        ----------
        reco_fragments : List[RecoFragment]
            (F) List of partial reconstructed fragments
        points : np.ndarray, optional
            (N, 3) Set of deposition coordinates in the image
        depositions : np.ndarray, optional
            (N) Set of deposition values
        sources : np.ndarray, optional
            (N, 2) Tensor which contains the module/tpc information;
            only used when `points` is provided

        Returns
        -------
        List[RecoFragment]
            List of restored reconstructed fragment instances

        Raises
        ------
        ValueError
            If a fragment ID does not match its position in the list, or
            if `points` is provided without `depositions`
        """
        # Loop over the stored fragments
        for i, fragment in enumerate(reco_fragments):
            # Check that the fragment ID checks out
            if fragment.id != i:
                raise ValueError("The ordering of the stored fragments is wrong.")

            # Update the fragment with its long-form attributes
            if points is not None:
                if depositions is None:
                    raise ValueError(
                        "Depositions must be provided to load reco fragments if "
                        "points are provided."
                    )
                fragment.points = points[fragment.index]
                fragment.depositions = depositions[fragment.index]
                if sources is not None:
                    fragment.sources = sources[fragment.index]

        return reco_fragments

    def load_truth(self, data: dict[str, Any]) -> list[TruthFragment]:
        """Load :class:`TruthFragment` objects from their stored versions.

        Parameters
        ----------
        data : dict
            Dictionary of data products; it is unpacked as keyword
            arguments into :meth:`_load_truth`, so its keys must match
            that method's parameter names

        Returns
        -------
        List[TruthFragment]
            List of restored true fragment instances
        """
        return self._load_truth(**data)

    def _load_truth(
        self,
        truth_fragments: list[TruthFragment],
        points_label: np.ndarray | None = None,
        depositions_label: np.ndarray | None = None,
        depositions_q_label: np.ndarray | None = None,
        points: np.ndarray | None = None,
        depositions: np.ndarray | None = None,
        points_g4: np.ndarray | None = None,
        depositions_g4: np.ndarray | None = None,
        sources_label: np.ndarray | None = None,
        sources: np.ndarray | None = None,
    ) -> list[TruthFragment]:
        """Load :class:`TruthFragment` objects from their stored versions.

        The fragments are updated in place and the input list is returned.

        Parameters
        ----------
        truth_fragments : List[TruthFragment]
            (F) List of partial truth fragments
        points_label : np.ndarray, optional
            (N', 3) Set of deposition coordinates in the label image (identical
            for pixel TPCs, different if deghosting is involved)
        depositions_label : np.ndarray, optional
            (N') Set of true deposition values in MeV
        depositions_q_label : np.ndarray, optional
            (N') Set of true deposition values in ADC, if relevant; only
            used when `points_label` is provided
        points : np.ndarray, optional
            (N, 3) Set of deposition coordinates in the image
        depositions : np.ndarray, optional
            (N) Set of deposition values
        points_g4 : np.ndarray, optional
            (N'', 3) Set of deposition coordinates in the Geant4 image
        depositions_g4 : np.ndarray, optional
            (N'') Set of deposition values in the Geant4 image
        sources_label : np.ndarray, optional
            (N', 2) Tensor which contains the label module/tpc information;
            only used when `points_label` is provided
        sources : np.ndarray, optional
            (N, 2) Tensor which contains the module/tpc information;
            only used when `points` is provided

        Returns
        -------
        List[TruthFragment]
            List of restored true fragment instances

        Raises
        ------
        ValueError
            If a fragment ID does not match its position in the list, or
            if a set of points (label, adapted or Geant4) is provided
            without the matching depositions
        """
        # Loop over the stored fragments
        for i, fragment in enumerate(truth_fragments):
            # Check that the fragment ID checks out
            if fragment.id != i:
                raise ValueError("The ordering of the stored fragments is wrong.")

            # Update the fragment with its long-form attributes
            if points_label is not None:
                if depositions_label is None:
                    raise ValueError(
                        "Depositions must be provided to load truth fragments if "
                        "label points are provided."
                    )
                fragment.points = points_label[fragment.index]
                fragment.depositions = depositions_label[fragment.index]
                if depositions_q_label is not None:
                    fragment.depositions_q = depositions_q_label[fragment.index]
                if sources_label is not None:
                    fragment.sources = sources_label[fragment.index]

            if points is not None:
                if depositions is None:
                    raise ValueError(
                        "Depositions must be provided to load adapted truth "
                        "fragments if points are provided."
                    )
                fragment.points_adapt = points[fragment.index_adapt]
                fragment.depositions_adapt = depositions[fragment.index_adapt]
                if sources is not None:
                    fragment.sources_adapt = sources[fragment.index_adapt]

            if points_g4 is not None:
                if depositions_g4 is None:
                    raise ValueError(
                        "Depositions must be provided to load Geant4 truth "
                        "fragments if points are provided."
                    )
                fragment.points_g4 = points_g4[fragment.index_g4]
                fragment.depositions_g4 = depositions_g4[fragment.index_g4]

        return truth_fragments