"""Base class of all analysis scripts."""
from abc import ABC, abstractmethod
from warnings import warn
from spine.io.write.csv import CSVWriter
class AnaBase(ABC):
    """Parent class of all analysis scripts.

    This base class performs the following functions:

    - Ensures that the necessary methods exist
    - Checks that the script is provided the necessary information
    - Writes the output of the analysis to CSV

    Attributes
    ----------
    name : str
        Name of the analysis script (to call it from a configuration file)
    aliases : Tuple[str]
        Alternative allowed names of the analysis script
    keys : Dict[str, bool]
        Data products used by the analysis script, mapped to whether they
        are required (True) or optional (False)
    units : str
        Units in which the coordinates are expressed
    """

    # Name of the analysis script (as specified in the configuration)
    name = None

    # Alternative allowed names of the analysis script
    aliases = ()

    # Units in which the analysis script expects objects to be expressed in
    units = "cm"

    # Set of data keys needed for this analysis script to operate, stored as
    # an immutable tuple of (key, necessity) pairs (see the `keys` property)
    _keys = ()

    # List of recognized object types
    _obj_types = ("fragment", "particle", "interaction")

    # Valid run modes
    _run_modes = ("reco", "truth", "both", "all")

    # List of known point modes for true particles and their corresponding keys
    _point_modes = (
        ("points", "points_label"),
        ("points_adapt", "points"),
        ("points_g4", "points_g4"),
    )

    # List of known deposition modes for true particles and their
    # corresponding keys
    _dep_modes = (
        ("depositions", "depositions_label"),
        ("depositions_q", "depositions_q_label"),
        ("depositions_adapt", "depositions_label_adapt"),
        ("depositions_adapt_q", "depositions"),
        ("depositions_g4", "depositions_g4"),
    )

    def __init__(
        self,
        obj_type=None,
        run_mode=None,
        truth_point_mode=None,
        truth_dep_mode=None,
        append=False,
        overwrite=False,
        log_dir=None,
        prefix=None,
        buffer_size=1,
    ):
        """Initialize default analysis script object properties.

        Parameters
        ----------
        obj_type : Union[str, List[str]], optional
            Name or list of names of the object types to process
        run_mode : str, optional
            If specified, tells whether the analysis script must run on
            reconstructed ('reco'), true ('truth') or both objects
            ('both' or 'all')
        truth_point_mode : str, optional
            If specified, tells which attribute of the :class:`TruthFragment`,
            :class:`TruthParticle` or :class:`TruthInteraction` object to use
            to fetch its point coordinates
        truth_dep_mode : str, optional
            If specified, tells which attribute of the :class:`TruthFragment`,
            :class:`TruthParticle` or :class:`TruthInteraction` object to use
            to fetch its energy depositions
        append : bool, default False
            If True, appends existing CSV files instead of creating new ones
        overwrite : bool, default False
            If True and an output CSV file exists, overwrite it
        log_dir : str, optional
            Output CSV file directory (shared with driver log)
        prefix : str, default None
            Name to prefix every output CSV file with
        buffer_size : int, default 1
            CSV file buffer size. 1 is line buffered (safe default),
            -1 uses system default, 0 is unbuffered, >1 is buffer size in bytes

        Raises
        ------
        AssertionError
            If `run_mode`, `obj_type`, `truth_point_mode` or `truth_dep_mode`
            is not recognized, or if the truth point and deposition modes
            are incompatible with each other
        """
        # Initialize default keys
        self.update_keys(
            {
                "index": True,
                "file_index": True,
                "file_entry_index": False,
                "run_info": False,
            }
        )

        # If run mode is specified, check that it is recognized
        self.run_mode = run_mode
        if run_mode is not None:
            assert run_mode in self._run_modes, (
                f"`run_mode` not recognized: {run_mode}. Must be one of "
                f"{self._run_modes}."
            )

        # Always define the list of prefixes. An unspecified run mode is
        # treated like 'both', consistently with the object keys below.
        self.prefixes = []
        if run_mode != "truth":
            self.prefixes.append("reco")
        if run_mode != "reco":
            self.prefixes.append("truth")

        # Check that all the object sources are recognized
        self.obj_type = obj_type
        if self.obj_type is not None:
            if isinstance(self.obj_type, str):
                self.obj_type = [self.obj_type]
            for obj in self.obj_type:
                assert obj in self._obj_types, (
                    f"Object type must be one of {self._obj_types}. Got "
                    f"`{obj}` instead."
                )

        # Make a list of object keys to process (one list per object type)
        self.fragment_keys = []
        self.particle_keys = []
        self.interaction_keys = []
        for obj_name in self._obj_types:
            obj_keys = []

            # Only fill the list for object types which are requested
            if self.obj_type is not None and obj_name in self.obj_type:
                if run_mode != "truth":
                    obj_keys.append(f"reco_{obj_name}s")
                if run_mode != "reco":
                    obj_keys.append(f"truth_{obj_name}s")

            setattr(self, f"{obj_name}_keys", obj_keys)

        self.obj_keys = (
            self.fragment_keys + self.particle_keys + self.interaction_keys
        )

        # Update underlying keys, if needed
        self.update_keys({k: True for k in self.obj_keys})

        # If a truth point mode is specified, store it
        if truth_point_mode is not None:
            assert truth_point_mode in self.point_modes, (
                "The `truth_point_mode` argument must be one of "
                f"{self.point_modes.keys()}. Got `{truth_point_mode}` instead."
            )
            self.truth_point_mode = truth_point_mode
            self.truth_point_key = self.point_modes[truth_point_mode]
            self.truth_index_mode = truth_point_mode.replace("points", "index")

        # If a truth deposition mode is specified, store it
        if truth_dep_mode is not None:
            assert truth_dep_mode in self.dep_modes, (
                "The `truth_dep_mode` argument must be one of "
                f"{self.dep_modes.keys()}. Got `{truth_dep_mode}` instead."
            )
            if truth_point_mode is not None:
                # Use a dedicated local name here: reusing `prefix` would
                # clobber the output file prefix argument stored below
                dep_prefix = truth_point_mode.replace("points", "depositions")
                assert truth_dep_mode.startswith(dep_prefix), (
                    f"Points mode {truth_point_mode} and deposition mode "
                    f"{truth_dep_mode} are incompatible."
                )
            self.truth_dep_mode = truth_dep_mode
            self.truth_dep_key = self.dep_modes[truth_dep_mode]

        # Store the append/overwrite flags
        self.append_file = append
        self.overwrite_file = overwrite

        # Initialize a writer dictionary to be filled by the children classes
        self.base_dict = {}
        self.log_dir = log_dir
        self.output_prefix = prefix
        self.buffer_size = buffer_size
        self.writers = {}

    def __del__(self):
        """Destructor to ensure CSV files are closed.

        This acts as a safety net in case close_writers() is not called
        explicitly. However, explicit cleanup is preferred.
        """
        # The object may be only partially initialized (e.g. `__init__`
        # raised during validation), in which case there is nothing to close
        if hasattr(self, "writers"):
            self.close_writers()

    def close_writers(self):
        """Close all CSV writers and flush any remaining data.

        This should be called when the analysis is complete to ensure
        all data is written and files are properly closed.
        """
        for writer in self.writers.values():
            writer.close()

    def flush_writers(self):
        """Flush all CSV writer buffers without closing the files.

        This forces any buffered data to be written to disk. Useful
        for ensuring data persistence at checkpoints.
        """
        for writer in self.writers.values():
            writer.flush()

    def initialize_writer(self, name):
        """Adds a CSV writer to the list of writers for this script.

        The output file is named `{prefix}_{script name}_{name}.csv` (the
        prefix is omitted if not set) and placed in `log_dir`, if provided.

        Parameters
        ----------
        name : str
            Name of the writer
        """
        # Define the name of the file to write to
        assert len(name) > 0, "Must provide a non-empty name."
        file_name = f"{self.name}_{name}.csv"
        if self.output_prefix:
            file_name = f"{self.output_prefix}_{file_name}"
        if self.log_dir:
            file_name = f"{self.log_dir}/{file_name}"

        # Initialize the writer
        self.writers[name] = CSVWriter(
            file_name,
            append=self.append_file,
            overwrite=self.overwrite_file,
            buffer_size=self.buffer_size,
        )

    @property
    def keys(self):
        """Dictionary of (key, necessity) pairs which determine which data keys
        are needed/optional for the analysis script to run.

        Returns
        -------
        Dict[str, bool]
            Dictionary of (key, necessity) pairs to be used
        """
        return dict(self._keys)

    @keys.setter
    def keys(self, keys):
        """Converts a dictionary of keys to an immutable tuple.

        Parameters
        ----------
        keys : Dict[str, bool]
            Dictionary of (key, necessity) pairs to be used
        """
        self._keys = tuple(keys.items())

    @property
    def point_modes(self):
        """Dictionary which makes the correspondence between the name of a true
        object point attribute with the underlying point tensor it points to.

        Returns
        -------
        Dict[str, str]
            Dictionary of (attribute, key) mapping for point coordinates
        """
        return dict(self._point_modes)

    @property
    def dep_modes(self):
        """Dictionary which makes the correspondence between the name of a true
        object deposition attribute with the underlying deposition array it
        points to.

        Returns
        -------
        Dict[str, str]
            Dictionary of (attribute, key) mapping for point depositions
        """
        return dict(self._dep_modes)

    def update_keys(self, update_dict):
        """Update the underlying set of keys and their necessity in place.

        Parameters
        ----------
        update_dict : Dict[str, bool]
            Dictionary of (key, necessity) pairs to update the keys with
        """
        if len(update_dict) > 0:
            # `self.keys` returns a fresh dictionary; update it and store it
            # back as an immutable tuple on the instance
            keys = self.keys
            keys.update(update_dict)
            self._keys = tuple(keys.items())

    def get_base_dict(self, data):
        """Builds the entry information dictionary.

        Parameters
        ----------
        data : dict
            Dictionary of data products

        Returns
        -------
        dict
            Dictionary of information for this entry
        """
        # Extract basic information to store in every row
        base_dict = {"index": data["index"], "file_index": data["file_index"]}
        if "file_entry_index" in data:
            base_dict["file_entry_index"] = data["file_entry_index"]

        # Add the run information, if it is available
        if "run_info" in data:
            base_dict.update(**data["run_info"].scalar_dict())
        else:
            warn("`run_info` is missing; will not be included in CSV file.")

        return base_dict

    def append(self, name, **kwargs):
        """Append a CSV log file with a set of values.

        The values are merged on top of the entry information dictionary
        built during the last call to the analysis script.

        Parameters
        ----------
        name : str
            Name of the writer
        **kwargs : dict
            Dictionary of information to save to the writer
        """
        self.writers[name].append({**self.base_dict, **kwargs})

    def __call__(self, data, entry=None):
        """Runs the analysis script on one entry.

        Parameters
        ----------
        data : dict
            Data dictionary for one entry
        entry : int, optional
            If specified, each data product is indexed with this value
            before being passed to the analysis script

        Returns
        -------
        object
            Output of :meth:`process` for this entry
        """
        # Fetch the necessary information
        data_filter = {}
        for key, req in self.keys.items():
            # If this key is needed, check that it exists
            assert not req or key in data, (
                f"Analysis script `{self.name}` is missing an essential "
                f"input to be used: `{key}`."
            )

            # Keep the data product (or its requested entry), if available
            if key in data:
                value = data[key]
                data_filter[key] = value if entry is None else value[entry]

        # Fetch the base dictionary
        self.base_dict = self.get_base_dict(data_filter)

        # Run the analysis script
        return self.process(data_filter)

    def get_index(self, obj):
        """Get a certain pre-defined index attribute of an object.

        The :class:`TruthFragment`, :class:`TruthParticle` and
        :class:`TruthInteraction` objects index are obtained using the
        `truth_index_mode` attribute of the class, which is only set if a
        `truth_point_mode` was provided at initialization.

        Parameters
        ----------
        obj : Union[FragmentBase, ParticleBase, InteractionBase]
            Fragment, Particle or Interaction object

        Returns
        -------
        np.ndarray
            (N) Object index
        """
        if not obj.is_truth:
            return obj.index

        return getattr(obj, self.truth_index_mode)

    def get_points(self, obj):
        """Get a certain pre-defined point attribute of an object.

        The :class:`TruthFragment`, :class:`TruthParticle` and
        :class:`TruthInteraction` objects points are obtained using the
        `truth_point_mode` attribute of the class, which is only set if it
        was provided at initialization.

        Parameters
        ----------
        obj : Union[FragmentBase, ParticleBase, InteractionBase]
            Fragment, Particle or Interaction object

        Returns
        -------
        np.ndarray
            (N, 3) Point coordinates
        """
        if not obj.is_truth:
            return obj.points

        return getattr(obj, self.truth_point_mode)

    @abstractmethod
    def process(self, data):
        """Place-holder method to be defined in each analysis script.

        Parameters
        ----------
        data : dict
            Filtered data dictionary for one entry
        """
        raise NotImplementedError("Must define the `process` function")