Source code for spine.post.manager

"""Manages the operation of post-processors."""

from collections import OrderedDict, defaultdict
from copy import deepcopy

import numpy as np

from spine.utils.stopwatch import StopwatchManager

from .factories import post_processor_factory


class PostManager:
    """Manager in charge of handling post-processing scripts.

    It loads all the post-processor objects once and feeds them data.
    """

    def __init__(self, cfg, post_list=None, parent_path=None):
        """Initialize the post-processing manager.

        Parameters
        ----------
        cfg : dict
            Post-processor configurations, keyed by post-processor name.
            Each configuration may carry an optional "priority" entry;
            modules with a higher priority are run first (default: -1).
        post_list : List[str], optional
            List of post-processors which have already been run
        parent_path : str, optional
            Path to the analysis tools configuration file

        Raises
        ------
        AssertionError
            If `post_list` is provided and a post-processor lists an
            upstream post-processor which is neither in `post_list` nor
            among the modules loaded so far.
        """
        # Work on a copy: the "priority" entries are popped from the
        # configuration blocks below and the caller's dict must stay intact
        cfg = deepcopy(cfg)

        # Loop over the post-processor modules and get their priorities
        keys = np.array(list(cfg.keys()))
        priorities = -np.ones(len(keys), dtype=np.int32)
        for i, key in enumerate(keys):
            if "priority" in cfg[key]:
                priorities[i] = cfg[key].pop("priority")

        # Add the modules to a processor list in decreasing order of priority.
        # A stable sort keeps the configuration order for equal priorities.
        self.watch = StopwatchManager()
        self.modules = OrderedDict()
        keys = keys[np.argsort(-priorities, kind="stable")]
        for key in keys:
            # Profile the module
            self.watch.initialize(key)

            # Append
            self.modules[key] = post_processor_factory(
                key, cfg[key], parent_path=parent_path
            )

            # Check dependencies
            if post_list is not None:
                # Build a single tuple of everything available so far;
                # `post_list` may be a list, so it cannot simply be added
                # to a tuple of module names
                available = (*post_list, *self.modules)
                for post in self.modules[key]._upstream:
                    assert post in available, (
                        f"Post-processor `{key}` is missing an essential "
                        f"upstream post-processor: `{post}`."
                    )

    def __call__(self, data):
        """Pass one batch of data through the post-processors.

        The products returned by each post-processor are written back into
        `data` in place, so that downstream post-processors can use them.

        Parameters
        ----------
        data : dict
            Dictionary of data products. Must contain an "index" entry: a
            scalar for a single entry, a sequence for a batch of entries.

        Raises
        ------
        AssertionError
            If, for a batch, a post-processor does not return one value
            per entry for each of the products it produces.
        """
        # Reset active stopwatches
        self.watch.reset_if_active()

        # Loop over the post-processor modules
        single_entry = np.isscalar(data["index"])
        for key, module in self.modules.items():
            # Run the post-processor on each entry
            self.watch.start(key)
            if single_entry:
                num_entries = 1
                result = module(data)

            else:
                num_entries = len(data["index"])
                result = defaultdict(list)
                for entry in range(num_entries):
                    result_e = module(data, entry)
                    if result_e is not None:
                        for k, v in result_e.items():
                            result[k].append(v)

            self.watch.stop(key)

            # Update the input dictionary
            if result is not None:
                # Use a distinct loop variable: `key` is the name of the
                # post-processor and is needed in the error message
                for name, val in result.items():
                    if not single_entry:
                        assert len(val) == num_entries, (
                            f"The number of {name} ({len(val)}) returned by "
                            f"the {key} post-processor does not match the "
                            f"number of entries ({num_entries}) in the batch."
                        )

                    data[name] = val