Source code for pipeline

"""This package handles the various steps of the GNN-based pipeline,
from preprocessing to track-finding evaluation.
"""
import logging
import typing
from Embedding.embedding_base import EmbeddingBase
from GNN.triplet_gnn_base import TripletGNNBase
from utils.commonutils.config import get_pipeline_config_path, load_config
from utils.modelutils.basemodel import ModelBase
from utils.modelutils.checkpoint_utils import (
    get_last_artifact,
    get_last_version_dir_from_config,
)


def get_model(path_or_config: str | dict, step: str) -> typing.Type[ModelBase]:
    """Get the model class of a given step for a given pipeline configuration.

    Args:
        path_or_config: pipeline configuration
        step: model step (``embedding`` or ``gnn``)

    Returns:
        Model class that can be later instantiated.
    """
    model_type = load_config(path_or_config)[step].get("model")
    if step == "embedding":
        from Embedding.models import get_model

        return get_model(model_type=model_type)
    elif step == "gnn":
        from GNN.models import get_model

        return get_model(model_type=model_type)
    else:
        raise ValueError(
            f"Model step `{step}` is not recognised. "
            "Only `embedding` and `gnn` are supported."
        )
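
# A minimal usage sketch (not part of the module), assuming a pipeline
# configuration at the hypothetical path ``configs/pipeline.yaml`` whose
# ``embedding`` section has a ``model`` key. The returned class can then be
# instantiated with that section as hyperparameters:
#
#     EmbeddingModel = get_model("configs/pipeline.yaml", step="embedding")
#     hparams = load_config("configs/pipeline.yaml")["embedding"]
#     model = EmbeddingModel(hparams=hparams)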


@typing.overload
def load_trained_model(
    path_or_config: str | dict,
    step: typing.Literal["embedding"],
    checkpoint_path: str | None = None,
    **kwargs,
) -> EmbeddingBase: ...


@typing.overload
def load_trained_model(
    path_or_config: str | dict,
    step: typing.Literal["gnn"],
    checkpoint_path: str | None = None,
    **kwargs,
) -> TripletGNNBase: ...


@typing.overload
def load_trained_model(
    path_or_config: str | dict, step: str, checkpoint_path: str | None = None, **kwargs
) -> ModelBase: ...


def load_trained_model(
    path_or_config: str | dict, step: str, checkpoint_path: str | None = None, **kwargs
) -> ModelBase:
    """Load a model that was already trained.

    Args:
        path_or_config: pipeline configuration
        step: model step (``embedding`` or ``gnn``)
        checkpoint_path: path to a checkpoint

    Returns:
        Trained model
    """
    config = load_config(path_or_config=path_or_config)
    # Get model class
    Model = get_model(path_or_config=path_or_config, step=step)
    if checkpoint_path is None:
        # Get last artifact path
        checkpoint_path = get_last_artifact(
            version_dir=get_last_version_dir_from_config(
                step=step, path_or_config=config
            )
        )
    # Load model from checkpoint
    return Model.load_from_checkpoint(
        checkpoint_path=checkpoint_path, hparams=config[step], **kwargs
    )
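
# A minimal usage sketch (not part of the module), assuming the hypothetical
# path ``configs/pipeline.yaml`` and a previous training run whose last
# checkpoint can be resolved from the configuration. An explicit checkpoint
# path may also be passed instead of relying on the last artifact:
#
#     gnn = load_trained_model("configs/pipeline.yaml", step="gnn")
#     embedding = load_trained_model(
#         "configs/pipeline.yaml",
#         step="embedding",
#         checkpoint_path="artifacts/embedding/last.ckpt",  # hypothetical path
#     )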


def instantiate_model_for_training(path_or_config: str | dict, step: str) -> ModelBase:
    """Instantiate a new model that can then be trained.

    The function auto-detects the model type. The model can also be
    instantiated from a trained model (transfer learning) using the
    ``from`` parameter in the configuration file.

    Args:
        path_or_config: pipeline configuration
        step: model step (``embedding`` or ``gnn``)

    Returns:
        Model that has not been trained.
    """
    config = load_config(path_or_config=path_or_config)
    model_config = config[step]
    # Get model class
    Model = get_model(path_or_config=path_or_config, step=step)
    # Check whether the model should be loaded from another trained model
    if (from_model := model_config.get("from")) is not None:
        other_config_path = get_pipeline_config_path(experiment_name=from_model)
        checkpoint_path = get_last_artifact(
            version_dir=get_last_version_dir_from_config(
                step=step, path_or_config=other_config_path
            )
        )
        logging.info(f"Load pre-trained model from {checkpoint_path}")
        model = Model.load_from_checkpoint(
            checkpoint_path=checkpoint_path, hparams=model_config
        )
    else:
        model = Model(hparams=model_config)
    return model
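
# A minimal usage sketch (not part of the module), assuming the hypothetical
# path ``configs/pipeline.yaml``. If the ``embedding`` section contains a
# ``from`` key naming another experiment, the returned model starts from that
# experiment's last checkpoint (transfer learning); otherwise it is freshly
# initialised:
#
#     model = instantiate_model_for_training(
#         "configs/pipeline.yaml", step="embedding"
#     )
#     # ``model`` can now be handed to the training step of the pipeline.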