# Source code for pipeline.GNN.gnn_plots

from __future__ import annotations
import typing

import numpy.typing as npt
from matplotlib.figure import Figure
from matplotlib.axes import Axes

from .triplet_gnn_base import TripletGNNBase
from .gnn_validation import GNNScoreCutExplorer, TripletGNNScoreCutExplorer


def plot_best_performances_score_cut(
    model: TripletGNNBase,
    partition: str,
    edge_score_cuts: typing.Sequence[float],
    builder: str = "default",
    n_events: int | None = None,
    seed: int | None = None,
    identifier: str | None = None,
    path_or_config: str | dict | None = None,
    step: str = "gnn",
    **kwargs,
) -> typing.Tuple[
    Figure | npt.NDArray,
    typing.List[Axes],
    typing.Dict[float, typing.Dict[typing.Tuple[str | None, str], float]],
]:
    """Scan edge score cuts with a ``GNNScoreCutExplorer`` and plot the result.

    Constructs ``GNNScoreCutExplorer(model=model, builder=builder)`` and
    returns the result of its ``plot`` method, scanning ``edge_score_cuts``.

    Args:
        model: GNN model handed to the explorer.
        partition: Forwarded to ``plot`` as ``partition``.
        edge_score_cuts: Cut values forwarded to ``plot`` as ``values``.
        builder: Forwarded to the explorer constructor. When it is not
            ``"default"`` and no ``identifier`` is given, ``"_{builder}"`` is
            appended to the default identifier.
        n_events: Forwarded to ``plot``.
        seed: Forwarded to ``plot``.
        identifier: Forwarded to ``plot``. Defaults to ``"_from_edges"``,
            or ``"_from_edges_{builder}"`` for a non-default builder. An
            explicitly passed value is used as-is.
        path_or_config: Forwarded to ``plot``.
        step: Forwarded to ``plot``.
        **kwargs: Extra keyword arguments forwarded to ``plot``.

    Returns:
        Whatever ``GNNScoreCutExplorer.plot`` returns. It is annotated as
        (figure or array of figures, list of axes, per-cut metrics dict).
    """
    if identifier is None:
        # Tag non-default builders so their outputs are distinguishable.
        builder_suffix = "" if builder == "default" else f"_{builder}"
        identifier = f"_from_edges{builder_suffix}"

    explorer = GNNScoreCutExplorer(model=model, builder=builder)
    return explorer.plot(
        path_or_config=path_or_config,
        partition=partition,
        values=edge_score_cuts,
        n_events=n_events,
        seed=seed,
        identifier=identifier,
        step=step,
        **kwargs,
    )
def plot_best_performances_score_cut_triplets(
    model: TripletGNNBase,
    partition: str,
    edge_score_cut: float,
    triplet_score_cuts: typing.Sequence[float],
    n_events: int | None = None,
    seed: int | None = None,
    identifier: str | None = None,
    path_or_config: str | dict | None = None,
    step: str = "gnn",
    **kwargs,
) -> typing.Tuple[
    Figure | npt.NDArray,
    typing.List[Axes],
    typing.Dict[float, typing.Dict[typing.Tuple[str | None, str], float]],
]:
    """Scan triplet score cuts at a fixed edge score cut and plot the result.

    Stores ``edge_score_cut`` in ``model.hparams["edge_score_cut"]``, builds a
    ``TripletGNNScoreCutExplorer`` for ``model`` and returns the result of
    its ``plot`` method, scanning ``triplet_score_cuts``.

    Args:
        model: Triplet GNN model handed to the explorer. Its ``hparams`` is
            modified in place (see note below).
        partition: Forwarded to ``plot`` as ``partition``.
        edge_score_cut: Value written to ``model.hparams["edge_score_cut"]``
            before the explorer is built.
        triplet_score_cuts: Cut values forwarded to ``plot`` as ``values``.
        n_events: Forwarded to ``plot``.
        seed: Forwarded to ``plot``.
        identifier: Forwarded to ``plot``. Defaults to ``"_from_triplets"``.
        path_or_config: Forwarded to ``plot``.
        step: Forwarded to ``plot``.
        **kwargs: Extra keyword arguments forwarded to ``plot``.

    Returns:
        Whatever ``TripletGNNScoreCutExplorer.plot`` returns. It is annotated
        as (figure or array of figures, list of axes, per-cut metrics dict).
    """
    # NOTE(review): this mutates the caller's model and the previous value is
    # never restored, so the new edge_score_cut persists after this call.
    # Kept as-is in case downstream code relies on it; confirm intent.
    model.hparams["edge_score_cut"] = edge_score_cut

    explorer = TripletGNNScoreCutExplorer(model=model)
    return explorer.plot(
        path_or_config=path_or_config,
        partition=partition,
        values=triplet_score_cuts,
        n_events=n_events,
        seed=seed,
        identifier="_from_triplets" if identifier is None else identifier,
        step=step,
        **kwargs,
    )