# Source code for pipeline.Preprocessing.preprocessing_paths
"""Module to handle the output path of the preprocessing.
"""
from __future__ import annotations
import typing
import json
import logging
import os
import os.path as op
import numpy as np
from utils.commonutils.config import load_config, get_detector_from_pipeline_config
from utils.commonutils.ctests import (
get_required_test_dataset_names,
get_preprocessed_test_dataset_dir,
)
def get_truncated_paths(input_dir: str) -> typing.List[str]:
    """Get the sorted list of truncated paths in a given preprocessing folder.

    Each pre-processed event produces several files sharing a common prefix
    before the first ``-`` (per the docstring below, ``hits_particles`` and
    ``particles`` variants), so duplicates are collapsed to a single path.

    Args:
        input_dir: directory that contains preprocessed files

    Returns:
        Sorted list of unique paths, truncated so that they contain neither
        the extension nor the ``hits_particles`` / ``particles`` suffix.
    """
    # A set de-duplicates the several files belonging to one event; the
    # `done` entry is a completion marker, not a data file, so skip it.
    return sorted(
        {
            op.join(input_dir, file_.name.split("-")[0])
            for file_ in os.scandir(input_dir)
            if file_.name != "done"
        }
    )
def get_truncated_paths_for_partition(
    path_or_config: str | dict, partition: str
) -> typing.List[str]:
    """Get the list of truncated paths for a given partition.

    Args:
        path_or_config: configuration dictionary, or path to the YAML file
            that contains the configuration
        partition: Dataset partition: ``train``, ``val`` or name of a test
            dataset

    Returns:
        List of the truncated paths of the pre-processed parquet files for
        this partition.

    Raises:
        ValueError: if ``partition`` is neither ``train``, ``val`` nor one of
            the required test dataset names from the configuration.
    """
    config = load_config(path_or_config=path_or_config)
    test_dataset_names = get_required_test_dataset_names(config)
    if partition in ("train", "val"):
        # The train/val split was materialised at preprocessing time as a
        # JSON file mapping partition name -> list of truncated paths.
        splitting_json_path = op.join(
            config["processing"]["output_dir"], "splitting.json"
        )
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logging.info("Load truncated paths for %s in %s", partition, splitting_json_path)
        with open(splitting_json_path, "r") as json_file:
            truncated_paths = json.load(json_file)[partition]
    elif partition in test_dataset_names:
        # Test datasets live in a per-detector directory; scan it directly
        # instead of reading the splitting JSON.
        detector = get_detector_from_pipeline_config(path_or_config)
        test_preprocessed_input_dir = get_preprocessed_test_dataset_dir(
            test_dataset_name=partition, detector=detector
        )
        logging.info(
            "Load pre-processed test datasets in %s.", test_preprocessed_input_dir
        )
        truncated_paths = get_truncated_paths(input_dir=test_preprocessed_input_dir)
    else:
        raise ValueError(
            "`partition` is not recognised. It can either be `train`, `val` "
            "or the name of a test dataset: " + str(test_dataset_names)
        )
    return truncated_paths