"""A module that defines :py:class:`ParamExplorer`, a class that allows to vary
a parameter and check the efficiency that is obtained for this choice.
"""
from __future__ import annotations
import typing
import abc
import os.path as op
from tqdm.auto import tqdm
import numpy as np
import numpy.typing as npt
import pandas as pd
from torch_geometric.data import Data
import montetracko as mt
import montetracko.lhcb as mtb
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from Evaluation.matching import perform_matching
from utils.loaderutils.preprocessing import load_preprocessed_dataframes
from utils.plotutils import plotools
from utils.commonutils.config import get_performance_directory_experiment
from .basemodel import ModelBase
from ..loaderutils.dataiterator import LazyDatasetBase
def _get_metric_errors(values, dict_performances, metric_name, category_name):
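    """Collect the errors on ``metric_name`` in category ``category_name`` for
    every explored ``value``.

    The result has the layout expected by the ``yerr`` argument of
    :py:meth:`matplotlib.axes.Axes.errorbar`: shape ``(2, n_values)`` when
    asymmetric errors (``err_neg``/``err_pos``) are provided, ``(n_values,)``
    when a single ``err`` is provided. Illustrative sketch (hypothetical
    numbers)::

        dict_performances = {
            0.1: {("velo", "efficiency"): {"mean": 0.95, "err_neg": 0.01, "err_pos": 0.02}},
        }
        _get_metric_errors([0.1], dict_performances, "efficiency", "velo")
        # -> array([[0.01], [0.02]])
    """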
metric_errors = []
for value in values:
output = dict_performances[value][category_name, metric_name]
if "err_pos" in output and "err_neg" in output:
metric_errors.append([output["err_neg"], output["err_pos"]])
elif "err" in output:
metric_errors.append(output["err"])
else:
            raise ValueError(
                "Neither `err_pos` and `err_neg`, nor `err` is in the dictionary."
            )
return np.array(metric_errors).T
class ParamExplorer(abc.ABC):
"""A class that allow to explore the track matching performance for various choices
of a given parameter of a trained model (e.g., best efficiency as a function
of the squared maximal distance of the kNN)
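
    A minimal usage sketch, assuming a hypothetical concrete subclass (all
    names and values below are illustrative only)::

        class MaxDistanceExplorer(ParamExplorer):
            @property
            def default_step(self) -> str:
                return "matching"

            def get_tracks(self, value, batches, **kwargs):
                ...  # build the dataframe of tracks for this parameter value

        # trained_model is a trained ModelBase instance
        explorer = MaxDistanceExplorer(model=trained_model, varname="max_distance")
        figs, axes, performances = explorer.plot(
            partition="val",
            values=[0.01, 0.02, 0.05],
            path_or_config="path/to/config.yaml",  # or a configuration dictionary
        )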
"""
def __init__(
self, model: ModelBase, varname: str, varlabel: str | None = None
) -> None:
self.model = model
self.varname = str(varname)
self.varlabel = str(varlabel) if varlabel is not None else self.varname
@property
def default_step(self) -> str:
"""Name of the temp to fall back to if not provided."""
raise NotImplementedError("No default step were provided.")
    def load_preprocessed_dataframes(
self,
batches: typing.List[Data] | LazyDatasetBase,
) -> typing.Tuple[pd.DataFrame, pd.DataFrame]:
"""Load the preprocessed dataframes of hits-particles and particles associated
with the PyTorch DataSets given as input.
Args:
batches: list of PyTorch Geometric Data objects
Returns:
Tuple of dataframes of hits-particles and particles
"""
truncated_paths = [batch.truncated_path for batch in batches]
df_hits_particles = load_preprocessed_dataframes(
truncated_paths=truncated_paths,
ending="-hits_particles",
columns=["particle_id", "hit_id"],
)
df_particles = load_preprocessed_dataframes(
truncated_paths=truncated_paths,
ending="-particles",
columns=[
"particle_id",
"has_velo",
"has_scifi",
"pid",
"eta",
"from_sdecay",
"mother_pid",
"p",
],
)
return df_hits_particles, df_particles
    def run_inference(
self, batches: typing.List[Data] | LazyDatasetBase
) -> typing.List[Data]:
"""Run the inference on a batch.
Args:
batches: List of batches
Returns:
List of inferred batches
"""
return [batch for batch in batches]
    @abc.abstractmethod
def get_tracks(
self, value: float, batches: typing.List[Data] | LazyDatasetBase, **kwargs
) -> pd.DataFrame:
"""Get the dataframe of tracks from the inferred batches.
Args:
value: current value of the parameter that is explored
batches: list of inferred batches
Returns:
Dataframe of tracks, with columns ``track_id`` and ``hit_id``
"""
raise NotImplementedError()
    def get_output_dir(self, path_or_config: str | dict, step: str):
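        """Return the directory where the performance figures of ``step`` are saved."""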
return op.join(get_performance_directory_experiment(path_or_config), step)
    def add_lhcb_text(self, ax: Axes, metric_name: str):
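        """Add the LHCb text to the axis (``metric_name`` is unused by this default implementation)."""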
plotools.add_text(ax, ha="right", y=0.3)
    def plot(
self,
partition: str,
values: typing.Sequence[float],
n_events: int | None = None,
seed: int | None = None,
metric_names: typing.List[str] | None = None,
categories: typing.List[mt.requirement.Category] | None = None,
track_metric_names: typing.List[str] | None = None,
identifier: str | None = None,
path_or_config: str | dict | None = None,
output_path: str | None = None,
same_fig: bool = True,
lhcb: bool = False,
category_name_to_color: dict | None = None,
step: str | None = None,
with_err: bool = True,
legend_inside: bool | None = None,
**kwargs,
) -> typing.Tuple[
Figure | npt.NDArray,
typing.List[Axes],
typing.Dict[
float, typing.Dict[typing.Tuple[str | None, str], typing.Dict[str, float]]
],
]:
"""Plot metrics in differences categories for different hyperparameter
values.
Args:
path_or_config: pipeline configuration
partition: ``train``, ``val`` or the name of a test dataset
values: list of values for the hyperparameter of interest
n_events: Maximal number of events for the evaluation
seed: Random seed for randomly selecting ``n_events``
metric_names: List of metric names to compute. If not set,
``efficiency``, ``clone_rate`` and ``hit_efficiency_per_candidate``
are computed and plotted.
categories: list of categories to compute the performance in.
By default, this is "Velo Without Electrons" and "Long Electrons".
track_metric_names: list of track-related metrics (that do not
depend on any category) to plot
identifier: Identifier for the figure name. Only used if ``output_path``
is not provided
output_path: Output path where the figure is saved. If ``same_fig``
is set to ``False``, the string should contain the placeholder
``{metric_name}``.
            same_fig: in the case where several metrics are plotted, whether to
                plot them in the same matplotlib figure object
            lhcb: whether to add the LHCb text on top of every axis
            category_name_to_color: optional mapping from a category name to the
                color used for the corresponding curve
            step: name of the step used to build the output directory. Defaults
                to :py:attr:`default_step`
            with_err: whether to compute and plot error bars
            legend_inside: whether to draw the legend inside the axes. Defaults
                to the value of ``same_fig``
            **kwargs: Other keyword arguments passed to
                :py:meth:`ParamExplorer.compute_performance_metrics`
Returns:
            3-tuple of the Matplotlib figure(s), the list of axes, and the
            dictionary of metric values, indexed first by ``value`` and then by
            the tuple ``(category_name, metric_name)``, where ``category_name``
            is ``None`` for track-level metrics
"""
if legend_inside is None:
legend_inside = same_fig
if category_name_to_color is None:
category_name_to_color = {}
if step is None:
step = self.default_step
if metric_names is None:
metric_names = ["efficiency", "clone_rate", "hit_efficiency_per_candidate"]
if categories is None:
categories = [
mtb.category.category_velo_no_electrons,
mtb.category.category_long_only_electrons,
]
if identifier is None:
identifier = ""
dict_performances = self.compute_performance_metrics(
values=values,
partition=partition,
metric_names=metric_names,
track_metric_names=track_metric_names,
categories=categories,
n_events=n_events,
seed=seed,
with_err=with_err,
**kwargs,
)
if track_metric_names is None:
track_metric_names = []
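        # Build the figures and axes: either all metrics on one figure, or one
        # figure per metric (per-category metrics first, then track-level ones).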
if same_fig:
figs, axes_ = plotools.get_figs_axes_on_grid(
1,
len(metric_names) + len(track_metric_names),
same_fig=True,
figsize=(8, 6),
)
else:
if metric_names:
figs_with_cat, axes_with_cat_ = plotools.get_figs_axes_on_grid(
1,
len(metric_names),
same_fig=False,
figsize=(8, 6) if legend_inside else (12, 6),
)
else:
figs_with_cat = None
axes_with_cat_ = None
if track_metric_names:
figs_track, axes_track_ = plotools.get_figs_axes_on_grid(
1,
len(track_metric_names),
same_fig=False,
figsize=(8, 6),
)
else:
figs_track = None
axes_track_ = None
figs = np.concatenate(
[figs_ for figs_ in (figs_with_cat, figs_track) if figs_ is not None],
)
axes_ = np.concatenate(
[
axes__
for axes__ in (axes_with_cat_, axes_track_)
if axes__ is not None
],
)
axes = np.atleast_1d(axes_).tolist()
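        # Per-category metrics: one curve per category, with error bars when
        # ``with_err`` is set.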
for metric_idx, metric_name in enumerate(metric_names):
axes[metric_idx].set_ylabel(mt.metricsLibrary.label(metric_name))
axes[metric_idx].grid(color="grey", alpha=0.5)
for category in categories:
if with_err:
metric_values = [
dict_performances[value][category.name, metric_name]["mean"]
for value in values
]
metric_errors = _get_metric_errors(
values=values,
dict_performances=dict_performances,
metric_name=metric_name,
category_name=category.name,
)
axes[metric_idx].errorbar(
values,
metric_values,
label=category.label,
marker=".",
yerr=metric_errors,
color=category_name_to_color.get(category.name),
)
else:
metric_values = [
dict_performances[value][category.name, metric_name]
for value in values
]
axes[metric_idx].plot(
values,
metric_values,
label=category.label,
marker=".",
color=category_name_to_color.get(category.name),
)
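        # Track-level metrics (not split by category) are drawn as a single
        # black curve after the per-category metrics.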
for rel_track_metric_idx, track_metric_name in enumerate(track_metric_names):
track_metric_idx = rel_track_metric_idx + len(metric_names)
axes[track_metric_idx].set_ylabel(
mt.metricsLibrary.label(track_metric_name)
)
if with_err:
metric_values = [
dict_performances[value][None, track_metric_name]["mean"]
for value in values
]
metric_errors = _get_metric_errors(
values=values,
dict_performances=dict_performances,
metric_name=track_metric_name,
category_name=None,
)
axes[track_metric_idx].errorbar(
values,
metric_values,
marker=".",
color="k",
yerr=metric_errors,
)
else:
metric_values = [
dict_performances[value][None, track_metric_name]
for value in values
]
axes[track_metric_idx].plot(
values, metric_values, marker=".", color="k"
)
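        # Common axis decoration: x-label, grid and, optionally, the LHCb text.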
for ax, metric_name in zip(axes, metric_names + track_metric_names):
ax.set_xlabel(self.varlabel)
ax.grid(color="grey", alpha=0.5)
if lhcb:
plotools.pad_on_top(ax=ax)
plotools.add_text(ax=ax, ha="left", va="top")
if legend_inside:
legend_kwargs = dict()
else:
legend_kwargs = dict(loc="center left", bbox_to_anchor=(1, 0.5))
if same_fig:
if legend_inside:
axes[0].legend(**legend_kwargs)
else:
axes[-1].legend(**legend_kwargs)
else:
for metric_idx, _ in enumerate(metric_names):
axes[metric_idx].legend(**legend_kwargs)
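        # Save either the single figure, or one figure per metric using the
        # ``{metric_name}`` placeholder in ``output_path``.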
if same_fig:
if output_path is None:
assert path_or_config is not None
output_path = op.join(
self.get_output_dir(path_or_config=path_or_config, step=step),
f"performance_given_{self.varname}{identifier}_{partition}",
)
assert isinstance(figs, Figure)
plotools.save_fig(fig=figs, path=output_path)
else:
if output_path is None:
assert path_or_config is not None
output_path = op.join(
self.get_output_dir(path_or_config=path_or_config, step=step),
f"{{metric_name}}_given_{self.varname}{identifier}_{partition}",
)
assert isinstance(figs, np.ndarray)
for fig, metric_name in zip(figs, metric_names + track_metric_names):
plotools.save_fig(fig, output_path.format(metric_name=metric_name))
return (figs, axes, dict_performances)