# Source code for pipeline.GNN.models

"""A package that defines various triplet-based GNNs.
"""

import typing
from ..triplet_gnn_base import TripletGNNBase


def get_model(model_type: str | None = None) -> typing.Type[TripletGNNBase]:
    """Get a GNN model class from its name.

    Each model class is imported inside its own branch, so a model's module
    is only imported when that model is requested.

    Args:
        model_type: GNN type. Available are ``triplet_interaction``
            (used when ``None`` is given), ``edge_based``,
            ``scifi_triplet_interaction``,
            ``simple_filtered_triplet_interaction``,
            ``filtered_triplet_interaction`` and ``shallow_interaction``.

    Returns:
        The GNN model class that can be instantiated.

    Raises:
        ValueError: If ``model_type`` is not one of the available types.
    """
    if model_type is None:
        model_type = "triplet_interaction"

    if model_type == "triplet_interaction":
        from .triplet_interaction_gnn import TripletInteractionGNN

        return TripletInteractionGNN
    if model_type == "edge_based":
        from .edge_based_gnn import EdgeBasedGNN

        return EdgeBasedGNN
    if model_type == "scifi_triplet_interaction":
        from .scifi_triplet_interaction_gnn import SciFiTripletInteractionGNN

        return SciFiTripletInteractionGNN
    if model_type == "simple_filtered_triplet_interaction":
        from .simple_filtered_triplet_interaction_gnn import (
            SimpleFilteredTripletInteractionGNN,
        )

        return SimpleFilteredTripletInteractionGNN
    if model_type == "filtered_triplet_interaction":
        from .filtered_triplet_interaction_gnn import (
            FilteredTripletInteractionGNN,
        )

        return FilteredTripletInteractionGNN
    if model_type == "shallow_interaction":
        from .shallow_interaction_gnn import ShallowInteractionGNN

        return ShallowInteractionGNN

    # Keep this tuple in sync with the branches above; it only feeds the
    # error message so callers can see which names are accepted.
    available = (
        "triplet_interaction",
        "edge_based",
        "scifi_triplet_interaction",
        "simple_filtered_triplet_interaction",
        "filtered_triplet_interaction",
        "shallow_interaction",
    )
    raise ValueError(
        f"GNN type {model_type} is not recognised. "
        f"Available types are: {', '.join(available)}."
    )