#!/usr/bin/env python3
"""A script that compares the performance of ETX4VELO and Allen in a test sample.
"""
from __future__ import annotations
import typing
import os.path as op
from argparse import ArgumentParser
import numpy as np
import montetracko as mt
import montetracko.lhcb as mtb
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from Evaluation.plotting import plot_histograms_trackevaluator
from utils.plotutils import plotools, plotconfig
from utils.commonutils.config import (
get_performance_directory_experiment,
cdirs,
load_config,
get_detector_from_pipeline_config,
)
from utils.scriptutils.parser import add_predefined_arguments
from scripts.evaluation.evaluate_etx4velo import evaluate_partition
from scripts.evaluation.evaluate_allen_on_test_sample import evaluate_allen
# Executed at import time, before any figure is created by the functions below.
# NOTE(review): presumably applies the project-wide matplotlib style -- confirm
# in `utils.plotutils.plotconfig`.
plotconfig.configure_matplotlib()
def compare_allen_vs_etx4velo_from_trackevaluators(
    trackEvaluators: mt.TrackEvaluator | typing.List[mt.TrackEvaluator],
    trackEvaluator_Allen: mt.TrackEvaluator,
    names: typing.List[str] | str | None = None,
    colors: typing.List[str] | str | None = None,
    categories: typing.List[mt.requirement.Category] | None = None,
    metric_names: typing.List[str] | None = None,
    columns: typing.List[str] | None = None,
    test_dataset_name: str | None = None,
    detector: str | None = None,
    output_dir: str | None = None,
    suffix: str | None = None,
    path_or_config: str | dict | None = None,
    same_fig: bool = True,
    with_err: bool = True,
    **kwargs,
):
    """Plot metric histograms of one or more track evaluators next to Allen.

    For every category, ``plot_histograms_trackevaluator`` is called with the
    given evaluators followed by the Allen evaluator (labelled ``"Allen"`` and
    drawn in ``"green"``). The resulting figures are saved and closed when an
    output directory is available; otherwise they are neither saved nor closed.

    Args:
        trackEvaluators: a single track evaluator or a list of them.
        trackEvaluator_Allen: the Allen track evaluator, always plotted last.
        names: label(s) of ``trackEvaluators``. Defaults to ``["ETX4VELO"]``.
        colors: color(s) of ``trackEvaluators``. A value that is not a list is
            treated as a single color. Defaults to ``["blue"]``.
        categories: categories to plot, one (set of) figure(s) per category.
            Defaults to velo-no-electrons, long-only-electrons and
            long-strange.
        metric_names: metrics to plot. Defaults to ``"efficiency"`` and
            ``"hit_efficiency_per_candidate"``.
        columns: truth variables to histogram against. When omitted, a default
            list is chosen according to ``detector``.
        test_dataset_name: name of the test dataset, only used to build the
            default output directory.
        detector: ``"velo"``, ``"scifi"`` or ``"scifi_xz"``. When omitted, it
            is read from ``path_or_config`` if given, else ``cdirs.detectors[0]``
            is used.
        output_dir: directory where figures are saved. When omitted and
            ``test_dataset_name`` is given, it is set to
            ``<experiment performance dir>/etx4velo_vs_allen/<test_dataset_name>``.
        suffix: string appended to the category name in the output file names.
        path_or_config: pipeline configuration (path or dictionary), used to
            resolve the detector and the default output directory.
        same_fig: whether ``plot_histograms_trackevaluator`` returns a single
            figure per category (``True``) or an array of figures indexed by
            ``[metric][column]`` (``False``).
        with_err: forwarded to ``plot_histograms_trackevaluator``.
        **kwargs: forwarded to ``plot_histograms_trackevaluator``.

    Raises:
        ValueError: if ``columns`` is omitted and ``detector`` is not
            recognised, or if ``test_dataset_name`` is given without
            ``output_dir`` and without ``path_or_config``.
    """
    # Normalise the scalar-or-list arguments to lists.
    if isinstance(trackEvaluators, mt.TrackEvaluator):
        trackEvaluators = [trackEvaluators]

    if names is None:
        names = ["ETX4VELO"]
    elif isinstance(names, str):
        names = [names]

    if colors is None:
        colors = ["blue"]
    elif not isinstance(colors, list):
        colors = [colors]

    if categories is None:
        categories = [
            mtb.category.category_velo_no_electrons,
            mtb.category.category_long_only_electrons,
            mtb.category.category_long_strange,
        ]

    if metric_names is None:
        metric_names = ["efficiency", "hit_efficiency_per_candidate"]

    if detector is None:
        detector = (
            get_detector_from_pipeline_config(path_or_config)
            if path_or_config is not None
            else cdirs.detectors[0]
        )

    if columns is None:
        if detector == "velo":
            columns = [
                "pt",
                "n_unique_planes",
                "n_skipped_planes",
                "n_shared_hits",
                "vz",
                "eta",
                "phi",
            ]
        elif detector in ("scifi", "scifi_xz"):
            columns = [
                "pt",
                "px",
                "py",
                "pz",
                "n_unique_planes",
                "n_skipped_planes",
                "n_shared_hits",
                "vz",
                "eta",
                "phi",
                "quadratic_coeff",
            ]
        else:
            raise ValueError(f"Detector {detector} is not recognised.")

    if suffix is None:
        suffix = ""

    # The output directory does not depend on the category: resolve it once,
    # so that a misconfiguration is reported before any plotting is done.
    if output_dir is None and test_dataset_name is not None:
        if path_or_config is None:
            raise ValueError(
                "`test_dataset_name` was provided, which is used to save "
                "the figure. However, `path_or_config` was not provided "
                "and is needed to determine the output directory. "
            )
        experiment_performance_dir = get_performance_directory_experiment(
            path_or_config=path_or_config
        )
        output_dir = op.join(
            experiment_performance_dir,
            "etx4velo_vs_allen",
            f"{test_dataset_name}",
        )

    for category in categories:
        figs, _, _ = plot_histograms_trackevaluator(
            trackEvaluator=[*trackEvaluators, trackEvaluator_Allen],
            label=[*names, "Allen"],
            color=[*colors, "green"],
            columns=columns,
            metric_names=metric_names,
            category=category,
            same_fig=same_fig,
            with_err=with_err,
            **kwargs,
        )

        # Without an output directory the figures are left open and unsaved.
        if output_dir is None:
            continue

        if same_fig:
            assert isinstance(figs, Figure)
            plotools.save_fig(figs, op.join(output_dir, category.name + suffix))
            plt.close(figs)
        else:
            assert isinstance(figs, np.ndarray)
            for idx_metric, metric_name in enumerate(metric_names):
                for idx_col, column in enumerate(columns):
                    fig: Figure = figs[idx_metric][idx_col]
                    plotools.save_fig(
                        fig,
                        op.join(
                            output_dir,
                            f"{category.name}{suffix}_{metric_name}_{column}",
                        ),
                    )
                    plt.close(fig)
