from __future__ import annotations
import typing
from tqdm.auto import tqdm
import pandas as pd
from matplotlib.axes import Axes
import torch
from torch_geometric.data import Data
from .triplet_gnn_base import TripletGNNBase, get_df_edges_from_batch_only
from utils.loaderutils.dataiterator import LazyDatasetBase
from TrackBuilding.builder import (
batch_from_triplets_to_tracks,
batch_from_edges_to_tracks,
)
from utils.loaderutils.tracks import get_tracks_from_batch
from utils.modelutils.exploration import ParamExplorer
from utils.plotutils import plotools
from utils.graphutils.tripletbuilding import from_edge_index_to_triplet_indices
class GNNScoreCutExplorer(ParamExplorer):
    """Scan the edge-score cut applied to the GNN output and rebuild tracks.

    The explored parameter is registered with the parent explorer under the
    name ``score_cut`` (label ``Score cut``). For every explored value the
    edges are filtered on their score and tracks are rebuilt with the
    configured ``builder``:

    * ``"default"``: tracks are built from the edges whose score is strictly
      greater than the cut, through ``batch_from_edges_to_tracks``.
    * ``"triplet"``: edges are filtered with the model's ``filter_edges``,
      triplets are formed from the surviving edges, every triplet receives a
      score of one, and tracks are built through
      ``batch_from_triplets_to_tracks`` with a triplet score cut of ``0.0``.
    """

    def __init__(self, model: TripletGNNBase, builder: str = "default") -> None:
        """Store the model and the name of the track-building strategy.

        Args:
            model: model used for inference; must be a ``TripletGNNBase``.
            builder: ``"default"`` or ``"triplet"``. The value is only
                checked when tracks are built, not here.
        """
        super().__init__(model, varname="score_cut", varlabel="Score cut")
        self.builder = str(builder)
        assert isinstance(model, TripletGNNBase)
        self.model = model

    @property
    def default_step(self) -> str:
        """Name of the step associated with this explorer."""
        return "gnn"

    def _add_score(self, batch: Data):
        """Evaluate the model on ``batch`` and attach the edge scores to it.

        The batch is modified in place (key ``edge_score``) and returned.
        """
        evaluation = self.model.shared_evaluation(
            batch, log=False, with_triplets=False
        )
        batch["edge_score"] = evaluation["edge_score"]
        return batch

    def run_inference(self, batches: typing.List[Data] | LazyDatasetBase):
        """Score every batch with gradient tracking disabled.

        Returns:
            A list holding, for each input batch, the result of calling
            ``.cpu()`` on the batch once its edge scores are attached.
        """
        scored_batches = []
        with torch.no_grad():
            for batch in batches:
                scored_batches.append(self._add_score(batch=batch).cpu())
        return scored_batches

    def _filter_edges(self, batch: Data, edge_score_min: float) -> Data:
        """Apply the edge-score cut to ``batch`` and build tracks from it.

        Args:
            batch: batch that already carries ``edge_index`` and
                ``edge_score``.
            edge_score_min: value of the score cut.

        Returns:
            The batch returned by the track builder. Only the ``"triplet"``
            branch calls ``.cpu()`` on the result.

        Raises:
            ValueError: if ``self.builder`` is neither ``"default"`` nor
                ``"triplet"``.
        """
        if self.builder == "default":
            edge_index = batch["edge_index"]
            above_cut = batch["edge_score"] > edge_score_min
            return batch_from_edges_to_tracks(
                batch=batch,
                edge_index=edge_index[:, above_cut],
            )

        if self.builder != "triplet":
            raise ValueError(f"builder `{self.builder}` is not recognised")

        # Triplet builder: keep the edges selected by the model (the returned
        # mask is not needed here), then form triplets from them.
        kept_edge_index, _ = self.model.filter_edges(
            edge_index=batch["edge_index"],
            edge_score=batch["edge_score"],
            edge_score_cut=edge_score_min,
        )
        triplet_indices = from_edge_index_to_triplet_indices(
            edge_index=kept_edge_index
        )
        batch["filtered_edge_index"] = kept_edge_index
        batch["triplet_indices"] = triplet_indices

        # Every triplet gets a score of one, created on the device of its
        # index tensor; the builder below is called with a cut of 0.0.
        unit_scores = {}
        for triplet_name, triplet_index in triplet_indices.items():
            unit_scores[triplet_name] = torch.ones(
                triplet_index.shape[1], device=triplet_index.device
            )
        batch["triplet_scores"] = unit_scores

        tracked_batch = batch_from_triplets_to_tracks(
            batch=batch,
            triplet_indices=batch["triplet_indices"],
            triplet_scores=batch["triplet_scores"],
            edge_index=batch["filtered_edge_index"],
            triplet_score_cut=0.0,
        )
        return tracked_batch.cpu()

    def get_tracks(self, value: float, batches: typing.List[Data]):
        """Build the tracks of all batches for the score cut ``value``.

        Returns:
            The concatenation of the per-batch track dataframes, with
            duplicated rows dropped.
        """
        # Track reconstruction is run on all batches first...
        reconstructed = []
        for batch in batches:
            reconstructed.append(
                self._filter_edges(batch=batch, edge_score_min=value)
            )
        # ...and the track dataframes are extracted and merged afterwards.
        track_frames = [
            get_tracks_from_batch(batch=batch) for batch in reconstructed
        ]
        all_tracks = pd.concat(track_frames)
        return all_tracks.drop_duplicates()
