# Source code for pipeline.Embedding.embedding_validation

"""A module that defines tools to perform the validation step of the embedding step.
"""
from __future__ import annotations
import typing
import logging

from uncertainties import ufloat
from uncertainties.core import Variable
from tqdm.auto import tqdm
from uncertainties import unumpy as unp
import numpy as np
import pandas as pd
import torch

from torch_geometric.data import Data
from matplotlib.axes import Axes
from matplotlib import rcParams

from GNN import perfect_gnn
from TrackBuilding.builder import batch_from_edges_to_tracks
from utils.modelutils.exploration import ParamExplorer
from utils.loaderutils.tracks import get_tracks_from_batch
from utils.plotutils import plotools


from .build_embedding import EmbeddingInferenceBuilder
from .embedding_base import EmbeddingBase


def get_default_squared_distance_max(
    model: EmbeddingBase, squared_distance_max: float | None = None
) -> float:
    """Resolve the maximal squared distance to use at inference time.

    An explicitly provided ``squared_distance_max`` always takes precedence.
    Otherwise the model hyperparameter ``squared_distance_max_infer`` is used,
    falling back to ``squared_distance_max`` when the former is missing or
    ``None``.

    Args:
        model: Embedding model whose ``hparams`` provide the default values.
        squared_distance_max: Explicit value; returned unchanged when not
            ``None``.

    Returns:
        The maximal squared distance to use for inference.
    """
    if squared_distance_max is not None:
        return squared_distance_max

    hparams = model.hparams
    inference_value = hparams.get("squared_distance_max_infer")
    if inference_value is not None:
        return inference_value
    return hparams["squared_distance_max"]
def evaluate_embedding_performance(
    model: EmbeddingBase,
    batches: typing.List[Data],
    squared_distance_max: float | None = None,
    k_max: int | None = None,
    overall: bool = False,
) -> typing.Tuple[Variable, Variable, Variable]:
    """Compute the edge efficiency, edge purity and graph size of a model.

    The embedding inference (``model.inference`` with ``evaluate=True``) is run
    without gradient tracking on every batch of ``batches``, and the per-batch
    metrics are averaged over the batches.

    Args:
        model: Embedding model to evaluate.
        batches: Batches to run the inference on, for instance a subset of the
            train, val or test dataset.
        squared_distance_max: Maximal squared distance for the KNN. If not
            given, it is resolved from the model hyperparameters
            (``squared_distance_max_infer``, else ``squared_distance_max``).
        k_max: Maximal number of neighbours for the KNN. If not given, taken
            from the ``k_max`` hyperparameter of the model.
        overall: Forwarded as-is to ``model.inference``.

    Returns:
        A tuple of 3 ufloat numbers: the batch-averaged edge efficiency, edge
        purity and graph size (``results["edge_indices"].shape[1]``). The
        uncertainty of each number is the standard deviation across batches
        (``ddof=0``), not the standard error on the mean.
    """
    # Handle default values for `squared_distance_max` and `k_max`
    squared_distance_max = get_default_squared_distance_max(
        model=model, squared_distance_max=squared_distance_max
    )
    if k_max is None:
        k_max = model.hparams["k_max"]

    n_batches = len(batches)
    efficiencies = np.full(shape=n_batches, fill_value=np.nan)
    purities = np.full(shape=n_batches, fill_value=np.nan)
    graph_sizes = np.full(shape=n_batches, fill_value=np.nan)

    # Pure evaluation: no gradient bookkeeping is needed
    with torch.no_grad():
        for batch_idx, batch in enumerate(batches):
            results = model.inference(
                batch=batch,
                squared_distance_max=squared_distance_max,
                k_max=k_max,
                evaluate=True,
                overall=overall,
            )
            efficiencies[batch_idx] = results["eff"]
            purities[batch_idx] = results["pur"]
            # presumably edge_indices is (2, n_edges), so this is the number
            # of edges of the built graph — TODO confirm
            graph_sizes[batch_idx] = results["edge_indices"].shape[1]

    return (
        ufloat(efficiencies.mean(), efficiencies.std()),
        ufloat(purities.mean(), purities.std()),
        ufloat(graph_sizes.mean(), graph_sizes.std()),
    )
