from __future__ import annotations
import typing
import os
import os.path as op
import logging
import numpy as np
import numpy.typing as npt
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
import montetracko as mt
from montetracko.evaluation.trackevaluator import TrackEvaluator
from montetracko.requirement.category import Category
from utils.plotutils import plotconfig, plotools
from utils.commonutils.config import cdirs
# Element type accepted by ``_to_list``: the kinds of values it is called
# with in this module (colors, labels, evaluators, numbers, or ``None``).
T = typing.TypeVar("T", str, int, float, TrackEvaluator, None)
def _to_list(element: T | typing.List[T], size: int | None = None) -> typing.List[T]:
"""A helper function to transform an element into a list of elements,
and check that the size of this list is right.
"""
if isinstance(element, list):
if size is not None and len(element) != size:
raise ValueError(
f"Expected size of {element} is {size}, instead of {len(element)}"
)
return element
elif element is None:
if size is None:
raise ValueError("No element nor size was provided.")
else:
return [element] * size
else:
if size is not None and size != 1:
raise ValueError(
f"Only a single element was provided, but the expected size is {size}"
)
return [element]
# ``same_fig`` selects the first element of the returned tuple: one shared
# ``Figure`` when True, an array of figures when False.  The ``Literal``
# overloads are listed BEFORE the general ``bool`` fallback because type
# checkers try overloads in declaration order: if the ``bool`` signature came
# first it would match every call and the precise return types would never be
# selected.  Defaults mirror the implementation (``bins=None`` means "use the
# configured per-column binning").
@typing.overload
def plot_histograms_trackevaluator(
    trackEvaluator: TrackEvaluator | typing.List[TrackEvaluator],
    columns: typing.List[str],
    metric_names: typing.List[str],
    color: str | typing.List[str] | None = None,
    label: str | typing.List[str] | None = None,
    column_labels: typing.Dict[str, str] | None = None,
    bins: int | typing.Sequence[float] | str | typing.Dict[str, typing.Any] | None = None,
    column_ranges: typing.Dict[str, typing.Tuple[float, float]] | None = None,
    category: Category | None = None,
    same_fig: typing.Literal[True] = True,
    lhcb: bool = False,
    with_err: bool = True,
    **kwargs,
) -> typing.Tuple[Figure, npt.NDArray, npt.NDArray]:
    ...


@typing.overload
def plot_histograms_trackevaluator(
    trackEvaluator: TrackEvaluator | typing.List[TrackEvaluator],
    columns: typing.List[str],
    metric_names: typing.List[str],
    color: str | typing.List[str] | None = None,
    label: str | typing.List[str] | None = None,
    column_labels: typing.Dict[str, str] | None = None,
    bins: int | typing.Sequence[float] | str | typing.Dict[str, typing.Any] | None = None,
    column_ranges: typing.Dict[str, typing.Tuple[float, float]] | None = None,
    category: Category | None = None,
    same_fig: typing.Literal[False] = False,
    lhcb: bool = False,
    with_err: bool = True,
    **kwargs,
) -> typing.Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
    ...


@typing.overload
def plot_histograms_trackevaluator(
    trackEvaluator: TrackEvaluator | typing.List[TrackEvaluator],
    columns: typing.List[str],
    metric_names: typing.List[str],
    color: str | typing.List[str] | None = None,
    label: str | typing.List[str] | None = None,
    column_labels: typing.Dict[str, str] | None = None,
    bins: int | typing.Sequence[float] | str | typing.Dict[str, typing.Any] | None = None,
    column_ranges: typing.Dict[str, typing.Tuple[float, float]] | None = None,
    category: Category | None = None,
    same_fig: bool = True,
    lhcb: bool = False,
    with_err: bool = True,
    **kwargs,
) -> typing.Tuple[Figure | npt.NDArray, npt.NDArray, npt.NDArray]:
    ...
