import typing
import torch
from torch_scatter import scatter_add
from utils.modelutils.mlp import make_mlp
from ..triplet_gnn_base import TripletGNNBase
class EdgeBasedGNN(TripletGNNBase):
    """A triplet-based GNN without using node encodings.

    Only edges carry hidden encodings:

    1. Each edge is encoded from the concatenated features of its two hits.
    2. The encodings are refined by edge-to-edge message passing through
       shared hits.
    3. The encodings are classified individually (edge output) or in pairs
       (triplet outputs).
    """

    def __init__(self, hparams: typing.Dict[str, typing.Any]):
        # Deliberately skips ``TripletGNNBase.__init__`` and runs the next
        # ``__init__`` in the MRO instead. Presumably this avoids building
        # networks of the base class that this model does not use — TODO
        # confirm.
        super(TripletGNNBase, self).__init__(hparams)
        nb_edge_layers: int = hparams["nb_edge_layers"]
        nb_hidden: int = hparams["hidden"]

        # Encodes the concatenated features of the start and end hits of
        # each edge into the hidden space.
        self.edge_encoder = make_mlp(
            2 * self.get_n_features(),
            [nb_hidden] * hparams.get("nb_edge_encoder_layers", nb_edge_layers),
            layer_norm=hparams["layernorm"],
            output_activation=None,
            hidden_activation=hparams["hidden_activation"],
        )

        # NOTE(review): ``message_step`` and ``_onnx_edge_network`` always
        # concatenate exactly two sum-aggregated messages, i.e. an input of
        # ``3 * nb_hidden`` features. With ``aggregation == "sum_max"`` the
        # network below is sized for ``5 * nb_hidden`` inputs. It would
        # presumably fail with a shape mismatch unless a subclass overrides
        # the message step — confirm this is intended.
        message_size = 4 if self.hparams["aggregation"] == "sum_max" else 2

        # The edge network computes new edge encodings from the current
        # encoding and the messages gathered at the edge's two hits.
        self.edge_network = make_mlp(
            (1 + message_size) * nb_hidden,
            [nb_hidden] * nb_edge_layers,
            layer_norm=hparams["layernorm"],
            output_activation=None,
            hidden_activation=hparams["hidden_activation"],
        )

        # Final edge output classification network: one logit per edge.
        self.output_edge_classifier = make_mlp(
            nb_hidden,
            [nb_hidden] * hparams.get("nb_edge_classifier_layers", nb_edge_layers)
            + [1],
            layer_norm=hparams["layernorm"],
            output_activation=None,
            hidden_activation=hparams["hidden_activation"],
        )

        # Triplet classification network: one logit per pair of edges,
        # computed from their concatenated encodings.
        self.output_triplet_classifier = make_mlp(
            2 * nb_hidden,
            [nb_hidden] * hparams.get("nb_edge_classifier_layers", nb_edge_layers)
            + [1],
            layer_norm=hparams["layernorm"],
            output_activation=None,
            hidden_activation=hparams["hidden_activation"],
        )

    def message_step(
        self, e: torch.Tensor, start: torch.Tensor, end: torch.Tensor, dim_size: int
    ) -> torch.Tensor:
        """Perform one edge-to-edge message-passing iteration.

        Args:
            e: Current edge encodings, one row per edge.
            start: Index of the start hit of each edge.
            end: Index of the end hit of each edge.
            dim_size: Number of hits (size of the per-hit message tensors).

        Returns:
            Updated edge encodings: the edge network output added to ``e``
            (residual connection).
        """
        # Per hit: sum of the encodings of edges ending there (message_in)
        # and of edges starting there (message_out).
        message_in = scatter_add(e, end, dim=0, dim_size=dim_size)
        message_out = scatter_add(e, start, dim=0, dim_size=dim_size)
        # Each edge receives the sum over edges ending at its start hit and
        # the sum over edges starting at its end hit.
        # This mirrors ``_onnx_edge_network``.
        edge_network_input = torch.cat(
            (e, message_in[start], message_out[end]), dim=-1
        )
        return self.edge_network(edge_network_input) + e

    def forward_edges(
        self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
    ) -> typing.Dict[str, torch.Tensor]:
        """Forward step for edge classification.

        Args:
            x: Hit features, one row per hit.
            start: Index of the start hit of each edge.
            end: Index of the end hit of each edge.

        Returns:
            A dictionary with:

            * ``e``: the edge encodings after message passing;
            * ``edge_output``: the edge classifier logits, one per edge.
        """
        # Encode the graph features into the hidden space
        e = self.edge_encoder(torch.cat((x[start], x[end]), dim=-1))
        # Iterate the edge-to-edge message passing
        for _ in range(self.hparams["n_graph_iters"]):
            e = self.message_step(e, start, end, dim_size=x.shape[0])  # type: ignore
        # Compute final edge logits (no sigmoid applied here)
        edge_output = self.output_edge_classifier(e).squeeze(-1)
        return {"e": e, "edge_output": edge_output}

    def _symmetric_triplet_output(
        self, e: torch.Tensor, triplet_indices: torch.Tensor
    ) -> torch.Tensor:
        """Return triplet logits averaged over both orderings of the two edges.

        Used for elbows, whose two edges share an endpoint and therefore have
        no natural order. Averaging ``f(e1, e2)`` and ``f(e2, e1)`` makes the
        output independent of that order.
        """
        e_first = e[triplet_indices[0]]
        e_second = e[triplet_indices[1]]
        output_forward = self.output_triplet_classifier(
            torch.cat((e_first, e_second), dim=-1)
        ).squeeze(-1)
        output_backward = self.output_triplet_classifier(
            torch.cat((e_second, e_first), dim=-1)
        ).squeeze(-1)
        return (output_forward + output_backward) / 2

    def triplet_output_step_articulation(
        self,
        e: torch.Tensor,
        edge_indices: torch.Tensor,
        triplet_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Compute logits of articulation triplets.

        The end hit of the first edge is the start hit of the second edge, so
        the pair is ordered and classified in that order only.

        Args:
            e: Edge encodings aligned with ``edge_indices``.
            edge_indices: Edge index (row 0: start hits, row 1: end hits).
            triplet_indices: Row 0 / row 1 give the first / second edge of
                each triplet.

        Returns:
            One logit per triplet.
        """
        assert torch.all(
            edge_indices[1][triplet_indices[0]] == edge_indices[0][triplet_indices[1]]
        )
        triplet_classifier_inputs = torch.cat(
            (e[triplet_indices[0]], e[triplet_indices[1]]),
            dim=-1,
        )
        return self.output_triplet_classifier(triplet_classifier_inputs).squeeze(-1)

    def triplet_output_step_elbow_left(
        self,
        e: torch.Tensor,
        edge_indices: torch.Tensor,
        triplet_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Compute logits of left elbows (two edges sharing their start hit).

        The logit is averaged over both orderings of the two edges.
        """
        assert torch.all(
            edge_indices[0][triplet_indices[0]] == edge_indices[0][triplet_indices[1]]
        )
        return self._symmetric_triplet_output(e, triplet_indices)

    def triplet_output_step_elbow_right(
        self,
        e: torch.Tensor,
        edge_indices: torch.Tensor,
        triplet_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Compute logits of right elbows (two edges sharing their end hit).

        The logit is averaged over both orderings of the two edges.
        """
        assert torch.all(
            edge_indices[1][triplet_indices[0]] == edge_indices[1][triplet_indices[1]]
        )
        return self._symmetric_triplet_output(e, triplet_indices)

    def forward_triplets(
        self,
        e: torch.Tensor,
        filtered_edge_index: torch.Tensor,
        edge_mask: torch.Tensor,
        dict_triplet_indices: typing.Dict[str, torch.Tensor],
        **kwargs,  # parameters ignored
    ) -> typing.Dict[str, torch.Tensor]:
        """Forward step for triplet classification.

        Args:
            e: Edge encodings after the edge forward step, for all edges.
            filtered_edge_index: Edge index after requiring the minimal edge
                score.
            edge_mask: Mask of the edges kept after filtering. Used to select
                the rows of ``e`` that correspond to ``filtered_edge_index``.
            dict_triplet_indices: Triplet edge indices, presumably keyed by
                triplet type — TODO confirm against ``triplet_output_step``.
            **kwargs: Accepted for interface compatibility and ignored by
                this model.

        Returns:
            A dictionary that associates ``articulation``, ``elbow_left`` and
            ``elbow_right`` with the logits of the corresponding triplets.
        """
        dict_triplet_outputs: typing.Dict[str, torch.Tensor] = self.triplet_output_step(
            e=e[edge_mask],
            edge_indices=filtered_edge_index,
            dict_triplet_indices=dict_triplet_indices,
        )
        return dict_triplet_outputs

    @property
    def subnetworks(self) -> typing.List[str]:
        """Names of the subnetworks: those of the parent class plus the edge ones."""
        return super(EdgeBasedGNN, self).subnetworks + [
            "edge_encoder",
            "edge_network",
            "edge_output_classifier",
        ]

    @property
    def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
        """Associate each subnetwork name with the names of its outputs."""
        return {
            **super(EdgeBasedGNN, self).subnetwork_to_outputs,
            "edge_encoder": ["e"],
            "edge_network": ["e"],
            "edge_output_classifier": ["edge_score"],
        }

    @property
    def subnetwork_groups(self) -> typing.Dict[str, typing.List[str]]:
        """Associate each group name with the subnetworks it contains."""
        return {
            **super(EdgeBasedGNN, self).subnetwork_groups,
            "edge_split": ["edge_encoder", "edge_network", "edge_output_classifier"],
        }

    @property
    def input_kwargs(self) -> typing.Dict[str, typing.Any]:
        """Size and dtype specification of each named subnetwork input.

        Presumably used to build example inputs for the ``_onnx_*``
        subnetworks — TODO confirm.
        """
        return {
            **super(EdgeBasedGNN, self).input_kwargs,
            "message_in": dict(
                size=(self._n_hits, self.n_hiddens), dtype=torch.float32
            ),
            "message_out": dict(
                size=(self._n_hits, self.n_hiddens), dtype=torch.float32
            ),
            "e": dict(size=(self._n_edges, self.n_hiddens), dtype=torch.float32),
        }

    @property
    def input_to_dynamic_axes(self):
        """A dictionary that associates a tensor name with its dynamic axes.

        Besides inputs, it also lists outputs such as ``edge_score``.
        """
        return {
            **super(EdgeBasedGNN, self).input_to_dynamic_axes,
            "e": {0: "n_edges"},
            "message_in": {0: "n_hits"},
            "message_out": {0: "n_hits"},
            "edge_score": {0: "n_edges"},
        }

    def _onnx_edge_output_classifier(self, e):
        """Return edge scores: the sigmoid of the edge classifier output."""
        return torch.sigmoid(self.output_edge_classifier(e))

    def _onnx_edge_encoder(
        self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
    ) -> torch.Tensor:
        """Encode each edge from the features of its start and end hits."""
        return self.edge_encoder(torch.cat((x[start], x[end]), dim=-1))

    def _onnx_edge_network(
        self,
        e: torch.Tensor,
        start: torch.Tensor,
        end: torch.Tensor,
        message_in: torch.Tensor,
        message_out: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the edge update of ``message_step`` to precomputed messages.

        The per-hit messages are taken as inputs rather than computed here.
        """
        e = (
            self.edge_network(
                torch.cat((e, message_in[start], message_out[end]), dim=-1)
            )
            + e
        )
        return e