def compare_etx4velo_vs_allen(
    path_or_config: str | dict,
    test_dataset_name: str,
    categories: typing.List[mt.requirement.Category] | None = None,
    metric_names: typing.List[str] | None = None,
    columns: typing.List[str] | None = None,
    same_fig: bool = True,
    output_dir: str | None = None,
    lhcb: bool = False,
    allen_report: bool = False,
    table_report: bool = False,
    suffix: str | None = None,
    with_err: bool = True,
    compare_trackevaluators: bool = True,
    **kwargs,
):
    """Evaluate ETX4VELO and Allen on a test sample and compare them.

    ``evaluate_partition`` and ``evaluate_allen`` are run on
    ``test_dataset_name``; their results are then handed to
    ``compare_allen_vs_etx4velo_from_trackevaluators`` unless
    ``compare_trackevaluators`` is ``False``.

    Args:
        path_or_config: pipeline configuration (path or dictionary).
        test_dataset_name: name of the test dataset to evaluate.
        categories: forwarded to the comparison function.
        metric_names: forwarded to the comparison function.
        columns: forwarded to the comparison function.
        same_fig: forwarded to the comparison function.
        output_dir: forwarded to the comparison function.
        lhcb: forwarded to the comparison function, which passes it on to
            ``plot_histograms_trackevaluator`` through its ``**kwargs``.
        allen_report: forwarded to both evaluation functions.
        table_report: forwarded to both evaluation functions.
        suffix: forwarded to both evaluation functions and to the comparison
            function.
        with_err: forwarded to the comparison function.
        compare_trackevaluators: whether to produce the comparison plots after
            the two evaluations.
        **kwargs: forwarded to ``evaluate_partition`` only.
    """
    print(f"Evaluation of {test_dataset_name} for ETX4VELO")
    trackEvaluator_etx4velo = evaluate_partition(
        path_or_config=path_or_config,
        partition=test_dataset_name,
        plotted_groups=["basic"],
        plot_categories=[],
        allen_report=allen_report,
        table_report=table_report,
        suffix=suffix,
        **kwargs,
    )

    # Resolve the detector with the same helper as the other comparison
    # functions of this module, so that the Allen evaluation and the plots
    # always agree on the detector.
    detector = get_detector_from_pipeline_config(path_or_config)

    print(f"Evaluation of {test_dataset_name} for Allen")
    trackEvaluator_allen = evaluate_allen(
        test_dataset_name=test_dataset_name,
        plotted_groups=["basic"],
        plot_categories=[],
        allen_report=allen_report,
        table_report=table_report,
        detector=detector,
        suffix=suffix,
    )

    if compare_trackevaluators:
        print(f"Compare ETX4VELO vs Allen in {test_dataset_name}")
        compare_allen_vs_etx4velo_from_trackevaluators(
            path_or_config=path_or_config,
            trackEvaluators=trackEvaluator_etx4velo,
            trackEvaluator_Allen=trackEvaluator_allen,
            test_dataset_name=test_dataset_name,
            categories=categories,
            metric_names=metric_names,
            columns=columns,
            detector=detector,
            output_dir=output_dir,
            suffix=suffix,
            lhcb=lhcb,
            same_fig=same_fig,
            with_err=with_err,
        )
