# Source code for pipeline.utils.commonutils.crun

import typing
import os
import logging


class InOutFunction(typing.Protocol):
    """Structural type of a callable that maps an input directory to an output directory.

    Any callable accepting ``input_dir``, ``output_dir`` and an optional
    ``reproduce`` flag (defaulting to ``True``) satisfies this protocol.
    The return value is unconstrained.
    """

    def __call__(
        self, input_dir: str, output_dir: str, reproduce: bool = True
    ) -> typing.Any:
        """Process ``input_dir`` into ``output_dir``.

        Args:
            input_dir: directory to read from.
            output_dir: directory to write to.
            reproduce: flag whose meaning is defined by the implementation.
        """
        ...
[docs]def run_for_different_partitions( func: InOutFunction, input_dir: str, output_dir: str, partitions: typing.List[str] = ["train", "val", "test"], test_dataset_names: typing.List[str] | None = None, reproduce: bool = True, list_kwargs: typing.List[dict] | None = None, **kwargs, ): """Run a function for different dataset "partitions". Args: func: Function to run, with input ``input_dir``, ``output_dir``, ``reproduce`` and possibly additional keyword arguments. input_dir: input directory output_dir: output directory partitions: Partitions to run run the ``func`` on: * ``train``: train dataset * ``val``: validation dataset * ``test``: all the test datasets * A specific test dataset name test_dataset_names: list of possible test dataset names reproduce: whether to reproduce the output. This will remove the output directory. **kwargs: keyword arguments passed to ``func`` """ for partition_idx, partition in enumerate(partitions): if list_kwargs is None: supplementary_kwargs = {} else: supplementary_kwargs = list_kwargs[partition_idx] logging.info( f"Use the following parameters for {partition}: {supplementary_kwargs}" ) if partition in ["train", "val"]: func( input_dir=os.path.join(input_dir, partition), output_dir=os.path.join(output_dir, partition), reproduce=reproduce, **supplementary_kwargs, **kwargs, ) elif partition == "test": assert test_dataset_names is not None, ( "Trying to run the inference on a test sample, " "but `test_dataset_names` was not provided." 
) for test_dataset_name in test_dataset_names: func( input_dir=os.path.join(input_dir, "test", test_dataset_name), output_dir=os.path.join(output_dir, "test", test_dataset_name), reproduce=reproduce, **supplementary_kwargs, **kwargs, ) elif (test_dataset_names is not None) and (partition in test_dataset_names): func( input_dir=os.path.join(input_dir, "test", partition), output_dir=os.path.join(output_dir, "test", partition), reproduce=reproduce, **supplementary_kwargs, **kwargs, ) else: raise ValueError( f"Partition `{partition}` is not recognised. " "A partition can either be `train`, `val`, `test` " "or a test dataset name: " + str(test_dataset_names) )