Source code for pipeline.Embedding.models
"""A package that defines various embedding networks.
"""
import typing
from ..embedding_base import EmbeddingBase
def get_model(model_type: str | None = None) -> typing.Type[EmbeddingBase]:
    """Return the embedding model class registered under ``model_type``.

    Args:
        model_type: Name of the embedding type. ``None`` selects the default,
            ``"layerless"``. Only ``"layerless"`` is available so far.

    Returns:
        The embedding model class itself (annotated as a subclass of
        :class:`EmbeddingBase`), not an instance; the caller instantiates it.

    Raises:
        ValueError: If ``model_type`` is not a recognised embedding type.
    """
    if model_type is None:
        model_type = "layerless"

    if model_type == "layerless":
        # Imported inside the function rather than at module level,
        # presumably so the model module is only loaded when it is requested.
        from .layerless_embedding import LayerlessEmbedding

        return LayerlessEmbedding

    raise ValueError(f"Embedding type {model_type} is not recognised.")