Source code for pipeline.GNN.gnn_validation

from __future__ import annotations
import typing

from tqdm.auto import tqdm
import pandas as pd
from matplotlib.axes import Axes
import torch
from torch_geometric.data import Data

from .triplet_gnn_base import TripletGNNBase, get_df_edges_from_batch_only
from utils.loaderutils.dataiterator import LazyDatasetBase
from TrackBuilding.builder import (
    batch_from_triplets_to_tracks,
    batch_from_edges_to_tracks,
)
from utils.loaderutils.tracks import get_tracks_from_batch
from utils.modelutils.exploration import ParamExplorer
from utils.plotutils import plotools
from utils.graphutils.tripletbuilding import from_edge_index_to_triplet_indices


class GNNScoreCutExplorer(ParamExplorer):
    """A class that allows varying the edge score cut applied after the GNN
    and comparing the resulting track-finding performance metrics.
    """

    def __init__(self, model: TripletGNNBase, builder: str = "default") -> None:
        super().__init__(model, varname="score_cut", varlabel="Score cut")
        self.builder = str(builder)
        assert isinstance(model, TripletGNNBase)
        self.model = model

    @property
    def default_step(self) -> str:
        return "gnn"

    def _add_score(self, batch: Data) -> Data:
        # Run the GNN forward pass and attach the edge scores to the batch.
        outputs = self.model.shared_evaluation(batch, log=False, with_triplets=False)
        batch["edge_score"] = outputs["edge_score"]
        return batch
    def run_inference(self, batches: typing.List[Data] | LazyDatasetBase):
        with torch.no_grad():
            batches = [self._add_score(batch=batch).cpu() for batch in batches]
        return batches
    def _filter_edges(self, batch: Data, edge_score_min: float) -> Data:
        if self.builder == "default":
            # Keep only the edges whose score passes the cut and build tracks
            # directly from the surviving edges.
            batch = batch_from_edges_to_tracks(
                batch=batch,
                edge_index=batch["edge_index"][:, batch["edge_score"] > edge_score_min],
            )
            return batch
        elif self.builder == "triplet":
            # Build triplets from the edges that survive the score cut
            filtered_edge_index, edge_mask = self.model.filter_edges(
                edge_index=batch["edge_index"],
                edge_score=batch["edge_score"],
                edge_score_cut=edge_score_min,
            )
            dict_triplet_indices = from_edge_index_to_triplet_indices(
                edge_index=filtered_edge_index
            )
            batch["filtered_edge_index"] = filtered_edge_index
            batch["triplet_indices"] = dict_triplet_indices
            # Every surviving triplet gets a unit score, so the triplet score
            # cut of 0.0 below applies no further selection.
            batch["triplet_scores"] = {
                triplet_name: torch.ones(
                    triplet_index.shape[1], device=triplet_index.device
                )
                for triplet_name, triplet_index in dict_triplet_indices.items()
            }
            return batch_from_triplets_to_tracks(
                batch=batch,
                triplet_indices=batch["triplet_indices"],
                triplet_scores=batch["triplet_scores"],
                edge_index=batch["filtered_edge_index"],
                triplet_score_cut=0.0,
            ).cpu()
        else:
            raise ValueError(f"builder `{self.builder}` is not recognised")
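
    # Worked illustration of the boolean-mask filtering used by the "default"
    # builder above (assumed shapes: `edge_index` is [2, num_edges],
    # `edge_score` is [num_edges]); the mask keeps the columns of
    # `edge_index` whose score passes the cut:
    #
    #     edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    #     edge_score = torch.tensor([0.9, 0.2, 0.8])
    #     edge_index[:, edge_score > 0.5]  # -> [[0, 2], [1, 3]]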
    def get_tracks(self, value: float, batches: typing.List[Data]):
        # Run track reconstruction
        batches = [
            self._filter_edges(batch=batch, edge_score_min=value)
            for batch in batches
        ]
        # Define dataframe of tracks
        return pd.concat(
            tuple(get_tracks_from_batch(batch=batch) for batch in batches)
        ).drop_duplicates()
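
# A minimal usage sketch, not part of the module: scan a few edge score cuts
# and collect the reconstructed tracks at each cut. `trained_model` and
# `validation_batches` are hypothetical placeholders for a trained
# TripletGNNBase instance and a list of torch_geometric `Data` batches.
#
#     explorer = GNNScoreCutExplorer(model=trained_model, builder="default")
#     batches = explorer.run_inference(validation_batches)
#     for score_cut in (0.3, 0.5, 0.7):
#         df_tracks = explorer.get_tracks(value=score_cut, batches=batches)
#         print(score_cut, len(df_tracks))
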
class TripletGNNScoreCutExplorer(ParamExplorer):
    """A class that allows varying the triplet score cut after the GNN
    and comparing the track-finding performance metrics.
    """

    def __init__(self, model: TripletGNNBase) -> None:
        super().__init__(model, varname="score_cut", varlabel="Score cut")
    def run_inference(self, batches: typing.List[Data] | LazyDatasetBase):
        assert isinstance(self.model, TripletGNNBase)
        with torch.no_grad():
            for batch in tqdm(batches, desc="GNN inference"):  # type: ignore
                outputs = self.model.shared_evaluation(
                    batch, log=False, with_triplets=True
                )
                batch.triplet_indices = outputs["triplet_indices"]
                batch.triplet_scores = outputs["triplet_scores"]
                batch.filtered_edge_index = outputs["filtered_edge_index"]
        return batches
    def add_lhcb_text(self, ax: Axes, metric_name: str):
        # Place the LHCb label where it does not overlap the curves,
        # depending on which metric is plotted.
        if metric_name in ["efficiency", "hit_efficiency_per_candidate"]:
            plotools.add_text(ax, ha="left", y=0.3)
        elif metric_name == "ghost_rate":
            plotools.add_text(ax, ha="right", va="top")
        elif metric_name == "clone_rate":
            plotools.add_text(ax, ha="center", y=0.65)
        else:
            plotools.add_text(ax, ha="left", y=0.3)
    def get_tracks(
        self,
        value: float,
        batches: typing.List[Data],
    ):
        # Run track reconstruction
        batches = [
            batch_from_triplets_to_tracks(
                batch=batch,
                triplet_indices=batch["triplet_indices"],
                triplet_scores=batch["triplet_scores"],
                edge_index=batch["filtered_edge_index"],
                triplet_score_cut=value,
            ).cpu()
            for batch in batches
        ]
        # Define dataframe of tracks
        return pd.concat(
            tuple(get_tracks_from_batch(batch=batch) for batch in batches)
        ).drop_duplicates()
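
# A minimal usage sketch, not part of the module: the same scan with the
# triplet score cut. `trained_model` and `validation_batches` are again
# hypothetical placeholders; here `run_inference` caches the triplet indices
# and scores on each batch, and `get_tracks` applies the cut to them.
#
#     explorer = TripletGNNScoreCutExplorer(model=trained_model)
#     batches = explorer.run_inference(validation_batches)
#     df_tracks = explorer.get_tracks(value=0.5, batches=batches)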