def plot_histograms_trackevaluator(
    trackEvaluator: TrackEvaluator | typing.List[TrackEvaluator],
    columns: typing.List[str],
    metric_names: typing.List[str],
    color: str | typing.List[str] | None = None,
    label: str | typing.List[str] | None = None,
    column_labels: typing.Dict[str, str] | None = None,
    bins: int
    | typing.Sequence[float]
    | str
    | typing.Dict[str, typing.Any]
    | None = None,
    column_ranges: typing.Dict[str, typing.Tuple[float, float]] | None = None,
    category: Category | None = None,
    same_fig: bool = True,
    lhcb: bool = False,
    with_err: bool = True,
    **kwargs,
) -> typing.Tuple[Figure | npt.NDArray, npt.NDArray, npt.NDArray]:
    """Plot multiple histograms of metrics.

    Args:
        trackEvaluator: one or more montetracko track evaluators to plot. They
            should share the same data distributions.
        columns: list of columns to histogrammise the metrics on
        metric_names: list of metric names to plot
        color: colors for each track evaluator
        label: labels for each track evaluator
        column_labels: Associates a column name with its label
        bins: Number of bins, bin edges, a binning strategy name, or a
            dictionary that associates a column name with any of these.
            Defaults to the binning configured in ``plotconfig.column_bins``;
            columns absent from a dictionary fall back to 20 bins.
        column_ranges: Associates a column name with the tuple of lower and
            upper ranges of the bins
        category: Particle category to plot
        same_fig: whether to put all the axes in the same figure
        lhcb: whether to add "LHCb Simulation" at the top of every matplotlib ax
        with_err: whether to draw error bars on the plotted metrics

    Returns:
        A tuple of 3 elements: the figure(s), the axes and the histogram axes.
    """
    trackEvaluators = _to_list(trackEvaluator)
    # A single evaluator defaults to black; otherwise colors must match the
    # number of evaluators (``[None] * n`` when no color is given).
    if color is None and len(trackEvaluators) == 1:
        colors = ["k"]
    else:
        colors = _to_list(color, len(trackEvaluators))  # type: ignore
    labels = _to_list(label, len(trackEvaluators))  # type: ignore
    if column_ranges is None:
        column_ranges = plotconfig.column_ranges
    if column_labels is None:
        column_labels = plotconfig.column_labels
    if bins is None:
        bins = plotconfig.column_bins
    figs, axes = plotools.get_figs_axes_on_grid(
        nrows=len(metric_names), ncols=len(columns), same_fig=same_fig
    )
    if not same_fig:
        figs = np.atleast_2d(figs)  # type: ignore
    axes = np.atleast_2d(axes)
    axes_histogram = np.empty_like(axes)
    for idx_col, column in enumerate(columns):
        # Bin edges computed for the first metric of a column are reused for
        # the other metrics so all plots of a column share the same binning.
        edges = None
        for idx_metric, metric_name in enumerate(metric_names):
            if edges is not None:
                bins_metric = edges
            elif isinstance(bins, dict):
                # Columns without a configured binning fall back to 20 bins.
                bins_metric = bins.get(column, 20)
            else:
                # A global binning (int, sequence or strategy name) applies to
                # every column.  (Previously a non-dict ``bins`` was silently
                # ignored and 20 bins were always used.)
                bins_metric = bins
            histogram_common_kwargs = dict(
                column=column,
                metric_name=metric_name,
                category=category,
                column_label=column_labels.get(column, column.replace("_", r"\_")),
                ax=axes[idx_metric][idx_col],
            )
            # Overlaid histograms are made translucent when several
            # evaluators are compared on the same axis.
            alpha = 0.65 if len(trackEvaluators) > 1 else 1.0
            (
                axes_histogram[idx_metric][idx_col],
                _,
                _,
                edges,
                _,
            ) = trackEvaluators[0].plot_histogram(
                range=column_ranges.get(column),
                bins=bins_metric,
                color=colors[0],
                label=labels[0],
                err="auto" if with_err else None,
                alpha=alpha,
                **histogram_common_kwargs,
                **kwargs,
            )
            # The remaining evaluators reuse the bin edges of the first one
            # and do not redraw the background histogram.  (Loop variables
            # renamed so they no longer shadow the ``color``/``label``
            # parameters.)
            for evaluator, evaluator_color, evaluator_label in zip(
                trackEvaluators[1:], colors[1:], labels[1:]
            ):
                evaluator.plot_histogram(
                    bins=edges,
                    color=evaluator_color,
                    label=evaluator_label,
                    err="auto" if with_err else None,
                    show_histogram=False,
                    alpha=alpha,
                    **histogram_common_kwargs,
                    **kwargs,
                )
            if column in plotconfig.integer_columns:
                axes[idx_metric][idx_col].xaxis.set_major_locator(
                    MaxNLocator(integer=True)
                )
    for idx_metric, metric_name in enumerate(metric_names):
        plotools.set_same_y_lim_for_all_axes(axes[idx_metric], ymin=0.0)
        if lhcb:
            for ax, ax_hist in zip(axes[idx_metric], axes_histogram[idx_metric]):
                plotools.pad_on_top(ax)
                plotools.pad_on_top(ax_hist)
                plotools.add_text(ax, ha="right", y=0.98)
    if same_fig:
        plotools.hide_repetitive_labels_in_grid(axes=axes)
        # Only the last histogram axis of each row keeps its y label.
        for line_axes_histogram in axes_histogram:
            for ax_histogram in line_axes_histogram[:-1]:
                ax_histogram.yaxis.label.set_visible(False)
        assert isinstance(figs, Figure)
        figs.tight_layout()
    # Define legend
    if not all(entry is None for entry in labels):
        if same_fig:
            axes[0][0].legend()
        else:
            for line_axes in axes:
                for ax in line_axes:
                    ax.legend()
    return figs, axes, axes_histogram
