"""A module that define utilities to handle test datasets.
"""
from __future__ import annotations
import typing
import os
import os.path as op
import yaml
from .config import load_dict, load_config, cdirs
def get_available_test_dataset_names(
    path_or_config_test: str | typing.Dict[str, typing.Any] | None = None,
) -> typing.List[str]:
    """List the test dataset names declared in the test dataset
    configuration file.

    Args:
        path_or_config_test: YAML test dataset configuration dictionary or
            path to it. Defaults to the configured test config path.

    Returns:
        List of test dataset names that can be produced or/and used.
    """
    config = load_dict(
        path_or_config=(
            cdirs.test_config_path
            if path_or_config_test is None
            else path_or_config_test
        )
    )
    return [*config]
def get_required_test_dataset_names(path_or_config: str | dict) -> typing.List[str]:
    """Return the dataset names that the given configuration requires.

    Args:
        path_or_config: YAML configuration dictionary or path to it.

    Returns:
        The ``test_dataset_names`` entry of the ``common`` section.
    """
    loaded = load_config(path_or_config=path_or_config)
    return loaded["common"]["test_dataset_names"]
def get_preprocessed_test_dataset_dir(test_dataset_name: str, detector: str) -> str:
    """Get the path to the directory that contains the preprocessed files
    of a given test dataset.

    Args:
        test_dataset_name: name of the test dataset to pre-process
        detector: name of the detector

    Returns:
        Path ``<data_directory>/__test__/<detector>/<test_dataset_name>``.
    """
    path_components = (cdirs.data_directory, "__test__", detector, test_dataset_name)
    return op.join(*path_components)
def get_test_config_for_preprocessing(
    test_dataset_name: str, path_or_config_test: str | dict, detector: str
) -> dict:
    """Build the configuration used for the pre-processing of a given test dataset.

    Args:
        test_dataset_name: name of the test dataset to pre-process
        path_or_config_test: YAML test dataset configuration dictionary or path to it
        detector: name of the detector

    Returns:
        Configuration dictionary with a ``common`` section (the detector)
        and a ``preprocessing`` section (output directory plus the
        dataset-specific entries).

    Raises:
        ValueError: if ``test_dataset_name`` is not declared in the
            test configuration.
    """
    test_config = load_dict(path_or_config=path_or_config_test)
    if test_dataset_name not in test_config:
        valid_names = ", ".join(test_config.keys())
        raise ValueError(
            f"`{test_dataset_name}` is not recognised as a valid test sample name. "
            "Valid sample names are: " + valid_names
        )
    output_dir = get_preprocessed_test_dataset_dir(
        test_dataset_name=test_dataset_name, detector=detector
    )
    # NB: the dataset-specific entries may deliberately override `output_dir`
    preprocessing_section = {"output_dir": output_dir}
    preprocessing_section.update(test_config[test_dataset_name])
    return {
        "common": {"detector": detector},
        "preprocessing": preprocessing_section,
    }
def load_preprocessing_test_config(
    test_dataset_name: str,
    reference_directory: str | None = None,
) -> typing.Dict[str, typing.Any]:
    """Load the per-dataset ``preprocessing.yaml`` overrides of a test dataset.

    Args:
        test_dataset_name: name of the test dataset
        reference_directory: directory that contains one sub-directory per
            test dataset. Defaults to the configured reference directory.

    Returns:
        The parsed YAML dictionary, or an empty dictionary when no
        ``preprocessing.yaml`` file exists for this dataset.
    """
    if reference_directory is None:
        reference_directory = cdirs.reference_directory
    config_path = op.join(
        reference_directory, test_dataset_name, "preprocessing.yaml"
    )
    if not op.exists(config_path):
        return {}
    with open(config_path, "r") as config_file:
        return yaml.load(config_file, Loader=yaml.SafeLoader)
def collect_test_samples(
    reference_directory: str | None = None,
    output_path: str | None = None,
    n_events: int = 1000,
    supplementary_test_config_path: str | None = None,
):
    """Scan the reference directory for test samples and dump their
    configuration to a YAML file.

    Args:
        reference_directory: directory that contains one sub-directory per
            test sample. Defaults to the configured reference directory.
        output_path: where the resulting YAML configuration is written.
            Defaults to the configured test config path.
        n_events: default number of events per test sample
        supplementary_test_config_path: optional path to a YAML file whose
            entries override / extend the collected configuration
    """
    if reference_directory is None:
        reference_directory = cdirs.reference_directory
    if output_path is None:
        output_path = cdirs.test_config_path
    # Default settings shared by every production (sample)
    defaults = {
        "n_events": n_events,
        "num_true_hits_threshold": None,
    }
    # One sub-directory of the reference directory per production (sample);
    # each may carry its own `preprocessing.yaml` overrides.
    test_config = {}
    for entry in os.scandir(reference_directory):
        if not entry.is_dir():
            continue
        sample_config = {
            "input_dir": op.join(reference_directory, entry.name, "xdigi2csv")
        }
        sample_config.update(defaults)
        sample_config.update(
            load_preprocessing_test_config(
                test_dataset_name=entry.name,
                reference_directory=reference_directory,
            )
        )
        test_config[entry.name] = sample_config
    # Merge in the supplementary configuration, which takes precedence
    if supplementary_test_config_path is not None:
        with open(
            supplementary_test_config_path, "r"
        ) as supplementary_test_config_file:
            test_config.update(
                yaml.load(supplementary_test_config_file, Loader=yaml.SafeLoader)
            )
    # Dump the collected configuration
    with open(output_path, "w") as test_config_file:
        yaml.dump(test_config, test_config_file, Dumper=yaml.SafeDumper)
    print("Configuration of test samples saved in", output_path)
def get_test_batch_dir(experiment_name: str, stage: str, test_dataset_name: str):
    """Get the directory where the batches of a particular experiment, and
    of a given test sample are saved.

    Args:
        experiment_name: name of the experiment
        stage: name of the pipeline stage
        test_dataset_name: name of the test dataset

    Returns:
        Path to the directory where the torch batch files are saved
        (with a trailing separator).
    """
    # Trailing "/" kept on the last component to preserve the historical
    # path format expected by callers.
    path_components = (
        cdirs.data_directory,
        experiment_name,
        stage,
        "test",
        test_dataset_name + "/",
    )
    return op.join(*path_components)
def get_test_batch_paths(
    experiment_name: str, stage: str, test_dataset_name: str
) -> typing.List[str]:
    """Get the list of paths of test batches of a given stage and experiment.

    Args:
        experiment_name: name of the experiment
        stage: name of the pipeline stage
        test_dataset_name: name of the test dataset

    Returns:
        List of paths of the test batches.

    Notes:
        If ``stage`` contains ``embedding`` and the test batch directory does
        not exist, the function tries to replace ``embedding`` by
        ``metric_learning`` for backward compatibility.
    """
    test_batch_dir = get_test_batch_dir(
        experiment_name=experiment_name,
        stage=stage,
        test_dataset_name=test_dataset_name,
    )
    # Backward compatibility: older runs stored the embedding batches under
    # a `metric_learning` stage name.
    if "embedding" in stage and not os.path.exists(test_batch_dir):
        test_batch_dir = get_test_batch_dir(
            experiment_name=experiment_name,
            stage=stage.replace("embedding", "metric_learning"),
            test_dataset_name=test_dataset_name,
        )
    batch_paths = []
    for entry in os.scandir(test_batch_dir):
        # The `done` marker file is not a batch; skip it.
        if entry.is_file() and entry.name != "done":
            batch_paths.append(entry.path)
    return batch_paths