Source code for pipeline.GNN.models.triplet_interaction_gnn

import typing

import torch
from torch_scatter import scatter_add, scatter_max

from utils.modelutils.mlp import make_mlp
from ..triplet_gnn_base import TripletGNNBase, check_and_discard

from utils.modelutils.export import TRTScatterAddOp


def _get_list_n_layers(n_layers: int | typing.List[int]) -> typing.List[int]:
    return [n_layers] if isinstance(n_layers, int) else n_layers


class TripletInteractionGNN(TripletGNNBase):
    """A triplet-based interaction network.

    Nodes and edges are first encoded into hidden representations, then
    refined by ``n_graph_iters`` message-passing steps. Two classifier heads
    produce logits: one for edges and one for triplets (pairs of edges).
    """

    #: A set of option names that customise the ONNX export. It is filled by
    #: :py:meth:`to_onnx` for the duration of one export and emptied afterwards.
    _onnx_options = set()

    @property
    def recursive(self) -> bool:
        """Whether the GNN is recursive, i.e., the networks involved in the
        message step are the same (exactly one edge network and one node
        network are configured).
        """
        return len(self.edge_networks) == 1 and len(self.node_networks) == 1

    @property
    def n_graph_iters(self) -> int:
        """Number of message-passing iterations."""
        return int(self.hparams["n_graph_iters"])

    @property
    def only_e(self) -> bool:
        """Whether the classifiers use the edge encodings only.

        Read from the ``only_e`` hyperparameter; ``False`` when absent.
        """
        return self.hparams.get("only_e", False)

    def __init__(self, hparams):
        """Initialise the Lightning Module that can scan over different GNN
        training regimes.

        Args:
            hparams: hyperparameter mapping. The keys read here are
                ``node_hidden`` / ``edge_hidden`` (each falling back to
                ``hidden``), ``nb_node_layers``, ``nb_edge_layers``,
                ``nb_node_encoder_layers``, ``nb_edge_encoder_layers``,
                ``nb_edge_classifier_layers``, ``hidden_activation``,
                ``output_activation`` (optional), ``layernorm`` and
                ``aggregation``. ``nb_node_layers`` and ``nb_edge_layers``
                may be a single integer or a list with one entry per
                dedicated network.
        """
        super().__init__(hparams)

        n_node_hiddens = hparams.get("node_hidden", hparams.get("hidden"))
        n_edge_hiddens = hparams.get("edge_hidden", hparams.get("hidden"))
        assert n_node_hiddens is not None
        assert n_edge_hiddens is not None
        # Width of the hidden layers of the two output classifiers.
        n_hiddens = max(hparams.get("hidden", 0), n_node_hiddens, n_edge_hiddens)

        list_n_node_layers = _get_list_n_layers(hparams["nb_node_layers"])
        list_n_edge_layers = _get_list_n_layers(hparams["nb_edge_layers"])

        output_activation = self.hparams.get("output_activation")
        # Keyword arguments shared by every MLP built below.
        mlp_kwargs = dict(
            hidden_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
        )

        # Setup input network
        self.node_encoder = make_mlp(
            self.get_n_features(),
            [n_node_hiddens] * hparams["nb_node_encoder_layers"],
            output_activation=output_activation,
            **mlp_kwargs,
        )

        # The edge encoder computes initial edge features from connected nodes
        self.edge_encoder = make_mlp(
            2 * n_node_hiddens,
            [n_edge_hiddens] * hparams["nb_edge_encoder_layers"],
            output_activation=output_activation,
            **mlp_kwargs,
        )

        # The edge network computes new edge features from connected nodes
        edge_network_input_size = 2 * n_node_hiddens + n_edge_hiddens
        if len(list_n_edge_layers) == 1:
            self.edge_network = make_mlp(
                edge_network_input_size,
                [n_edge_hiddens] * list_n_edge_layers[0],
                output_activation=output_activation,
                **mlp_kwargs,
            )
            # Plain list on purpose: the module is already an attribute of
            # ``self`` under the name ``edge_network``.
            self.edge_networks = [self.edge_network]
        else:
            self.edge_networks = torch.nn.ModuleList(
                [
                    make_mlp(
                        edge_network_input_size,
                        [n_edge_hiddens] * n_edge_layers,
                        output_activation=output_activation,
                        **mlp_kwargs,
                    )
                    for n_edge_layers in list_n_edge_layers
                ]
            )

        # Number of aggregated edge messages concatenated to each node
        # encoding: (sum, max) x (incoming, outgoing) for ``sum_max``,
        # otherwise one aggregate per direction.
        message_size = 4 if self.hparams["aggregation"] == "sum_max" else 2

        # The node network computes new node features
        node_network_input_size = n_node_hiddens + message_size * n_edge_hiddens
        if len(list_n_node_layers) == 1:
            self.node_network = make_mlp(
                node_network_input_size,
                [n_node_hiddens] * list_n_node_layers[0],
                output_activation=output_activation,
                **mlp_kwargs,
            )
            # Plain list on purpose, see ``edge_networks`` above.
            self.node_networks = [self.node_network]
        else:
            self.node_networks = torch.nn.ModuleList(
                [
                    make_mlp(
                        node_network_input_size,
                        [n_node_hiddens] * n_node_layers,
                        output_activation=output_activation,
                        **mlp_kwargs,
                    )
                    for n_node_layers in list_n_node_layers
                ]
            )

        # Final output classification networks
        if self.only_e:
            edge_classifier_input_size = n_edge_hiddens
            triplet_classifier_input_size = 2 * n_edge_hiddens
        else:
            edge_classifier_input_size = 2 * n_node_hiddens + n_edge_hiddens
            triplet_classifier_input_size = 3 * n_node_hiddens + 2 * n_edge_hiddens

        self.output_edge_classifier = make_mlp(
            edge_classifier_input_size,
            [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
            output_activation=None,
            **mlp_kwargs,
        )
        # The triplet classifier depth is also driven by
        # ``nb_edge_classifier_layers``.
        self.output_triplet_classifier = make_mlp(
            triplet_classifier_input_size,
            [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
            output_activation=None,
            **mlp_kwargs,
        )

    def scatter_add(
        self, source: torch.Tensor, index: torch.Tensor, h: torch.Tensor
    ) -> torch.Tensor:
        """Scatter add operation.

        In ONNX export mode for TensorRT, the operation is replaced by a fake
        operator that is implemented through a plugin in TensorRT.

        Args:
            source: values to aggregate, indexed along dimension 0
            index: destination row of each row of ``source``
            h: tensor whose first dimension fixes the number of output rows

        Returns:
            The values of ``source`` summed per destination row.
        """
        if torch.onnx.is_in_onnx_export() and "use_trt_scatter" in self._onnx_options:
            return TRTScatterAddOp.apply(source, index, h)  # type: ignore
        return scatter_add(source, index, dim=0, dim_size=h.shape[0])

    def scatter_max(self, source: torch.Tensor, index: torch.Tensor, h: torch.Tensor):
        """Scatter max operation.

        Args:
            source: values to aggregate, indexed along dimension 0
            index: destination row of each row of ``source``
            h: tensor whose first dimension fixes the number of output rows

        Returns:
            The first element of the result of ``torch_scatter.scatter_max``
            (the maxima; the second element is discarded).

        Raises:
            NotImplementedError: in ONNX export mode with the
                ``use_trt_scatter`` option.
        """
        if torch.onnx.is_in_onnx_export() and "use_trt_scatter" in self._onnx_options:
            raise NotImplementedError(
                "`scatter_max` is not defined for ONNX export with TensorRT."
            )
        return scatter_max(source, index, dim=0, dim_size=h.shape[0])[0]

    def message_step(
        self,
        h: torch.Tensor,
        start: torch.Tensor,
        end: torch.Tensor,
        e: torch.Tensor,
        step: int,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Apply one step of message-passing that updates the node and edge
        encodings.

        Args:
            h: node encodings
            start: start indices of the edges
            end: end indices of the edges
            e: edge encodings
            step: index of the message-passing iteration. It selects the node
                and edge networks; once it exceeds the number of configured
                networks, the last one is reused.

        Returns:
            The updated node encodings and edge encodings.

        Raises:
            ValueError: if the ``aggregation`` hyperparameter is not one of
                ``sum``, ``max`` or ``sum_max``.
        """
        aggregation = self.hparams["aggregation"]
        if aggregation == "sum":
            messages = (
                self.scatter_add(e, end, h=h),
                self.scatter_add(e, start, h=h),
            )
        elif aggregation == "max":
            messages = (
                self.scatter_max(e, end, h=h),
                self.scatter_max(e, start, h=h),
            )
        elif aggregation == "sum_max":
            messages = (
                self.scatter_max(e, end, h=h),
                self.scatter_add(e, end, h=h),
                self.scatter_max(e, start, h=h),
                self.scatter_add(e, start, h=h),
            )
        else:
            raise ValueError(
                f"Aggregation `{self.hparams['aggregation']}` not recognised"
            )
        node_inputs = torch.cat((h, *messages), dim=-1)

        # Residual update of the node encodings
        node_network = self.node_networks[min(step, len(self.node_networks) - 1)]
        h = node_network(node_inputs) + h

        # Compute new edge features (residual update, using the updated nodes)
        edge_inputs = torch.cat([h[start], h[end], e], dim=-1)
        edge_network = self.edge_networks[min(step, len(self.edge_networks) - 1)]
        e = edge_network(edge_inputs) + e

        return h, e

    def output_step(
        self, h: torch.Tensor, start: torch.Tensor, end: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        """Apply the edge output classifier to edges to get edge logits.

        With ``only_e`` the classifier sees the edge encodings alone;
        otherwise it sees the encodings of both end nodes followed by the
        edge encoding.
        """
        if self.only_e:
            classifier_inputs = e
        else:
            classifier_inputs = torch.cat((h[start], h[end], e), dim=-1)
        return self.output_edge_classifier(classifier_inputs).squeeze(-1)

    def triplet_output_step_articulation(self, h, e, edge_indices, triplet_indices):
        """Compute the logits of triplets whose two edges are chained: the
        end node of the first edge is the start node of the second edge.

        Args:
            h: node encodings
            e: edge encodings
            edge_indices: ``edge_indices[0]`` holds the start node and
                ``edge_indices[1]`` the end node of every edge
            triplet_indices: ``triplet_indices[0]`` and ``triplet_indices[1]``
                hold the indices of the first and second edge of each triplet

        Returns:
            One logit per triplet.
        """
        src, dst = edge_indices[0], edge_indices[1]
        first, second = triplet_indices[0], triplet_indices[1]
        assert torch.all(dst[first] == src[second])

        if self.only_e:
            triplet_classifier_inputs = torch.cat((e[first], e[second]), dim=-1)
        else:
            triplet_classifier_inputs = torch.cat(
                (
                    h[dst[first]],  # shared
                    h[src[first]],  # first
                    h[dst[second]],  # second
                    e[first],
                    e[second],
                ),
                dim=-1,
            )
        return self.output_triplet_classifier(triplet_classifier_inputs).squeeze(-1)

    def triplet_output_step_elbow_left(self, h, e, edge_indices, triplet_indices):
        """Compute the logits of triplets whose two edges share their start
        node.

        Neither edge is naturally "first", so the classifier is evaluated on
        both orderings of the edge pair and the two logits are averaged.

        Args:
            h: node encodings
            e: edge encodings
            edge_indices: ``edge_indices[0]`` holds the start node and
                ``edge_indices[1]`` the end node of every edge
            triplet_indices: ``triplet_indices[0]`` and ``triplet_indices[1]``
                hold the indices of the two edges of each triplet

        Returns:
            One logit per triplet.
        """
        src, dst = edge_indices[0], edge_indices[1]
        first, second = triplet_indices[0], triplet_indices[1]
        assert torch.all(src[first] == src[second])

        outputs = []
        for edge_a, edge_b in ((first, second), (second, first)):
            if self.only_e:
                classifier_inputs = torch.cat((e[edge_a], e[edge_b]), dim=-1)
            else:
                classifier_inputs = torch.cat(
                    (
                        h[src[edge_a]],  # shared
                        h[dst[edge_a]],  # first
                        h[dst[edge_b]],  # second
                        e[edge_a],
                        e[edge_b],
                    ),
                    dim=-1,
                )
            outputs.append(
                self.output_triplet_classifier(classifier_inputs).squeeze(-1)
            )
        return (outputs[0] + outputs[1]) / 2

    def triplet_output_step_elbow_right(self, h, e, edge_indices, triplet_indices):
        """Compute the logits of triplets whose two edges share their end
        node.

        Neither edge is naturally "first", so the classifier is evaluated on
        both orderings of the edge pair and the two logits are averaged.

        Args:
            h: node encodings
            e: edge encodings
            edge_indices: ``edge_indices[0]`` holds the start node and
                ``edge_indices[1]`` the end node of every edge
            triplet_indices: ``triplet_indices[0]`` and ``triplet_indices[1]``
                hold the indices of the two edges of each triplet

        Returns:
            One logit per triplet.
        """
        src, dst = edge_indices[0], edge_indices[1]
        first, second = triplet_indices[0], triplet_indices[1]
        assert torch.all(dst[first] == dst[second])

        outputs = []
        for edge_a, edge_b in ((first, second), (second, first)):
            if self.only_e:
                classifier_inputs = torch.cat((e[edge_a], e[edge_b]), dim=-1)
            else:
                classifier_inputs = torch.cat(
                    (
                        h[dst[edge_a]],  # shared
                        h[src[edge_a]],  # first
                        h[src[edge_b]],  # second
                        e[edge_a],
                        e[edge_b],
                    ),
                    dim=-1,
                )
            outputs.append(
                self.output_triplet_classifier(classifier_inputs).squeeze(-1)
            )
        return (outputs[0] + outputs[1]) / 2

    def encoding_step(
        self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Initial encoding step of the GNN.

        Args:
            x: hit input features
            start: start indices of the edges
            end: end indices of the edges

        Returns:
            The node encodings ``h`` and edge encodings ``e``
        """
        h = self.node_encoder(x)
        e = self.edge_encoder(torch.cat((h[start], h[end]), dim=-1))
        return h, e

    def forward_edges(
        self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
    ) -> typing.Dict[str, torch.Tensor]:
        """Forward step for edge classification.

        Args:
            x: hit input features
            start: start indices of the edges
            end: end indices of the edges

        Returns:
            A dictionary with three entries: ``h`` and ``e``, the hit
            encodings and edge encodings after message passing, and
            ``edge_output``, the edge classifier logits.
        """
        # Encode the graph features into the hidden space
        h, e = self.encoding_step(x=x, start=start, end=end)

        # Loop over iterations of edge and node networks
        for step in range(self.n_graph_iters):
            h, e = self.message_step(h=h, start=start, end=end, e=e, step=step)

        # Compute final edge scores; use original edge directions only
        edge_output = self.output_step(h=h, start=start, end=end, e=e)
        return {"h": h, "e": e, "edge_output": edge_output}

    def forward_triplets(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        filtered_edge_index: torch.Tensor,
        edge_mask: torch.Tensor,
        dict_triplet_indices: typing.Dict[str, torch.Tensor],
        **kwargs,
    ) -> typing.Dict[str, torch.Tensor]:
        """Forward step for triplet classification.

        Args:
            h: Hit encodings after the edge forward step
            e: Edge encodings after the edge forward step
            filtered_edge_index: edge index after requiring the minimal edge
                score
            edge_mask: edge mask of edges kept after filtering; it is applied
                to ``e`` so that the edge encodings line up with
                ``filtered_edge_index``
            dict_triplet_indices: triplet indices, forwarded unchanged to
                ``triplet_output_step``
            **kwargs: accepted and ignored by this method

        Returns:
            A dictionary that associates ``articulation``, ``elbow_left`` and
            ``elbow_right`` with the logits of the corresponding triplets.
        """
        dict_triplet_outputs: typing.Dict[str, torch.Tensor] = self.triplet_output_step(
            h=h,
            e=e[edge_mask],
            edge_indices=filtered_edge_index,
            dict_triplet_indices=dict_triplet_indices,
        )
        return dict_triplet_outputs

    @property
    def subnetwork_groups(self) -> typing.Dict[str, typing.List[str]]:
        """Groups of subnetworks: the parent's groups plus ``edge_split``."""
        return {
            **super().subnetwork_groups,
            "edge_split": ["encoder", "network", "edge_output_classifier"],
        }

    @property
    def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
        """Output names of every subnetwork, extending the parent's mapping."""
        return {
            **super().subnetwork_to_outputs,
            "encoder": ["h", "e"],
            "network": ["h", "e"],
            "edge_output_classifier": ["edge_score"],
            "edge_all": (
                ["e", "edge_score"] if self.only_e else ["h", "e", "edge_score"]
            ),
            "triplet_output_classifier": ["triplet_score"],
        }

    @property
    def input_kwargs(self) -> typing.Dict[str, typing.Any]:
        """Size and dtype of the additional inputs, extending the parent's
        mapping.
        """
        # NOTE(review): ``n_hiddens``, ``_n_hits`` and ``_n_edges`` are not
        # defined in this class; presumably provided by the base class.
        per_hit = dict(size=(self._n_hits, self.n_hiddens), dtype=torch.float32)
        per_edge = dict(size=(self._n_edges, self.n_hiddens), dtype=torch.float32)
        return {
            **super().input_kwargs,
            "message_in": dict(per_hit),
            "message_out": dict(per_hit),
            "h": dict(per_hit),
            "e": per_edge,
        }

    @property
    def input_to_dynamic_axes(self):
        """A dictionary that associates an input name with the dynamic axis
        specification.
        """
        return {
            **super().input_to_dynamic_axes,
            "h": {0: "n_hits"},
            "e": {0: "n_edges"},
            "message_in": {0: "n_hits"},
            "message_out": {0: "n_hits"},
            "edge_score": {0: "n_edges"},
            "triplet_score": {0: "n_triplets"},
        }

    def _onnx_edge_output_classifier(
        self, h: torch.Tensor, start: torch.Tensor, end: torch.Tensor, e: torch.Tensor
    ):
        """Edge classifier followed by a sigmoid (the output is not squeezed)."""
        # NOTE(review): unlike ``output_step`` this always concatenates the
        # node encodings, i.e. it ignores ``only_e`` -- confirm it is intended.
        return torch.sigmoid(
            self.output_edge_classifier(torch.cat((h[start], h[end], e), dim=-1))
        )

    def _onnx_encoder(
        self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Encoding step, returning the node and edge encodings."""
        # ``end`` must be forwarded as the edge end indices; passing ``start``
        # twice would encode each edge from its start node only.
        return self.encoding_step(x=x, start=start, end=end)

    def _onnx_network(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        start: torch.Tensor,
        end: torch.Tensor,
        message_in: torch.Tensor,
        message_out: torch.Tensor,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """One message-passing step where the aggregated messages are given as
        inputs instead of being computed by a scatter operation.
        """
        # NOTE(review): relies on ``node_network`` / ``edge_network``, which
        # ``__init__`` only sets when a single network is configured.
        h = self.node_network(torch.cat((h, message_in, message_out), dim=-1)) + h

        # Compute new edge features
        e = self.edge_network(torch.cat([h[start], h[end], e], dim=-1)) + e
        return h, e

    def _onnx_triplet_output_classifier(
        self, e: torch.Tensor, triplet_start: torch.Tensor, triplet_end: torch.Tensor
    ) -> torch.Tensor:
        """Triplet classifier applied to the two edge encodings, followed by a
        sigmoid.
        """
        # NOTE(review): uses edge encodings only, which matches the ``only_e``
        # input layout of the triplet classifier -- confirm for other setups.
        return torch.sigmoid(
            self.output_triplet_classifier(
                torch.cat((e[triplet_start], e[triplet_end]), dim=-1)
            )
        ).squeeze(-1)

    def to_onnx(
        self,
        outpath: str,
        mode: str | None = None,
        options: typing.Iterable[str] | None = None,
    ) -> None:
        """Export to ONNX, handling the options specific to this class.

        The ``use_trt_scatter`` option is consumed here and made visible to
        :py:meth:`scatter_add` and :py:meth:`scatter_max` through
        ``_onnx_options`` for the duration of the export. The remaining
        options are passed to the parent class method.

        Args:
            outpath: passed through to the parent ``to_onnx``
            mode: passed through to the parent ``to_onnx``; when it names an
                entry of ``subnetwork_groups`` the options are left untouched
            options: names of the export options
        """
        options = set() if options is None else set(options)

        # If the mode corresponds to a groups of submodes, the options will
        # be passed to each mode individually.
        use_options = mode not in self.subnetwork_groups

        if use_options:
            # Propagate options.
            # Discard the options that are used and pass the others to the
            # parent class method.
            self._onnx_options = {
                option_name
                for option_name in ["use_trt_scatter"]
                if check_and_discard(options, option_name)
            }

        try:
            super().to_onnx(outpath, mode, options=options)
        finally:
            # Always restore an empty *set*, even if the export failed, so
            # that stale options cannot leak into later calls.
            if use_options:
                self._onnx_options = set()