Source code for pipeline.utils.scriptutils.convenience_utils

from __future__ import annotations
import os

import pandas as pd
import torch
from pytorch_lightning import Trainer


# Default to the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def get_training_metrics(trainer: Trainer | str) -> pd.DataFrame:
    """Get the dataframe of the training metrics.

    Args:
        trainer: Either a PyTorch Lightning Trainer object, or the path to the
            metrics file to load directly.

    Returns:
        Dataframe of the training metrics (one row per epoch).
    """
    if isinstance(trainer, Trainer):
        logger = trainer.logger
        assert logger is not None, "No logger was assigned to this trainer."
        log_dir = logger.log_dir
        assert log_dir is not None, "The logger was not assigned a local directory."
        log_file = os.path.join(log_dir, "metrics.csv")
    elif isinstance(trainer, str):
        log_file = trainer
    else:
        raise TypeError(
            "`trainer` should be a str or a PyTorch Lightning Trainer, but is "
            + type(trainer).__name__
        )

    metrics = pd.read_csv(log_file, sep=",")

    # Training loss is logged on its own rows; keep only those, and shift the
    # epoch index by one so it lines up with the validation rows before merging.
    train_metrics = metrics[~metrics["train_loss"].isna()][["epoch", "train_loss"]]
    train_metrics["epoch"] -= 1

    # Validation rows carry the validation loss, efficiency, purity and learning rate.
    val_metrics = metrics[~metrics["val_loss"].isna()][
        ["val_loss", "eff", "pur", "current_lr", "epoch"]
    ]

    metrics = pd.merge(left=train_metrics, right=val_metrics, how="inner", on="epoch")
    return metrics
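
A minimal usage sketch, assuming training was run with a CSV logger so that a
metrics.csv file exists; the "logs/version_0/metrics.csv" path below is
hypothetical and should be replaced by the actual log directory:

if __name__ == "__main__":
    # Either pass the fitted Trainer itself, or point directly at the CSV file
    # written by its logger (hypothetical path shown here).
    df = get_training_metrics("logs/version_0/metrics.csv")
    print(df[["epoch", "train_loss", "val_loss", "current_lr"]].head())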