def evaluate_embedding_performances_given_squared_distance_max_k_max(
    model: EmbeddingBase,
    partitions: typing.Sequence[str] = ("train", "val"),
    n_events: int = 10,
    squared_distance_max: typing.Sequence[float] | float | None = None,
    k_max: typing.Sequence[int] | int | None = None,
    seed: int | None = None,
    overall: bool = False,
) -> typing.Dict[str, typing.Dict[str, np.ndarray]]:
    """Compute edge efficiency, purity and graph size while scanning a KNN cut.

    Exactly one of ``k_max`` (maximal number of neighbours) or
    ``squared_distance_max`` (maximal squared distance in the embedding space)
    must be given as a list (or tuple) of values to scan over. The other one is
    kept fixed for the whole scan; ``None`` means that the model default is
    used.

    Args:
        model: Embedding model.
        partitions: Partitions to evaluate the performance on.
        n_events: Maximal number of events to use for each partition.
        squared_distance_max: Maximal squared distance in the embedding space,
            either a single value or a list of values to scan.
        k_max: Maximal number of neighbours, either a single value or a list of
            values to scan. Scanned values are converted with ``int``.
        seed: Random seed for the random choice of the ``n_events`` events.
        overall: Forwarded to :py:func:`evaluate_embedding_performance`.

    Returns:
        Dictionary that associates a metric name (``edge_efficiency``,
        ``edge_purity`` or ``graph_size``) with another dictionary that
        associates a partition with the array of metric values (ufloats), one
        per scanned value of ``squared_distance_max`` or ``k_max``.

    Raises:
        ValueError: If both, or neither, of ``k_max`` and
            ``squared_distance_max`` are given as lists.
    """
    scan_k_max = isinstance(k_max, (list, tuple))
    scan_squared_distance_max = isinstance(squared_distance_max, (list, tuple))
    if scan_k_max and scan_squared_distance_max:
        raise ValueError(
            "Both `k_max` and `squared_distance_max` are provided as array but only one "
            "is supported."
        )
    if scan_k_max:
        list_hyperparam_values = k_max
        hyperparam_name = "k_max"
    elif scan_squared_distance_max:
        list_hyperparam_values = squared_distance_max
        hyperparam_name = "squared_distance_max"
    else:
        raise ValueError(
            "Either `k_max` or `squared_distance_max` must be provided as a list "
            "of values to scan over."
        )

    dict_metrics_partitions: typing.Dict[str, typing.Dict[str, np.ndarray]] = {
        "edge_efficiency": {},
        "edge_purity": {},
        "graph_size": {},
    }
    for partition in partitions:
        logging.info("Compute edge performance metrics for %s", partition)
        # Batches are loaded directly on the model device
        batches = model.fetch_partition(
            partition=partition,
            n_events=n_events,
            shuffle=True,
            seed=seed,
            map_location=model.device,
        )

        efficiencies = []
        purities = []
        graph_sizes = []
        for hyperparam_value in (pbar := tqdm(list_hyperparam_values)):
            pbar.set_description(
                f"Loop over {hyperparam_name} (current value: {hyperparam_value})"
            )
            # The scanned hyperparameter takes the current value; the other one
            # stays as given by the caller
            if scan_k_max:
                current_squared_distance_max = squared_distance_max
                current_k_max = int(hyperparam_value)  # type: ignore
            else:
                current_squared_distance_max = hyperparam_value
                current_k_max = k_max

            efficiency, purity, graph_size = evaluate_embedding_performance(
                model=model,
                batches=batches,  # type: ignore
                squared_distance_max=current_squared_distance_max,  # type: ignore
                k_max=current_k_max,  # type: ignore
                overall=overall,
            )
            efficiencies.append(efficiency)
            purities.append(purity)
            graph_sizes.append(graph_size)

        dict_metrics_partitions["edge_efficiency"][partition] = np.array(efficiencies)
        dict_metrics_partitions["edge_purity"][partition] = np.array(purities)
        dict_metrics_partitions["graph_size"][partition] = np.array(graph_sizes)

    return dict_metrics_partitions
