# Source code for scripts.train_model

#!/usr/bin/env python3
"""A script that runs the training of a model (embedding or GNN).
"""
from __future__ import annotations
import typing
from argparse import ArgumentParser, Namespace
import os
import logging

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger

from utils.commonutils.config import load_config, cdirs
from utils.scriptutils import configure_logger, headline
from utils.scriptutils.parser import add_predefined_arguments
from pipeline import instantiate_model_for_training

configure_logger()


def get_parsed_args() -> Namespace:
    """Parse the command-line arguments of the training script.

    Returns:
        Namespace holding the predefined ``pipeline_config`` and ``step``
        arguments plus the optional ``identifier``.
    """
    # Use ``description=``: the first positional parameter of ArgumentParser
    # is ``prog`` (the displayed program name), so the original call replaced
    # the program name instead of adding a help description.
    parser = ArgumentParser(description="Train a model.")
    add_predefined_arguments(parser, ["pipeline_config", "step"])
    parser.add_argument(
        "-i",
        "--identifier",
        required=False,
        help="identifier added at the end of the step name",
    )
    return parser.parse_args()
def train_model(
    path_or_config: str | dict, step: str, identifier: str | None = None
) -> typing.Tuple[Trainer, torch.nn.Module]:
    """Train the model configured for ``step`` and save a checkpoint.

    Args:
        path_or_config: Pipeline configuration dict or path to it.
        step: Model step, such as `embedding` or `gnn`.
        identifier: Identifier added at the end of the step name.

    Returns:
        Trainer and trained model.
    """
    suffix = identifier if identifier is not None else ""
    config = load_config(path_or_config)

    logging.info(headline(f"{step} Training"))
    common_config = config["common"]
    # The per-step config section is keyed by step name plus optional suffix.
    model_config = config[step + suffix]

    logging.info(headline("a) Initialising model"))
    model = instantiate_model_for_training(path_or_config=config, step=step)
    logging.info(f"Model type: {model.__class__.__name__}")

    logging.info(headline("b) Running training"))
    save_directory = os.path.abspath(
        os.path.join(cdirs.artifact_directory, step + suffix)
    )
    csv_logger = CSVLogger(save_directory, name=common_config["experiment_name"])
    logging.info("Save hyperparameters, metrics and artifacts in " + csv_logger.log_dir)

    trainer = Trainer(
        # Fall back to CPU when no CUDA device is visible.
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=common_config.get("gpus", 1),
        max_epochs=model_config["max_epochs"],
        logger=csv_logger,
        gradient_clip_val=model_config.get("gradient_clip_val"),
        reload_dataloaders_every_n_epochs=1,
    )
    trainer.fit(model)

    logging.info(headline("c) Saving model"))
    os.makedirs(save_directory, exist_ok=True)
    checkpoint_path = os.path.join(
        save_directory, common_config["experiment_name"] + ".ckpt"
    )
    trainer.save_checkpoint(checkpoint_path)
    return trainer, model
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and launch training.
    args = get_parsed_args()
    train_model(
        path_or_config=args.config,
        step=args.step,
        identifier=args.identifier,
    )