Source code for pipeline
"""This package handles the various steps of the GNN-based pipeline,
from preprocessing to track-finding evaluation.
"""
import logging
import typing
from Embedding.embedding_base import EmbeddingBase
from GNN.triplet_gnn_base import TripletGNNBase
from utils.commonutils.config import get_pipeline_config_path, load_config
from utils.modelutils.basemodel import ModelBase
from utils.modelutils.checkpoint_utils import (
get_last_artifact,
get_last_version_dir_from_config,
)
def get_model(path_or_config: str | dict, step: str) -> typing.Type[ModelBase]:
    """Return the model class configured for one step of the pipeline.

    The ``model`` entry of the step's section in the loaded configuration is
    forwarded to the ``get_model`` factory of the matching sub-package
    (``Embedding.models`` or ``GNN.models``).

    Args:
        path_or_config: pipeline configuration, passed as-is to ``load_config``
        step: model step (``embedding`` or ``gnn``)

    Returns:
        Model class that can be later instantiated.

    Raises:
        ValueError: if ``step`` is present in the configuration but is neither
            ``embedding`` nor ``gnn``. A step that is absent from the loaded
            configuration already fails at the lookup, before this check.
    """
    step_config = load_config(path_or_config)[step]
    model_type = step_config.get("model")

    # The factories are imported lazily, so only the sub-package of the
    # requested step is loaded. They are aliased to avoid shadowing this
    # function's own name.
    if step == "embedding":
        from Embedding.models import get_model as resolve_model_class
    elif step == "gnn":
        from GNN.models import get_model as resolve_model_class
    else:
        raise ValueError(
            f"Model step `{step}` is not recognised. "
            "Only `embedding` and `gnn` are supported."
        )
    return resolve_model_class(model_type=model_type)
# The overloads below only refine the static return type: a literal
# ``"embedding"`` or ``"gnn"`` step maps to the corresponding base class, and
# any other string falls back to ``ModelBase``.
@typing.overload
def load_trained_model(
    path_or_config: str | dict,
    step: typing.Literal["embedding"],
    checkpoint_path: str | None = None,
    **kwargs,
) -> EmbeddingBase:
    ...


@typing.overload
def load_trained_model(
    path_or_config: str | dict,
    step: typing.Literal["gnn"],
    checkpoint_path: str | None = None,
    **kwargs,
) -> TripletGNNBase:
    ...


@typing.overload
def load_trained_model(
    path_or_config: str | dict,
    step: str,
    checkpoint_path: str | None = None,
    **kwargs,
) -> ModelBase:
    ...


def load_trained_model(
    path_or_config: str | dict,
    step: str,
    checkpoint_path: str | None = None,
    **kwargs,
) -> ModelBase:
    """Load an already-trained model from a checkpoint.

    Args:
        path_or_config: pipeline configuration
        step: model step (``embedding`` or ``gnn``)
        checkpoint_path: path to a checkpoint. When ``None``, the last
            artifact of the last version directory found for ``step`` in the
            configuration is used.
        **kwargs: extra keyword arguments forwarded to the model class's
            ``load_from_checkpoint``

    Returns:
        Trained model
    """
    config = load_config(path_or_config=path_or_config)

    # Resolve the class to instantiate for this step.
    model_cls = get_model(path_or_config=path_or_config, step=step)

    if checkpoint_path is None:
        # No explicit checkpoint: fall back to the most recent artifact.
        latest_version_dir = get_last_version_dir_from_config(
            step=step, path_or_config=config
        )
        checkpoint_path = get_last_artifact(version_dir=latest_version_dir)

    # The step's configuration section is handed over as hyperparameters.
    step_hparams = config[step]
    return model_cls.load_from_checkpoint(
        checkpoint_path=checkpoint_path, hparams=step_hparams, **kwargs
    )
def instantiate_model_for_training(path_or_config: str | dict, step: str) -> ModelBase:
    """Instantiate a model that is ready to be trained.

    The model class is resolved from the configuration through ``get_model``.
    If the step's configuration section has a ``from`` entry, that entry is
    used as the name of another experiment: the model is then initialised
    from that experiment's last checkpoint artifact (transfer learning),
    while still receiving the current step's section as hyperparameters.

    Args:
        path_or_config: pipeline configuration
        step: model step (``embedding`` or ``gnn``)

    Returns:
        A freshly built model, or one initialised from the checkpoint of the
        experiment named by ``from``.
    """
    config = load_config(path_or_config=path_or_config)
    step_hparams = config[step]

    # Resolve the class to instantiate for this step.
    model_cls = get_model(path_or_config=path_or_config, step=step)

    source_experiment = step_hparams.get("from")
    if source_experiment is None:
        # No source experiment requested: start from scratch.
        return model_cls(hparams=step_hparams)

    # Locate the last artifact of the experiment to start from.
    source_config_path = get_pipeline_config_path(experiment_name=source_experiment)
    source_version_dir = get_last_version_dir_from_config(
        step=step, path_or_config=source_config_path
    )
    checkpoint_path = get_last_artifact(version_dir=source_version_dir)

    logging.info(f"Load pre-trained model from {checkpoint_path}")
    return model_cls.load_from_checkpoint(
        checkpoint_path=checkpoint_path, hparams=step_hparams
    )