Source code for scripts.evaluation.evaluate_allen_on_test_sample

#!/usr/bin/env python3
"""A script that evaluates the tracking performance of Allen for these test samples.
"""
import typing
import os
import os.path as op
from argparse import ArgumentParser

import numpy as np
import montetracko as mt

from utils.commonutils.config import cdirs
from utils.commonutils.ctests import get_available_test_dataset_names
from utils.scriptutils.parser import add_predefined_arguments
from utils.commonutils.cdetector import get_coordinate_names
from utils.commonutils.ctests import get_preprocessed_test_dataset_dir
from utils.loaderutils.preprocessing import load_preprocessed_dataframes
from scripts.evaluation.evaluate_etx4velo import evaluate


def get_event_ids(preprocessing_dir: str) -> typing.List[str]:
    """Return the distinct event IDs found in a preprocessing directory.

    Every directory entry except the one named ``done`` is read as
    ``event<ID>-<suffix>``: the ID is the part of the entry name before the
    first ``-``, with the leading ``len("event")`` characters removed.

    Args:
        preprocessing_dir: Directory whose entries are scanned.

    Returns:
        The distinct event IDs as strings, sorted by ``np.unique``. An empty
        list is returned when the directory holds no entry other than
        ``done``.
    """
    prefix_length = len("event")
    # ``os.scandir`` keeps a directory handle open; the ``with`` block
    # releases it even if reading an entry raises.
    with os.scandir(preprocessing_dir) as entries:
        event_ids = [
            entry.name.split("-")[0][prefix_length:]
            for entry in entries
            if entry.name != "done"
        ]
    # Several files belong to the same event, hence the de-duplication.
    return np.unique(event_ids).tolist()
def evaluate_allen(
    test_dataset_name: str,
    detector: str | None = None,
    event_ids: typing.List[str] | typing.List[int] | None = None,
    output_dir: str | None = None,
    **kwargs,
) -> mt.TrackEvaluator:
    """Evaluate the track finding performance of Allen in a test sample.

    Args:
        test_dataset_name: Name of a test dataset
        detector: Name of the detector. If not provided, the first entry of
            ``cdirs.detectors`` is used
        event_ids: Event IDs to consider. If not provided, all the events
            found in the preprocessed directory of the test dataset are
            considered. Provided IDs are converted to strings and
            zero-padded to 18 characters
        output_dir: Output directory where to save the reports and figures.
            If not provided, it is
            ``<cdirs.performance_directory>/allen/<detector>/<test_dataset_name>``
        **kwargs: Other keyword arguments passed to
            :py:func:`scripts.evaluation.evaluate_etx4velo.evaluate`

    Returns:
        The object returned by
        :py:func:`scripts.evaluation.evaluate_etx4velo.evaluate`.
    """
    if detector is None:
        detector = cdirs.detectors[0]

    preprocessed_input_dir = get_preprocessed_test_dataset_dir(
        test_dataset_name=test_dataset_name, detector=detector
    )

    event_ids_str: typing.List[str]
    if event_ids is None:
        # Discover the events from the files present on disk.
        event_ids_str = get_event_ids(preprocessed_input_dir)
    else:
        event_ids_str = [str(event_id).zfill(18) for event_id in event_ids]

    # The directory of the Allen track CSV files is the same for every event,
    # so it is resolved once rather than once per event.
    track_dir = op.join(
        cdirs.reference_directory,
        test_dataset_name,
        cdirs.get_filenames_from_detector(detector=detector)["track_subdirectory"],
    )
    trackhandler = mt.TrackHandler.from_padded_csv(
        paths=[
            op.join(track_dir, f"{event_id_str}.csv")
            for event_id_str in event_ids_str
        ],
        padding_value=0,
        skip_header=True,
    )

    # Path of each preprocessed event without its "-<ending>" suffix; the
    # suffix is supplied through ``ending`` below. The IDs are used as-is so
    # that the paths match the file names they were read from.
    truncated_paths = [
        op.join(preprocessed_input_dir, "event" + event_id_str)
        for event_id_str in event_ids_str
    ]

    hits_particles_columns = ["particle_id", "hit_id", "plane"]
    hits_particles_columns += get_coordinate_names(detector=detector)
    if detector == "scifi_xz":
        # ``dxdy`` is only requested for the ``scifi_xz`` detector.
        hits_particles_columns.append("dxdy")

    df_hits_particles = load_preprocessed_dataframes(
        truncated_paths=truncated_paths,
        ending="-hits_particles",
        columns=hits_particles_columns,
    )
    df_particles = load_preprocessed_dataframes(
        truncated_paths=truncated_paths, ending="-particles"
    )

    if output_dir is None:
        output_dir = op.join(
            cdirs.performance_directory, "allen", detector, test_dataset_name
        )

    return evaluate(
        df_hits_particles=df_hits_particles,
        df_particles=df_particles,
        df_tracks=trackhandler.dataframe,
        output_dir=output_dir,
        detector=detector,
        **kwargs,
    )
if __name__ == "__main__":
    # Command-line entry point: evaluate Allen on one test dataset.
    # The text is passed as ``description``: the first positional parameter
    # of ``ArgumentParser`` is ``prog`` (the program name shown in the usage
    # line), which is not what a sentence like this is meant for.
    parser = ArgumentParser(
        description="Evaluate the track finding performance of Allen on a test sample."
    )
    add_predefined_arguments(
        parser=parser, arguments=["test_dataset_name", "test_config", "detector"]
    )
    # NOTE(review): ``--indir`` is accepted but its value is not forwarded;
    # ``evaluate_allen`` reads the track files from
    # ``cdirs.reference_directory``. The option is kept so that existing
    # command lines keep working.
    parser.add_argument(
        "-i",
        "--indir",
        help=(
            "Directory where the test samples are located. "
            "They are generated by the XDIGI2CSV repository."
        ),
        default=cdirs.reference_directory,
        required=False,
    )
    parser.add_argument(
        "-g",
        "--plotted_groups",
        help="Group of variables to plot the performance metrics on.",
        nargs="+",
        default=[],
        required=False,
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        help="Output directory where to save the report and plots.",
        required=False,
    )
    parsed_args = parser.parse_args()

    plotted_groups: typing.List[str] = parsed_args.plotted_groups

    # NOTE(review): the returned names are not used afterwards. The call is
    # kept in case it checks the test configuration - confirm before
    # removing it.
    get_available_test_dataset_names(path_or_config_test=parsed_args.test_config)

    evaluate_allen(
        test_dataset_name=parsed_args.test_dataset_name,
        plotted_groups=plotted_groups,
        # A missing or empty value becomes ``None`` so that ``evaluate_allen``
        # applies its own default,
        # ``<performance_directory>/allen/<detector>/<test_dataset_name>``,
        # after it has resolved the detector.
        output_dir=parsed_args.output_dir or None,
        timestamp=False,
        detector=parsed_args.detector,
    )