# Source code for scripts.train_model
#!/usr/bin/env python3
"""A script that runs the training of a model (embedding or GNN).
"""
from __future__ import annotations
import typing
from argparse import ArgumentParser, Namespace
import os
import logging
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from utils.commonutils.config import load_config, cdirs
from utils.scriptutils import configure_logger, headline
from utils.scriptutils.parser import add_predefined_arguments
from pipeline import instantiate_model_for_training
configure_logger()
def get_parsed_args() -> Namespace:
    """Parse the command-line arguments of the training script.

    Predefined arguments (pipeline config and step) come from the shared
    parser helper; only the optional identifier is declared here.

    Returns:
        Parsed command-line arguments.
    """
    arg_parser = ArgumentParser("Train a model.")
    add_predefined_arguments(arg_parser, ["pipeline_config", "step"])
    arg_parser.add_argument(
        "-i",
        "--identifier",
        required=False,
        help="identifier added at the end of the step name",
    )
    return arg_parser.parse_args()
def train_model(
    path_or_config: str | dict, step: str, identifier: str | None = None
) -> typing.Tuple[Trainer, torch.nn.Module]:
    """Run the training of a model.

    Args:
        path_or_config: pipeline configuration or path to it.
        step: Model step, such as `embedding` or `gnn`.
        identifier: Identifier added at the end of the step name.

    Returns:
        Trainer and trained model.
    """
    config = load_config(path_or_config)
    # Treat a missing identifier as an empty suffix on the step name.
    suffix = "" if identifier is None else identifier

    logging.info(headline(f"{step} Training"))
    common_config = config["common"]
    model_config = config[step + suffix]

    logging.info(headline("a) Initialising model"))
    model = instantiate_model_for_training(path_or_config=config, step=step)
    logging.info(f"Model type: {model.__class__.__name__}")

    logging.info(headline("b) Running training"))
    # Artifacts (CSV logs, checkpoint) go under <artifact_dir>/<step><suffix>.
    save_directory = os.path.abspath(
        os.path.join(cdirs.artifact_directory, step + suffix)
    )
    csv_logger = CSVLogger(save_directory, name=common_config["experiment_name"])
    logging.info(
        f"Save hyperparameters, metrics and artifacts in {csv_logger.log_dir}"
    )
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    trainer = Trainer(
        accelerator=accelerator,
        devices=common_config.get("gpus", 1),
        max_epochs=model_config["max_epochs"],
        logger=csv_logger,
        gradient_clip_val=model_config.get("gradient_clip_val"),
        reload_dataloaders_every_n_epochs=1,
    )
    trainer.fit(model)

    logging.info(headline("c) Saving model"))
    os.makedirs(save_directory, exist_ok=True)
    checkpoint_path = os.path.join(
        save_directory, common_config["experiment_name"] + ".ckpt"
    )
    trainer.save_checkpoint(checkpoint_path)
    return trainer, model
if __name__ == "__main__":
    # Entry point: parse the CLI arguments and launch the training run.
    args = get_parsed_args()
    train_model(
        path_or_config=args.config,
        step=args.step,
        identifier=args.identifier,
    )