def compare_etx4velo_vs_allen_global(
    paths_or_configs: typing.List[str | dict],
    names: typing.List[str],
    colors: typing.List[str],
    test_dataset_name: str,
    categories: typing.List[mt.requirement.Category] | None = None,
    metric_names: typing.List[str] | None = None,
    columns: typing.List[str] | None = None,
    same_fig: bool = True,
    lhcb: bool = False,
    allen_report: bool = False,
    table_report: bool = False,
    suffix: str | None = None,
    compare_trackevaluators: bool = True,
    **kwargs,
):
    """Evaluate several ETX4VELO pipelines and Allen, and compare them all.

    ``evaluate_partition`` is run once per entry of ``paths_or_configs`` and
    ``evaluate_allen`` once, all on ``test_dataset_name``. Unless
    ``compare_trackevaluators`` is ``False``, the results are plotted together
    and saved under
    ``<cdirs.performance_directory>/etx4velo_vs_allen/<names joined by "_">``.

    Args:
        paths_or_configs: pipeline configurations (paths or dictionaries), one
            per pipeline to evaluate. The detector is read from the first one.
        names: label of each pipeline; must be as long as
            ``paths_or_configs``.
        colors: plot color of each pipeline.
        test_dataset_name: name of the test dataset to evaluate.
        categories: forwarded to the comparison function.
        metric_names: forwarded to the comparison function.
        columns: forwarded to the comparison function.
        same_fig: forwarded to the comparison function.
        lhcb: forwarded to the comparison function, which passes it on to
            ``plot_histograms_trackevaluator`` through its ``**kwargs``.
        allen_report: forwarded to all evaluation functions.
        table_report: forwarded to all evaluation functions.
        suffix: forwarded to all evaluation functions and to the comparison
            function.
        compare_trackevaluators: whether to produce the comparison plots after
            the evaluations.
        **kwargs: forwarded to ``evaluate_partition`` only.

    Raises:
        AssertionError: if ``paths_or_configs`` and ``names`` differ in length.
    """
    assert len(paths_or_configs) == len(names), (
        f"Got {len(paths_or_configs)} pipeline configuration(s) "
        f"but {len(names)} name(s)."
    )

    trackEvaluators = []
    for path_or_config in paths_or_configs:
        print(f"Evaluation of {test_dataset_name} for ETX4VELO")
        trackEvaluators.append(
            evaluate_partition(
                path_or_config=path_or_config,
                partition=test_dataset_name,
                plotted_groups=["basic"],
                plot_categories=[],
                allen_report=allen_report,
                table_report=table_report,
                suffix=suffix,
                **kwargs,
            )
        )

    print(f"Evaluation of {test_dataset_name} for Allen")
    # Only the first configuration is consulted for the detector.
    # NOTE(review): this presumes all pipelines target the same detector --
    # confirm with callers.
    detector = get_detector_from_pipeline_config(paths_or_configs[0])
    trackEvaluator_allen = evaluate_allen(
        test_dataset_name=test_dataset_name,
        plotted_groups=["basic"],
        plot_categories=[],
        allen_report=allen_report,
        table_report=table_report,
        detector=detector,
        suffix=suffix,
    )

    if compare_trackevaluators:
        print(f"Compare ETX4VELO vs Allen in {test_dataset_name}")
        compare_allen_vs_etx4velo_from_trackevaluators(
            trackEvaluators=trackEvaluators,
            trackEvaluator_Allen=trackEvaluator_allen,
            names=names,
            colors=colors,
            test_dataset_name=test_dataset_name,
            categories=categories,
            metric_names=metric_names,
            columns=columns,
            output_dir=op.join(
                cdirs.performance_directory, "etx4velo_vs_allen", "_".join(names)
            ),
            detector=detector,
            suffix=suffix,
            lhcb=lhcb,
            same_fig=same_fig,
        )
if __name__ == "__main__":
    # Command-line entry point: compare ETX4VELO and Allen on one test sample.
    parser = ArgumentParser(
        description="Plot histograms of metrics for etx4velo and for Allen."
    )
    add_predefined_arguments(
        parser, ["pipeline_config", "test_dataset_name", "output_dir"]
    )
    parsed_args = parser.parse_args()

    pipeline_config_path: str = parsed_args.pipeline_config
    test_dataset_name: str = parsed_args.test_dataset_name
    output_dir: str | None = parsed_args.output_dir

    # From the command line, `lhcb` is enabled and every (metric, column)
    # histogram is written to its own figure.
    compare_etx4velo_vs_allen(
        path_or_config=pipeline_config_path,
        test_dataset_name=test_dataset_name,
        output_dir=output_dir,
        lhcb=True,
        same_fig=False,
    )