"""Plotting helpers that scan GNN edge and triplet score cuts.

Module: ``pipeline.GNN.gnn_plots``.
"""
from __future__ import annotations
import typing
import numpy.typing as npt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from .triplet_gnn_base import TripletGNNBase
from .gnn_validation import GNNScoreCutExplorer, TripletGNNScoreCutExplorer
def plot_best_performances_score_cut(
    model: TripletGNNBase,
    partition: str,
    edge_score_cuts: typing.Sequence[float],
    builder: str = "default",
    n_events: int | None = None,
    seed: int | None = None,
    identifier: str | None = None,
    path_or_config: str | dict | None = None,
    step: str = "gnn",
    **kwargs,
) -> typing.Tuple[
    Figure | npt.NDArray,
    typing.List[Axes],
    typing.Dict[float, typing.Dict[typing.Tuple[str | None, str], float]],
]:
    """Plot performances while scanning a list of GNN edge score cuts.

    Thin wrapper around ``GNNScoreCutExplorer(model=..., builder=...).plot``.
    Every argument except ``model`` and ``builder`` is forwarded unchanged to
    ``plot``. ``edge_score_cuts`` is forwarded as ``values``.

    Args:
        model: Trained GNN model handed to the explorer.
        partition: Data partition name, forwarded to ``plot``.
        edge_score_cuts: Edge score cut values to scan. Forwarded to
            ``plot`` as ``values``.
        builder: Builder name passed to the explorer constructor. When it is
            not ``"default"``, it is also appended to the default
            ``identifier``.
        n_events: Forwarded to ``plot``. Presumably limits the number of
            events used; ``None`` is passed through as-is. TODO confirm
            against ``GNNScoreCutExplorer.plot``.
        seed: Forwarded to ``plot``. Presumably seeds event selection;
            TODO confirm.
        identifier: Label forwarded to ``plot``. Defaults to
            ``"_from_edges"``, or ``"_from_edges_<builder>"`` for a
            non-default builder.
        path_or_config: Forwarded to ``plot``. Expects a path, a config
            dict, or ``None``.
        step: Pipeline step name forwarded to ``plot``. Defaults to
            ``"gnn"``.
        **kwargs: Extra keyword arguments forwarded to ``plot``.

    Returns:
        The value returned by ``GNNScoreCutExplorer.plot``. Per the
        annotation, this is a tuple of a figure (or array of figures), a
        list of axes, and a dict keyed by score-cut value.
    """
    if identifier is None:
        # Tag outputs as edge-based. A non-default builder is included so
        # that results from different builders do not share a label.
        identifier = (
            "_from_edges" if builder == "default" else f"_from_edges_{builder}"
        )

    explorer = GNNScoreCutExplorer(model=model, builder=builder)
    return explorer.plot(
        path_or_config=path_or_config,
        partition=partition,
        values=edge_score_cuts,
        n_events=n_events,
        seed=seed,
        identifier=identifier,
        step=step,
        **kwargs,
    )
def plot_best_performances_score_cut_triplets(
    model: TripletGNNBase,
    partition: str,
    edge_score_cut: float,
    triplet_score_cuts: typing.Sequence[float],
    n_events: int | None = None,
    seed: int | None = None,
    identifier: str | None = None,
    path_or_config: str | dict | None = None,
    step: str = "gnn",
    **kwargs,
) -> typing.Tuple[
    Figure | npt.NDArray,
    typing.List[Axes],
    typing.Dict[float, typing.Dict[typing.Tuple[str | None, str], float]],
]:
    """Plot performances while scanning a list of GNN triplet score cuts.

    The edge score cut is fixed to ``edge_score_cut``. The triplet score cuts
    are scanned through ``TripletGNNScoreCutExplorer(model=...).plot``.

    Side effect: ``model.hparams["edge_score_cut"]`` is overwritten with
    ``edge_score_cut``. The new value stays on ``model`` after this function
    returns.

    Args:
        model: Trained triplet GNN model. Its ``hparams`` mapping is mutated
            (see above).
        partition: Data partition name, forwarded to ``plot``.
        edge_score_cut: Edge score cut stored in the model's hyper-parameters
            before exploring triplet cuts.
        triplet_score_cuts: Triplet score cut values to scan. Forwarded to
            ``plot`` as ``values``.
        n_events: Forwarded to ``plot``. Presumably limits the number of
            events used; TODO confirm against the explorer.
        seed: Forwarded to ``plot``. Presumably seeds event selection;
            TODO confirm.
        identifier: Label forwarded to ``plot``. Defaults to
            ``"_from_triplets"``.
        path_or_config: Forwarded to ``plot``. Expects a path, a config
            dict, or ``None``.
        step: Pipeline step name forwarded to ``plot``. Defaults to
            ``"gnn"``.
        **kwargs: Extra keyword arguments forwarded to ``plot``.

    Returns:
        The value returned by ``TripletGNNScoreCutExplorer.plot``. Per the
        annotation, this is a tuple of a figure (or array of figures), a
        list of axes, and a dict keyed by score-cut value.
    """
    if identifier is None:
        identifier = "_from_triplets"

    # The explorer receives only ``model``, so the fixed edge cut is
    # communicated through the model's hyper-parameters. This presumably is
    # read by the explorer when building triplets; verify in the explorer.
    # NOTE(review): this mutation persists after the call. It is kept
    # intentionally for callers that may rely on it.
    model.hparams["edge_score_cut"] = edge_score_cut

    explorer = TripletGNNScoreCutExplorer(model=model)
    return explorer.plot(
        path_or_config=path_or_config,
        partition=partition,
        values=triplet_score_cuts,
        n_events=n_events,
        seed=seed,
        identifier=identifier,
        step=step,
        **kwargs,
    )