class TripletGNNScoreCutExplorer(ParamExplorer):
    """Scan the triplet-score cut applied to the GNN output and rebuild tracks.

    The explored parameter is registered with the parent explorer under the
    name ``score_cut`` (label ``Score cut``). Each explored value is passed as
    ``triplet_score_cut`` to ``batch_from_triplets_to_tracks``.
    """

    def __init__(self, model: TripletGNNBase) -> None:
        """Register ``model`` and the explored variable with the parent."""
        super().__init__(model, varname="score_cut", varlabel="Score cut")

    def run_inference(self, batches: typing.List[Data] | LazyDatasetBase):
        """Evaluate the model with triplets on every batch, without gradients.

        The ``triplet_indices``, ``triplet_scores`` and ``filtered_edge_index``
        outputs of the model are set as attributes on each batch, in place.

        Returns:
            The very same ``batches`` object that was passed in.
        """
        assert isinstance(self.model, TripletGNNBase)
        attached_keys = ("triplet_indices", "triplet_scores", "filtered_edge_index")
        with torch.no_grad():
            for batch in tqdm(batches, desc="GNN inference"):  # type: ignore
                evaluation = self.model.shared_evaluation(
                    batch, log=False, with_triplets=True
                )
                for key in attached_keys:
                    setattr(batch, key, evaluation[key])
        return batches

    def add_lhcb_text(self, ax: Axes, metric_name: str):
        """Call ``plotools.add_text`` on ``ax`` with a metric-dependent placement.

        ``ghost_rate`` and ``clone_rate`` have dedicated placements; every
        other metric (including ``efficiency`` and
        ``hit_efficiency_per_candidate``) uses the left-aligned one.
        """
        if metric_name == "ghost_rate":
            placement = {"ha": "right", "va": "top"}
        elif metric_name == "clone_rate":
            placement = {"ha": "center", "y": 0.65}
        else:
            placement = {"ha": "left", "y": 0.3}
        plotools.add_text(ax, **placement)

    def get_tracks(
        self,
        value: float,
        batches: typing.List[Data],
    ):
        """Build the tracks of all batches for the triplet score cut ``value``.

        Each batch must already carry ``triplet_indices``, ``triplet_scores``
        and ``filtered_edge_index``.

        Returns:
            The concatenation of the per-batch track dataframes, with
            duplicated rows dropped.
        """
        # Track reconstruction is run on all batches first...
        reconstructed = []
        for batch in batches:
            tracked_batch = batch_from_triplets_to_tracks(
                batch=batch,
                triplet_indices=batch["triplet_indices"],
                triplet_scores=batch["triplet_scores"],
                edge_index=batch["filtered_edge_index"],
                triplet_score_cut=value,
            )
            reconstructed.append(tracked_batch.cpu())
        # ...and the track dataframes are extracted and merged afterwards.
        track_frames = [
            get_tracks_from_batch(batch=batch) for batch in reconstructed
        ]
        all_tracks = pd.concat(track_frames)
        return all_tracks.drop_duplicates()