def plot_evaluation(
    trackEvaluator: mt.TrackEvaluator,
    category: mt.requirement.Category,
    plotted_groups: typing.List[str] | None = None,
    detector: str | None = None,
    output_dir: str | None = None,
    suffix: str | None = None,
):
    """Generate and display histograms of track evaluation metrics in specified
    particle-related columns.

    Args:
        trackEvaluator: A ``TrackEvaluator`` instance containing the results
            of the track matching
        category: Truth category for the plot
        plotted_groups: Pre-configured metrics and columns to plot.
            Each group corresponds to one plot that shows the distributions of
            various metrics as a function of various truth variables,
            as hard-coded in this function. Defaults to ``["basic"]``.
            There are 3 groups: ``basic``, ``geometry`` and ``challenging``.
        detector: name of the detector (``velo`` or ``scifi``)
        output_dir: directory where the figures are saved. If ``None``, the
            figures are not saved.
        suffix: Suffix to add at the end of the figure names

    Raises:
        ValueError: if the detector or one of the requested groups is not
            recognised.
    """
    # ``None`` stands for the default group; avoids a mutable default argument.
    if plotted_groups is None:
        plotted_groups = ["basic"]
    plotconfig.configure_matplotlib()
    if detector is None:
        detector = cdirs.detectors[0]
    if detector == "velo":
        group_configurations = {
            "basic": dict(
                columns=["pt", "p", "eta", "vz"],
                metric_names=[
                    "efficiency",
                    "clone_rate",
                    # "hit_purity_per_candidate",
                    "hit_efficiency_per_candidate",
                ],
            ),
            "challenging": dict(
                columns=["vz", "nhits_velo"],
                metric_names=["efficiency"],
            ),
            "geometry": dict(
                columns=[
                    "distance_to_line",
                    "distance_to_z_axis",
                    "xz_angle",
                    "yz_angle",
                ],
                metric_names=["efficiency"],
            ),
        }
    elif detector == "scifi" or detector == "scifi_xz":
        group_configurations = {
            "basic": dict(
                columns=[
                    "pt",
                    "pl",
                    "n_unique_planes",
                    "n_skipped_planes",
                    "n_shared_hits",
                    "vz",
                    # "eta",
                    "phi",
                    "quadratic_coeff",
                    # "distance_to_line",
                    # "distance_to_z_axis",
                ],
                metric_names=["efficiency"],
            ),
        }
    else:
        raise ValueError(f"Detector {detector} is not recognised")
    # Normalise the suffix once instead of inside the saving loop.
    if suffix is None:
        suffix = ""
    for group_name in plotted_groups:
        if group_name not in group_configurations:
            raise ValueError(
                f"Group `{group_name}` is unknown. "
                "Valid groups are: " + ", ".join(group_configurations.keys())
            )
        group_config = group_configurations[group_name]
        fig, _, _ = plot_histograms_trackevaluator(
            trackEvaluator=trackEvaluator,
            **group_config,  # type: ignore
            category=category,
        )
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            plot_path = op.join(
                output_dir, f"hist1d_{group_name}_{category.name}{suffix}.pdf"
            )
            plotools.save_fig(fig, plot_path, dpi=200)
            logging.info(
                f"Plot {group_name} for category {category.name} saved in {plot_path}"
            )
def plot_evaluation_categories(
    trackEvaluator: mt.TrackEvaluator,
    detector: str | None = None,
    categories: typing.Iterable[mt.requirement.Category] | None = None,
    plotted_groups: typing.Sequence[str] | None = ("basic",),
    output_dir: str | None = None,
    suffix: str | None = None,
):
    """Generate and display histograms of track evaluation metrics in specified
    particle-related columns, for various categories.

    Args:
        trackEvaluator: A ``TrackEvaluator`` instance containing the results
            of the track matching
        detector: name of the detector, forwarded to ``plot_evaluation``
        categories: truth categories to plot, one set of figures per category.
            Nothing is plotted when ``None``.
        plotted_groups: Pre-configured metrics and columns to plot.
            Each group corresponds to one plot that shows the distributions of
            various metrics as a function of various truth variables,
            as hard-coded in ``plot_evaluation``.
            There are 3 groups: ``basic``, ``geometry`` and ``challenging``.
            Nothing is plotted when ``None`` or empty.
        output_dir: directory where the figures are saved, if given
        suffix: Suffix to add at the end of the figure names
    """
    # Early exit preserves the original behavior: an explicit ``None`` or an
    # empty group list skips plotting entirely.  The default is an immutable
    # tuple instead of a mutable list.
    if categories is None or not plotted_groups:
        return
    groups = list(plotted_groups)
    for category in categories:
        plot_evaluation(
            trackEvaluator=trackEvaluator,
            detector=detector,
            category=category,
            output_dir=output_dir,
            plotted_groups=groups,
            suffix=suffix,
        )