Source code for pipeline.utils.modelutils.exploration

"""A module that defines :py:class:`ParamExplorer`, a class that allows to vary
a parameter and check the efficiency that is obtained for this choice.
"""

from __future__ import annotations
import typing
import abc
import os.path as op

from tqdm.auto import tqdm
import numpy as np
import numpy.typing as npt
import pandas as pd
from torch_geometric.data import Data
import montetracko as mt
import montetracko.lhcb as mtb
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from Evaluation.matching import perform_matching
from utils.loaderutils.preprocessing import load_preprocessed_dataframes

from utils.plotutils import plotools
from utils.commonutils.config import get_performance_directory_experiment
from .basemodel import ModelBase
from ..loaderutils.dataiterator import LazyDatasetBase


def _get_metric_errors(values, dict_performances, metric_name, category_name):
    """Collect the errors on a metric for every hyperparameter value, in the
    shape expected by :py:meth:`matplotlib.axes.Axes.errorbar`: ``(2, N)`` for
    asymmetric errors, ``(N,)`` for symmetric ones.
    """
    metric_errors = []
    for value in values:
        output = dict_performances[value][category_name, metric_name]
        if "err_pos" in output and "err_neg" in output:
            metric_errors.append([output["err_neg"], output["err_pos"]])
        elif "err" in output:
            metric_errors.append(output["err"])
        else:
            raise Exception(
                "Neither `err_pos` and `err_neg`, nor `err` is in the dictionary."
            )
    return np.array(metric_errors).T
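

# A minimal sketch of the expected input and output of `_get_metric_errors`,
# assuming asymmetric errors were returned by the evaluator (the category
# name, metric values and errors below are made up for illustration):
#
#     dict_performances = {
#         0.1: {("velo", "efficiency"): {"mean": 0.95, "err_neg": 0.01, "err_pos": 0.02}},
#         0.2: {("velo", "efficiency"): {"mean": 0.97, "err_neg": 0.01, "err_pos": 0.01}},
#     }
#     _get_metric_errors([0.1, 0.2], dict_performances, "efficiency", "velo")
#     # -> array([[0.01, 0.01],
#     #           [0.02, 0.01]])   # shape (2, N), ready for Axes.errorbar(yerr=...)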


class ParamExplorer(abc.ABC):
    """A class that allows exploring the track-matching performance for
    various choices of a given parameter of a trained model (e.g., best
    efficiency as a function of the squared maximal distance of the kNN).
    """

    def __init__(
        self, model: ModelBase, varname: str, varlabel: str | None = None
    ) -> None:
        self.model = model
        self.varname = str(varname)
        self.varlabel = str(varlabel) if varlabel is not None else self.varname

    @property
    def default_step(self) -> str:
        """Name of the step to fall back to if not provided."""
        raise NotImplementedError("No default step was provided.")
    def load_preprocessed_dataframes(
        self,
        batches: typing.List[Data] | LazyDatasetBase,
    ) -> typing.Tuple[pd.DataFrame, pd.DataFrame]:
        """Load the preprocessed dataframes of hits-particles and particles
        associated with the PyTorch datasets given as input.

        Args:
            batches: list of PyTorch Geometric Data objects

        Returns:
            Tuple of dataframes of hits-particles and particles
        """
        truncated_paths = [batch.truncated_path for batch in batches]
        df_hits_particles = load_preprocessed_dataframes(
            truncated_paths=truncated_paths,
            ending="-hits_particles",
            columns=["particle_id", "hit_id"],
        )
        df_particles = load_preprocessed_dataframes(
            truncated_paths=truncated_paths,
            ending="-particles",
            columns=[
                "particle_id",
                "has_velo",
                "has_scifi",
                "pid",
                "eta",
                "from_sdecay",
                "mother_pid",
                "p",
            ],
        )
        return df_hits_particles, df_particles
    def run_inference(
        self, batches: typing.List[Data] | LazyDatasetBase
    ) -> typing.List[Data]:
        """Run the inference on the batches. The base implementation is the
        identity and simply materialises the batches into a list.

        Args:
            batches: List of batches

        Returns:
            List of inferred batches
        """
        return [batch for batch in batches]
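
    # A subclass would typically override `run_inference` to actually apply
    # the trained model. A minimal sketch, assuming the model is callable on
    # a batch and that its output can be attached to the batch (`output` is
    # a hypothetical attribute name, not part of the pipeline):
    #
    #     def run_inference(self, batches):
    #         inferred = []
    #         for batch in batches:
    #             batch.output = self.model(batch)  # run the trained model
    #             inferred.append(batch)
    #         return inferred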
    def compute_performance_metrics(
        self,
        values: typing.Sequence[float],
        partition: str,
        metric_names: typing.List[str],
        categories: typing.List[mt.requirement.Category],
        n_events: int | None = None,
        seed: int | None = None,
        track_metric_names: typing.List[str] | None = None,
        with_err: bool = True,
        **kwargs,
    ) -> typing.Dict[
        float, typing.Dict[typing.Tuple[str | None, str], typing.Dict[str, float]]
    ]:
        """Compute the performance metrics for different values of a
        hyperparameter.

        Args:
            values: list of values for the hyperparameter of interest
            partition: ``train``, ``val`` or the name of a test dataset
            n_events: Maximal number of events for the evaluation
            seed: Random seed for randomly selecting ``n_events``
            metric_names: List of metric names to compute
            categories: list of categories to compute the performance in

        Returns:
            Dictionary of metric values for every tuple
            ``(value, category.name, metric_name)``
        """
        # Load PyTorch Geometric Data objects
        batches = self.model.fetch_partition(
            partition=partition,
            n_events=n_events,
            shuffle=True,
            seed=seed,
            map_location=self.model.device,
        )
        # Move batches to the same device as the model
        # batches = [batch.to(model.device) for batch in batches]  # type: ignore

        # Load the associated pre-processed files that contain the information
        # used for matching
        df_hits_particles, df_particles = self.load_preprocessed_dataframes(
            batches=batches
        )

        # Run model inference
        batches = self.run_inference(batches)

        dict_performance = {}
        for value in (pbar := tqdm(values)):
            pbar.set_description(f"Loop over {self.varname} (current value: {value})")
            df_tracks = self.get_tracks(value=value, batches=batches, **kwargs)
            dict_performance[value] = self.get_performance_from_tracks(
                df_tracks=df_tracks,
                df_hits_particles=df_hits_particles,
                df_particles=df_particles,
                metric_names=metric_names,
                categories=categories,
                track_metric_names=track_metric_names,
                with_err=with_err,
            )
        return dict_performance
    def get_performance_from_tracks(
        self,
        df_tracks: pd.DataFrame,
        df_hits_particles: pd.DataFrame,
        df_particles: pd.DataFrame,
        metric_names: typing.List[str],
        categories: typing.List[mt.requirement.Category],
        track_metric_names: typing.List[str] | None = None,
        with_err: bool = True,
    ) -> typing.Dict[typing.Tuple[str | None, str], typing.Dict[str, float]]:
        """Get the performance dictionary for given tracks.

        Args:
            df_tracks: dataframe of tracks
            df_hits_particles: dataframe of hits-particles
            df_particles: dataframe of particles
            metric_names: List of metric names to compute
            categories: list of categories to compute the performance in

        Returns:
            Dictionary that associates the 2-tuple
            ``(category.name, metric_name)`` with the metric value for the
            given category
        """
        trackEvaluator = perform_matching(
            df_tracks=df_tracks,
            df_hits_particles=df_hits_particles,
            df_particles=df_particles,
            min_track_length=3,
        )
        if track_metric_names is None:
            track_metric_names = []

        dict_performances: typing.Dict[typing.Tuple[str | None, str]] = {  # type: ignore
            (category.name, metric_name): trackEvaluator.compute_metric(
                metric_name=metric_name,
                category=category,
                err="auto" if with_err else None,
            )
            for category in categories
            for metric_name in metric_names
        }
        for track_metric_name in track_metric_names:
            dict_performances[None, track_metric_name] = trackEvaluator.compute_metric(
                metric_name=track_metric_name,
                category=None,
                err="auto" if with_err else None,
            )
        return dict_performances
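
    # For illustration, the dictionary returned by `get_performance_from_tracks`
    # has the following shape (category names, metric names and numbers below
    # are made up; track-level metrics are keyed with category `None`):
    #
    #     {
    #         ("velo_no_electrons", "efficiency"): {"mean": 0.95, "err": 0.01},
    #         ("velo_no_electrons", "clone_rate"): {"mean": 0.02, "err": 0.003},
    #         (None, "some_track_metric"): {"mean": 1234.0, "err": 30.0},
    #     }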
    @abc.abstractmethod
    def get_tracks(
        self, value: float, batches: typing.List[Data] | LazyDatasetBase, **kwargs
    ) -> pd.DataFrame:
        """Get the dataframe of tracks from the inferred batches.

        Args:
            value: current value of the parameter that is explored
            batches: list of inferred batches

        Returns:
            Dataframe of tracks, with columns ``track_id`` and ``hit_id``
        """
        raise NotImplementedError()
    def get_output_dir(self, path_or_config: str | dict, step: str):
        return op.join(get_performance_directory_experiment(path_or_config), step)
    def add_lhcb_text(self, ax: Axes, metric_name: str):
        plotools.add_text(ax, ha="right", y=0.3)
    def plot(
        self,
        partition: str,
        values: typing.Sequence[float],
        n_events: int | None = None,
        seed: int | None = None,
        metric_names: typing.List[str] | None = None,
        categories: typing.List[mt.requirement.Category] | None = None,
        track_metric_names: typing.List[str] | None = None,
        identifier: str | None = None,
        path_or_config: str | dict | None = None,
        output_path: str | None = None,
        same_fig: bool = True,
        lhcb: bool = False,
        category_name_to_color: dict | None = None,
        step: str | None = None,
        with_err: bool = True,
        legend_inside: bool | None = None,
        **kwargs,
    ) -> typing.Tuple[
        Figure | npt.NDArray,
        typing.List[Axes],
        typing.Dict[
            float, typing.Dict[typing.Tuple[str | None, str], typing.Dict[str, float]]
        ],
    ]:
        """Plot metrics in different categories for different hyperparameter
        values.

        Args:
            path_or_config: pipeline configuration
            partition: ``train``, ``val`` or the name of a test dataset
            values: list of values for the hyperparameter of interest
            n_events: Maximal number of events for the evaluation
            seed: Random seed for randomly selecting ``n_events``
            metric_names: List of metric names to compute. If not set,
                ``efficiency``, ``clone_rate`` and
                ``hit_efficiency_per_candidate`` are computed and plotted.
            categories: list of categories to compute the performance in.
                By default, these are "Velo Without Electrons" and
                "Long Electrons".
            track_metric_names: list of track-related metrics (that do not
                depend on any category) to plot
            identifier: Identifier for the figure name. Only used if
                ``output_path`` is not provided
            output_path: Output path where the figure is saved. If
                ``same_fig`` is set to ``False``, the string should contain
                the placeholder ``{metric_name}``.
            same_fig: in the case where several metrics are plotted, whether
                to plot them in the same matplotlib figure object
            **kwargs: Other keyword arguments passed to
                :py:func:`ParamExplorer.compute_performance_metrics`

        Returns:
            3-tuple of the Matplotlib Figures and Axes, and the dictionary of
            metric values for every tuple
            ``(value, category.name, metric_name)``
        """
        if legend_inside is None:
            legend_inside = same_fig
        if category_name_to_color is None:
            category_name_to_color = {}
        if step is None:
            step = self.default_step
        if metric_names is None:
            metric_names = ["efficiency", "clone_rate", "hit_efficiency_per_candidate"]
        if categories is None:
            categories = [
                mtb.category.category_velo_no_electrons,
                mtb.category.category_long_only_electrons,
            ]
        if identifier is None:
            identifier = ""

        dict_performances = self.compute_performance_metrics(
            values=values,
            partition=partition,
            metric_names=metric_names,
            track_metric_names=track_metric_names,
            categories=categories,
            n_events=n_events,
            seed=seed,
            with_err=with_err,
            **kwargs,
        )
        if track_metric_names is None:
            track_metric_names = []

        if same_fig:
            figs, axes_ = plotools.get_figs_axes_on_grid(
                1,
                len(metric_names) + len(track_metric_names),
                same_fig=True,
                figsize=(8, 6),
            )
        else:
            if metric_names:
                figs_with_cat, axes_with_cat_ = plotools.get_figs_axes_on_grid(
                    1,
                    len(metric_names),
                    same_fig=False,
                    figsize=(8, 6) if legend_inside else (12, 6),
                )
            else:
                figs_with_cat = None
                axes_with_cat_ = None
            if track_metric_names:
                figs_track, axes_track_ = plotools.get_figs_axes_on_grid(
                    1,
                    len(track_metric_names),
                    same_fig=False,
                    figsize=(8, 6),
                )
            else:
                figs_track = None
                axes_track_ = None
            figs = np.concatenate(
                [figs_ for figs_ in (figs_with_cat, figs_track) if figs_ is not None],
            )
            axes_ = np.concatenate(
                [
                    axes__
                    for axes__ in (axes_with_cat_, axes_track_)
                    if axes__ is not None
                ],
            )
        axes = np.atleast_1d(axes_).tolist()
        for metric_idx, metric_name in enumerate(metric_names):
            axes[metric_idx].set_ylabel(mt.metricsLibrary.label(metric_name))
            axes[metric_idx].grid(color="grey", alpha=0.5)
            for category in categories:
                if with_err:
                    metric_values = [
                        dict_performances[value][category.name, metric_name]["mean"]
                        for value in values
                    ]
                    metric_errors = _get_metric_errors(
                        values=values,
                        dict_performances=dict_performances,
                        metric_name=metric_name,
                        category_name=category.name,
                    )
                    axes[metric_idx].errorbar(
                        values,
                        metric_values,
                        label=category.label,
                        marker=".",
                        yerr=metric_errors,
                        color=category_name_to_color.get(category.name),
                    )
                else:
                    metric_values = [
                        dict_performances[value][category.name, metric_name]
                        for value in values
                    ]
                    axes[metric_idx].plot(
                        values,
                        metric_values,
                        label=category.label,
                        marker=".",
                        color=category_name_to_color.get(category.name),
                    )

        for rel_track_metric_idx, track_metric_name in enumerate(track_metric_names):
            track_metric_idx = rel_track_metric_idx + len(metric_names)
            axes[track_metric_idx].set_ylabel(
                mt.metricsLibrary.label(track_metric_name)
            )
            if with_err:
                metric_values = [
                    dict_performances[value][None, track_metric_name]["mean"]
                    for value in values
                ]
                metric_errors = _get_metric_errors(
                    values=values,
                    dict_performances=dict_performances,
                    metric_name=track_metric_name,
                    category_name=None,
                )
                axes[track_metric_idx].errorbar(
                    values,
                    metric_values,
                    marker=".",
                    color="k",
                    yerr=metric_errors,
                )
            else:
                metric_values = [
                    dict_performances[value][None, track_metric_name]
                    for value in values
                ]
                axes[track_metric_idx].plot(
                    values, metric_values, marker=".", color="k"
                )

        for ax, metric_name in zip(axes, metric_names + track_metric_names):
            ax.set_xlabel(self.varlabel)
            ax.grid(color="grey", alpha=0.5)
            if lhcb:
                plotools.pad_on_top(ax=ax)
                plotools.add_text(ax=ax, ha="left", va="top")

        if legend_inside:
            legend_kwargs = dict()
        else:
            legend_kwargs = dict(loc="center left", bbox_to_anchor=(1, 0.5))

        if same_fig:
            if legend_inside:
                axes[0].legend(**legend_kwargs)
            else:
                axes[-1].legend(**legend_kwargs)
        else:
            for metric_idx, _ in enumerate(metric_names):
                axes[metric_idx].legend(**legend_kwargs)

        if same_fig:
            if output_path is None:
                assert path_or_config is not None
                output_path = op.join(
                    self.get_output_dir(path_or_config=path_or_config, step=step),
                    f"performance_given_{self.varname}{identifier}_{partition}",
                )
            assert isinstance(figs, Figure)
            plotools.save_fig(fig=figs, path=output_path)
        else:
            if output_path is None:
                assert path_or_config is not None
                output_path = op.join(
                    self.get_output_dir(path_or_config=path_or_config, step=step),
                    f"{{metric_name}}_given_{self.varname}{identifier}_{partition}",
                )
            assert isinstance(figs, np.ndarray)
            for fig, metric_name in zip(figs, metric_names + track_metric_names):
                plotools.save_fig(fig, output_path.format(metric_name=metric_name))

        return (figs, axes, dict_performances)
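

# A minimal sketch of a concrete explorer and how it would be driven,
# assuming a hypothetical `build_tracks` helper that turns the inferred
# batches into a dataframe of tracks for a given kNN squared-distance cut.
# `KnnDistanceExplorer`, `build_tracks`, the `"knn"` step name and the
# `plot` arguments below are illustrative, not part of the pipeline:
#
#     class KnnDistanceExplorer(ParamExplorer):
#         @property
#         def default_step(self) -> str:
#             return "knn"
#
#         def get_tracks(self, value, batches, **kwargs):
#             # Build a dataframe with columns `track_id` and `hit_id`
#             # for the current squared maximal distance `value`
#             return build_tracks(batches, max_squared_distance=value)
#
#     explorer = KnnDistanceExplorer(model=model, varname="max_sq_dist")
#     figs, axes, performances = explorer.plot(
#         partition="val",
#         values=[0.01, 0.02, 0.05, 0.1],
#         path_or_config="pipeline.yaml",
#     )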