from __future__ import annotations
import typing
from tqdm.auto import tqdm
import pandas as pd
from matplotlib.axes import Axes
import torch
from torch_geometric.data import Data
from .triplet_gnn_base import TripletGNNBase, get_df_edges_from_batch_only
from utils.loaderutils.dataiterator import LazyDatasetBase
from TrackBuilding.builder import (
    batch_from_triplets_to_tracks,
    batch_from_edges_to_tracks,
)
from utils.loaderutils.tracks import get_tracks_from_batch
from utils.modelutils.exploration import ParamExplorer
from utils.plotutils import plotools
from utils.graphutils.tripletbuilding import from_edge_index_to_triplet_indices
class GNNScoreCutExplorer(ParamExplorer):
    """A class that allows to vary the score cut after the GNN, and compare the
    metric performances of track finding.

    The edge scores are computed once by :meth:`run_inference`. :meth:`get_tracks`
    then applies an edge-score cut and builds tracks in one of two ways:

    * ``builder="default"``: directly from the edges that pass the cut;
    * ``builder="triplet"``: from triplets formed out of those edges.
    """

    # Track-building strategies accepted by the ``builder`` argument.
    _BUILDERS = ("default", "triplet")

    def __init__(self, model: TripletGNNBase, builder: str = "default") -> None:
        """
        Args:
            model: GNN whose ``shared_evaluation`` provides the edge scores and
                whose ``filter_edges`` is used by the ``"triplet"`` builder.
            builder: ``"default"`` or ``"triplet"`` (see the class docstring).

        Raises:
            ValueError: if ``builder`` is not one of the supported values.
        """
        assert isinstance(model, TripletGNNBase)
        builder = str(builder)
        # Fail fast: otherwise an unknown builder only surfaces in
        # ``get_tracks``, i.e. after the (expensive) inference has been run.
        if builder not in self._BUILDERS:
            raise ValueError(f"builder `{builder}` is not recognised")
        super().__init__(model, varname="score_cut", varlabel="Score cut")
        self.builder = builder
        self.model = model

    @property
    def default_step(self) -> str:
        """Step name reported to :class:`ParamExplorer` (here ``"gnn"``)."""
        return "gnn"

    def _add_score(self, batch: Data) -> Data:
        """Run the GNN on ``batch`` and store its edge scores, in place, under
        ``batch["edge_score"]``. Returns the same batch object."""
        outputs = self.model.shared_evaluation(batch, log=False, with_triplets=False)
        batch["edge_score"] = outputs["edge_score"]
        return batch

    def run_inference(
        self, batches: typing.List[Data] | LazyDatasetBase
    ) -> typing.List[Data]:
        """Score the edges of every batch with the GNN, without tracking gradients.

        Args:
            batches: batches to run the model on; iterated once.

        Returns:
            A list of the scored batches, each moved to the CPU. :meth:`get_tracks`
            can then apply any edge-score cut to them without re-running the model.
        """
        with torch.no_grad():
            return [
                self._add_score(batch=batch).cpu()
                for batch in tqdm(batches, desc="GNN inference")  # type: ignore
            ]

    def _filter_edges(self, batch: Data, edge_score_min: float) -> Data:
        """Apply the edge-score cut to ``batch`` and build its tracks.

        Args:
            batch: batch holding ``edge_index`` and ``edge_score`` (as filled by
                :meth:`run_inference`).
            edge_score_min: edge-score cut.

        Returns:
            The batch returned by the track builder.

        Raises:
            ValueError: if ``self.builder`` is not recognised.
        """
        if self.builder == "default":
            # Keep only the edges whose score is strictly above the cut.
            kept_edges = batch["edge_score"] > edge_score_min
            return batch_from_edges_to_tracks(
                batch=batch,
                edge_index=batch["edge_index"][:, kept_edges],
            )

        if self.builder == "triplet":
            # The cut is delegated to the model's own edge filter; the edge mask
            # it also returns is not needed here.
            filtered_edge_index, _ = self.model.filter_edges(
                edge_index=batch["edge_index"],
                edge_score=batch["edge_score"],
                edge_score_cut=edge_score_min,
            )
            dict_triplet_indices = from_edge_index_to_triplet_indices(
                edge_index=filtered_edge_index
            )
            # Every triplet gets a score of 1 and the triplet cut below is 0.0,
            # presumably so that only the edge-score cut selects triplets.
            dict_triplet_scores = {
                triplet_name: torch.ones(
                    triplet_index.shape[1], device=triplet_index.device
                )
                for triplet_name, triplet_index in dict_triplet_indices.items()
            }
            # Also stored on the batch (in place), as before.
            batch["filtered_edge_index"] = filtered_edge_index
            batch["triplet_indices"] = dict_triplet_indices
            batch["triplet_scores"] = dict_triplet_scores
            return batch_from_triplets_to_tracks(
                batch=batch,
                triplet_indices=dict_triplet_indices,
                triplet_scores=dict_triplet_scores,
                edge_index=filtered_edge_index,
                triplet_score_cut=0.0,
            ).cpu()

        raise ValueError(f"builder `{self.builder}` is not recognised")

    def get_tracks(self, value: float, batches: typing.List[Data]) -> pd.DataFrame:
        """Build the tracks obtained with the edge-score cut ``value``.

        Args:
            value: edge-score cut.
            batches: batches already processed by :meth:`run_inference`.

        Returns:
            Dataframe of tracks: the per-batch track dataframes concatenated, with
            duplicated rows removed.
        """
        # Run track reconstruction
        built_batches = [
            self._filter_edges(batch=batch, edge_score_min=value) for batch in batches
        ]
        # Define dataframe of tracks
        return pd.concat(
            tuple(get_tracks_from_batch(batch=batch) for batch in built_batches)
        ).drop_duplicates()
