# Source code for pipeline.Embedding.embedding_plots

"""A module that handles the validation plots for the embedding phase specifically.
"""

from __future__ import annotations
import typing
import os.path as op

from uncertainties import unumpy as unp
import numpy.typing as npt
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib import rcParams

from Embedding.embedding_validation import (
    evaluate_embedding_performances_given_squared_distance_max_k_max,
)
from utils.plotutils.plotconfig import partition_to_color, partition_to_label
from utils.plotutils.plotools import save_fig, add_text
from utils.commonutils.config import load_config, get_performance_directory_experiment
from .embedding_base import EmbeddingBase
from .embedding_validation import EmbeddingDistanceMaxExplorer


# Per-metric keyword arguments forwarded to ``add_text`` when ``lhcb=True``:
# the horizontal (``ha``) and vertical (``va``) alignment of the added text.
metric_name_to_lhcb_text_params: typing.Dict[str, typing.Dict[str, typing.Any]] = dict(
    edge_efficiency=dict(ha="right", va="bottom"),
    edge_purity=dict(ha="right", va="top"),
    graph_size=dict(ha="left", va="top"),
)

# Y-axis label used for each metric in the validation plots.
# The ``\\#`` escapes the hash sign (presumably for LaTeX text rendering).
metric_name_to_label: typing.Dict[str, str] = dict(
    edge_efficiency="Average graph edge efficiency",
    edge_purity="Average graph edge purity",
    graph_size="\\# graph edges",
)


def plot_embedding_performance_given_squared_distance_max_k_max(
    model: EmbeddingBase,
    path_or_config: str | dict | None = None,
    partitions: typing.List[str] | typing.Dict[str, str | None] | None = None,
    n_events: int = 10,
    squared_distance_max: typing.List[float] | float | None = None,
    k_max: typing.List[int] | int | None = None,
    show_err: bool = True,
    output_wpath: str | None = None,
    lhcb: bool = False,
    step: str = "embedding",
    overall: bool = False,
) -> typing.Tuple[
    typing.Dict[str, typing.Tuple[Figure, Axes]],
    typing.Dict[str, typing.Dict[str, unp.matrix]],
]:
    """Plot edge efficiency, purity and graph size as a function of either the
    maximal squared distance or the maximal number of neighbours used in the
    k-nearest neighbour algorithm.

    Exactly one of ``squared_distance_max`` and ``k_max`` must be given as a
    list: it is the hyperparameter varied along the x-axis.

    Args:
        model: Embedding model
        path_or_config: YAML configuration. Only needed if ``output_wpath`` is
            not provided.
        partitions: Partitions to plot, either as a list of partition names or
            as a dictionary mapping a partition name to its fallback legend
            label. Defaults to ``["train", "val"]``. A label registered in
            ``partition_to_label`` takes precedence over the fallback.
        n_events: Maximal number of events to use for each partition for
            performance evaluation
        squared_distance_max: Maximal squared distance in the embedding space
            (a list to vary it)
        k_max: Maximal number of neighbours (a list to vary it)
        show_err: whether to show the error bars
        output_wpath: wildcard path where the plots are saved, with placeholder
            ``{metric_name}``
        lhcb: whether to call ``add_text`` on each plot, placed according to
            ``metric_name_to_lhcb_text_params``
        step: step name used to build the default output directory when
            ``output_wpath`` is not provided
        overall: forwarded to the performance evaluation; also appends
            ``_overall`` to the step directory of the default output path

    Returns:
        Tuple of two dictionaries. The first dictionary associates a metric
        name with the tuple of matplotlib Figure and Axes. The second
        dictionary associates a metric name with another dictionary that
        associates a partition with the list of metric values, for the
        different ``squared_distance_max`` or ``k_max`` given as input.

    Raises:
        ValueError: if both or neither of ``squared_distance_max`` and
            ``k_max`` are lists, or if neither ``output_wpath`` nor
            ``path_or_config`` is provided.
    """
    # Avoid a shared mutable default argument.
    if partitions is None:
        partitions = ["train", "val"]

    # Validate the arguments before doing any filesystem-related work.
    k_max_is_array = isinstance(k_max, list)
    squared_distance_max_is_array = isinstance(squared_distance_max, list)
    if squared_distance_max_is_array and k_max_is_array:
        raise ValueError(
            "Error: Cannot vary `squared_distance_max` and `k_max` at the same time "
            "but they were both provided as a list."
        )
    if squared_distance_max_is_array:
        list_hyperparam_values = squared_distance_max
        hyperparam_name = "squared_distance_max"
    elif k_max_is_array:
        list_hyperparam_values = k_max
        hyperparam_name = "k_max"
    else:
        raise ValueError(
            "Exactly one of `squared_distance_max` or `k_max` should be "
            "provided as a list."
        )

    if output_wpath is None:
        # `assert` would be stripped under `python -O`; raise explicitly.
        if path_or_config is None:
            raise ValueError(
                "`path_or_config` must be provided when `output_wpath` is not."
            )
        performance_dir = get_performance_directory_experiment(
            path_or_config=path_or_config
        )
        step_extended = f"{step}_overall" if overall else step
        output_wpath = op.join(performance_dir, step_extended, "{metric_name}")

    if isinstance(partitions, dict):
        partition_labels = partitions
        partition_names = list(partitions.keys())
    else:
        partition_labels = {partition: partition for partition in partitions}
        partition_names = partitions

    dict_metrics_partitions = (
        evaluate_embedding_performances_given_squared_distance_max_k_max(
            model=model,
            partitions=partition_names,
            squared_distance_max=squared_distance_max,
            k_max=k_max,
            n_events=n_events,
            overall=overall,
        )
    )

    # Legend labels and x-axis label do not depend on the metric: compute once.
    legend_labels = {
        partition_name: partition_to_label.get(partition_name, partition_label)
        for partition_name, partition_label in partition_labels.items()
    }
    show_legend = any(label is not None for label in legend_labels.values())
    if hyperparam_name == "squared_distance_max":
        xlabel = (
            r"$d_{\text{max}}^{2}$" if rcParams["text.usetex"] else r"$d_{max}^{2}$"
        )
    else:
        xlabel = "Maximal number of neighbours"

    dict_figs_axs = {}
    for metric_name, dict_partitions in dict_metrics_partitions.items():
        fig, ax = plt.subplots(figsize=(8, 6))
        for partition_name, legend_label in legend_labels.items():
            partition_values = dict_partitions[partition_name]
            ax.errorbar(
                x=list_hyperparam_values,
                y=unp.nominal_values(partition_values),
                yerr=(
                    unp.std_devs(partition_values)  # type: ignore
                    if show_err
                    else None
                ),
                color=partition_to_color.get(partition_name, "k"),
                label=legend_label,
                marker=".",
            )
        ax.set_xlabel(xlabel)
        ax.grid(color="grey", alpha=0.5)
        if show_legend:
            ax.legend()
        ax.set_ylabel(metric_name_to_label[metric_name])
        if lhcb:
            add_text(ax, **metric_name_to_lhcb_text_params[metric_name])
        save_fig(fig=fig, path=output_wpath.format(metric_name=metric_name))
        dict_figs_axs[metric_name] = (fig, ax)

    return dict_figs_axs, dict_metrics_partitions
