# Source code for the pipeline.GNN.models package.
"""A package that defines various triplet-based GNN models.

Use :func:`get_model` to look up a model class by its name.
"""
import typing
from ..triplet_gnn_base import TripletGNNBase
[docs]def get_model(model_type: str | None = None) -> typing.Type[TripletGNNBase]:
"""Get a GNN model class from its name.
Args:
model_type: GNN type. Available are ``triplet_interaction``,
``edge_based`` and ``scifi_triplet_interaction``.
Returns:
The GNN model class that can be instantiated.
"""
if model_type is None:
model_type = "triplet_interaction"
if model_type == "triplet_interaction":
from .triplet_interaction_gnn import TripletInteractionGNN
return TripletInteractionGNN
elif model_type == "edge_based":
from .edge_based_gnn import EdgeBasedGNN
return EdgeBasedGNN
elif model_type == "scifi_triplet_interaction":
from .scifi_triplet_interaction_gnn import SciFiTripletInteractionGNN
return SciFiTripletInteractionGNN
elif model_type == "simple_filtered_triplet_interaction":
from .simple_filtered_triplet_interaction_gnn import SimpleFilteredTripletInteractionGNN
return SimpleFilteredTripletInteractionGNN
elif model_type == "filtered_triplet_interaction":
from .filtered_triplet_interaction_gnn import (
FilteredTripletInteractionGNN,
)
return FilteredTripletInteractionGNN
elif model_type == "shallow_interaction":
from .shallow_interaction_gnn import ShallowInteractionGNN
return ShallowInteractionGNN
else:
raise ValueError(f"GNN type {model_type} is not recognised.")