import typing
import os
import logging
class InOutFunction(typing.Protocol):
    """Structural type of a callable that maps an input directory to an output one.

    Any callable accepting ``input_dir``, ``output_dir`` and an optional
    ``reproduce`` flag satisfies this protocol; it is the type expected for the
    ``func`` argument of ``run_for_different_partitions``.
    """

    def __call__(
        self, input_dir: str, output_dir: str, reproduce: bool = True
    ) -> typing.Any:
        """Process ``input_dir`` into ``output_dir``.

        Args:
            input_dir: input directory
            output_dir: output directory
            reproduce: whether to reproduce the output

        Returns:
            Anything; the return type is not constrained by the protocol.
        """
        ...
[docs]def run_for_different_partitions(
func: InOutFunction,
input_dir: str,
output_dir: str,
partitions: typing.List[str] = ["train", "val", "test"],
test_dataset_names: typing.List[str] | None = None,
reproduce: bool = True,
list_kwargs: typing.List[dict] | None = None,
**kwargs,
):
"""Run a function for different dataset "partitions".
Args:
func: Function to run, with input ``input_dir``, ``output_dir``, ``reproduce``
and possibly additional keyword arguments.
input_dir: input directory
output_dir: output directory
partitions: Partitions to run run the ``func`` on:
* ``train``: train dataset
* ``val``: validation dataset
* ``test``: all the test datasets
* A specific test dataset name
test_dataset_names: list of possible test dataset names
reproduce: whether to reproduce the output. This will remove the output
directory.
**kwargs: keyword arguments passed to ``func``
"""
for partition_idx, partition in enumerate(partitions):
if list_kwargs is None:
supplementary_kwargs = {}
else:
supplementary_kwargs = list_kwargs[partition_idx]
logging.info(
f"Use the following parameters for {partition}: {supplementary_kwargs}"
)
if partition in ["train", "val"]:
func(
input_dir=os.path.join(input_dir, partition),
output_dir=os.path.join(output_dir, partition),
reproduce=reproduce,
**supplementary_kwargs,
**kwargs,
)
elif partition == "test":
assert test_dataset_names is not None, (
"Trying to run the inference on a test sample, "
"but `test_dataset_names` was not provided."
)
for test_dataset_name in test_dataset_names:
func(
input_dir=os.path.join(input_dir, "test", test_dataset_name),
output_dir=os.path.join(output_dir, "test", test_dataset_name),
reproduce=reproduce,
**supplementary_kwargs,
**kwargs,
)
elif (test_dataset_names is not None) and (partition in test_dataset_names):
func(
input_dir=os.path.join(input_dir, "test", partition),
output_dir=os.path.join(output_dir, "test", partition),
reproduce=reproduce,
**supplementary_kwargs,
**kwargs,
)
else:
raise ValueError(
f"Partition `{partition}` is not recognised. "
"A partition can either be `train`, `val`, `test` "
"or a test dataset name: " + str(test_dataset_names)
)