"""This module define the default parser used in all the scripts.
"""
import typing
from argparse import ArgumentParser
from utils.commonutils.config import cdirs
def parse_args(**kwargs) -> typing.Optional[str]:
    """Parse the command line with a parser that has a single ``config`` option,
    which allows to specify the configuration.

    Args:
        **kwargs: passed to the :py:class:`argparse.ArgumentParser`

    Returns:
        Path to the configuration file given through ``-c`` / ``--config``,
        or ``None`` when the option is omitted (it is neither required nor
        given a default).
    """
    parser = ArgumentParser(**kwargs)
    parser.add_argument("-c", "--config", help="Path to the YAML configuration file.")
    # Called without arguments, ``parse_args`` reads the process command line.
    return parser.parse_args().config
# Registry of reusable command-line arguments. Each key is the name under which
# the argument is requested; each value is a ``(flags, options)`` pair where
# ``flags`` are the option strings and ``options`` the keyword arguments, both
# handed to ``ArgumentParser.add_argument``.
# Several entries share a short flag (``-p``, ``-o``).
predefined_arguments: typing.Dict[
    str, typing.Tuple[typing.List[str], typing.Dict[str, typing.Any]]
] = {
    "test_dataset_name": (
        ["-t", "--test_dataset_name"],
        {"help": "Name of the test dataset.", "required": True},
    ),
    "partition": (
        ["-p", "--partition"],
        {
            "help": "Name of the test dataset, or `train` or `val`.",
            "required": True,
        },
    ),
    "partitions": (
        ["-p", "--partitions"],
        {
            "help": "Names of the test dataset, `train` or `val`.",
            "required": True,
            "nargs": "+",
        },
    ),
    "test_config": (
        ["-tx", "--test_config"],
        {
            "help": "Path where the test sample YAML configuration is saved.",
            "default": cdirs.test_config_path,
            "required": False,
        },
    ),
    "experiment_name": (
        ["-x", "--experiment_name"],
        {"help": "Name of the model", "required": True},
    ),
    "n_workers": (
        ["-w", "--n_workers"],
        {
            "help": "Number of workers run in parallel.",
            "required": False,
            "default": 16,
            "type": int,
        },
    ),
    "output_dir": (
        ["-o", "--output_dir"],
        {"help": "Output directory where to save the plots", "required": False},
    ),
    "output_path": (
        ["-o", "--output_path"],
        {"help": "Output path where to save the plot", "required": False},
    ),
    "pipeline_config": (
        ["-c", "--pipeline_config"],
        {"help": "Path to the configuration pipeline", "required": True},
    ),
    "reproduce": (
        ["-r", "--reproduce"],
        {
            "help": "Whether to run even though the output file(s) already exist.",
            "action": "store_true",
        },
    ),
    "detector": (
        ["-d", "--detector"],
        {
            "help": "Detector to evaluate Allen on.",
            "required": False,
            # The first configured detector is the default choice.
            "default": cdirs.detectors[0],
            "choices": cdirs.detectors,
        },
    ),
    "step": (
        ["-s", "--step"],
        {"required": True, "help": "Model step, such as `embedding` or `gnn`."},
    ),
}
def add_predefined_arguments(
    parser: ArgumentParser, arguments: typing.Iterable[str], **kwargs
) -> None:
    """Add predefined arguments defined by the :py:data:`.predefined_arguments`
    dictionary.

    Args:
        parser: Argument parser which to add the predefined arguments to
        arguments: names of predefined arguments to add
        **kwargs: extra keyword arguments forwarded to
            :py:meth:`argparse.ArgumentParser.add_argument` for every requested
            argument. On a key clash they override the predefined options.

    Raises:
        ValueError: if a requested name is not a key of
            :py:data:`.predefined_arguments`, or if its entry is not a
            two-element ``(flags, options)`` pair.
    """
    for argument in arguments:
        if argument not in predefined_arguments:
            raise ValueError(f"Argument {argument} is not recognised.")
        # Unpacking enforces the two-element shape, also under ``python -O``
        # where an ``assert`` would be stripped.
        flags, options = predefined_arguments[argument]
        # Merge into one mapping so that the caller's value wins: expanding both
        # mappings separately raises ``TypeError`` when they share a key.
        parser.add_argument(*flags, **{**options, **kwargs})