def plot_best_performances_squared_distance_max(
    model: EmbeddingBase,
    path_or_config: str | dict,
    partition: str,
    list_squared_distance_max: typing.Sequence[float],
    k_max: int | None = None,
    n_events: int | None = None,
    seed: int | None = None,
    builder: str = "default",
    step: str = "embedding",
    identifier: str | None = None,
    **kwargs,
) -> typing.Tuple[
    Figure | npt.NDArray,
    typing.List[Axes],
    typing.Dict[
        float, typing.Dict[typing.Tuple[str | None, str], typing.Dict[str, float]]
    ],
]:
    """Plot best performance for perfect inference as a function of the squared
    maximal distance.

    Args:
        model: Embedding model
        path_or_config: YAML configuration or path to it. It is loaded with
            ``load_config``; its ``step`` section may contain a ``processing``
            entry that is forwarded to the explorer.
        partition: Partition on which the performance is evaluated
        list_squared_distance_max: list of squared maximal distances to try
        k_max: Maximal number of neighbours. If not provided, the one stored
            in the model is used.
        n_events: Maximal number of events to use for each partition for
            performance evaluation
        seed: Random seed for the random choice of ``n_events`` events
        builder: Builder to use to build the tracks after the GNN. It can be
            ``default`` (build the tracks by applying a connected component
            algorithm on the hits) or ``triplets`` (build triplets and form the
            tracks from these triplets.)
        step: Name of the configuration section read for ``processing``;
            also forwarded to the explorer.
        identifier: Identifier forwarded to the explorer. Defaults to
            ``"_{builder}"`` when a non-default builder is used, and ``None``
            otherwise.
        **kwargs: Other keyword arguments forwarded to
            ``EmbeddingDistanceMaxExplorer.plot``.

    Returns:
        The result of ``EmbeddingDistanceMaxExplorer.plot``. Per the return
        annotation, this is a tuple of the figure (or array of figures), the
        list of axes, and a dictionary keyed by squared maximal distance whose
        values map tuple keys to metric dictionaries (NOTE(review): exact key
        semantics are defined by the explorer — confirm there).
    """
    explorer = EmbeddingDistanceMaxExplorer(model=model, builder=builder)
    # Give non-default builders a distinct identifier unless one was given.
    if identifier is None and builder != "default":
        identifier = f"_{builder}"
    config = load_config(path_or_config=path_or_config)
    return explorer.plot(
        path_or_config=config,
        partition=partition,
        values=list_squared_distance_max,
        n_events=n_events,
        seed=seed,
        k_max=k_max,
        processing=config[step].get("processing"),
        step=step,
        identifier=identifier,
        **kwargs,
    )