"""Score-cut performance plots for GNN models (module ``pipeline.GNN.gnn_plots``)."""
from __future__ import annotations
import typing
import numpy.typing as npt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from .triplet_gnn_base import TripletGNNBase
from .gnn_validation import GNNScoreCutExplorer, TripletGNNScoreCutExplorer
def plot_best_performances_score_cut(
    model: TripletGNNBase,
    partition: str,
    edge_score_cuts: typing.Sequence[float],
    builder: str = "default",
    n_events: int | None = None,
    seed: int | None = None,
    identifier: str | None = None,
    path_or_config: str | dict | None = None,
    step: str = "gnn",
    **kwargs,
) -> typing.Tuple[
    Figure | npt.NDArray,
    typing.List[Axes],
    typing.Dict[float, typing.Dict[typing.Tuple[str | None, str], float]],
]:
    """Plot the best performances of ``model`` while scanning edge score cuts.

    A :class:`GNNScoreCutExplorer` is built from ``model`` and ``builder``,
    and its ``plot`` method is called with ``edge_score_cuts`` as the values
    to scan.

    Args:
        model: GNN model handed to the explorer.
        partition: Dataset partition, forwarded to ``plot``.
        edge_score_cuts: Edge score cuts to scan, forwarded as ``values``.
        builder: Builder name given to the explorer. When ``identifier`` is
            not provided and ``builder`` is not ``"default"``, the suffix
            ``"_{builder}"`` is appended to the default identifier.
        n_events: Forwarded to ``plot``. Presumably the number of events to
            process (``None`` meaning the explorer's default); confirm.
        seed: Forwarded to ``plot``. Presumably a seed for event selection.
        identifier: Identifier forwarded to ``plot``. Defaults to
            ``"_from_edges"`` (plus the builder suffix described above).
        path_or_config: Forwarded to ``plot``.
        step: Forwarded to ``plot``. Presumably the pipeline step name.
        **kwargs: Any other keyword arguments, forwarded to ``plot``.

    Returns:
        The result of ``GNNScoreCutExplorer.plot``. Per the annotation, this is
        a figure (or array), a list of axes, and a dict keyed by score cut
        whose values map ``(str | None, str)`` keys to floats.
    """
    if identifier is None:
        # NOTE(review): the builder suffix is applied only to the default
        # identifier; an explicit ``identifier`` is used verbatim — confirm
        # this matches the intended nesting.
        identifier = (
            "_from_edges" if builder == "default" else f"_from_edges_{builder}"
        )

    explorer = GNNScoreCutExplorer(model=model, builder=builder)
    return explorer.plot(
        partition=partition,
        values=edge_score_cuts,
        path_or_config=path_or_config,
        identifier=identifier,
        n_events=n_events,
        seed=seed,
        step=step,
        **kwargs,
    )
def plot_best_performances_score_cut_triplets(
    model: TripletGNNBase,
    partition: str,
    edge_score_cut: float,
    triplet_score_cuts: typing.Sequence[float],
    n_events: int | None = None,
    seed: int | None = None,
    identifier: str | None = None,
    path_or_config: str | dict | None = None,
    step: str = "gnn",
    **kwargs,
) -> typing.Tuple[
    Figure | npt.NDArray,
    typing.List[Axes],
    typing.Dict[float, typing.Dict[typing.Tuple[str | None, str], float]],
]:
    """Plot the best performances of ``model`` while scanning triplet score cuts.

    ``model.hparams["edge_score_cut"]`` is set to ``edge_score_cut`` while a
    :class:`TripletGNNScoreCutExplorer` scans ``triplet_score_cuts``. The
    previous value (or the absence of the key) is restored afterwards, even
    if plotting raises, so this call does not alter the caller's model
    configuration.

    Args:
        model: Triplet GNN model handed to the explorer.
        partition: Dataset partition, forwarded to ``plot``.
        edge_score_cut: Edge score cut stored in ``model.hparams`` for the
            duration of the scan.
        triplet_score_cuts: Triplet score cuts to scan, forwarded as
            ``values``.
        n_events: Forwarded to ``plot``. Presumably the number of events to
            process (``None`` meaning the explorer's default); confirm.
        seed: Forwarded to ``plot``. Presumably a seed for event selection.
        identifier: Identifier forwarded to ``plot``. Defaults to
            ``"_from_triplets"``.
        path_or_config: Forwarded to ``plot``.
        step: Forwarded to ``plot``. Presumably the pipeline step name.
        **kwargs: Any other keyword arguments, forwarded to ``plot``.

    Returns:
        The result of ``TripletGNNScoreCutExplorer.plot``. Per the annotation,
        this is a figure (or array), a list of axes, and a dict keyed by score
        cut whose values map ``(str | None, str)`` keys to floats.
    """
    if identifier is None:
        identifier = "_from_triplets"

    # Presumably the explorer reads the edge score cut from the model's
    # hparams, so it must be set there. Remember the prior state so the
    # caller's configuration is not permanently changed.
    missing = object()
    try:
        previous_edge_score_cut = model.hparams["edge_score_cut"]
    except KeyError:
        previous_edge_score_cut = missing

    model.hparams["edge_score_cut"] = edge_score_cut
    try:
        explorer = TripletGNNScoreCutExplorer(model=model)
        return explorer.plot(
            path_or_config=path_or_config,
            partition=partition,
            values=triplet_score_cuts,
            n_events=n_events,
            seed=seed,
            identifier=identifier,
            step=step,
            **kwargs,
        )
    finally:
        if previous_edge_score_cut is missing:
            # The key did not exist before this call; remove it again.
            model.hparams.pop("edge_score_cut", None)
        else:
            model.hparams["edge_score_cut"] = previous_edge_score_cut