#!/usr/bin/env python3
"""A script that evaluates the tracking performance of Allen for these test samples.
"""
import typing
import os
import os.path as op
from argparse import ArgumentParser
import numpy as np
import montetracko as mt
from utils.commonutils.config import cdirs
from utils.commonutils.ctests import get_available_test_dataset_names
from utils.scriptutils.parser import add_predefined_arguments
from utils.commonutils.cdetector import get_coordinate_names
from utils.commonutils.ctests import get_preprocessed_test_dataset_dir
from utils.loaderutils.preprocessing import load_preprocessed_dataframes
from scripts.evaluation.evaluate_etx4velo import evaluate
def get_event_ids(preprocessing_dir: str) -> typing.List[str]:
    """List the unique event IDs found in a preprocessing directory.

    Event IDs are extracted from file names of the form ``event<id>-<suffix>``.
    """
    return np.unique(
        [
            file_.name.split("-")[0][len("event") :]
            for file_ in os.scandir(preprocessing_dir)
            if file_.name != "done"
        ]
    ).tolist()
def evaluate_allen(
test_dataset_name: str,
detector: str | None = None,
event_ids: typing.List[str] | typing.List[int] | None = None,
output_dir: str | None = None,
**kwargs,
) -> mt.TrackEvaluator:
"""Evaluate the track finding performance of Allen in a test sample.
Args:
test_dataset_name: Name of a test dataset
event_ids: Event IDs to consider. If not provided, all available are
considered
output_dir: Output directory where the save the reports and figures
**kwargs: Other keyword arguments passed to :py:func:`scripts.evaluate.evaluate`
"""
if detector is None:
detector = cdirs.detectors[0]
preprocessed_input_dir = get_preprocessed_test_dataset_dir(
test_dataset_name=test_dataset_name, detector=detector
)
    if event_ids is None:
        # By default, evaluate every event found in the preprocessed directory.
        event_ids_str: typing.List[str] = get_event_ids(preprocessed_input_dir)
    else:
        # Event IDs in the file names are zero-padded to 18 characters.
        event_ids_str = [str(event_id).zfill(18) for event_id in event_ids]
    # The Allen tracks are stored as one padded CSV file per event, under a
    # detector-dependent subdirectory of the reference directory.
    track_subdirectory = cdirs.get_filenames_from_detector(detector=detector)[
        "track_subdirectory"
    ]
    trackhandler = mt.TrackHandler.from_padded_csv(
        paths=[
            op.join(
                cdirs.reference_directory,
                test_dataset_name,
                track_subdirectory,
                f"{event_id_str}.csv",
            )
            for event_id_str in event_ids_str
        ],
        padding_value=0,
        skip_header=True,
    )
    truncated_paths = [
        op.join(preprocessed_input_dir, f"event{event_id_str}")
        for event_id_str in event_ids_str
    ]
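    # Load the per-event hit/particle association and particle dataframes
    # produced by the preprocessing step.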
df_hits_particles = load_preprocessed_dataframes(
truncated_paths=truncated_paths,
ending="-hits_particles",
columns=[
"particle_id",
"hit_id",
"plane",
]
+ get_coordinate_names(detector=detector)
+ (["dxdy"] if detector == "scifi_xz" else []),
)
df_particles = load_preprocessed_dataframes(
truncated_paths=truncated_paths, ending="-particles"
)
if output_dir is None:
output_dir = op.join(
cdirs.performance_directory, "allen", detector, test_dataset_name
)
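    # Delegate the matching and performance reporting to the shared ``evaluate``
    # routine also used for ETX4VELO.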
return evaluate(
df_hits_particles=df_hits_particles,
df_particles=df_particles,
df_tracks=trackhandler.dataframe,
output_dir=output_dir,
detector=detector,
**kwargs,
)
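# Programmatic use (a sketch; the dataset and detector names are placeholders,
# and ``plotted_groups`` is forwarded to the underlying ``evaluate`` call):
#
#     evaluator = evaluate_allen(
#         test_dataset_name="<test_dataset>",
#         detector="<detector>",
#         plotted_groups=["<group>"],
#     )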
if __name__ == "__main__":
parser = ArgumentParser("Run the preprocessing of all the test sample.")
add_predefined_arguments(
parser=parser, arguments=["test_dataset_name", "test_config", "detector"]
)
parser.add_argument(
"-i",
"--indir",
help=(
"Directory where the test samples are located. "
"They are generated by the XDIGI2CSV repository."
),
default=cdirs.reference_directory,
required=False,
)
parser.add_argument(
"-g",
"--plotted_groups",
help="Group of variables to plot the performance metrics on.",
nargs="+",
default=[],
required=False,
)
parser.add_argument(
"-o",
"--output_dir",
help="Output directory where to save the report and plots.",
required=False,
)
parsed_args = parser.parse_args()
test_dataset_name: str = parsed_args.test_dataset_name
indir: str = parsed_args.indir
test_config_path: str = parsed_args.test_config
plotted_groups: typing.List[str] = parsed_args.plotted_groups
detector: str = parsed_args.detector
    output_dir: str = parsed_args.output_dir or op.join(
        cdirs.performance_directory, "allen", detector, test_dataset_name
    )
    # Check that the requested dataset is defined in the test configuration.
    test_dataset_names = get_available_test_dataset_names(
        path_or_config_test=test_config_path
    )
    if test_dataset_name not in test_dataset_names:
        raise ValueError(f"Unknown test dataset {test_dataset_name!r}")
evaluate_allen(
test_dataset_name=test_dataset_name,
plotted_groups=plotted_groups,
output_dir=output_dir,
timestamp=False,
detector=detector,
)