Source code for pipeline.Embedding.embedding_validation
"""A module that defines tools to perform the validation step of the embedding step.
"""
from __future__ import annotations
import typing
import logging
from uncertainties import ufloat
from uncertainties.core import Variable
from tqdm.auto import tqdm
from uncertainties import unumpy as unp
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from matplotlib.axes import Axes
from matplotlib import rcParams
from GNN import perfect_gnn
from TrackBuilding.builder import batch_from_edges_to_tracks
from utils.modelutils.exploration import ParamExplorer
from utils.loaderutils.tracks import get_tracks_from_batch
from utils.plotutils import plotools
from .build_embedding import EmbeddingInferenceBuilder
from .embedding_base import EmbeddingBase
def get_default_squared_distance_max(
    model: EmbeddingBase, squared_distance_max: float | None = None
) -> float:
    """Resolve the squared distance max to use for inference.

    Args:
        model: Embedding model whose ``hparams`` provide the fallback values.
        squared_distance_max: Explicit value. When given, it is returned as is
            and the model hyperparameters are not consulted.

    Returns:
        ``squared_distance_max`` if not ``None``. Otherwise the
        ``squared_distance_max_infer`` hyperparameter, or, if that one is missing
        or ``None``, the ``squared_distance_max`` hyperparameter.
    """
    if squared_distance_max is not None:
        return squared_distance_max

    hparams = model.hparams
    # Prefer the inference-specific value; fall back on the training one.
    inference_value = hparams.get("squared_distance_max_infer")
    if inference_value is None:
        return hparams["squared_distance_max"]
    return inference_value
def _mean_with_spread(values: np.ndarray) -> Variable:
    """Return ``ufloat(mean, std)`` of ``values``, or ``ufloat(nan, nan)`` if empty.

    The empty case is handled explicitly so that numpy does not emit
    "Mean of empty slice" / "Degrees of freedom <= 0" RuntimeWarnings. The
    result is the same NaN value the plain ``mean``/``std`` calls would give.
    """
    if values.size == 0:
        return ufloat(np.nan, np.nan)
    # ``std`` uses numpy's default ddof=0 (population standard deviation).
    return ufloat(values.mean(), values.std())


def evaluate_embedding_performance(
    model: EmbeddingBase,
    batches: typing.List[Data],
    squared_distance_max: float | None = None,
    k_max: int | None = None,
    overall: bool = False,
) -> typing.Tuple[Variable, Variable, Variable]:
    """Compute the edge efficiency, edge purity and graph size of an embedding
    model on a list of already-loaded batches.

    Args:
        model: Embedding model. ``model.inference`` is run on every batch.
        batches: Batches (e.g. a subset of the train, val or test dataset) to
            evaluate the model on.
        squared_distance_max: Maximal squared distance for the k-NN. If not
            given, it is resolved by :py:func:`get_default_squared_distance_max`
            (``squared_distance_max_infer`` hyperparameter, falling back on
            ``squared_distance_max``).
        k_max: Maximal number of neighbours for the k-NN. If not given, taken
            from the ``k_max`` hyperparameter of the model.
        overall: Forwarded as is to ``model.inference``.

    Returns:
        A tuple of 3 ufloat numbers ``(edge_efficiency, edge_purity, graph_size)``.
        Each is the average over the batches, with the (population) standard
        deviation as uncertainty. The graph size of a batch is
        ``edge_indices.shape[1]`` of the inference output (presumably the number
        of edges). If ``batches`` is empty, all three are ``ufloat(nan, nan)``.
    """
    # Handle default values for `k_max` and `squared_distance_max`
    squared_distance_max = get_default_squared_distance_max(
        model=model, squared_distance_max=squared_distance_max
    )
    if k_max is None:
        k_max = model.hparams["k_max"]

    n_batches = len(batches)
    # Pre-filled with NaN so that any unfilled slot is obviously invalid.
    efficiencies = np.full(shape=n_batches, fill_value=np.nan)
    purities = np.full(shape=n_batches, fill_value=np.nan)
    graph_sizes = np.full(shape=n_batches, fill_value=np.nan)

    # Compute performance for each batch (no gradients needed for evaluation)
    with torch.no_grad():
        for batch_idx, batch in enumerate(batches):
            results = model.inference(
                batch=batch,
                squared_distance_max=squared_distance_max,
                k_max=k_max,
                evaluate=True,
                overall=overall,
            )
            efficiencies[batch_idx] = results["eff"]
            purities[batch_idx] = results["pur"]
            graph_sizes[batch_idx] = results["edge_indices"].shape[1]

    return (
        _mean_with_spread(efficiencies),
        _mean_with_spread(purities),
        _mean_with_spread(graph_sizes),
    )
