# Source code for pipeline.Embedding.embedding_plots
"""A module that handles the validation plots for the embedding phase specifically.
"""
from __future__ import annotations
import typing
import os.path as op
from uncertainties import unumpy as unp
import numpy.typing as npt
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib import rcParams
from Embedding.embedding_validation import (
evaluate_embedding_performances_given_squared_distance_max_k_max,
)
from utils.plotutils.plotconfig import partition_to_color, partition_to_label
from utils.plotutils.plotools import save_fig, add_text
from utils.commonutils.config import load_config, get_performance_directory_experiment
from .embedding_base import EmbeddingBase
from .embedding_validation import EmbeddingDistanceMaxExplorer
# Per-metric text-alignment keyword arguments. They are unpacked into
# ``add_text`` by the plotting function of this module when its ``lhcb`` flag
# is enabled.
metric_name_to_lhcb_text_params: typing.Dict[str, typing.Dict[str, typing.Any]] = {
    "edge_efficiency": dict(ha="right", va="bottom"),
    "edge_purity": dict(ha="right", va="top"),
    "graph_size": dict(ha="left", va="top"),
}
# Y-axis label used for each metric name.
# NOTE(review): the ``#`` is backslash-escaped, presumably so that it renders
# literally when matplotlib typesets labels with LaTeX -- confirm.
metric_name_to_label: typing.Dict[str, str] = dict(
    edge_efficiency="Average graph edge efficiency",
    edge_purity="Average graph edge purity",
    graph_size="\\# graph edges",
)
def plot_embedding_performance_given_squared_distance_max_k_max(
    model: EmbeddingBase,
    path_or_config: str | dict | None = None,
    partitions: typing.List[str] | typing.Dict[str, str | None] = ["train", "val"],
    n_events: int = 10,
    squared_distance_max: typing.List[float] | float | None = None,
    k_max: typing.List[int] | int | None = None,
    show_err: bool = True,
    output_wpath: str | None = None,
    lhcb: bool = False,
    step: str = "embedding",
    overall: bool = False,
) -> typing.Tuple[
    typing.Dict[str, typing.Tuple[Figure, Axes]],
    typing.Dict[str, typing.Dict[str, unp.matrix]],
]:
    """Plot the embedding metrics as a function of one scanned hyperparameter.

    One figure is produced per metric returned by
    ``evaluate_embedding_performances_given_squared_distance_max_k_max``
    (the labels known to this module are edge efficiency, edge purity and graph
    size). The x-axis is whichever of ``squared_distance_max`` or ``k_max`` was
    given as a list; each partition is drawn as one error-bar series. Every
    figure is saved through ``save_fig``.

    Args:
        model: Embedding model.
        path_or_config: YAML configuration (path or loaded dictionary). Only
            needed if ``output_wpath`` is not provided.
        partitions: Partitions to plot. Either a list of partition names, or a
            dictionary mapping each partition name to the fallback legend label
            used when the name is not found in ``partition_to_label``.
        n_events: Maximal number of events to use for each partition for the
            performance evaluation.
        squared_distance_max: Squared maximal distance in the embedding space.
            Pass a list to scan over it.
        k_max: Maximal number of neighbours. Pass a list to scan over it.
        show_err: Whether to show the error bars.
        output_wpath: Wildcard path where the plots are saved, with placeholder
            ``{metric_name}``. If not provided, it is built from the performance
            directory of the experiment and ``step``.
        lhcb: Whether to call ``add_text`` on each axes, using the alignment
            stored in ``metric_name_to_lhcb_text_params``.
        step: Name of the sub-directory of the performance directory used to
            build the default ``output_wpath``.
        overall: Forwarded to the performance evaluation. When ``True``, the
            suffix ``_overall`` is also appended to ``step`` in the default
            ``output_wpath``.

    Returns:
        Tuple of 2 dictionaries. The first dictionary associates a metric name
        with the tuple of matplotlib Figure and Axes.
        The second dictionary associates a metric name with another dictionary
        that associates a partition with the metric values, for the different
        ``squared_distance_max`` or ``k_max`` given as input.

    Raises:
        ValueError: If neither ``output_wpath`` nor ``path_or_config`` is
            provided, if both ``squared_distance_max`` and ``k_max`` are lists,
            or if neither of them is a list.
    """
    if output_wpath is None:
        # Explicit check instead of ``assert``, which is stripped under ``-O``.
        if path_or_config is None:
            raise ValueError(
                "`path_or_config` must be provided when `output_wpath` is None."
            )
        performance_dir = get_performance_directory_experiment(
            path_or_config=path_or_config
        )
        step_extended = f"{step}_overall" if overall else step
        output_wpath = op.join(performance_dir, step_extended, "{metric_name}")

    # Exactly one of the two hyperparameters must be a list: it is the x-axis.
    k_max_is_array = isinstance(k_max, list)
    squared_distance_max_is_array = isinstance(squared_distance_max, list)
    if squared_distance_max_is_array and k_max_is_array:
        raise ValueError(
            "Error: Cannot vary `squared_distance_max` and `k_max` at the same time "
            "but they were both provided as a list."
        )
    elif squared_distance_max_is_array:
        list_hyperparam_values = squared_distance_max
    elif k_max_is_array:
        list_hyperparam_values = k_max
    else:
        raise ValueError(
            "Either `squared_distance_max` or `k_max` should be a list."
        )

    if isinstance(partitions, dict):
        partition_labels = partitions
    else:
        partition_labels = {partition: partition for partition in partitions}
    # Pass a fresh list downstream so that neither the caller's object nor the
    # shared default list of the signature can be mutated by the evaluation.
    partition_names = list(partition_labels)

    dict_metrics_partitions = (
        evaluate_embedding_performances_given_squared_distance_max_k_max(
            model=model,
            partitions=partition_names,
            squared_distance_max=squared_distance_max,
            k_max=k_max,
            n_events=n_events,
            overall=overall,
        )
    )

    # The x-axis label and the legend labels are the same for every metric,
    # so they are resolved once, outside the loop.
    if squared_distance_max_is_array:
        # Label variant depends on whether matplotlib renders text with LaTeX.
        x_label = (
            r"$d_{\text{max}}^{2}$" if rcParams["text.usetex"] else r"$d_{max}^{2}$"
        )
    else:
        x_label = "Maximal number of neighbours"
    legend_labels = {
        partition_name: partition_to_label.get(partition_name, partition_label)
        for partition_name, partition_label in partition_labels.items()
    }
    has_legend = any(label is not None for label in legend_labels.values())

    dict_figs_axs = {}
    for metric_name, dict_partitions in dict_metrics_partitions.items():
        fig, ax = plt.subplots(figsize=(8, 6))
        for partition_name, legend_label in legend_labels.items():
            metric_values = dict_partitions[partition_name]
            ax.errorbar(
                x=list_hyperparam_values,
                y=unp.nominal_values(metric_values),
                yerr=(
                    unp.std_devs(metric_values)  # type: ignore
                    if show_err
                    else None
                ),
                color=partition_to_color.get(partition_name, "k"),
                label=legend_label,
                marker=".",
            )

        ax.set_xlabel(x_label)
        ax.grid(color="grey", alpha=0.5)
        # A legend is only drawn if at least one series has a label.
        if has_legend:
            ax.legend()
        ax.set_ylabel(metric_name_to_label[metric_name])
        if lhcb:
            add_text(ax, **metric_name_to_lhcb_text_params[metric_name])

        save_fig(fig=fig, path=output_wpath.format(metric_name=metric_name))
        dict_figs_axs[metric_name] = (fig, ax)

    return dict_figs_axs, dict_metrics_partitions
