Source code for pipeline.Processing.run_processing

"""A module that defines a function to run the processing from a configuration file.
"""
from __future__ import annotations
import typing
import logging
import os
import json
from functools import partial
from pathlib import Path

from joblib import Parallel, delayed
from tqdm.auto import tqdm

from utils.tools.tfiles import delete_directory
from utils.commonutils.config import load_config, cdirs
from utils.commonutils.ctests import get_preprocessed_test_dataset_dir
from .processing import prepare_event
from .splitting import randomly_split_list
from Preprocessing.preprocessing_paths import get_truncated_paths


def run_processing_in_parallel(
    truncated_paths: typing.List[str],
    output_dir: str,
    max_workers: int,
    reproduce: bool = True,
    **processing_config,
):
    """Run the processing step in parallel.

    Args:
        truncated_paths: List of the truncated paths of the input files
            (which correspond to the hits-particles and particles dataframes)
        output_dir: Directory where the output files are written
        max_workers: Maximum number of processes to run in parallel
        reproduce: Whether to delete the output directory before writing to it
        **processing_config: Other keyword arguments passed to
            :py:func:`.processing.prepare_event`
    """
    if reproduce:
        delete_directory(output_dir)

    os.makedirs(output_dir, exist_ok=True)

    if os.path.exists(os.path.join(output_dir, "done")):
        logging.info(
            f"Output folder is not empty so processing was not run: {output_dir}"
        )
    else:
        logging.info("Writing outputs to " + output_dir)

        # Process input files with a worker pool and progress bar
        process_func = partial(
            prepare_event, output_dir=output_dir, **processing_config
        )
        if max_workers == 1:
            for truncated_path in tqdm(truncated_paths):
                process_func(truncated_path=truncated_path)
        else:
            Parallel(n_jobs=max_workers)(
                delayed(process_func)(truncated_path)
                for truncated_path in tqdm(truncated_paths)
            )

        # Mark the directory as complete so reruns can be skipped
        Path(os.path.join(output_dir, "done")).touch()
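

# Illustrative sketch, not part of the original module: how a single split
# could be processed directly, bypassing the configuration file. The directory
# names and default worker count are hypothetical; any extra keyword arguments
# are forwarded to prepare_event, exactly as in run_processing_from_config
# below.
def _example_process_single_split(
    input_dir: str,
    output_dir: str,
    max_workers: int = 4,
    **processing_config,
) -> None:
    """Process every preprocessed event found in ``input_dir`` into ``output_dir``."""
    truncated_paths = get_truncated_paths(input_dir=input_dir)
    run_processing_in_parallel(
        truncated_paths=truncated_paths,
        output_dir=output_dir,
        max_workers=max_workers,
        reproduce=False,  # keep any existing outputs and "done" marker
        **processing_config,
    )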


def run_processing_test_dataset(
    truncated_paths: typing.List[str],
    output_dir: str,
    n_workers: int,
    reproduce: bool = True,
    **processing_config,
):
    """Run the processing for the test dataset.

    There is no train-val splitting for a test sample.

    Args:
        truncated_paths: List of the truncated paths of the input files
            (which correspond to the hits-particles and particles dataframes)
        output_dir: Directory where the output files are written
        n_workers: Maximum number of processes to run in parallel
        reproduce: Whether to delete the output directory before writing to it
        **processing_config: Other keyword arguments passed to
            :py:func:`.processing.prepare_event`
    """
    run_processing_in_parallel(
        truncated_paths=truncated_paths,
        output_dir=output_dir,
        max_workers=n_workers,
        reproduce=reproduce,
        **processing_config,
    )


def run_processing_from_config(
    path_or_config: str | dict,
    reproduce: bool = True,
    test_dataset_name: str | None = None,
):
    """Loop over the events saved during the pre-processing step and transform
    them into the relevant format for training.

    Args:
        path_or_config: The configuration dictionary, or the path to the
            configuration file
        reproduce: Whether to reproduce an existing processing
        test_dataset_name: Name of the test dataset to produce. If ``None``
            (default), the train and val datasets are produced instead.
    """
    config = load_config(path_or_config)
    processing_config = config["processing"]
    input_dir = processing_config.pop("input_dir")
    output_dir = processing_config.pop("output_dir")
    n_workers = processing_config.pop("n_workers")

    if test_dataset_name is not None:
        detector = config["common"].get("detector")
        if detector is None:
            detector = cdirs.detectors[0]
        input_dir = get_preprocessed_test_dataset_dir(
            test_dataset_name=test_dataset_name, detector=detector
        )
        output_dir = os.path.join(output_dir, "test", test_dataset_name)

    logging.info(f"Input directory: {input_dir}")
    truncated_paths = get_truncated_paths(input_dir=input_dir)

    n_train_events = processing_config.pop("n_train_events")
    n_val_events = processing_config.pop("n_val_events")
    split_seed = processing_config.pop("split_seed")

    if test_dataset_name is None:
        # Split into train + val
        list_truncated_paths = randomly_split_list(
            truncated_paths, sizes=[n_train_events, n_val_events], seed=split_seed
        )

        # Run the processing for train, then for val
        for split_name, split_truncated_paths in zip(
            ["train", "val"], list_truncated_paths
        ):
            run_processing_in_parallel(
                truncated_paths=split_truncated_paths,
                output_dir=os.path.join(output_dir, split_name),
                max_workers=n_workers,
                reproduce=reproduce,
                **processing_config,
            )

        # Save the splitting into a JSON file
        splitting_json_path = os.path.join(output_dir, "splitting.json")
        with open(splitting_json_path, "w") as json_file:
            json.dump(
                {
                    "train": list_truncated_paths[0],
                    "val": list_truncated_paths[1],
                },
                json_file,
                indent=4,
            )
        logging.info(f"Splitting was saved in {splitting_json_path}.")
    else:
        # Run the processing for the test dataset
        run_processing_test_dataset(
            truncated_paths=truncated_paths,
            output_dir=output_dir,
            n_workers=n_workers,
            reproduce=reproduce,
            **processing_config,
        )
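

# Minimal driver sketch (illustrative, not part of the original module). The
# configuration keys below are the ones read by run_processing_from_config;
# every value is a hypothetical placeholder, and it is assumed that
# load_config accepts an in-memory dict as well as a path, as the
# ``path_or_config`` annotation suggests.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    example_config = {
        # "common"/"detector" is only read when producing a test dataset;
        # a missing detector falls back to cdirs.detectors[0]
        "common": {},
        "processing": {
            "input_dir": "preprocessed/events",  # hypothetical input location
            "output_dir": "processed/events",  # hypothetical output location
            "n_workers": 4,
            "n_train_events": 800,
            "n_val_events": 200,
            "split_seed": 42,
            # any remaining keys are forwarded to prepare_event
        },
    }
    run_processing_from_config(example_config, reproduce=False)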