# Source code for pipeline.GNN.models.edge_based_gnn

import typing
import torch
from torch_scatter import scatter_add

from utils.modelutils.mlp import make_mlp

from ..triplet_gnn_base import TripletGNNBase


class EdgeBasedGNN(TripletGNNBase):
    """A triplet-based GNN without using node encodings.

    Hit features are only used to build the initial edge encodings (the
    features of both hits of an edge are concatenated); message passing then
    updates edge encodings only, using per-hit aggregates of edge encodings as
    intermediate messages.
    """

    def __init__(self, hparams: typing.Dict[str, typing.Any]):
        """Build the edge encoder, the edge network and the output classifiers.

        Args:
            hparams: Hyperparameter dictionary. Keys read here:
                ``nb_edge_layers``, ``hidden``, ``layernorm``,
                ``hidden_activation``, ``aggregation`` (through
                ``self.hparams``) and, optionally, ``nb_edge_encoder_layers``
                and ``nb_edge_classifier_layers`` (both default to
                ``nb_edge_layers``).
        """
        # NOTE(review): this deliberately bypasses ``TripletGNNBase.__init__``
        # and runs the next initializer in the MRO instead — presumably so that
        # node-based sub-networks of the base class are not built; confirm.
        super(TripletGNNBase, self).__init__(hparams)
        nb_edge_layers: int = hparams["nb_edge_layers"]
        nb_hidden: int = hparams["hidden"]
        # Keyword arguments shared by every MLP of this model.
        mlp_kwargs: typing.Dict[str, typing.Any] = dict(
            layer_norm=hparams["layernorm"],
            output_activation=None,
            hidden_activation=hparams["hidden_activation"],
        )
        nb_classifier_layers: int = hparams.get(
            "nb_edge_classifier_layers", nb_edge_layers
        )

        # Encodes the concatenated features of the two hits of an edge.
        self.edge_encoder = make_mlp(
            2 * self.get_n_features(),
            [nb_hidden] * hparams.get("nb_edge_encoder_layers", nb_edge_layers),
            **mlp_kwargs,
        )

        # Number of hidden-sized messages concatenated to each edge encoding:
        # 4 for "sum_max" (presumably sum and max for each direction), 2
        # otherwise. Must match what ``message_step`` concatenates.
        message_size = 4 if self.hparams["aggregation"] == "sum_max" else 2

        # The edge network computes new edge encodings from the current edge
        # encoding and the messages gathered at the edge's hits.
        self.edge_network = make_mlp(
            (1 + message_size) * nb_hidden,
            [nb_hidden] * nb_edge_layers,
            **mlp_kwargs,
        )

        # Final edge output classification network (one score per edge).
        self.output_edge_classifier = make_mlp(
            nb_hidden,
            [nb_hidden] * nb_classifier_layers + [1],
            **mlp_kwargs,
        )

        # Triplet classifier: takes the concatenated encodings of two edges.
        # NOTE(review): its depth reuses ``nb_edge_classifier_layers``; confirm
        # no dedicated triplet hyperparameter was intended.
        self.output_triplet_classifier = make_mlp(
            2 * nb_hidden,
            [nb_hidden] * nb_classifier_layers + [1],
            **mlp_kwargs,
        )
