# Source code for pipeline.utils.commonutils.ctests

"""A module that define utilities to handle test datasets.
"""

from __future__ import annotations
import typing
import os
import os.path as op
import yaml
from .config import load_dict, load_config, cdirs


def get_available_test_dataset_names(
    path_or_config_test: str | typing.Dict[str, typing.Any] | None = None,
) -> typing.List[str]:
    """Get the list of available test dataset names from the test dataset
    configuration file.

    Args:
        path_or_config_test: YAML test dataset configuration dictionary or path to it

    Returns:
        List of test dataset names that can be produced and/or used.
    """
    if path_or_config_test is None:
        path_or_config_test = cdirs.test_config_path
    config = load_dict(path_or_config=path_or_config_test)
    return list(config.keys())

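# Usage sketch for ``get_available_test_dataset_names``. The function also
# accepts an already-loaded configuration dictionary; the sample names below
# are illustrative, not defined by this module:
#
#     >>> get_available_test_dataset_names({"sample_a": {}, "sample_b": {}})
#     ['sample_a', 'sample_b']
#     >>> get_available_test_dataset_names()  # falls back to ``cdirs.test_config_path``
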
def get_required_test_dataset_names(path_or_config: str | dict) -> typing.List[str]:
    """Get the list of the dataset names required by the configuration."""
    config = load_config(path_or_config=path_or_config)
    return config["common"]["test_dataset_names"]

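# Usage sketch for ``get_required_test_dataset_names``. ``pipeline.yaml`` is a
# hypothetical pipeline configuration whose ``common.test_dataset_names`` entry
# lists the required samples:
#
#     >>> get_required_test_dataset_names("pipeline.yaml")
#     ['sample_a', 'sample_b']
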
def get_preprocessed_test_dataset_dir(test_dataset_name: str, detector: str) -> str:
    """Get the path to the directory that contains the preprocessed files of a given
    test dataset.

    Args:
        test_dataset_name: name of the test dataset to pre-process
        detector: name of the detector

    Returns:
        Path to the preprocessed test dataset directory.
    """
    return op.join(cdirs.data_directory, "__test__", detector, test_dataset_name)

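# Usage sketch for ``get_preprocessed_test_dataset_dir``, assuming
# ``cdirs.data_directory`` is ``/data`` (dataset and detector names are
# illustrative):
#
#     >>> get_preprocessed_test_dataset_dir("sample_a", detector="scifi")
#     '/data/__test__/scifi/sample_a'
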
def get_test_config_for_preprocessing(
    test_dataset_name: str, path_or_config_test: str | dict, detector: str
) -> dict:
    """Get the configuration used for the pre-processing of a given test dataset.

    Args:
        test_dataset_name: name of the test dataset to pre-process
        path_or_config_test: YAML test dataset configuration dictionary or path to it
        detector: name of the detector

    Returns:
        Configuration dictionary with ``common`` and ``preprocessing`` sections.
    """
    test_config = load_dict(path_or_config=path_or_config_test)
    if test_dataset_name not in test_config:
        raise ValueError(
            f"`{test_dataset_name}` is not recognised as a valid test sample name. "
            "Valid sample names are: " + ", ".join(test_config.keys())
        )
    return {
        "common": {"detector": detector},
        "preprocessing": {
            "output_dir": get_preprocessed_test_dataset_dir(
                test_dataset_name=test_dataset_name, detector=detector
            ),
            **test_config[test_dataset_name],
        },
    }

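# Usage sketch for ``get_test_config_for_preprocessing`` with an in-memory test
# configuration (all names are illustrative; ``output_dir`` depends on
# ``cdirs.data_directory``):
#
#     >>> test_config = {"sample_a": {"input_dir": "ref/sample_a/xdigi2csv"}}
#     >>> get_test_config_for_preprocessing("sample_a", test_config, detector="scifi")
#     {'common': {'detector': 'scifi'},
#      'preprocessing': {'output_dir': '...', 'input_dir': 'ref/sample_a/xdigi2csv'}}
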
def load_preprocessing_test_config(
    test_dataset_name: str,
    reference_directory: str | None = None,
) -> typing.Dict[str, typing.Any]:
    """Load the ``preprocessing.yaml`` configuration of a test dataset from the
    reference directory, or return an empty dictionary if the file does not exist.

    Args:
        test_dataset_name: name of the test dataset
        reference_directory: directory that contains the test dataset reference
            files. Defaults to ``cdirs.reference_directory``.

    Returns:
        The loaded configuration dictionary, or ``{}`` if no ``preprocessing.yaml``
        file is found for the given test dataset.
    """
    if reference_directory is None:
        reference_directory = cdirs.reference_directory
    test_config_path = op.join(
        reference_directory, test_dataset_name, "preprocessing.yaml"
    )
    if op.exists(test_config_path):
        with open(test_config_path, "r") as test_config_file:
            return yaml.load(test_config_file, Loader=yaml.SafeLoader)
    return {}

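# Usage sketch for ``load_preprocessing_test_config`` (paths are illustrative):
#
#     >>> load_preprocessing_test_config("sample_a", reference_directory="/ref")
#     {}  # empty if ``/ref/sample_a/preprocessing.yaml`` does not exist
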
def collect_test_samples(
    reference_directory: str | None = None,
    output_path: str | None = None,
    n_events: int = 1000,
    supplementary_test_config_path: str | None = None,
):
    """Collect the test samples found in the reference directory and dump their
    configuration into a single YAML file.

    Args:
        reference_directory: directory that contains one sub-directory per test
            dataset. Defaults to ``cdirs.reference_directory``.
        output_path: path to the YAML file where the collected configuration is
            dumped. Defaults to ``cdirs.test_config_path``.
        n_events: default number of events for every test sample
        supplementary_test_config_path: path to an extra YAML configuration that
            is merged on top of the collected one
    """
    if reference_directory is None:
        reference_directory = cdirs.reference_directory
    if output_path is None:
        output_path = cdirs.test_config_path
    # Gather production (sample) names
    test_dataset_names = [
        file_or_dir.name
        for file_or_dir in os.scandir(reference_directory)
        if file_or_dir.is_dir()
    ]
    default_config = {
        "n_events": n_events,
        "num_true_hits_threshold": None,
    }
    # Deduce configuration of test samples (same for every production)
    test_config = {
        test_dataset_name: {
            "input_dir": op.join(reference_directory, test_dataset_name, "xdigi2csv"),
            **default_config,
            **load_preprocessing_test_config(
                test_dataset_name=test_dataset_name,
                reference_directory=reference_directory,
            ),
        }
        for test_dataset_name in test_dataset_names
    }
    # Merge the supplementary configuration, if provided
    if supplementary_test_config_path is not None:
        with open(
            supplementary_test_config_path, "r"
        ) as supplementary_test_config_file:
            supplementary_test_config = yaml.load(
                supplementary_test_config_file, Loader=yaml.SafeLoader
            )
        test_config = {**test_config, **supplementary_test_config}
    # Dump configuration
    with open(output_path, "w") as test_config_file:
        yaml.dump(test_config, test_config_file, Dumper=yaml.SafeDumper)
    print("Configuration of test samples saved in", output_path)

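# Usage sketch for ``collect_test_samples`` (paths are illustrative):
#
#     >>> collect_test_samples(
#     ...     reference_directory="/ref",
#     ...     output_path="test_samples.yaml",
#     ...     n_events=500,
#     ... )
#     Configuration of test samples saved in test_samples.yaml
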
def get_test_batch_dir(experiment_name: str, stage: str, test_dataset_name: str) -> str:
    """Get the directory where the batches of a particular experiment and of a given
    test sample are saved.

    Args:
        experiment_name: name of the experiment
        stage: name of the pipeline stage
        test_dataset_name: name of the test dataset

    Returns:
        Path to the directory where the torch batch files are saved.
    """
    return op.join(
        cdirs.data_directory, experiment_name, stage, "test", test_dataset_name + "/"
    )

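# Usage sketch for ``get_test_batch_dir``, assuming ``cdirs.data_directory`` is
# ``/data`` (experiment, stage and dataset names are illustrative); note the
# trailing slash appended by the function:
#
#     >>> get_test_batch_dir("my_experiment", "gnn", "sample_a")
#     '/data/my_experiment/gnn/test/sample_a/'
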
def get_test_batch_paths(
    experiment_name: str, stage: str, test_dataset_name: str
) -> typing.List[str]:
    """Get the list of paths of test batches of a given stage and experiment.

    Args:
        experiment_name: name of the experiment
        stage: name of the pipeline stage
        test_dataset_name: name of the test dataset

    Returns:
        List of paths of the test batches.

    Notes:
        If ``stage`` contains ``embedding`` and the test batch directory does not
        exist, the function tries to replace ``embedding`` with ``metric_learning``
        for backward compatibility.
    """
    test_batch_dir = get_test_batch_dir(
        experiment_name=experiment_name,
        stage=stage,
        test_dataset_name=test_dataset_name,
    )
    if "embedding" in stage and not op.exists(test_batch_dir):
        test_batch_dir = get_test_batch_dir(
            experiment_name=experiment_name,
            stage=stage.replace("embedding", "metric_learning"),
            test_dataset_name=test_dataset_name,
        )
    return [
        file_.path
        for file_ in os.scandir(test_batch_dir)
        if file_.is_file() and file_.name != "done"
    ]
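
# Usage sketch for ``get_test_batch_paths`` (names and batch files are
# illustrative); a ``done`` marker file, if present, is excluded:
#
#     >>> get_test_batch_paths("my_experiment", "embedding", "sample_a")
#     ['/data/my_experiment/embedding/test/sample_a/batch_0.pt', ...]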