Source code for pipeline.utils.scriptutils.convenience_utils
from __future__ import annotations

import os

import pandas as pd
import torch
from pytorch_lightning import Trainer

device = "cuda" if torch.cuda.is_available() else "cpu"


def get_training_metrics(trainer: Trainer | str) -> pd.DataFrame:
    """Get the dataframe of the training metrics.

    Args:
        trainer: either a PyTorch Lightning ``Trainer`` object, or the path to the
            metrics file to load directly.

    Returns:
        Dataframe of the training metrics (one row per epoch).
    """
    if isinstance(trainer, Trainer):
        logger = trainer.logger
        assert logger is not None, "No logger was assigned to this trainer."
        log_dir = logger.log_dir
        assert log_dir is not None, "The logger was not assigned a local directory."
        log_file = os.path.join(log_dir, "metrics.csv")
    elif isinstance(trainer, str):
        log_file = trainer
    else:
        raise TypeError(
            "`trainer` should be a str or a PyTorch Lightning Trainer, but is "
            f"{type(trainer).__name__}"
        )
    metrics = pd.read_csv(log_file, sep=",")

    # Train and validation metrics are logged on separate rows; keep only the rows
    # where the corresponding loss is present.
    train_metrics = metrics[metrics["train_loss"].notna()][
        ["epoch", "train_loss"]
    ].copy()
    # Shift train epochs down by one so they line up with the validation rows.
    train_metrics["epoch"] -= 1
    val_metrics = metrics[metrics["val_loss"].notna()][
        ["val_loss", "eff", "pur", "current_lr", "epoch"]
    ]
    # Recombine into a single dataframe with one row per epoch.
    metrics = pd.merge(left=train_metrics, right=val_metrics, how="inner", on="epoch")
    return metrics
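

# Minimal usage sketch: assumes the metrics file has the layout written by
# pytorch_lightning's CSVLogger (train and validation metrics land on separate
# rows, with NaN elsewhere). The toy dataframe below is purely illustrative.
if __name__ == "__main__":
    import tempfile

    toy = pd.DataFrame(
        {
            "epoch": [1, 0, 2, 1],
            "train_loss": [0.9, None, 0.7, None],
            "val_loss": [None, 0.8, None, 0.6],
            "eff": [None, 0.91, None, 0.93],
            "pur": [None, 0.84, None, 0.88],
            "current_lr": [None, 1e-3, None, 9e-4],
        }
    )
    with tempfile.TemporaryDirectory() as tmp:
        csv_path = os.path.join(tmp, "metrics.csv")
        toy.to_csv(csv_path, index=False)
        # Passing a path works; a fitted ``Trainer`` with a CSVLogger would work too.
        print(get_training_metrics(csv_path))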