[docs] def message_step( self, e: torch.Tensor, start: torch.Tensor, end: torch.Tensor, dim_size: int ) -> torch.Tensor: # Compute messages message_in = scatter_add(e, end, dim=0, dim_size=dim_size) message_out = scatter_add(e, start, dim=0, dim_size=dim_size) # Propagate messages to edges e = ( self.edge_network( torch.cat((e, message_in[start], message_out[end]), dim=-1) ) + e ) return e
[docs] def forward_edges( self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor ) -> typing.Dict[str, torch.Tensor]: """Forward step for edge classification. Args: x: Hit features edge_index: Torch tensor with 2 rows that define the edges. Returns: A tuple of 3 tensors: the hit encodings and edge encodings after message passing, and the edge classifier output. """ # Encode the graph features into the hidden space e = self.edge_encoder(torch.cat((x[start], x[end]), dim=-1)) # Loop over iterations of edge and node networks for _ in range(self.hparams["n_graph_iters"]): e = self.message_step(e, start, end, dim_size=x.shape[0]) # type: ignore # Compute final edge scores; use original edge directions only edge_output = self.output_edge_classifier(e).squeeze(-1) return {"e": e, "edge_output": edge_output}
[docs] def triplet_output_step_articulation(self, e, edge_indices, triplet_indices): assert torch.all( edge_indices[1][triplet_indices[0]] == edge_indices[0][triplet_indices[1]] ) triplet_classifier_inputs = torch.cat( (e[[triplet_indices[0]]], e[triplet_indices[1]]), dim=-1, ) return self.output_triplet_classifier(triplet_classifier_inputs).squeeze(-1)
[docs] def triplet_output_step_elbow_left(self, e, edge_indices, triplet_indices): assert torch.all( edge_indices[0][triplet_indices[0]] == edge_indices[0][triplet_indices[1]] ) triplet_classifier_inputs_1 = torch.cat( (e[[triplet_indices[0]]], e[triplet_indices[1]]), dim=-1, ).squeeze(-1) triplet_classifier_inputs_2 = torch.cat( (e[triplet_indices[0]], e[triplet_indices[1]]), dim=-1, ).squeeze(-1) output_1 = self.output_triplet_classifier(triplet_classifier_inputs_1).squeeze( -1 ) output_2 = self.output_triplet_classifier(triplet_classifier_inputs_2).squeeze( -1 ) return (output_1 + output_2) / 2
[docs] def triplet_output_step_elbow_right(self, e, edge_indices, triplet_indices): assert torch.all( edge_indices[1][triplet_indices[0]] == edge_indices[1][triplet_indices[1]] ) triplet_classifier_inputs_1 = torch.cat( (e[[triplet_indices[0]]], e[triplet_indices[1]]), dim=-1, ).squeeze(-1) triplet_classifier_inputs_2 = torch.cat( (e[triplet_indices[0]], e[triplet_indices[1]]), dim=-1, ).squeeze(-1) output_1 = self.output_triplet_classifier(triplet_classifier_inputs_1).squeeze( -1 ) output_2 = self.output_triplet_classifier(triplet_classifier_inputs_2).squeeze( -1 ) return (output_1 + output_2) / 2
[docs] def forward_triplets( self, e: torch.Tensor, filtered_edge_index: torch.Tensor, edge_mask: torch.Tensor, dict_triplet_indices: typing.Dict[str, torch.Tensor], **kwargs, # parameters ignored ) -> typing.Dict[str, torch.Tensor]: """Forward step for triplet classification. Args: e: Edge encodings after the edge forward step filtered_edge_index: edge index after requiring the minimal edge score edge_mask: edge mask of edges kept after filtering edge_score_cut: custom edge score cut. If not specified, the value is taken from the internal ``hparams`` dictionary. Returns: A dictionary that associates ``articulation``, ``elbow_left`` and ``elbow_right`` with the logits of the corresponding triplets. """ dict_triplet_outputs: typing.Dict[str, torch.Tensor] = self.triplet_output_step( e=e[edge_mask], edge_indices=filtered_edge_index, dict_triplet_indices=dict_triplet_indices, ) return dict_triplet_outputs
@property def subnetworks(self) -> typing.List[str]: return super(EdgeBasedGNN, self).subnetworks + [ "edge_encoder", "edge_network", "edge_output_classifier", ] @property def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]: return { **super(EdgeBasedGNN, self).subnetwork_to_outputs, "edge_encoder": ["e"], "edge_network": ["e"], "edge_output_classifier": ["edge_score"], } @property def subnetwork_groups(self) -> typing.Dict[str, typing.List[str]]: return { **super(EdgeBasedGNN, self).subnetwork_groups, "edge_split": ["edge_encoder", "edge_network", "edge_output_classifier"], } @property def input_kwargs(self) -> typing.Dict[str, typing.Any]: return { **super(EdgeBasedGNN, self).input_kwargs, "message_in": dict( size=(self._n_hits, self.n_hiddens), dtype=torch.float32 ), "message_out": dict( size=(self._n_hits, self.n_hiddens), dtype=torch.float32 ), "e": dict(size=(self._n_edges, self.n_hiddens), dtype=torch.float32), } @property def input_to_dynamic_axes(self): """A dictionary that associates an input name with the dynamic axis specification. """ return { **super(EdgeBasedGNN, self).input_to_dynamic_axes, "e": {0: "n_edges"}, "message_in": {0: "n_hits"}, "message_out": {0: "n_hits"}, "edge_score": {0: "n_edges"}, } def _onnx_edge_output_classifier(self, e): return torch.sigmoid(self.output_edge_classifier(e)) def _onnx_edge_encoder( self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor ) -> torch.Tensor: return self.edge_encoder(torch.cat((x[start], x[end]), dim=-1)) def _onnx_edge_network( self, e: torch.Tensor, start: torch.Tensor, end: torch.Tensor, message_in: torch.Tensor, message_out: torch.Tensor, ) -> torch.Tensor: e = ( self.edge_network( torch.cat((e, message_in[start], message_out[end]), dim=-1) ) + e ) return e