Source code for pipeline.Preprocessing.preprocessing_paths

"""Module to handle the output path of the preprocessing.
"""
from __future__ import annotations
import typing
import json
import logging
import os
import os.path as op
import numpy as np


from utils.commonutils.config import load_config, get_detector_from_pipeline_config
from utils.commonutils.ctests import (
    get_required_test_dataset_names,
    get_preprocessed_test_dataset_dir,
)


def get_truncated_paths(input_dir: str) -> typing.List[str]:
    """Get the list of the truncated paths in a given preprocessing folder.

    A pre-processed event is stored as several files sharing a common
    prefix before the first ``-`` (e.g. ``<event>-particles.parquet`` and
    ``<event>-hits_particles.parquet``). This returns one path per event,
    truncated so that it contains neither the extension nor the
    ``hits_particles`` / ``particles`` suffix.

    Args:
        input_dir: directory that contains preprocessed files

    Returns:
        Sorted list of the unique truncated paths. The ``done`` marker
        file is ignored.
    """
    # A set de-duplicates the prefixes shared by the ``particles`` and
    # ``hits_particles`` files of the same event; this replaces the
    # previous ``sorted(np.unique([...]))``, which built an intermediate
    # array and returned ``np.str_`` objects instead of plain ``str``.
    truncated = {
        op.join(input_dir, entry.name.split("-")[0])
        for entry in os.scandir(input_dir)
        if entry.name != "done"
    }
    return sorted(truncated)
def get_truncated_paths_for_partition(
    path_or_config: str | dict, partition: str
) -> typing.List[str]:
    """Get the list of truncated paths for a given partition.

    Args:
        path_or_config: configuration dictionary, or path to the YAML file
            that contains the configuration
        partition: Dataset partition: ``train``, ``val`` or name of a test
            dataset

    Returns:
        List of the truncated paths of the pre-processed parquet files for
        this partition.

    Raises:
        ValueError: if ``partition`` is neither ``train``, ``val`` nor a
            known test dataset name.
    """
    config = load_config(path_or_config=path_or_config)
    test_dataset_names = get_required_test_dataset_names(config)

    if partition in ("train", "val"):
        # Train/val truncated paths were recorded at pre-processing time
        # in a ``splitting.json`` file next to the pre-processed output.
        splitting_json_path = op.join(
            config["processing"]["output_dir"], "splitting.json"
        )
        logging.info(f"Load truncated paths for {partition} in {splitting_json_path}")
        with open(splitting_json_path, "r") as json_file:
            return json.load(json_file)[partition]

    if partition in test_dataset_names:
        # Test datasets live in a detector-specific, dataset-specific
        # directory; scan it for the truncated paths.
        detector = get_detector_from_pipeline_config(path_or_config)
        test_preprocessed_input_dir = get_preprocessed_test_dataset_dir(
            test_dataset_name=partition, detector=detector
        )
        logging.info(
            f"Load pre-processed test datasets in {test_preprocessed_input_dir}."
        )
        return get_truncated_paths(input_dir=test_preprocessed_input_dir)

    raise ValueError(
        "`partition` is not recognised. It can either be `train`, `val` "
        "or the name of a test dataset: " + str(test_dataset_names)
    )