def evaluate_embedding_performances_given_squared_distance_max_k_max(
    model: EmbeddingBase,
    partitions: typing.Sequence[str] = ("train", "val"),
    n_events: int = 10,
    squared_distance_max: typing.Sequence[float] | float | None = None,
    k_max: typing.Sequence[int] | int | None = None,
    seed: int | None = None,
    overall: bool = False,
) -> typing.Dict[str, typing.Dict[str, np.ndarray]]:
    """Compute edge efficiency, edge purity and graph size as a function of
    either the maximal squared distance or the maximal number of neighbours of
    the k-nearest-neighbour search.

    Exactly one of ``squared_distance_max`` and ``k_max`` must be a list (or
    tuple) of values to scan. The other one is kept fixed. If it is ``None``, it
    is resolved from the model hyperparameters by
    :py:func:`evaluate_embedding_performance`.

    Args:
        model: Embedding model
        partitions: Partitions to evaluate the performance on
        n_events: Maximal number of events to use for each partition for
            performance evaluation
        squared_distance_max: Maximal squared distance in the embedding space,
            or list of values to scan
        k_max: Maximal number of neighbours, or list of values to scan
        seed: Random seed for the random choice of ``n_events`` events
        overall: Forwarded to :py:func:`evaluate_embedding_performance`

    Returns:
        Dictionary that associates a metric name (``edge_efficiency``,
        ``edge_purity``, ``graph_size``) with another dictionary. That inner
        dictionary associates a partition with an array of ufloat values, one
        per scanned ``squared_distance_max`` or ``k_max`` value, in input order.

    Raises:
        ValueError: if both or neither of ``k_max`` and
            ``squared_distance_max`` are given as a list of values.
    """
    scannable_types = (list, tuple)
    scan_k_max = isinstance(k_max, scannable_types)
    scan_squared_distance_max = isinstance(squared_distance_max, scannable_types)

    if scan_k_max and scan_squared_distance_max:
        raise ValueError(
            "Both `k_max` and `squared_distance_max` are provided as array but only one "
            "is supported."
        )
    if scan_k_max:
        list_hyperparam_values = k_max
        hyperparam_name = "k_max"
    elif scan_squared_distance_max:
        list_hyperparam_values = squared_distance_max
        hyperparam_name = "squared_distance_max"
    else:
        raise ValueError(
            "Neither `k_max` nor `squared_distance_max` is provided as array: "
            "exactly one of them must be a list of values to scan."
        )

    dict_metrics_partitions: typing.Dict[str, typing.Dict[str, np.ndarray]] = {
        "edge_efficiency": {},
        "edge_purity": {},
        "graph_size": {},
    }
    for partition in partitions:
        logging.info("Compute edge performance metrics for %s", partition)
        # Batches are loaded directly on the same device as the model.
        batches = model.fetch_partition(
            partition=partition,
            n_events=n_events,
            shuffle=True,
            seed=seed,
            map_location=model.device,
        )
        efficiencies = []
        purities = []
        graph_sizes = []
        pbar = tqdm(list_hyperparam_values)
        for hyperparam_value in pbar:
            pbar.set_description(
                f"Loop over {hyperparam_name} (current value: {hyperparam_value})"
            )
            # Only the scanned hyperparameter varies; the other one stays fixed.
            if hyperparam_name == "k_max":
                current_k_max = int(hyperparam_value)
                current_squared_distance_max = squared_distance_max
            else:
                current_k_max = k_max
                current_squared_distance_max = hyperparam_value
            efficiency, purity, graph_size = evaluate_embedding_performance(
                model=model,
                batches=batches,  # type: ignore
                squared_distance_max=current_squared_distance_max,  # type: ignore
                k_max=current_k_max,  # type: ignore
                overall=overall,
            )
            efficiencies.append(efficiency)
            purities.append(purity)
            graph_sizes.append(graph_size)
        dict_metrics_partitions["edge_efficiency"][partition] = np.array(efficiencies)
        dict_metrics_partitions["edge_purity"][partition] = np.array(purities)
        dict_metrics_partitions["graph_size"][partition] = np.array(graph_sizes)
    return dict_metrics_partitions