class TripletGNNScoreCutExplorer(ParamExplorer):
    """A class that allows to vary the triplet-score cut after the GNN, and compare
    the metric performances of track finding.

    The GNN is run once by :meth:`run_inference`, which stores the triplets and
    their scores on each batch. :meth:`get_tracks` then rebuilds the tracks for a
    given triplet-score cut.
    """

    def __init__(self, model: TripletGNNBase) -> None:
        """
        Args:
            model: GNN whose ``shared_evaluation`` (with ``with_triplets=True``)
                provides the triplets and their scores.
        """
        super().__init__(model, varname="score_cut", varlabel="Score cut")

    def run_inference(
        self, batches: typing.List[Data] | LazyDatasetBase
    ) -> typing.List[Data]:
        """Run the GNN on every batch and attach its triplet outputs to it.

        Each batch gets ``triplet_indices``, ``triplet_scores`` and
        ``filtered_edge_index`` set in place.

        Args:
            batches: batches to run the model on; iterated once.

        Returns:
            A list of the processed batches. The batches are collected rather than
            returning ``batches`` itself. With a lazily-loaded dataset, a second
            iteration may yield fresh objects without the attributes set here, and
            :meth:`get_tracks` needs those attributes.
        """
        assert isinstance(self.model, TripletGNNBase)
        processed_batches: typing.List[Data] = []
        with torch.no_grad():
            for batch in tqdm(batches, desc="GNN inference"):  # type: ignore
                outputs = self.model.shared_evaluation(
                    batch, log=False, with_triplets=True
                )
                batch.triplet_indices = outputs["triplet_indices"]
                batch.triplet_scores = outputs["triplet_scores"]
                batch.filtered_edge_index = outputs["filtered_edge_index"]
                processed_batches.append(batch)
        return processed_batches

    def add_lhcb_text(self, ax: Axes, metric_name: str) -> None:
        """Add the text label (via ``plotools.add_text``) to ``ax``.

        The label is placed at a position that depends on ``metric_name``.
        """
        if metric_name == "ghost_rate":
            plotools.add_text(ax, ha="right", va="top")
        elif metric_name == "clone_rate":
            plotools.add_text(ax, ha="center", y=0.65)
        else:
            # Default placement, also used for "efficiency" and
            # "hit_efficiency_per_candidate".
            plotools.add_text(ax, ha="left", y=0.3)

    def get_tracks(
        self,
        value: float,
        batches: typing.List[Data],
    ) -> pd.DataFrame:
        """Build the tracks obtained with the triplet-score cut ``value``.

        Args:
            value: triplet-score cut.
            batches: batches already processed by :meth:`run_inference`, i.e.
                holding ``triplet_indices``, ``triplet_scores`` and
                ``filtered_edge_index``.

        Returns:
            Dataframe of tracks: the per-batch track dataframes concatenated, with
            duplicated rows removed.
        """
        # Run track reconstruction
        built_batches = [
            batch_from_triplets_to_tracks(
                batch=batch,
                triplet_indices=batch["triplet_indices"],
                triplet_scores=batch["triplet_scores"],
                edge_index=batch["filtered_edge_index"],
                triplet_score_cut=value,
            ).cpu()
            for batch in batches
        ]
        # Define dataframe of tracks
        return pd.concat(
            tuple(get_tracks_from_batch(batch=batch) for batch in built_batches)
        ).drop_duplicates()