def plot_best_performances_squared_distance_max(
    model: EmbeddingBase,
    path_or_config: str | dict,
    partition: str,
    list_squared_distance_max: typing.Sequence[float],
    k_max: int | None = None,
    n_events: int | None = None,
    seed: int | None = None,
    builder: str = "default",
    step: str = "embedding",
    identifier: str | None = None,
    **kwargs,
) -> typing.Tuple[
    Figure | npt.NDArray,
    typing.List[Axes],
    typing.Dict[
        float, typing.Dict[typing.Tuple[str | None, str], typing.Dict[str, float]]
    ],
]:
    """Plot the best achievable performance versus the squared maximal distance.

    This is a thin wrapper: it builds an ``EmbeddingDistanceMaxExplorer`` for
    ``model`` and ``builder``, loads the configuration, and returns the result
    of the explorer's ``plot`` method.

    Args:
        model: Embedding model.
        path_or_config: YAML configuration (path or loaded dictionary). It is
            loaded with ``load_config`` and must contain a ``step`` section.
        partition: Partition forwarded to the explorer.
        list_squared_distance_max: Squared maximal distances to try. They are
            forwarded to the explorer as ``values``.
        k_max: Maximal number of neighbours, forwarded to the explorer.
            NOTE(review): the fallback to the value stored in the model, when
            ``None``, is presumably handled by the explorer -- confirm.
        n_events: Maximal number of events, forwarded to the explorer.
        seed: Random seed, forwarded to the explorer.
        builder: Builder handed to the explorer. It can be ``default`` or
            ``triplets``.
            NOTE(review): how each builder forms the tracks is decided by the
            explorer, not by this function.
        step: Name of the configuration section from which the ``processing``
            entry is read. It is also forwarded to the explorer.
        identifier: Identifier forwarded to the explorer. If not provided and
            ``builder`` is not ``default``, it is set to ``_{builder}``.
        **kwargs: Extra keyword arguments forwarded to the explorer's ``plot``.

    Returns:
        Whatever the explorer's ``plot`` method returns (annotated as a figure,
        a list of axes and a dictionary of metrics).
    """
    explorer = EmbeddingDistanceMaxExplorer(model=model, builder=builder)

    # Derive an identifier from the builder, unless the caller supplied one
    # or the default builder is used.
    if identifier is None:
        if builder != "default":
            identifier = f"_{builder}"

    config = load_config(path_or_config=path_or_config)
    processing = config[step].get("processing")

    return explorer.plot(
        path_or_config=config,
        partition=partition,
        values=list_squared_distance_max,
        n_events=n_events,
        seed=seed,
        k_max=k_max,
        processing=processing,
        step=step,
        identifier=identifier,
        **kwargs,
    )