class EmbeddingDistanceMaxExplorer(ParamExplorer):
    """Vary the maximal squared distance of the embedding k-NN and compare the
    best achievable track-finding performance.

    "Best" means that all the fake edges are filtered out before building
    tracks. This is done with the "perfect" builders of
    :py:mod:`GNN.perfect_gnn`.
    """

    def __init__(self, model: EmbeddingBase, builder: str | None = "default") -> None:
        """
        Args:
            model: Embedding model to explore.
            builder: ``"default"`` (perfect edge GNN, then tracks built from the
                selected edges) or ``"triplet"`` (perfect triplet builder).
                ``None`` is treated as ``"default"``. Any other value is
                stored, but is rejected by :py:meth:`get_tracks`.
        """
        # Use a LaTeX-only construct only when matplotlib renders with LaTeX.
        if rcParams["text.usetex"]:
            varlabel = r"$d_{\text{max}}^{2}$"
        else:
            varlabel = r"$d_{max}^{2}$"
        super().__init__(model, varname="squared_max_distance", varlabel=varlabel)
        self.builder = "default" if builder is None else str(builder)

    @property
    def default_step(self) -> str:
        """Return ``"embedding"``, the default pipeline step of this explorer."""
        return "embedding"

    def add_lhcb_text(self, ax: Axes, metric_name: str):
        """Add the text label to ``ax`` with :py:func:`plotools.add_text`.

        The placement depends on ``metric_name``.
        """
        placement_per_metric = {
            "efficiency": {"ha": "right", "y": 0.3},
            "hit_efficiency_per_candidate": {"ha": "right", "y": 0.2},
            "clone_rate": {"ha": "left", "va": "center"},
        }
        placement = placement_per_metric.get(metric_name, {"ha": "left", "y": 0.3})
        plotools.add_text(ax, **placement)

    def get_tracks(
        self,
        value: float,
        batches: typing.List[Data],
        k_max: int | None = None,
        processing: str | typing.List[str] | None = None,
    ) -> pd.DataFrame:
        """Build the tracks obtained with a given maximal squared distance.

        Args:
            value: Maximal squared distance used to build the graphs.
            batches: Batches to process. Each one is cloned before graph
                building, so the inputs themselves are passed as clones.
            k_max: Maximal number of neighbours. Defaults to the ``k_max``
                hyperparameter of the model.
            processing: Forwarded to
                ``EmbeddingInferenceBuilder.process_one_step``.

        Returns:
            Dataframe of the tracks of all the batches, with duplicated rows
            dropped.

        Raises:
            ValueError: if ``self.builder`` is neither ``"default"`` nor
                ``"triplet"``. The graphs have already been built when this is
                raised.
        """
        assert isinstance(self.model, EmbeddingBase)

        # Graph building with the embedding model
        graph_builder = EmbeddingInferenceBuilder(
            model=self.model,
            k_max=self.model.hparams["k_max"] if k_max is None else k_max,
            squared_distance_max=value,
            max_plane_diff=self.model.hparams.get("max_plane_diff"),
        )
        graph_batches = []
        for input_batch in tqdm(batches, desc="Graph Building"):
            graph_batches.append(
                graph_builder.process_one_step(
                    batch=input_batch.clone(),
                    processing=processing,
                )
            )

        if self.builder == "triplet":
            triplet_builder = perfect_gnn.PerfectTripletInferenceBuilder()
            track_batches = [
                triplet_builder.construct_downstream(batch=graph_batch)
                for graph_batch in graph_batches
            ]
        elif self.builder == "default":
            # Perfect GNN scoring of every graph first...
            edge_builder = perfect_gnn.PerfectInferenceBuilder()
            scored_batches = [
                edge_builder.construct_downstream(batch=graph_batch)
                for graph_batch in graph_batches
            ]
            # ...then track building from the edges kept by the scores.
            track_batches = []
            for scored_batch in scored_batches:
                kept_edges = scored_batch["edge_index"][:, scored_batch["scores"]]
                track_batches.append(
                    batch_from_edges_to_tracks(
                        batch=scored_batch, edge_index=kept_edges
                    ).cpu()
                )
        else:
            raise ValueError(
                f"`builder` attribute is {self.builder} which is not recognised."
            )

        track_frames = [
            get_tracks_from_batch(batch=track_batch.cpu())
            for track_batch in track_batches
        ]
        return pd.concat(track_frames).drop_duplicates()