class EmbeddingDistanceMaxExplorer(ParamExplorer):
    """Explore the maximal squared distance used by the embedding graph building.

    For each value of the maximal squared distance, graphs are built with the
    embedding model and processed by a perfect GNN, so that the best metric
    performances of track finding can be compared in the case where all the
    fake edges are filtered out.
    """

    def __init__(self, model: EmbeddingBase, builder: str | None = "default") -> None:
        """Initialise the explorer.

        Args:
            model: Embedding model used to build the graphs.
            builder: Perfect-GNN builder to use: ``"default"``
                (``perfect_gnn.PerfectInferenceBuilder`` followed by track
                building from the selected edges) or ``"triplet"``
                (``perfect_gnn.PerfectTripletInferenceBuilder``). ``None`` is
                treated as ``"default"``. Other values are rejected by
                :py:meth:`get_tracks`.
        """
        label = r"$d_{\text{max}}^{2}$" if rcParams["text.usetex"] else r"$d_{max}^{2}$"
        super().__init__(model, varname="squared_max_distance", varlabel=label)
        self.builder = str(builder) if builder is not None else "default"

    @property
    def default_step(self) -> str:
        """Name of the pipeline step this explorer relates to."""
        return "embedding"

    def add_lhcb_text(self, ax: Axes, metric_name: str):
        """Add the LHCb text to ``ax``, at a position depending on ``metric_name``.

        Args:
            ax: Matplotlib axes to annotate.
            metric_name: Name of the metric plotted on ``ax``.
        """
        if metric_name == "efficiency":
            plotools.add_text(ax, ha="right", y=0.3)
        elif metric_name == "hit_efficiency_per_candidate":
            plotools.add_text(ax, ha="right", y=0.2)
        elif metric_name == "clone_rate":
            plotools.add_text(ax, ha="left", va="center")
        else:
            plotools.add_text(ax, ha="left", y=0.3)

    def get_tracks(
        self,
        value: float,
        batches: typing.List[Data],
        k_max: int | None = None,
        processing: str | typing.List[str] | None = None,
    ) -> pd.DataFrame:
        """Reconstruct the tracks obtained for a given maximal squared distance.

        Args:
            value: Maximal squared distance used to build the graphs.
            batches: Batches to reconstruct. Each batch is cloned before the
                embedding inference is run on it.
            k_max: Maximal number of neighbours. Defaults to the ``k_max``
                hyperparameter of the model.
            processing: Forwarded to
                :py:meth:`EmbeddingInferenceBuilder.process_one_step`.

        Returns:
            Dataframe of the tracks of all the batches, with duplicate rows
            dropped.

        Raises:
            ValueError: If :py:attr:`builder` is neither ``"default"`` nor
                ``"triplet"``.
        """
        assert isinstance(self.model, EmbeddingBase)
        # Fail fast: reject an unknown builder before the expensive inference
        if self.builder not in ("default", "triplet"):
            raise ValueError(
                f"`builder` attribute is {self.builder} which is not recognised."
            )

        # Run embedding inference (graph building)
        embedding_inference_builder = EmbeddingInferenceBuilder(
            model=self.model,
            k_max=self.model.hparams["k_max"] if k_max is None else k_max,
            squared_distance_max=value,
            max_plane_diff=self.model.hparams.get("max_plane_diff"),
        )
        batches = [
            embedding_inference_builder.process_one_step(
                batch=batch.clone(),
                processing=processing,
            )
            for batch in tqdm(batches, desc="Graph Building")
        ]

        # Run perfect GNN inference
        if self.builder == "default":
            perfect_inference_builder = perfect_gnn.PerfectInferenceBuilder()
            batches = [
                perfect_inference_builder.construct_downstream(batch=batch)
                for batch in batches
            ]
            # Run track reconstruction from the edges selected by ``scores``
            # (presumably a boolean mask set by the perfect GNN — TODO confirm)
            batches = [
                batch_from_edges_to_tracks(
                    batch=batch, edge_index=batch["edge_index"][:, batch["scores"]]
                ).cpu()
                for batch in batches
            ]
        else:  # self.builder == "triplet"
            perfect_triplet_inference_builder = (
                perfect_gnn.PerfectTripletInferenceBuilder()
            )
            batches = [
                perfect_triplet_inference_builder.construct_downstream(batch=batch)
                for batch in batches
            ]

        # Define dataframe of tracks
        return pd.concat(
            [get_tracks_from_batch(batch=batch.cpu()) for batch in batches]
        ).drop_duplicates()