# Source code for pipeline.Embedding.models

"""A package that defines various embedding networks.
"""
import typing
from ..embedding_base import EmbeddingBase


def get_model(model_type: str | None = None) -> typing.Type[EmbeddingBase]:
    """Get an embedding model class from its name.

    Args:
        model_type: Name of the embedding model. Only ``"layerless"`` is
            available so far. ``None`` selects the default, ``"layerless"``.

    Returns:
        The embedding model class (a subclass of ``EmbeddingBase``), which
        the caller can instantiate.

    Raises:
        ValueError: If ``model_type`` is not a recognised embedding type.
    """
    if model_type is None:
        model_type = "layerless"

    if model_type == "layerless":
        # Imported inside the branch, so the model module is only loaded
        # when this particular model is requested.
        from .layerless_embedding import LayerlessEmbedding

        return LayerlessEmbedding

    raise ValueError(f"Embedding type {model_type} is not recognised.")