"""A module that defines a function to run the processing from a configuration file.
"""
from __future__ import annotations
import typing
import logging
import os
import json
from functools import partial
from pathlib import Path
from joblib import Parallel, delayed
from tqdm.auto import tqdm
from utils.tools.tfiles import delete_directory
from utils.commonutils.config import load_config, cdirs
from utils.commonutils.ctests import get_preprocessed_test_dataset_dir
from .processing import prepare_event
from .splitting import randomly_split_list
from Preprocessing.preprocessing_paths import get_truncated_paths


def run_processing_in_parallel(
truncated_paths: typing.List[str],
output_dir: str,
max_workers: int,
reproduce: bool = True,
**processing_config,
):
"""Run the processing step in parallel.
Args:
truncated_paths: List of the truncated paths of the input files (which
correspond to the hits-particles and particles dataframe)
output_dir: directory where to write the output files
max_workers: maximal number of processes to run in parallel
reproduce: whether to delete the output directory before writting on it
**processing_config: Other keyword arguments passed to :py:func:`prepare_event`
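
    Example:
        A minimal, illustrative call (paths and the worker count are placeholders)::

            run_processing_in_parallel(
                truncated_paths=["/data/preprocessed/train/event000001000"],
                output_dir="/data/processed/train",
                max_workers=4,
                reproduce=False,
            )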
"""
if reproduce:
delete_directory(output_dir)
    os.makedirs(output_dir, exist_ok=True)
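    # A "done" marker file indicates that this output directory was already fully processed.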
    if os.path.exists(os.path.join(output_dir, "done")):
        logging.info(
            f"Output directory already contains a 'done' marker, "
            f"so processing was not run: {output_dir}"
        )
    else:
        logging.info(f"Writing outputs to {output_dir}")
# Process input files with a worker pool and progress bar
process_func = partial(
prepare_event, output_dir=output_dir, **processing_config
)
if max_workers == 1:
for truncated_path in tqdm(truncated_paths):
# logging.info(f"Truncated path: {truncated_path}")
process_func(truncated_path=truncated_path)
else:
            Parallel(n_jobs=max_workers)(
                delayed(process_func)(truncated_path=truncated_path)
                for truncated_path in tqdm(truncated_paths)
            )
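        # Mark the processing of this output directory as complete.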
Path(os.path.join(output_dir, "done")).touch()


def run_processing_test_dataset(
truncated_paths: typing.List[str],
output_dir: str,
n_workers: int,
reproduce: bool = True,
**processing_config,
):
"""Run the processing for the test dataset. There is no splitting train-val
for a test sample.
Args:
truncated_paths: List of the truncated paths of the input files (which
correspond to the hits-particles and particles dataframe)
output_dir: directory where to write the output files
n_workers: maximal number of processes to run in parallel
reproduce: whether to delete the output directory before writting on it
**processing_config: other keyword arguments passed to
:py:func:`.processing.prepare_event`
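
    Example:
        A minimal, illustrative call (paths are placeholders)::

            run_processing_test_dataset(
                truncated_paths=["/data/preprocessed/test/event000002000"],
                output_dir="/data/processed/test/some_test_dataset",
                n_workers=2,
            )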
"""
run_processing_in_parallel(
truncated_paths=truncated_paths,
output_dir=output_dir,
max_workers=n_workers,
reproduce=reproduce,
**processing_config,
)


def run_processing_from_config(
path_or_config: str | dict,
reproduce: bool = True,
test_dataset_name: str | None = None,
):
"""Loop over the events saved during the pre-processing step, and transform
them into the relevant format for training.
Args:
path_or_config: the overall configuration
reproduce: whether to reproduce an existing processing
test_dataset_name: Name of the test dataset to produce. If ``None`` (default),
the train and val datasets are produced instead.
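
    Example:
        A minimal, illustrative call (the configuration path is a placeholder)::

            run_processing_from_config("configs/processing.yaml")
            run_processing_from_config(
                "configs/processing.yaml", test_dataset_name="some_test_dataset"
            )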
"""
config = load_config(path_or_config)
processing_config = config["processing"]
input_dir = processing_config.pop("input_dir")
output_dir = processing_config.pop("output_dir")
n_workers = processing_config.pop("n_workers")
if test_dataset_name is not None:
detector = config["common"].get("detector")
if detector is None:
detector = cdirs.detectors[0]
input_dir = get_preprocessed_test_dataset_dir(
test_dataset_name=test_dataset_name, detector=detector
)
output_dir = os.path.join(output_dir, "test", test_dataset_name)
logging.info(f"Input directory: {input_dir}")
truncated_paths = get_truncated_paths(input_dir=input_dir)
    n_train_events = processing_config.pop("n_train_events")
    n_val_events = processing_config.pop("n_val_events")
split_seed = processing_config.pop("split_seed")
if test_dataset_name is None:
# Split into train + val
list_truncated_paths = randomly_split_list(
            truncated_paths, sizes=[n_train_events, n_val_events], seed=split_seed
)
        # Run the processing for train, then for val
        for split_name, split_paths in zip(["train", "val"], list_truncated_paths):
            run_processing_in_parallel(
                truncated_paths=split_paths,
output_dir=os.path.join(output_dir, split_name),
max_workers=n_workers,
reproduce=reproduce,
**processing_config,
)
# Save the splitting into a JSON file
splitting_json_path = os.path.join(output_dir, "splitting.json")
with open(splitting_json_path, "w") as json_file:
json.dump(
{
"train": list_truncated_paths[0],
"val": list_truncated_paths[1],
},
json_file,
indent=4,
)
logging.info(f"Splitting was saved in {splitting_json_path}.")
else:
        # Run the processing for the test dataset
run_processing_test_dataset(
truncated_paths=truncated_paths,
output_dir=output_dir,
n_workers=n_workers,
reproduce=reproduce,
**processing_config,
)