Source code for spine.post.base

"""Contains base class of all post-processors."""

from abc import ABC, abstractmethod


[docs] class PostBase(ABC): """Base class of all post-processors. This base class performs the following functions: - Ensures that the necessary method exist - Checks that the post-processor is provided the necessary information to do its job - Fetches the appropriate coordinate attributes - Ensures that the appropriate units are provided Attributes ---------- name : str Name of the post-processor as defined in the configuration file aliases : Tuple[str] Alternative acceptable names for a post-processor """ # Name of the post-processor (as specified in the configuration) name = None # Alternative allowed names of the post-processor aliases = () # Units in which the post-processor expects objects to be expressed in units = "cm" # Whether this post-processor might use paths relative to the parent configuration path provide_parent_path = False # Set of data keys needed for this post-processor to operate _keys = () # Set of post-processors which must be run before this one is _upstream = () # List of recognized object types _obj_types = ("fragment", "particle", "interaction") # List of recognized run modes _run_modes = ("reco", "truth", "both", "all") # List of known reconstructed particle identification modes _pid_modes = ("pid", "chi2_pid") # List of known point modes for true particles and their corresponding keys _point_modes = ( ("points", "points_label"), ("points_adapt", "points"), ("points_g4", "points_g4"), ) # List of known source modes for true particles and their corresponding keys _source_modes = ( ("sources", "sources_label"), ("sources_adapt", "sources"), ("sources_g4", "sources_g4"), ) # List of known deposition modes for true particles and their corresponding keys _dep_modes = ( ("depositions", "depositions_label"), ("depositions_q", "depositions_q_label"), ("depositions_adapt", "depositions_label_adapt"), ("depositions_adapt_q", "depositions"), ("depositions_g4", "depositions_g4"), ) def __init__( self, obj_type=None, run_mode=None, truth_point_mode=None, truth_dep_mode=None, pid_mode=None, parent_path=None, ): """Initialize default post-processor object properties. Parameters ---------- obj_type : Union[str, List[str]] Name or list of names of the object types to process run_mode : str, optional If specified, tells whether the post-processor must run on reconstructed ('reco'), true ('true') or both objects ('both' or 'all') truth_point_mode : str, optional If specified, tells which attribute of the :class:`TruthFragment`, :class:`TruthParticle` or :class:`TruthInteraction` object to use to fetch its point coordinates truth_dep_mode : str, optional If specified, tells which attribute of the :class:`TruthFragment`, :class:`TruthParticle` or :class:`TruthInteraction` object to use to fetch its depositions parent_path : str, optional Path to the parent directory of the main analysis configuration. This allows for the use of relative paths in the post-processors. """ # If run mode is specified, process it if run_mode is not None: # Check that the run mode is recognized assert run_mode in self._run_modes, ( f"`run_mode` not recognized: {run_mode}. Must be one of " f"{self._run_modes}." ) # Check that all the object sources are recognized if obj_type is None: obj_type = [] elif isinstance(obj_type, str): obj_type = [obj_type] for obj in obj_type: assert obj in self._obj_types, ( f"Object type must be one of {self._obj_types}. Got " f"`{obj}` instead." ) # Make a list of object keys to process self.fragment_keys = [] self.particle_keys = [] self.interaction_keys = [] for name in self._obj_types: # Initialize one list per object type setattr(self, f"{name}_keys", []) # Skip object types which are not requested if name in obj_type: if run_mode != "truth": getattr(self, f"{name}_keys").append(f"reco_{name}s") if run_mode != "reco": getattr(self, f"{name}_keys").append(f"truth_{name}s") self.obj_keys = self.fragment_keys + self.particle_keys + self.interaction_keys # Update underlying keys, if needed self.update_keys({k: True for k in self.obj_keys}) # If a truth point mode is specified, store it if truth_point_mode is not None: assert truth_point_mode in self.point_modes, ( "The `truth_point_mode` argument must be one of " f"{self.point_modes.keys()}. Got `{truth_point_mode}` instead." ) self.truth_point_mode = truth_point_mode self.truth_point_key = self.point_modes[self.truth_point_mode] self.truth_source_mode = truth_point_mode.replace("points", "sources") self.truth_source_key = self.source_modes[self.truth_source_mode] self.truth_index_mode = truth_point_mode.replace("points", "index") # If a truth deposition mode is specified, store it if truth_dep_mode is not None: assert truth_dep_mode in self.dep_modes, ( "The `truth_dep_mode` argument must be one of " f"{self.dep_modes.keys()}. Got `{truth_dep_mode}` instead." ) if truth_point_mode is not None: prefix = truth_point_mode.replace("points", "depositions") assert truth_dep_mode.startswith(prefix), ( f"Points mode {truth_point_mode} and deposition mode " f"{truth_dep_mode} are incompatible." ) self.truth_dep_mode = truth_dep_mode self.truth_dep_key = self.dep_modes[truth_dep_mode] # If a PID mode is specified, store it if pid_mode is not None: assert pid_mode in self._pid_modes, ( f"The `pid_mode` argument must be one of {self._pid_modes}. " f"Got {pid_mode} instead." ) self.pid_mode = pid_mode # Store the parent path self.parent_path = parent_path @property def keys(self): """Dictionary of (key, necessity) pairs which determine which data keys are needed/optional for the post-processor to run. Returns ------- Dict[str, bool] Dictionary of (key, necessity) pairs to be used """ return dict(self._keys) @property def point_modes(self): """Dictionary which makes the correspondance between the name of a true object point attribute with the underlying point tensor it points to. Returns ------- Dict[str, str] Dictionary of (attribute, key) mapping for point coordinates """ return dict(self._point_modes) @property def source_modes(self): """Dictionary which makes the correspondance between the name of a true object source attribute with the underlying source tensor it points to. Returns ------- Dict[str, str] Dictionary of (attribute, key) mapping for point sources """ return dict(self._source_modes) @property def dep_modes(self): """Dictionary which makes the correspondance between the name of a true object deposition attribute with the underlying deposition array it points to. Returns ------- Dict[str, str] Dictionary of (attribute, key) mapping for point depositions """ return dict(self._dep_modes)
[docs] def update_keys(self, update_dict): """Update the underlying set of keys and their necessity in place. Parameters ---------- update_dict : Dict[str, bool] Dictionary of (key, necessity) pairs to update the keys with """ if len(update_dict) > 0: keys = self.keys keys.update(update_dict) self._keys = tuple(keys.items())
[docs] def update_upstream(self, key): """Update the underlying set of required upstream modules in place. Parameters ---------- key : str Post-processor module name to add to the list """ self._upstream = (*self._upstream, key)
def __call__(self, data, entry=None): """Calls the post processor on one entry. Parameters ---------- data : dict Dicitionary of data products entry : int, optional Entry in the batch Returns ------- dict Update to the input dictionary """ # Fetch the input dictionary data_filter = {} for key, req in self._keys: # If this key is needed, check that it exists assert not req or key in data, ( f"Post-processor `{self.name}` is missing an essential " f"input to be used: `{key}`." ) # Append if key in data: data_filter[key] = data[key] if entry is not None: data_filter[key] = data[key][entry] # Run the post-processor return self.process(data_filter)
[docs] def get_index(self, obj): """Get a certain pre-defined index attribute of an object. The :class:`TruthFragment`, :class:`TruthParticle` and :class:`TruthInteraction` objects index are obtained using the `truth_index_mode` attribute of the class. Parameters ---------- obj : Union[FragmentBase, ParticleBase, InteractionBase] Fragment, Particle or Interaction object Results ------- np.ndarray (N) Object index """ if not obj.is_truth: return obj.index else: return getattr(obj, self.truth_index_mode)
[docs] def get_points(self, obj): """Get a certain pre-defined point attribute of an object. The :class:`TruthFragment`, :class:`TruthParticle` and :class:`TruthInteraction` objects points are obtained using the `truth_point_mode` attribute of the class. Parameters ---------- obj : Union[FragmentBase, ParticleBase, InteractionBase] Fragment, Particle or Interaction object Results ------- np.ndarray (N, 3) Point coordinates """ if not obj.is_truth: return obj.points else: return getattr(obj, self.truth_point_mode)
[docs] def get_sources(self, obj): """Get a certain pre-defined sources attribute of an object. The :class:`TruthFragment`, :class:`TruthParticle` and :class:`TruthInteraction` objects sources are obtained using the `truth_source_mode` attribute of the class. Parameters ---------- obj : Union[FragmentBase, ParticleBase, InteractionBase] Fragment, Particle or Interaction object Results ------- np.ndarray (N, 2) Object sources """ if not obj.is_truth: return obj.sources else: return getattr(obj, self.truth_source_mode)
[docs] def get_depositions(self, obj): """Get a certain pre-defined deposition attribute of an object. The :class:`TruthFragment`, :class:`TruthParticle` and :class:`TruthInteraction` objects points are obtained using the `truth_dep_mode` attribute of the class. Parameters ---------- obj : Union[FragmentBase, ParticleBase, InteractionBase] Fragment, Particle or Interaction object Results ------- np.ndarray (N) Depositions """ if not obj.is_truth: return obj.depositions else: return getattr(obj, self.truth_dep_mode)
[docs] def get_pid(self, obj): """Get a certain pre-defined PID prediction of an object. The :class:`TruthParticle` PID predictions are obtained using the `pid_mode` attribute of the class. Parameters ---------- obj : Union[ParticleBase] Particle object Results ------- int Particle identification enumerator """ if not obj.is_truth: return getattr(obj, self.pid_mode) else: return obj.pid
[docs] def check_units(self, obj): """Check that the point coordinates of an object are as expected. Parameters ---------- obj : Union[FragmentBase, ParticleBase, InteractionBase] Particle or interaction object Results ------- np.ndarray (N, 3) Point coordinates """ if obj.units != self.units: raise ValueError( f"Coordinates must be expressed in {self.units} but are " f"currently in {obj.units} instead." )
[docs] @abstractmethod def process(self, data): """Place-holder method to be defined in each post-processor. Parameters ---------- data : dict Dictionary of processed data products """ raise NotImplementedError("Must define the `process` function.")