"""A module that define helper functions for checkpointing.
"""
from __future__ import annotations

import glob
import os
import typing

import pandas as pd
from pytorch_lightning import Trainer

from utils.commonutils.config import cdirs, load_config
def get_last_version_dir(experiment_dir: str) -> str:
    """Get the path of the last "version directory" inside the directory of a
    given training experiment. Only version directories that contain a
    ``metrics.csv`` file are considered.

    Args:
        experiment_dir: path to the training experiment of interest

    Returns:
        Path to the last version directory
    """
    version_folder_paths = glob.glob(
        os.path.join(experiment_dir, "version_*", "metrics.csv")
    )
    # Recover the integer version number from each matched path,
    # e.g. `.../version_3/metrics.csv` -> 3.
    available_versions = [
        int(os.path.basename(os.path.dirname(version_folder_path))[len("version_"):])
        for version_folder_path in version_folder_paths
    ]
    if not available_versions:
        raise ValueError(f"No version with `metrics.csv` found in {experiment_dir}")
    last_version = max(available_versions)
    return os.path.join(experiment_dir, f"version_{last_version}")
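
# Usage sketch (hypothetical path; assumes the layout produced by Lightning's
# ``CSVLogger``, i.e. ``<experiment_dir>/version_<n>/metrics.csv``):
#
#   >>> get_last_version_dir("artifacts/gnn/my_experiment")  # doctest: +SKIP
#   'artifacts/gnn/my_experiment/version_3'
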
def get_last_artifact(version_dir: str, ckpt_dirname: str = "checkpoints") -> str:
    """Get the last artifact stored in a given version directory.

    The last artifact is the one with the highest epoch number, then the
    highest step number, then the highest checkpoint version.

    Args:
        version_dir: path to a directory that stores the training outcomes
            of a given experiment
        ckpt_dirname: name of the subdirectory of ``version_dir`` that holds
            the checkpoint files

    Returns:
        Path to the last PyTorch artifact file
    """
    checkpoints_dir = os.path.join(version_dir, ckpt_dirname)
    available_artifact_files = os.listdir(checkpoints_dir)
    info_artifact_files = []
    for available_artifact_file in available_artifact_files:
        available_artifact_file_without_ext, ext = os.path.splitext(
            available_artifact_file
        )
        if ext != ".ckpt":
            continue
        # Expected names: `epoch=<e>-step=<s>.ckpt` or `epoch=<e>-step=<s>-v<n>.ckpt`
        tuple_info = available_artifact_file_without_ext.split("-")
        if len(tuple_info) == 2:
            epoch_str, step_str = tuple_info
            version_str = None
        elif len(tuple_info) == 3:
            epoch_str, step_str, version_str = tuple_info
        else:
            raise ValueError(
                f"Artifact name `{available_artifact_file}` could not be parsed."
            )
        epoch = int(epoch_str[len("epoch="):])
        step = int(step_str[len("step="):])
        # Cast the version to int so that sorting does not mix ints and strings.
        version = 0 if version_str is None else int(version_str[len("v"):])
        info_artifact_files.append(
            {
                "epoch": epoch,
                "step": step,
                "version": version,
                "file_name": available_artifact_file,
            }
        )
    if not info_artifact_files:
        raise ValueError(f"No artifact files were found in {checkpoints_dir}")
    df_info_artifact_files = pd.DataFrame(info_artifact_files)
    last_artifact_file_name = df_info_artifact_files.sort_values(
        by=["epoch", "step", "version"],
        ascending=False,
    )["file_name"].iloc[0]
    return os.path.join(checkpoints_dir, last_artifact_file_name)
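
# Usage sketch (hypothetical path; assumes Lightning's default
# ``ModelCheckpoint`` naming, e.g. ``epoch=9-step=1250.ckpt``, with a ``-v1``
# suffix appended when a checkpoint of that name already exists):
#
#   >>> get_last_artifact("artifacts/gnn/my_experiment/version_3")  # doctest: +SKIP
#   'artifacts/gnn/my_experiment/version_3/checkpoints/epoch=9-step=1250-v1.ckpt'
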
def get_last_version_dir_from_config(step: str, path_or_config: str | dict) -> str:
    """Get the path to the last version directory given the configuration.

    Args:
        step: ``embedding`` or ``gnn``
        path_or_config: configuration dictionary, or path to the YAML file
            that contains the configuration

    Returns:
        Path to the last version directory

    Notes:
        For backward compatibility, if the ``embedding`` directory does not
        exist, ``metric_learning`` is used instead.
    """
    configs = load_config(path_or_config)
    experiment_name = configs["common"]["experiment_name"]
    experiment_dir = os.path.join(cdirs.artifact_directory, step, experiment_name)
    # For backward compatibility: if the ``embedding`` directory does not
    # exist, fall back to the legacy ``metric_learning`` directory.
    if step == "embedding" and not os.path.exists(experiment_dir):
        experiment_dir = os.path.join(
            cdirs.artifact_directory, "metric_learning", experiment_name
        )
    return get_last_version_dir(experiment_dir=experiment_dir)
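
# Usage sketch (a minimal, hypothetical configuration; only the
# ``common.experiment_name`` key read above is required):
#
#   >>> config = {"common": {"experiment_name": "my_experiment"}}
#   >>> get_last_version_dir_from_config("gnn", config)  # doctest: +SKIP
#   '<cdirs.artifact_directory>/gnn/my_experiment/version_3'
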
def get_training_metrics(
    trainer: Trainer | str | typing.List[str], suffix: str = ""
) -> pd.DataFrame:
    """Get the dataframe of the training metrics.

    Args:
        trainer: either a PyTorch Lightning ``Trainer`` object, or the path(s)
            to the metric file(s) to load directly
        suffix: suffix to add to the name of the columns in the CSV file

    Returns:
        Dataframe of the training metrics (one row per epoch).
    """
    if isinstance(trainer, Trainer):
        logger = trainer.logger
        assert logger is not None, "No logger was assigned to this trainer."
        log_dir = logger.log_dir
        assert log_dir is not None, "The logger was not assigned a local directory."
        log_file = os.path.join(log_dir, "metrics.csv")
    elif isinstance(trainer, str):
        log_file = trainer
    elif isinstance(trainer, (list, tuple)):
        return pd.concat(
            (get_training_metrics(trainer=log_file) for log_file in trainer),
            axis=0,
        )
    else:
        raise TypeError(
            "`trainer` should be a str, a list of str or a PyTorch Lightning "
            f"Trainer, but is {type(trainer).__name__}"
        )
    metrics = pd.read_csv(log_file, sep=",")
    # Lightning logs epoch-level metrics either under the plain name or with
    # an `_epoch` suffix, depending on how they were logged.
    train_loss_column = (
        f"train_loss{suffix}"
        if f"train_loss{suffix}" in metrics
        else f"train_loss{suffix}_epoch"
    )
    val_loss_column = (
        f"val_loss{suffix}"
        if f"val_loss{suffix}" in metrics
        else f"val_loss{suffix}_epoch"
    )
    # Training and validation metrics land on separate rows of `metrics.csv`;
    # keep the non-missing rows of each and merge them back on the epoch.
    train_metrics = metrics[~metrics[train_loss_column].isna()][
        ["epoch", train_loss_column]
    ]
    val_metrics = metrics[~metrics[val_loss_column].isna()][
        [
            column if column in metrics else f"{column}_epoch"
            for column in [
                f"val_loss{suffix}",
                f"eff{suffix}",
                f"pur{suffix}",
                "current_lr",
                "epoch",
            ]
        ]
    ]
    metrics = pd.merge(left=train_metrics, right=val_metrics, how="inner", on="epoch")
    # Normalize the column names: drop the `_epoch` marker, then the suffix.
    for column in metrics.columns:
        if column.endswith("_epoch"):
            metrics.rename(columns={column: column[: -len("_epoch")]}, inplace=True)
    if suffix:
        metrics.rename(
            columns={
                column: column[: -len(suffix)]
                for column in metrics.columns
                if column.endswith(suffix)
            },
            inplace=True,
        )
    return metrics
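
# Usage sketch (hypothetical path; ``metrics.csv`` is the file written by
# Lightning's ``CSVLogger``; the resulting columns follow the renaming done
# above):
#
#   >>> df = get_training_metrics(
#   ...     "artifacts/gnn/my_experiment/version_3/metrics.csv"
#   ... )  # doctest: +SKIP
#   >>> list(df.columns)  # doctest: +SKIP
#   ['epoch', 'train_loss', 'val_loss', 'eff', 'pur', 'current_lr']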