from __future__ import annotations
import typing
import os
import os.path as op
import logging
import numpy as np
import numpy.typing as npt
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
import montetracko as mt
from montetracko.evaluation.trackevaluator import TrackEvaluator
from montetracko.requirement.category import Category
from utils.plotutils import plotconfig, plotools
from utils.commonutils.config import cdirs
T = typing.TypeVar("T", str, int, float, TrackEvaluator, None)
def _to_list(element: T | typing.List[T], size: int | None = None) -> typing.List[T]:
    """A helper function to transform an element into a list of elements,
    and check that the size of this list is right.
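
    Examples (doctest-style, matching the behaviour implemented below)::

        >>> _to_list("a")
        ['a']
        >>> _to_list(["a", "b"], size=2)
        ['a', 'b']
        >>> _to_list(None, size=2)
        [None, None]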
    """
    if isinstance(element, list):
        if size is not None and len(element) != size:
            raise ValueError(
                f"Expected size of {element} is {size}, instead of {len(element)}"
            )
        return element
    elif element is None:
        if size is None:
            raise ValueError("No element nor size was provided.")
        else:
            return [element] * size
    else:
        if size is not None and size != 1:
            raise ValueError(
                f"Only a single element was provided, but the expected size is {size}"
            )
        return [element]
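# Overloads: the type of the first returned element depends on ``same_fig``
# (a single ``Figure`` when all axes share one figure, otherwise an array of
# figures). The general ``bool`` overload comes last so that type checkers
# match the ``Literal`` overloads first.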
@typing.overload
def plot_histograms_trackevaluator(
    trackEvaluator: TrackEvaluator | typing.List[TrackEvaluator],
    columns: typing.List[str],
    metric_names: typing.List[str],
    color: str | typing.List[str] | None = None,
    label: str | typing.List[str] | None = None,
    column_labels: typing.Dict[str, str] | None = None,
    bins: int | typing.Sequence[float] | str | typing.Dict[str, typing.Any] | None = None,
    column_ranges: typing.Dict[str, typing.Tuple[float, float]] | None = None,
    category: Category | None = None,
    same_fig: typing.Literal[True] = True,
    **kwargs,
) -> typing.Tuple[Figure, npt.NDArray, npt.NDArray]:
    ...
@typing.overload
def plot_histograms_trackevaluator(
    trackEvaluator: TrackEvaluator | typing.List[TrackEvaluator],
    columns: typing.List[str],
    metric_names: typing.List[str],
    color: str | typing.List[str] | None = None,
    label: str | typing.List[str] | None = None,
    column_labels: typing.Dict[str, str] | None = None,
    bins: int | typing.Sequence[float] | str | typing.Dict[str, typing.Any] | None = None,
    column_ranges: typing.Dict[str, typing.Tuple[float, float]] | None = None,
    category: Category | None = None,
    same_fig: typing.Literal[False] = False,
    **kwargs,
) -> typing.Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
    ...
@typing.overload
def plot_histograms_trackevaluator(
    trackEvaluator: TrackEvaluator | typing.List[TrackEvaluator],
    columns: typing.List[str],
    metric_names: typing.List[str],
    color: str | typing.List[str] | None = None,
    label: str | typing.List[str] | None = None,
    column_labels: typing.Dict[str, str] | None = None,
    bins: int | typing.Sequence[float] | str | typing.Dict[str, typing.Any] | None = None,
    column_ranges: typing.Dict[str, typing.Tuple[float, float]] | None = None,
    category: Category | None = None,
    same_fig: bool = True,
    **kwargs,
) -> typing.Tuple[Figure | npt.NDArray, npt.NDArray, npt.NDArray]:
    ...
def plot_histograms_trackevaluator(
    trackEvaluator: TrackEvaluator | typing.List[TrackEvaluator],
    columns: typing.List[str],
    metric_names: typing.List[str],
    color: str | typing.List[str] | None = None,
    label: str | typing.List[str] | None = None,
    column_labels: typing.Dict[str, str] | None = None,
    bins: int
    | typing.Sequence[float]
    | str
    | typing.Dict[str, typing.Any]
    | None = None,
    column_ranges: typing.Dict[str, typing.Tuple[float, float]] | None = None,
    category: Category | None = None,
    same_fig: bool = True,
    lhcb: bool = False,
    with_err: bool = True,
    **kwargs,
) -> typing.Tuple[Figure | npt.NDArray, npt.NDArray, npt.NDArray]:
    """Plot multiple histograms of metrics.
    Args:
        trackEvaluator: one or more montetracko track evaluators to plot. They
            should share the same data distributions.
        columns: list of columns to histogrammise the metrics on
        metric_names: list of metric names to plot
        color: colors for each track evaluator
        label: labels for each track evaluator
        column_labels: Associates a column name with its label
        bins: Number of bins, bin edges, a binning strategy name, or a
            dictionary that associates a column name with any of these
        column_ranges: Associates a column name with a tuple of lower and
            upper bounds for the bins
        category: Particle category to plot
        same_fig: whether to put all the axes in the same figure
        lhcb: whether to add "LHCb Simulation" at the top of every
            matplotlib axis
        with_err: whether to draw the error bars of the histograms
    Returns:
        A tuple of 3 elements: the figure(s), the axes and the histogram axes.
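
    Example (a minimal sketch; ``evaluator`` is assumed to be an already
    configured ``TrackEvaluator``)::

        fig, axes, axes_hist = plot_histograms_trackevaluator(
            trackEvaluator=evaluator,
            columns=["pt", "eta"],
            metric_names=["efficiency", "clone_rate"],
            label="baseline",
        )
        fig.savefig("histograms.pdf")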
    """
    trackEvaluators = _to_list(trackEvaluator)
    if color is None and len(trackEvaluators) == 1:
        colors = ["k"]
    else:
        colors = _to_list(color, len(trackEvaluators))  # type: ignore
    labels = _to_list(label, len(trackEvaluators))  # type: ignore
    if column_ranges is None:
        column_ranges = plotconfig.column_ranges
    if column_labels is None:
        column_labels = plotconfig.column_labels
    if bins is None:
        bins = plotconfig.column_bins
    figs, axes = plotools.get_figs_axes_on_grid(
        nrows=len(metric_names), ncols=len(columns), same_fig=same_fig
    )
    if not same_fig:
        figs = np.atleast_2d(figs)  # type: ignore
    axes = np.atleast_2d(axes)
    axes_histogram = np.empty_like(axes)
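    # Within each column, the bin edges computed for the first metric and the
    # first evaluator are reused, so that all overlaid histograms share the
    # same binning.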
    for idx_col, column in enumerate(columns):
        edges = None
        for idx_metric, metric_name in enumerate(metric_names):
            if edges is not None:
                bins_metric = edges
            elif isinstance(bins, dict):
                # Fall back to a default binning for columns absent from the dict
                bins_metric = bins.get(column, 20)
            else:
                # ``bins`` is a global specification (number of bins, edges or
                # strategy name) that applies to every column
                bins_metric = bins
            histogram_common_kwargs = dict(
                column=column,
                metric_name=metric_name,
                category=category,
                column_label=column_labels.get(column, column.replace("_", r"\_")),
                ax=axes[idx_metric][idx_col],
            )
            alpha = 0.65 if len(trackEvaluators) > 1 else 1.0
            (
                axes_histogram[idx_metric][idx_col],
                _,
                _,
                edges,
                _,
            ) = trackEvaluators[0].plot_histogram(
                range=column_ranges.get(column),
                bins=bins_metric,
                color=colors[0],
                label=labels[0],
                err="auto" if with_err else None,
                alpha=alpha,
                **histogram_common_kwargs,
                **kwargs,
            )
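            # Overlay the remaining evaluators on the same axes, reusing the
            # bin edges returned by the first evaluator's histogram.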
            for trackEvaluator, color, label in zip(
                trackEvaluators[1:], colors[1:], labels[1:]
            ):
                trackEvaluator.plot_histogram(
                    bins=edges,
                    color=color,
                    label=label,
                    err="auto" if with_err else None,
                    show_histogram=False,
                    alpha=alpha,
                    **histogram_common_kwargs,
                    **kwargs,
                )
            if column in plotconfig.integer_columns:
                axes[idx_metric][idx_col].xaxis.set_major_locator(
                    MaxNLocator(integer=True)
                )
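    # Give each metric row a common y-range and, if requested, pad the top of
    # the axes and add the LHCb label.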
    for idx_metric, metric_name in enumerate(metric_names):
        plotools.set_same_y_lim_for_all_axes(axes[idx_metric], ymin=0.0)
        if lhcb:
            for ax, ax_hist in zip(axes[idx_metric], axes_histogram[idx_metric]):
                plotools.pad_on_top(ax)
                plotools.pad_on_top(ax_hist)
                plotools.add_text(ax, ha="right", y=0.98)
    if same_fig:
        plotools.hide_repetitive_labels_in_grid(axes=axes)
        for line_axes_histogram in axes_histogram:
            for ax_histogram in line_axes_histogram[:-1]:
                ax_histogram.yaxis.label.set_visible(False)
        assert isinstance(figs, Figure)
        figs.tight_layout()
    # Define legend
    if any(label is not None for label in labels):
        if same_fig:
            axes[0][0].legend()
        else:
            for line_axes in axes:
                for ax in line_axes:
                    ax.legend()
    return figs, axes, axes_histogram 
def plot_evaluation(
    trackEvaluator: mt.TrackEvaluator,
    category: mt.requirement.Category,
    plotted_groups: typing.List[str] = ["basic"],
    detector: str | None = None,
    output_dir: str | None = None,
    suffix: str | None = None,
):
    """Generate and display histograms of track evaluation metrics in specified
    particle-related columns.
    Args:
        trackEvaluator: A ``TrackEvaluator`` instance containing the results
            of the track matching
        category: Truth category for the plot
        plotted_groups: Pre-configured groups of metrics and columns to plot.
            Each group corresponds to one plot that shows the distributions of
            various metrics as a function of various truth variables,
            as hard-coded in this function.
            There are 3 groups: ``basic``, ``geometry`` and ``challenging``
            (only ``basic`` is available for the ``scifi`` detector).
        detector: name of the detector (``velo`` or ``scifi``)
        output_dir: directory where the figures are saved.
            If ``None``, the figures are not saved.
        suffix: Suffix to add at the end of the figure names
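
    Example (a hypothetical call; ``evaluator`` and ``category`` are assumed
    to come from a previous track-matching step)::

        plot_evaluation(
            trackEvaluator=evaluator,
            category=category,
            plotted_groups=["basic", "geometry"],
            detector="velo",
            output_dir="plots",
        )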
    """
    plotconfig.configure_matplotlib()
    if detector is None:
        detector = cdirs.detectors[0]
    if detector == "velo":
        group_configurations = {
            "basic": dict(
                columns=["pt", "p", "eta", "vz"],
                metric_names=[
                    "efficiency",
                    "clone_rate",
                    # "hit_purity_per_candidate",
                    "hit_efficiency_per_candidate",
                ],
            ),
            "challenging": dict(
                columns=["vz", "nhits_velo"],
                metric_names=["efficiency"],
            ),
            "geometry": dict(
                columns=[
                    "distance_to_line",
                    "distance_to_z_axis",
                    "xz_angle",
                    "yz_angle",
                ],
                metric_names=["efficiency"],
            ),
        }
    elif detector == "scifi" or detector == "scifi_xz":
        group_configurations = {
            "basic": dict(
                columns=[
                    "pt",
                    "pl",
                    "n_unique_planes",
                    "n_skipped_planes",
                    "n_shared_hits",
                    "vz",
                    # "eta",
                    "phi",
                    "quadratic_coeff",
                    # "distance_to_line",
                    # "distance_to_z_axis",
                ],
                metric_names=["efficiency"],
            ),
        }
    else:
        raise ValueError(f"Detector {detector} is not recognised")
    for group_name in plotted_groups:
        if group_name not in group_configurations:
            raise ValueError(
                f"Group `{group_name}` is unknown. "
                "Valid groups are: " + ", ".join(group_configurations.keys())
            )
        group_config = group_configurations[group_name]
        fig, _, _ = plot_histograms_trackevaluator(
            trackEvaluator=trackEvaluator,
            **group_config,  # type: ignore
            category=category,
        )
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            if suffix is None:
                suffix = ""
            plot_path = op.join(
                output_dir, f"hist1d_{group_name}_{category.name}{suffix}.pdf"
            )
            plotools.save_fig(fig, plot_path, dpi=200)
            logging.info(
                f"Plot {group_name} for category {category.name} saved in {plot_path}"
            ) 
def plot_evaluation_categories(
    trackEvaluator: mt.TrackEvaluator,
    detector: str | None = None,
    categories: typing.Iterable[mt.requirement.Category] | None = None,
    plotted_groups: typing.List[str] | None = ["basic"],
    output_dir: str | None = None,
    suffix: str | None = None,
):
    """Generate and display histograms of track evaluation metrics in specified
    particle-related columns, for various categories.
    Args:
        trackEvaluator: A ``TrackEvaluator`` instance containing the results
            of the track matching
        detector: name of the detector (``velo`` or ``scifi``)
        categories: truth categories to plot. ``plot_evaluation`` is called
            once per category.
        plotted_groups: Pre-configured groups of metrics and columns to plot.
            Each group corresponds to one plot that shows the distributions of
            various metrics as a function of various truth variables,
            as hard-coded in ``plot_evaluation``.
            There are 3 groups: ``basic``, ``geometry`` and ``challenging``.
        output_dir: directory where the figures are saved.
            If ``None``, the figures are not saved.
        suffix: Suffix to add at the end of the figure names
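
    Example (a hypothetical call; the category names are placeholders, assumed
    to be built with ``montetracko``'s requirement API)::

        plot_evaluation_categories(
            trackEvaluator=evaluator,
            detector="velo",
            categories=[category_long, category_velo_only],
            plotted_groups=["basic"],
            output_dir="plots",
        )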
    """
    if categories is not None and plotted_groups:
        for category in categories:
            plot_evaluation(
                trackEvaluator=trackEvaluator,
                detector=detector,
                category=category,
                output_dir=output_dir,
                plotted_groups=plotted_groups,
                suffix=suffix,
            )