"""Class to build all require representations."""
from __future__ import annotations
from collections import OrderedDict, defaultdict
from collections.abc import Sequence
from typing import Any, ClassVar
import numpy as np
from spine.constants import COORD_COLS, VALUE_COL
from .fragment import FragmentBuilder
from .interaction import InteractionBuilder
from .particle import ParticleBuilder
from .utils import RunMode, Units, get_batch_size, is_single_index
class BuildManager:
    """Manager which constructs data representations based on the chain output.

    Takes care of two scenarios:
      - Interpret the raw output of the reconstruction chain
      - Load up existing objects stored as dictionaries
    """

    # List of recognized run modes
    _run_modes = ("reco", "truth", "both", "all")

    # List of recognized units
    _units = ("cm", "px")

    # Name of input data products needed to build representations. These names
    # are not set in stone; they can be set in the configuration. Each product
    # maps onto a tuple of candidate keys; `build_sources` uses the first
    # candidate which is present in the data dictionary.
    _default_sources: ClassVar[tuple[tuple[str, tuple[str, ...]], ...]] = (
        ("data_tensor", ("data_adapt", "data")),
        ("label_tensor", ("clust_label",)),
        ("label_adapt_tensor", ("clust_label_adapt",)),
        ("label_g4_tensor", ("clust_label_g4",)),
        ("depositions_q_label", ("charge_label",)),
        ("graph_label", ("graph_label",)),
        ("orig_index", ("orig_index",)),
        ("orig_index_label", ("orig_index_label", "orig_index")),
        ("sources", ("sources_adapt", "sources")),
        ("sources_label", ("sources_label",)),
        ("particles", ("particles",)),
        ("neutrinos", ("neutrinos",)),
        ("flashes", ("flashes",)),
        ("crthits", ("crthits",)),
    )

    def __init__(
        self,
        fragments: bool,
        particles: bool,
        interactions: bool,
        mode: RunMode = "both",
        units: Units = "cm",
        sources: dict[str, str | tuple[str, ...] | list[str]] | None = None,
        lite: bool = False,
    ) -> None:
        """Initializes the build manager.

        Parameters
        ----------
        fragments : bool
            Build/load RecoFragment/TruthFragment objects
        particles : bool
            Build/load RecoParticle/TruthParticle objects
        interactions : bool
            Build/load RecoInteraction/TruthInteraction objects
        mode : str, default 'both'
            Whether to construct reconstructed objects, true objects or both.
            Must be one of the entries of `_run_modes`.
        units : str, default 'cm'
            Units in which the point coordinates are expressed. Must be one
            of the entries of `_units`.
        sources : Dict[str, Union[str, Sequence[str]]], optional
            Dictionary which maps the necessary data products onto a name
            (or an ordered sequence of candidate names) in the input/output
            dictionary of the reconstruction chain. Overrides the defaults
            on a per-product basis.
        lite : bool, default False
            If `True`, the objects being loaded are lite and do not map
            to long-form attributes. Simply load the matches.

        Raises
        ------
        ValueError
            If `mode` or `units` is not recognized, or if `interactions` is
            requested without `particles`.
        KeyError
            If `sources` contains a data product name which is not known.
        TypeError
            If a `sources` override is neither a string nor a sequence
            of strings.
        """
        # Check on the mode, store it
        if mode not in self._run_modes:
            raise ValueError(
                f"Run mode not recognized: {mode}. Must be one {self._run_modes}"
            )
        self.mode = mode

        # Check on the units, store them
        if units not in self._units:
            raise ValueError(
                f"Units not recognized: {units}. Must be one {self._units}"
            )
        self.units = units

        # Resolve the per-instance source mapping (never mutate the defaults)
        sources_dict = dict(self._default_sources)
        if sources is not None:
            for key, value in sources.items():
                if key not in sources_dict:
                    raise KeyError(
                        "Unexpected data product specified in `sources`: "
                        f"{key}. Should be one of {list(sources_dict)}."
                    )

                # A single name is shorthand for a one-element candidate list
                if isinstance(value, str):
                    sources_dict[key] = (value,)
                    continue

                # `bytes` is a Sequence, but not a sequence of names
                if not isinstance(value, Sequence) or isinstance(value, bytes):
                    raise TypeError(
                        "Source overrides must be strings or sequences of strings."
                    )
                if not all(isinstance(item, str) for item in value):
                    raise TypeError(
                        "Source override sequences must contain only strings."
                    )
                sources_dict[key] = tuple(value)

        self.sources = sources_dict

        # Initialize the builders. Order matters: builders are called in
        # insertion order and interactions are built from particles.
        self.builders = OrderedDict()
        if fragments:
            self.builders["fragment"] = FragmentBuilder(mode, units)
        if particles:
            self.builders["particle"] = ParticleBuilder(mode, units)
        if interactions:
            if not particles:
                raise ValueError(
                    "Interactions are built from particles. If `interactions` "
                    "is True, so must `particles` be."
                )
            self.builders["interaction"] = InteractionBuilder(mode, units)

        # Store whether to load the long-form attributes or not
        self.lite = lite

    def __call__(self, data: dict[str, Any]) -> None:
        """Build the representations for one entry.

        Parameters
        ----------
        data : dict
            Dictionary of input data and model outputs

        Notes
        -----
        Modifies the data dictionary in place.
        """
        # Objects are being loaded (as opposed to built from scratch) if they
        # are lite or if the shared point tensors already exist in the data
        load = self.lite or "points" in data or "points_label" in data

        # If building from scratch, first build the objects shared between
        # fragments/particles/interactions
        if not load:
            data.update(**self._collate(self.build_sources, data))

        # Loop over builders
        for name, builder in self.builders.items():
            # Build representations
            builder(data)

            # Generate match pairs from stored matches
            if load and self.mode in ("both", "all"):
                data.update(**self._collate(self.load_match_pairs, data, name))

    @staticmethod
    def _collate(func, data: dict[str, Any], *args: Any) -> dict[str, Any]:
        """Run a per-entry function on unbatched or batched data.

        Parameters
        ----------
        func : callable
            Function with signature `func(data, *args, entry=None)` which
            returns a dictionary
        data : dict
            Dictionary of input data and model outputs; `data['index']` is
            used to tell whether the data is batched
        *args : Any
            Additional positional arguments forwarded to `func`

        Returns
        -------
        dict
            Output of `func` if the data holds a single entry. Otherwise,
            a dictionary which maps each output key onto the list of
            per-entry values.
        """
        if is_single_index(data["index"]):
            return func(data, *args)

        collated = defaultdict(list)
        for entry in range(get_batch_size(data["index"])):
            for key, val in func(data, *args, entry).items():
                collated[key].append(val)

        return collated

    def build_sources(
        self, data: dict[str, Any], entry: int | None = None
    ) -> dict[str, Any]:
        """Construct the reference coordinate and value tensors used by
        all the representations built by the module.

        These objects should be stored along with the constructed objects
        if the objects are to be loaded later on.

        Parameters
        ----------
        data : dict
            Dictionary of input data and model outputs
        entry : int, optional
            Entry number

        Returns
        -------
        dict
            Dictionary of shared objects to be merged in the data dictionary

        Raises
        ------
        KeyError
            If the units are not pixels and no metadata is provided

        Notes
        -----
        When converting to cm, the truth particles/neutrinos which are not
        already expressed in the requested units are converted in place.
        """
        # Fetch the original sources, using the first candidate name present
        sources = {}
        for key, alt_keys in self.sources.items():
            for alt in alt_keys:
                if alt in data:
                    sources[key] = data[alt] if entry is None else data[alt][entry]
                    break

        # Build additional information
        update = {}
        if self.mode != "truth" or "label_adapt_tensor" in sources:
            update["points"] = sources["data_tensor"][:, COORD_COLS]
            update["depositions"] = sources["data_tensor"][:, VALUE_COL]
            if "sources" in sources:
                update["sources"] = sources["sources"].astype(np.int32, copy=False)
            if "orig_index" in sources:
                update["orig_index"] = sources["orig_index"].astype(
                    np.int32, copy=False
                )

        if self.mode != "reco":
            update["label_tensor"] = sources["label_tensor"]
            update["points_label"] = sources["label_tensor"][:, COORD_COLS]
            update["depositions_label"] = sources["label_tensor"][:, VALUE_COL]
            if "depositions_q_label" in sources:
                update["depositions_q_label"] = sources["depositions_q_label"][
                    :, VALUE_COL
                ]
            if "label_adapt_tensor" in sources:
                update["label_adapt_tensor"] = sources["label_adapt_tensor"]
                update["depositions_label_adapt"] = sources["label_adapt_tensor"][
                    :, VALUE_COL
                ]
            if "label_g4_tensor" in sources:
                update["label_g4_tensor"] = sources["label_g4_tensor"]
                update["points_g4"] = sources["label_g4_tensor"][:, COORD_COLS]
                update["depositions_g4"] = sources["label_g4_tensor"][:, VALUE_COL]
            if "sources_label" in sources:
                update["sources_label"] = sources["sources_label"].astype(
                    np.int32, copy=False
                )
            if "orig_index_label" in sources:
                update["orig_index_label"] = sources["orig_index_label"].astype(
                    np.int32, copy=False
                )

        # If provided, fetch the point attributes to check their units
        for obj in ("fragment", "particle"):
            for key in (f"{obj}_start_points", f"{obj}_end_points"):
                if key in data:
                    update[key] = data[key] if entry is None else data[key][entry]

        # Convert everything to the proper units once and for all
        if self.units != "px":
            # Fetch metadata
            if "meta" not in data:
                raise KeyError("Must provide metadata to build objects in cm.")
            meta = data["meta"][entry] if entry is not None else data["meta"]

            # Convert a copy of each set of points, so that the input tensors
            # are left untouched. Only values of existing keys are reassigned,
            # which is safe to do while iterating over the dictionary.
            for key, value in update.items():
                if "points" in key:
                    update[key] = meta.to_cm(np.copy(value), center=True)

            # Convert the truth objects which are not in the right units yet
            for key in ("particles", "neutrinos"):
                if key in sources:
                    update[key] = sources[key]
                    for obj in sources[key]:
                        if obj.units != self.units:
                            obj.to_cm(meta)

        return update

    @staticmethod
    def load_match_pairs(
        data: dict[str, Any], name: str, entry: int | None = None
    ) -> dict[str, list[Any]]:
        """Generate lists of matched object pairs from stored matches.

        Parameters
        ----------
        data : dict
            Dictionary of input data and model outputs
        name : str
            Object type name
        entry : int, optional
            Entry number

        Returns
        -------
        dict
            For each direction (`r2t` and `t2r`), the list of
            (source, best match) pairs under `{name}_matches_{direction}`
            and the list of corresponding overlap values under
            `{name}_matches_{direction}_overlap`. Unmatched objects are
            paired with `None` and given an overlap of -1.
        """
        # Initialize the name of the match lists
        prefix = f"{name}_matches"

        # Create match pairs in both directions (true to reco and vice versa)
        result = {}
        for source, target in (("reco", "truth"), ("truth", "reco")):
            # Fetch the lists of objects to match
            sources = data[f"{source}_{name}s"]
            targets = data[f"{target}_{name}s"]
            if entry is not None:
                sources, targets = sources[entry], targets[entry]

            # Loop over the source objects, build the pairs
            match_key = f"{prefix}_{source[0]}2{target[0]}"
            pairs, overlaps = [], []
            for obj in sources:
                if not obj.is_matched:
                    # If no match is found, give an empty value to the match
                    pairs.append((obj, None))
                    overlaps.append(-1.0)
                else:
                    # If a match is found, the first is always the best match
                    pairs.append((obj, targets[obj.match_ids[0]]))
                    overlaps.append(obj.match_overlaps[0])

            result[match_key] = pairs
            result[f"{match_key}_overlap"] = overlaps

        return result