import typing
import torch
from torch_scatter import scatter_add, scatter_max
from utils.modelutils.mlp import make_mlp
from ..triplet_gnn_base import TripletGNNBase, check_and_discard
from utils.modelutils.export import TRTScatterAddOp
def _get_list_n_layers(n_layers: int | typing.List[int]) -> typing.List[int]:
return [n_layers] if isinstance(n_layers, int) else n_layers
class TripletInteractionGNN(TripletGNNBase):
    """A triplet-based interaction network.

    Edges are classified by an interaction-network GNN:

    - hits and edges are encoded (:py:meth:`encoding_step`);
    - the encodings are updated over several message-passing steps
      (:py:meth:`message_step`);
    - the edge encodings are classified (:py:meth:`output_step`).

    Triplets of edges are then classified from the resulting encodings
    (:py:meth:`forward_triplets`).
    """

    #: Names of the options that customise the ONNX export. Only
    #: ``"use_trt_scatter"`` is recognised. It is set by :py:meth:`to_onnx`
    #: for the duration of an export.
    _onnx_options: typing.Set[str] = set()

    @property
    def recursive(self) -> bool:
        """Whether the GNN is recursive, i.e., the same edge and node networks
        are used at every message-passing step.
        """
        return len(self.edge_networks) == 1 and len(self.node_networks) == 1

    @property
    def n_graph_iters(self) -> int:
        """Number of message-passing iterations."""
        return int(self.hparams["n_graph_iters"])

    @property
    def only_e(self) -> bool:
        """Whether the edge and triplet classifiers take only edge encodings as
        input, rather than edge encodings together with hit encodings.
        """
        return self.hparams.get("only_e", False)

    def __init__(self, hparams):
        """Initialise the Lightning Module that can scan over different GNN
        training regimes.

        Args:
            hparams: dictionary of hyperparameters.

                - The node and edge hidden sizes are read from
                  ``node_hidden`` / ``edge_hidden``, falling back to
                  ``hidden``.
                - ``nb_node_layers`` and ``nb_edge_layers`` may be an
                  integer, giving one network shared by every
                  message-passing step.
                - Alternatively they may be a list of integers, giving one
                  network per step; the last network is reused for any
                  additional step.

        Raises:
            ValueError: if the node or edge hidden size cannot be determined.
        """
        super().__init__(hparams)

        n_node_hiddens: int | None = hparams.get("node_hidden", hparams.get("hidden"))
        n_edge_hiddens: int | None = hparams.get("edge_hidden", hparams.get("hidden"))
        if n_node_hiddens is None:
            raise ValueError("Either `node_hidden` or `hidden` must be specified.")
        if n_edge_hiddens is None:
            raise ValueError("Either `edge_hidden` or `hidden` must be specified.")
        # Width of the hidden layers of the classifiers. ``hidden`` may be
        # absent or explicitly set to ``None``.
        n_hiddens = max(hparams.get("hidden") or 0, n_node_hiddens, n_edge_hiddens)
        list_n_node_layers = _get_list_n_layers(hparams["nb_node_layers"])
        list_n_edge_layers = _get_list_n_layers(hparams["nb_edge_layers"])
        output_activation = self.hparams.get("output_activation")
        # Each node update concatenates the hit encoding with the messages
        # aggregated over the edges ending at and starting from the hit:
        # 2 aggregates, or 4 for ``sum_max`` (max and sum for each side).
        message_size = 4 if self.hparams["aggregation"] == "sum_max" else 2

        # Exact tensor widths, used to build the dummy inputs of the ONNX export.
        self._n_node_hiddens = n_node_hiddens
        self._n_edge_hiddens = n_edge_hiddens
        # Width of the messages aggregated on one side (start or end) of the edges
        self._n_message_hiddens = (message_size // 2) * n_edge_hiddens

        def build_mlp(input_size, sizes, mlp_output_activation=output_activation):
            # All the networks share the hidden activation and layer-norm settings.
            return make_mlp(
                input_size,
                sizes,
                layer_norm=hparams["layernorm"],
                output_activation=mlp_output_activation,
                hidden_activation=hparams["hidden_activation"],
            )

        # NOTE: the creation order of the networks below determines the order
        # in which their parameters are randomly initialised. Keep it stable.

        # Setup input network: encodes the hit features
        self.node_encoder = build_mlp(
            self.get_n_features(),
            [n_node_hiddens] * hparams["nb_node_encoder_layers"],
        )
        # The edge encoder computes the initial edge features from the
        # encodings of the two connected hits
        self.edge_encoder = build_mlp(
            2 * n_node_hiddens,
            [n_edge_hiddens] * hparams["nb_edge_encoder_layers"],
        )

        # The edge network computes new edge features from the connected hits
        # and the current edge features
        edge_network_input_size = 2 * n_node_hiddens + n_edge_hiddens
        if len(list_n_edge_layers) == 1:
            self.edge_network = build_mlp(
                edge_network_input_size, [n_edge_hiddens] * list_n_edge_layers[0]
            )
            # A plain list is enough: the network is already registered as a
            # submodule through ``self.edge_network``.
            self.edge_networks = [self.edge_network]
        else:
            self.edge_networks = torch.nn.ModuleList(
                [
                    build_mlp(
                        edge_network_input_size, [n_edge_hiddens] * n_edge_layers
                    )
                    for n_edge_layers in list_n_edge_layers
                ]
            )

        # The node network computes new node features from the current node
        # features and the aggregated messages
        node_network_input_size = n_node_hiddens + message_size * n_edge_hiddens
        if len(list_n_node_layers) == 1:
            self.node_network = build_mlp(
                node_network_input_size, [n_node_hiddens] * list_n_node_layers[0]
            )
            # A plain list is enough: the network is already registered as a
            # submodule through ``self.node_network``.
            self.node_networks = [self.node_network]
        else:
            self.node_networks = torch.nn.ModuleList(
                [
                    build_mlp(
                        node_network_input_size, [n_node_hiddens] * n_node_layers
                    )
                    for n_node_layers in list_n_node_layers
                ]
            )

        # Final edge and triplet output classification networks (logits)
        edge_classifier_input_size = (
            n_edge_hiddens if self.only_e else 2 * n_node_hiddens + n_edge_hiddens
        )
        triplet_classifier_input_size = (
            2 * n_edge_hiddens
            if self.only_e
            else 3 * n_node_hiddens + 2 * n_edge_hiddens
        )
        self.output_edge_classifier = build_mlp(
            edge_classifier_input_size,
            [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
            mlp_output_activation=None,
        )
        self.output_triplet_classifier = build_mlp(
            triplet_classifier_input_size,
            [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
            mlp_output_activation=None,
        )

    def scatter_add(
        self, source: torch.Tensor, index: torch.Tensor, h: torch.Tensor
    ) -> torch.Tensor:
        """Sum the rows of ``source`` that share the same ``index`` value.

        In ONNX export mode with the ``use_trt_scatter`` option, the operation
        is replaced by a fake operator that is implemented through a plugin
        in TensorRT.

        Args:
            source: tensor whose rows are scattered, e.g. edge encodings.
            index: for each row of ``source``, the output row it is added to.
            h: only used for its first dimension, the number of output rows.

        Returns:
            A tensor with ``h.shape[0]`` rows.
        """
        if torch.onnx.is_in_onnx_export() and "use_trt_scatter" in self._onnx_options:
            return TRTScatterAddOp.apply(source, index, h)  # type: ignore
        else:
            return scatter_add(source, index, dim=0, dim_size=h.shape[0])

    def scatter_max(self, source: torch.Tensor, index: torch.Tensor, h: torch.Tensor):
        """Element-wise maximum of the rows of ``source`` that share the same
        ``index`` value.

        Args:
            source: tensor whose rows are scattered, e.g. edge encodings.
            index: for each row of ``source``, the output row it contributes to.
            h: only used for its first dimension, the number of output rows.

        Returns:
            A tensor with ``h.shape[0]`` rows (the argmax output of
            ``torch_scatter.scatter_max`` is discarded).

        Raises:
            NotImplementedError: in ONNX export mode with the
                ``use_trt_scatter`` option.
        """
        if torch.onnx.is_in_onnx_export() and "use_trt_scatter" in self._onnx_options:
            raise NotImplementedError(
                "`scatter_max` is not defined for ONNX export with TensorRT."
            )
        else:
            return scatter_max(source, index, dim=0, dim_size=h.shape[0])[0]

    def message_step(
        self,
        h: torch.Tensor,
        start: torch.Tensor,
        end: torch.Tensor,
        e: torch.Tensor,
        step: int,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Apply one step of message-passing that updates the node and edge
        encodings.

        Args:
            h: hit encodings.
            start: start indices of the edges.
            end: end indices of the edges.
            e: edge encodings.
            step: index of the message-passing step. It selects the node and
                edge networks; the last network is reused once ``step``
                exceeds the number of networks.

        Returns:
            The updated hit encodings ``h`` and edge encodings ``e``.

        Raises:
            ValueError: if the ``aggregation`` hyperparameter is not ``sum``,
                ``max`` or ``sum_max``.
        """
        aggregation = self.hparams["aggregation"]
        # Aggregate the edge encodings onto their end hits, then onto their
        # start hits.
        if aggregation == "sum":
            messages = (
                self.scatter_add(e, end, h=h),
                self.scatter_add(e, start, h=h),
            )
        elif aggregation == "max":
            messages = (
                self.scatter_max(e, end, h=h),
                self.scatter_max(e, start, h=h),
            )
        elif aggregation == "sum_max":
            messages = (
                self.scatter_max(e, end, h=h),
                self.scatter_add(e, end, h=h),
                self.scatter_max(e, start, h=h),
                self.scatter_add(e, start, h=h),
            )
        else:
            raise ValueError(
                f"Aggregation `{self.hparams['aggregation']}` not recognised"
            )
        node_inputs = torch.cat((h, *messages), dim=-1)
        # Residual update of the hit encodings
        h = self.node_networks[min(step, len(self.node_networks) - 1)](node_inputs) + h
        # Residual update of the edge encodings from the updated hit encodings
        edge_inputs = torch.cat([h[start], h[end], e], dim=-1)
        e = self.edge_networks[min(step, len(self.edge_networks) - 1)](edge_inputs) + e
        return h, e

    def output_step(
        self, h: torch.Tensor, start: torch.Tensor, end: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        """Apply the edge output classifier to edges to get edge logits.

        With ``only_e``, only the edge encodings are classified. Otherwise the
        encodings of the start and end hits are concatenated to them.
        """
        if self.only_e:
            classifier_inputs = e
        else:
            classifier_inputs = torch.cat((h[start], h[end], e), dim=-1)
        return self.output_edge_classifier(classifier_inputs).squeeze(-1)

    def triplet_output_step_articulation(self, h, e, edge_indices, triplet_indices):
        """Compute the logits of articulation triplets.

        In an articulation triplet, the end hit of the first edge is the start
        hit of the second edge. The two edges are indexed by
        ``triplet_indices[0]`` and ``triplet_indices[1]``.

        Args:
            h: hit encodings.
            e: encodings of the edges indexed by ``triplet_indices``.
            edge_indices: start (row 0) and end (row 1) hit indices of the edges.
            triplet_indices: indices of the first (row 0) and second (row 1)
                edge of each triplet.

        Returns:
            One logit per triplet.
        """
        assert torch.all(
            edge_indices[1][triplet_indices[0]] == edge_indices[0][triplet_indices[1]]
        )
        if self.only_e:
            triplet_classifier_inputs = torch.cat(
                (e[triplet_indices[0]], e[triplet_indices[1]]),
                dim=-1,
            )
        else:
            triplet_classifier_inputs = torch.cat(
                (
                    h[edge_indices[1][triplet_indices[0]]],  # shared
                    h[edge_indices[0][triplet_indices[0]]],  # first
                    h[edge_indices[1][triplet_indices[1]]],  # second
                    e[triplet_indices[0]],
                    e[triplet_indices[1]],
                ),
                dim=-1,
            )
        return self.output_triplet_classifier(triplet_classifier_inputs).squeeze(-1)

    def _triplet_output_step_elbow(
        self, h, e, edge_indices, triplet_indices, shared_side: int
    ):
        """Compute the logits of elbow triplets, whose two edges share the hit
        stored in row ``shared_side`` of ``edge_indices`` (0: start, 1: end).

        The classifier is applied to both orderings of the two edges and the
        outputs are averaged. The result is therefore symmetric under
        swapping the edges.
        """
        shared_hits = edge_indices[shared_side]
        other_hits = edge_indices[1 - shared_side]
        first_edges, second_edges = triplet_indices[0], triplet_indices[1]
        assert torch.all(shared_hits[first_edges] == shared_hits[second_edges])

        def classifier_inputs(edges_a, edges_b):
            if self.only_e:
                return torch.cat((e[edges_a], e[edges_b]), dim=-1)
            return torch.cat(
                (
                    h[shared_hits[edges_a]],  # shared
                    h[other_hits[edges_a]],  # first
                    h[other_hits[edges_b]],  # second
                    e[edges_a],
                    e[edges_b],
                ),
                dim=-1,
            )

        output_1 = self.output_triplet_classifier(
            classifier_inputs(first_edges, second_edges)
        ).squeeze(-1)
        output_2 = self.output_triplet_classifier(
            classifier_inputs(second_edges, first_edges)
        ).squeeze(-1)
        return (output_1 + output_2) / 2

    def triplet_output_step_elbow_left(self, h, e, edge_indices, triplet_indices):
        """Compute the logits of left-elbow triplets, whose two edges share
        their start hit.

        The output is the average over both orderings of the two edges.
        Arguments are as in :py:meth:`triplet_output_step_articulation`.
        """
        return self._triplet_output_step_elbow(
            h, e, edge_indices, triplet_indices, shared_side=0
        )

    def triplet_output_step_elbow_right(self, h, e, edge_indices, triplet_indices):
        """Compute the logits of right-elbow triplets, whose two edges share
        their end hit.

        The output is the average over both orderings of the two edges.
        Arguments are as in :py:meth:`triplet_output_step_articulation`.
        """
        return self._triplet_output_step_elbow(
            h, e, edge_indices, triplet_indices, shared_side=1
        )

    def encoding_step(
        self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Initial encoding step of the GNN.

        Args:
            x: hit input features
            start: start indices of the edges
            end: end indices of the edges

        Returns:
            The node encodings ``h`` and edge encodings ``e``
        """
        h = self.node_encoder(x)
        e = self.edge_encoder(torch.cat((h[start], h[end]), dim=-1))
        return h, e

    def forward_edges(
        self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
    ) -> typing.Dict[str, torch.Tensor]:
        """Forward step for edge classification.

        Args:
            x: hit input features
            start: start indices of the edges
            end: end indices of the edges

        Returns:
            A dictionary with the hit encodings ``h`` and edge encodings ``e``
            after message passing, and the edge classifier logits
            ``edge_output``.
        """
        # Encode the graph features into the hidden space
        h, e = self.encoding_step(x=x, start=start, end=end)
        # Loop over iterations of edge and node networks
        for step in range(self.n_graph_iters):
            h, e = self.message_step(h=h, start=start, end=end, e=e, step=step)
        # Compute final edge scores; use original edge directions only
        edge_output = self.output_step(h=h, start=start, end=end, e=e)
        return {"h": h, "e": e, "edge_output": edge_output}

    def forward_triplets(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        filtered_edge_index: torch.Tensor,
        edge_mask: torch.Tensor,
        dict_triplet_indices: typing.Dict[str, torch.Tensor],
        **kwargs,
    ) -> typing.Dict[str, torch.Tensor]:
        """Forward step for triplet classification.

        Args:
            h: Hit encodings after the edge forward step
            e: Edge encodings after the edge forward step
            filtered_edge_index: edge index after requiring the minimal edge score
            edge_mask: edge mask of edges kept after filtering. It selects the
                rows of ``e`` that correspond to ``filtered_edge_index``.
            dict_triplet_indices: triplet indices for each triplet type,
                presumably indexing the filtered edges (TODO confirm against
                ``triplet_output_step``).
            kwargs: additional keyword arguments, ignored by this method.

        Returns:
            A dictionary that associates ``articulation``, ``elbow_left`` and
            ``elbow_right`` with the logits of the corresponding triplets.
        """
        dict_triplet_outputs: typing.Dict[str, torch.Tensor] = self.triplet_output_step(
            h=h,
            e=e[edge_mask],
            edge_indices=filtered_edge_index,
            dict_triplet_indices=dict_triplet_indices,
        )
        return dict_triplet_outputs

    @property
    def subnetwork_groups(self) -> typing.Dict[str, typing.List[str]]:
        """Groups of ONNX export modes. ``edge_split`` exports the edge part of
        the network as three separate sub-networks.
        """
        return {
            **super(TripletInteractionGNN, self).subnetwork_groups,
            "edge_split": ["encoder", "network", "edge_output_classifier"],
        }

    @property
    def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
        """Names of the outputs of each exportable sub-network."""
        return {
            **super(TripletInteractionGNN, self).subnetwork_to_outputs,
            "encoder": ["h", "e"],
            "network": ["h", "e"],
            "edge_output_classifier": ["edge_score"],
            "edge_all": (
                ["e", "edge_score"] if self.only_e else ["h", "e", "edge_score"]
            ),
            "triplet_output_classifier": ["triplet_score"],
        }

    @property
    def input_kwargs(self) -> typing.Dict[str, typing.Any]:
        """Size and dtype of the dummy inputs used for the ONNX export, keyed
        by input name.

        The widths match the networks built in ``__init__``:

        - hit encodings: ``node_hidden``;
        - edge encodings: ``edge_hidden``;
        - each aggregated message: ``edge_hidden``, or twice that for
          ``sum_max``.
        """
        return {
            **super(TripletInteractionGNN, self).input_kwargs,
            "message_in": dict(
                size=(self._n_hits, self._n_message_hiddens), dtype=torch.float32
            ),
            "message_out": dict(
                size=(self._n_hits, self._n_message_hiddens), dtype=torch.float32
            ),
            "h": dict(size=(self._n_hits, self._n_node_hiddens), dtype=torch.float32),
            "e": dict(size=(self._n_edges, self._n_edge_hiddens), dtype=torch.float32),
        }

    @property
    def input_to_dynamic_axes(self):
        """A dictionary that associates an input name
        with the dynamic axis specification.
        """
        return {
            **super(TripletInteractionGNN, self).input_to_dynamic_axes,
            "h": {0: "n_hits"},
            "e": {0: "n_edges"},
            "message_in": {0: "n_hits"},
            "message_out": {0: "n_hits"},
            "edge_score": {0: "n_edges"},
            "triplet_score": {0: "n_triplets"},
        }

    def _onnx_edge_output_classifier(
        self, h: torch.Tensor, start: torch.Tensor, end: torch.Tensor, e: torch.Tensor
    ):
        """Edge scores (sigmoid of the edge classifier logits) for the ONNX export.

        Like :py:meth:`output_step`, the hit encodings are used only if
        ``only_e`` is false. The trailing dimension of size 1 is kept.
        """
        if self.only_e:
            classifier_inputs = e
        else:
            classifier_inputs = torch.cat((h[start], h[end], e), dim=-1)
        return torch.sigmoid(self.output_edge_classifier(classifier_inputs))

    def _onnx_encoder(
        self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Initial hit and edge encodings for the ONNX export."""
        return self.encoding_step(x=x, start=start, end=end)

    def _onnx_network(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        start: torch.Tensor,
        end: torch.Tensor,
        message_in: torch.Tensor,
        message_out: torch.Tensor,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """One message-passing step for the ONNX export.

        The aggregated messages are computed outside the exported network.
        Following :py:meth:`message_step`:

        - ``message_in`` plays the role of the aggregation over the edge end
          hits;
        - ``message_out`` plays the role of the aggregation over the edge
          start hits.

        Raises:
            NotImplementedError: if the GNN is not recursive, since a single
                exported step cannot represent per-step networks.
        """
        if not self.recursive:
            raise NotImplementedError(
                "The ONNX `network` export only supports recursive GNNs "
                "(a single node network and a single edge network)."
            )
        h = self.node_network(torch.cat((h, message_in, message_out), dim=-1)) + h
        # Compute new edge features
        e = self.edge_network(torch.cat([h[start], h[end], e], dim=-1)) + e
        return h, e

    def _onnx_triplet_output_classifier(
        self, e: torch.Tensor, triplet_start: torch.Tensor, triplet_end: torch.Tensor
    ) -> torch.Tensor:
        """Triplet scores (sigmoid of the triplet classifier logits) for the
        ONNX export, computed from the encodings of the two edges of each
        triplet.

        Raises:
            NotImplementedError: if ``only_e`` is false, since the classifier
                then also expects hit encodings.
        """
        if not self.only_e:
            raise NotImplementedError(
                "The ONNX triplet output classifier only supports `only_e` models."
            )
        return torch.sigmoid(
            self.output_triplet_classifier(
                torch.cat((e[triplet_start], e[triplet_end]), dim=-1)
            )
        ).squeeze(-1)

    def to_onnx(
        self,
        outpath: str,
        mode: str | None = None,
        options: typing.Iterable[str] | None = None,
    ) -> None:
        """Export the model, or the sub-network selected by ``mode``, to ONNX.

        Args:
            outpath: output path, forwarded to the parent class method.
            mode: export mode, forwarded to the parent class method.
            options: export options. When ``mode`` is not a group of
                :py:attr:`subnetwork_groups`:

                - ``use_trt_scatter`` is consumed here and enabled for the
                  duration of the export;
                - the remaining options are forwarded to the parent class
                  method.

                For a group, all options are forwarded unchanged.
        """
        options = set() if options is None else set(options)
        # If the mode corresponds to a groups of submodes, the options will
        # be passed to each mode individually.
        use_options = mode not in self.subnetwork_groups
        if use_options:
            # Discard the options that are used here; the others are passed
            # to the parent class method.
            self._onnx_options = {
                option_name
                for option_name in ["use_trt_scatter"]
                if check_and_discard(options, option_name)
            }
        try:
            super().to_onnx(outpath, mode, options=options)
        finally:
            # Always reset, so the options never leak into later calls.
            if use_options:
                self._onnx_options = set()