Source code for pipeline.utils.modelutils.checkpoint_utils

"""A module that define helper functions for checkpointing.
"""
from __future__ import annotations

import glob
import os
import typing

import pandas as pd
from pytorch_lightning import Trainer

from utils.commonutils.config import load_config, cdirs


def get_last_version_dir(experiment_dir: str) -> str:
    """Get the path of the last "version directory" given the path to the
    directory of the given training experiment.

    This directory must have a file "metrics.csv".

    Args:
        experiment_dir: path to the training experiment of interest

    Returns:
        Path to the last version directory
    """
    version_folder_paths = glob.glob(
        os.path.join(experiment_dir, "version_*/metrics.csv")
    )
    # Extract the integer version number from each matching path.
    available_versions = [
        int(
            version_folder_path[
                len(os.path.join(experiment_dir, "version_")) : -len("/metrics.csv")
            ]
        )
        for version_folder_path in version_folder_paths
    ]
    if not available_versions:
        raise ValueError(f"No version with `metrics.csv` found in {experiment_dir}")
    last_version = sorted(available_versions)[-1]
    return os.path.join(experiment_dir, f"version_{last_version}")
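
A minimal usage sketch (the experiment path below is hypothetical; any directory
containing ``version_*/metrics.csv`` subdirectories works):

# Hypothetical experiment directory with version_0/, version_1/, ... inside.
experiment_dir = "artifacts/gnn/my_experiment"
last_version_dir = get_last_version_dir(experiment_dir)
print(last_version_dir)  # e.g. "artifacts/gnn/my_experiment/version_1"
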
def get_last_artifact(version_dir: str, ckpt_dirname: str = "checkpoints") -> str:
    """Get the last artifact stored in a given version directory.

    The last artifact is the one that has the largest number of epochs,
    the largest number of steps and the last version.

    Args:
        version_dir: path to a directory that stores the training outcomes
            of a given experiment
        ckpt_dirname: name of the subdirectory that contains the checkpoints

    Returns:
        Path to the last PyTorch artifact file
    """
    checkpoints_dir = os.path.join(version_dir, ckpt_dirname)
    available_artifact_files = os.listdir(checkpoints_dir)
    dict_info_artifact_files = []
    for available_artifact_file in available_artifact_files:
        available_artifact_file_without_ext, ext = os.path.splitext(
            available_artifact_file
        )
        if ext == ".ckpt":
            # Checkpoint names look like `epoch=E-step=S.ckpt` or
            # `epoch=E-step=S-vN.ckpt`.
            tuple_info = available_artifact_file_without_ext.split("-")
            if len(tuple_info) == 2:
                epoch_str, step_str = tuple_info
                version_str = None
            elif len(tuple_info) == 3:
                epoch_str, step_str, version_str = tuple_info
            else:
                raise ValueError(
                    f"Artifact name `{available_artifact_file}` could not be parsed."
                )
            epoch = int(epoch_str[len("epoch=") :])
            step = int(step_str[len("step=") :])
            # Cast the version to `int` so that it sorts numerically together
            # with the default version 0.
            version = 0 if version_str is None else int(version_str[len("v") :])
            dict_info_artifact_files.append(
                {
                    "epoch": epoch,
                    "step": step,
                    "version": version,
                    "file_name": available_artifact_file,
                }
            )
    df_info_artifact_files = pd.DataFrame(dict_info_artifact_files)
    if len(df_info_artifact_files) == 0:
        raise ValueError(f"No artifact files were found in {checkpoints_dir}")
    last_artifact_file_name = df_info_artifact_files.sort_values(
        by=["epoch", "step", "version"],
        ascending=False,
    )["file_name"].iloc[0]
    return os.path.join(checkpoints_dir, last_artifact_file_name)
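
The two helpers above are typically chained to resolve the most recent
checkpoint of an experiment. A sketch, with hypothetical paths:

# Resolve the newest checkpoint of a (hypothetical) experiment.
version_dir = get_last_version_dir("artifacts/gnn/my_experiment")
ckpt_path = get_last_artifact(version_dir)
# `ckpt_path` points at e.g. ".../checkpoints/epoch=9-step=1000-v1.ckpt",
# the file with the highest (epoch, step, version) triple.
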
def get_last_version_dir_from_config(step: str, path_or_config: str | dict) -> str:
    """Get the path to the last version directory given the configuration.

    Args:
        step: ``embedding`` or ``gnn``
        path_or_config: configuration dictionary, or path to the YAML file
            that contains the configuration

    Returns:
        Path to the last version directory

    Notes:
        For backward compatibility, if the ``embedding`` directory does not
        exist, it is replaced by ``metric_learning``.
    """
    configs = load_config(path_or_config)
    experiment_name = configs["common"]["experiment_name"]
    experiment_dir = os.path.join(cdirs.artifact_directory, step, experiment_name)

    # For backward compatibility, if the ``embedding`` directory does not
    # exist, fall back to ``metric_learning``.
    if step == "embedding" and not os.path.exists(experiment_dir):
        experiment_dir = os.path.join(
            cdirs.artifact_directory, "metric_learning", experiment_name
        )

    return get_last_version_dir(experiment_dir=experiment_dir)
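
A sketch of the config-driven lookup. The experiment name is hypothetical, and
this assumes ``load_config`` passes a dictionary through unchanged, as its
``str | dict`` signature suggests:

# Hypothetical configuration; a YAML path would work the same way, e.g.
# get_last_version_dir_from_config("gnn", "configs/gnn.yaml").
config = {"common": {"experiment_name": "my_experiment"}}
version_dir = get_last_version_dir_from_config(step="gnn", path_or_config=config)
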
def get_training_metrics(
    trainer: Trainer | str | typing.List[str], suffix: str = ""
) -> pd.DataFrame:
    """Get the dataframe of the training metrics.

    Args:
        trainer: either a PyTorch Lightning Trainer object, or the path(s)
            to the metric file(s) to load directly.
        suffix: suffix to add to the name of the columns in the CSV file

    Returns:
        Dataframe of the training metrics (one row per epoch).
    """
    if isinstance(trainer, Trainer):
        logger = trainer.logger
        assert logger is not None, "No logger was assigned to this trainer."
        log_dir = logger.log_dir
        assert log_dir is not None, "The logger was not assigned a local directory."
        log_file = os.path.join(log_dir, "metrics.csv")
    elif isinstance(trainer, str):
        log_file = trainer
    elif isinstance(trainer, (list, tuple)):
        # Concatenate the metrics of several files (e.g. a resumed training).
        return pd.concat(
            (get_training_metrics(trainer=log_file) for log_file in trainer),
            axis=0,
        )
    else:
        raise TypeError(
            "`trainer` should be str, a list of str or a PyTorch Lightning "
            f"Trainer, but is {type(trainer).__name__}"
        )

    metrics = pd.read_csv(log_file, sep=",")

    # Lightning may log the epoch-level losses either as `train_loss<suffix>`
    # or as `train_loss<suffix>_epoch`, depending on the logging options.
    train_loss_column = (
        f"train_loss{suffix}"
        if f"train_loss{suffix}" in metrics
        else f"train_loss{suffix}_epoch"
    )
    val_loss_column = (
        f"val_loss{suffix}"
        if f"val_loss{suffix}" in metrics
        else f"val_loss{suffix}_epoch"
    )

    train_metrics = metrics[~metrics[train_loss_column].isna()][
        ["epoch", train_loss_column]
    ]
    # train_metrics["epoch"] -= 1
    val_metrics = metrics[~metrics[val_loss_column].isna()][
        [
            column if column in metrics else column + "_epoch"
            for column in [
                "val_loss" + suffix,
                "eff" + suffix,
                "pur" + suffix,
                "current_lr",
                "epoch",
            ]
        ]
    ]
    metrics = pd.merge(left=train_metrics, right=val_metrics, how="inner", on="epoch")

    # Strip the `_epoch` suffix and the user-provided suffix from the
    # column names.
    for column in metrics.columns:
        if column.endswith("_epoch"):
            metrics.rename(columns={column: column[: -len("_epoch")]}, inplace=True)
    if suffix:
        metrics.rename(
            columns={
                column: column[: -len(suffix)]
                for column in metrics.columns
                if column.endswith(suffix)
            },
            inplace=True,
        )
    return metrics
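
A sketch of loading metrics directly from CSV files (the paths are
hypothetical, and the files are assumed to contain the columns selected above,
i.e. the losses plus ``eff``, ``pur`` and ``current_lr``). Passing a list
concatenates the per-file dataframes, which is handy when a training was
resumed across several ``version_*`` directories:

# Hypothetical metric files from two runs of the same experiment.
metrics = get_training_metrics(
    trainer=[
        "artifacts/gnn/my_experiment/version_0/metrics.csv",
        "artifacts/gnn/my_experiment/version_1/metrics.csv",
    ]
)
print(metrics[["epoch", "train_loss", "val_loss"]].tail())
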