"""A module that defines :py:class:`.TripletGNNBase`, the base class of all
triplet-based GNNs in this repository.
"""
from __future__ import annotations

import logging
import typing

import numpy as np
from sklearn.metrics import roc_auc_score
import torch
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import Data
from torchvision.ops import sigmoid_focal_loss

from utils.modelutils.basemodel import ModelBase, check_and_discard
from utils.modelutils.metrics import compute_classification_efficiency_purity
from utils.loaderutils.dataiterator import LazyDatasetBase
from utils.tools import tarray
from utils.graphutils import batch2df
from utils.graphutils.tripletbuilding import (
    from_edge_index_to_triplet_indices,
    get_triplet_truths_from_tensors,
)
def get_df_edges_from_batch_only(batch: Data) -> tarray.DataFrame:
    """Build the dataframe of edges of ``batch``, with particle ID information.

    The dataframe of hits-particles is first derived from the batch with
    ``batch2df.get_df_hits_particles`` and then handed over to
    ``batch2df.get_df_edges``.

    Args:
        batch: event graph

    Returns:
        The dataframe returned by ``batch2df.get_df_edges``.
    """
    return batch2df.get_df_edges(
        batch,
        df_hits_particles=batch2df.get_df_hits_particles(batch=batch),
    )
class TripletGNNLazyDataset(LazyDatasetBase):
    """Lazy dataset of event graphs used by the triplet-based GNNs.

    It only differs from :py:class:`LazyDatasetBase` by the way an event is
    fetched: ``map_location="cpu"`` is always forwarded to the parent loader.
    """

    def __init__(
        self,
        input_dir: str,
        n_events: int | None = None,
        shuffle: bool = False,
        seed: int | None = None,
        **kwargs,
    ):
        """Forward all the arguments, unchanged, to the parent constructor."""
        super(TripletGNNLazyDataset, self).__init__(
            input_dir, n_events, shuffle, seed, **kwargs
        )

    def fetch_dataset(self, input_path: str, **kwargs) -> Data:
        """Load the event stored at ``input_path``.

        Args:
            input_path: path to the event to load
            **kwargs: other keyword arguments passed to the parent
                ``fetch_dataset`` method

        Returns:
            The event returned by the parent ``fetch_dataset`` method.
        """
        # The event is always mapped to the CPU at loading time, whether or not
        # CUDA is available.
        return super().fetch_dataset(
            input_path=input_path,
            map_location="cpu",
            **kwargs,
        )
def _get_triplet_scores(
    outputs: typing.Dict[str, typing.Dict[str, torch.Tensor]]
) -> typing.Dict[str, torch.Tensor]:
    """Get the triplet scores out of the outputs of a model.

    Args:
        outputs: dictionary of outputs. It contains either ``triplet_scores``
            or ``triplet_outputs`` (the triplet logits).

    Returns:
        ``outputs["triplet_scores"]`` if this entry exists. Otherwise, a new
        dictionary that associates every triplet name of
        ``outputs["triplet_outputs"]`` with the sigmoid of its logits.
    """
    # Already-computed scores take precedence: the logits are not even read.
    if "triplet_scores" in outputs:
        return outputs["triplet_scores"]

    triplet_scores = {}
    for name, logits in outputs["triplet_outputs"].items():
        triplet_scores[name] = torch.sigmoid(logits)
    return triplet_scores
[docs]class TripletGNNBase(ModelBase):
"""The base class for triplet-based models, which first classify edges, then triplets."""
@property
def loss(self) -> str:
    """Name of the loss, given by the ``loss`` hyperparameter.

    Defaults to ``"focal"`` if the hyperparameter is not set.
    """
    loss_name = self.hparams.get("loss", "focal")
    return loss_name
def get_lazy_dataset(self, *args, **kwargs) -> TripletGNNLazyDataset:
    """Instantiate a :py:class:`TripletGNNLazyDataset`.

    Args:
        *args, **kwargs: passed, unchanged, to the dataset constructor

    Returns:
        The newly created lazy dataset.
    """
    lazy_dataset = TripletGNNLazyDataset(*args, **kwargs)
    return lazy_dataset
def get_lazy_dataset_partition(
    self, partition: str, *args, **kwargs
) -> LazyDatasetBase:
    """Get the lazy dataset of a given partition.

    The call is delegated to the parent implementation, whatever the partition
    is: the training/validation partitions and the other ones used to go
    through two strictly identical branches.

    Args:
        partition: name of the partition (e.g., ``train`` or ``val``)
        *args, **kwargs: passed to the parent ``get_lazy_dataset_partition``

    Returns:
        The lazy dataset returned by the parent implementation.
    """
    # ``partition`` is deliberately kept as a keyword argument, as before.
    return super().get_lazy_dataset_partition(*args, partition=partition, **kwargs)
@property
def with_triplets(self) -> bool:
    """Whether the triplet step is currently included.

    This is driven by the ``triplets_step`` hyperparameter:

    * not set (or ``None``): triplets are always included
    * ``-1``: triplets are never included
    * otherwise: triplets are included once ``global_step`` has reached
      ``triplets_step``
    """
    triplets_step = self.hparams.get("triplets_step")
    if triplets_step is None:
        return True
    if triplets_step == -1:
        return False
    return self.global_step >= triplets_step
@property
def triplet_checkpointing(self) -> bool:
    """Value of the ``triplet_checkpointing`` hyperparameter (``True`` if unset).

    When true, :py:func:`forward` wraps the triplet forward step into
    ``torch.utils.checkpoint.checkpoint``.
    """
    enabled = self.hparams.get("triplet_checkpointing", True)
    return enabled
@property
def edge_checkpointing(self) -> bool:
    """Value of the ``edge_checkpointing`` hyperparameter (``True`` if unset).

    When true, :py:func:`forward` wraps the edge forward step into
    ``torch.utils.checkpoint.checkpoint``.
    """
    enabled = self.hparams.get("edge_checkpointing", True)
    return enabled
def triplet_output_step_articulation(self, triplet_indices, *args, **kwargs):
    """Compute the logits of the ``articulation`` triplets.

    This is a placeholder that has to be overridden.

    Raises:
        NotImplementedError: always
    """
    message = "Triplet articulation step not implemented."
    raise NotImplementedError(message)
def triplet_output_step_elbow_left(self, triplet_indices, *args, **kwargs):
    """Compute the logits of the ``elbow_left`` triplets.

    This is a placeholder that has to be overridden.

    Raises:
        NotImplementedError: always
    """
    message = "Triplet left elbow step not implemented."
    raise NotImplementedError(message)
def triplet_output_step_elbow_right(self, triplet_indices, *args, **kwargs):
    """Compute the logits of the ``elbow_right`` triplets.

    This is a placeholder that has to be overridden.

    Raises:
        NotImplementedError: always
    """
    message = "Triplet right elbow step not implemented."
    raise NotImplementedError(message)
def triplet_output_step(
    self, dict_triplet_indices, *args, **kwargs
) -> typing.Dict[str, torch.Tensor]:
    """Run the output step of the three kinds of triplets.

    For each kind ``articulation``, ``elbow_left`` and ``elbow_right``
    (in this order), the method ``triplet_output_step_<kind>`` is called with
    the matching entry of ``dict_triplet_indices``.

    Args:
        dict_triplet_indices: associates each kind of triplets with its
            triplet indices
        *args, **kwargs: passed, unchanged, to each of the three steps

    Returns:
        Dictionary that associates each kind of triplets with what its output
        step returned.
    """
    dict_triplet_outputs = {}
    for triplet_kind in ("articulation", "elbow_left", "elbow_right"):
        output_step = getattr(self, f"triplet_output_step_{triplet_kind}")
        dict_triplet_outputs[triplet_kind] = output_step(
            *args,
            triplet_indices=dict_triplet_indices[triplet_kind],
            **kwargs,
        )
    return dict_triplet_outputs  # type: ignore
def forward_edges(
    self,
    x: torch.Tensor,
    start: torch.Tensor,
    end: torch.Tensor,
) -> typing.Dict[str, torch.Tensor]:
    """Forward step for edge classification. Has to be overridden.

    Args:
        x: Hit features
        start: tensor of start indices of the edges
        end: tensor of end indices of the edges

    Returns:
        A dictionary of tensors. Should at least contain ``edge_output``,
        the logits of each edge.

    Raises:
        NotImplementedError: always, in this base class
    """
    message = "Forward step for edges is not implemented."
    raise NotImplementedError(message)
def forward_triplets(
    self,
    dict_triplet_indices: typing.Dict[str, torch.Tensor],
    *args,
    **kwargs,
) -> typing.Dict[str, torch.Tensor]:
    """Forward step for triplet classification.

    This simply delegates to :py:func:`triplet_output_step`.

    Args:
        dict_triplet_indices: associates ``articulation``, ``elbow_left``
            and ``elbow_right`` with the corresponding triplet indices.
        *args, **kwargs: Other arguments to pass to the triplet output step.

    Returns:
        A dictionary that associates ``articulation``, ``elbow_left`` and
        ``elbow_right`` with the logits of the corresponding triplets.
    """
    return self.triplet_output_step(dict_triplet_indices, *args, **kwargs)
def filter_edges(
    self,
    edge_index: torch.Tensor,
    edge_score: torch.Tensor,
    edge_score_cut: float | None = None,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    """Only keep the edges whose score is strictly above a threshold.

    Args:
        edge_index: edge indices; the edges are indexed along the second
            dimension
        edge_score: score of each edge
        edge_score_cut: threshold on the edge score. If ``None``, the
            ``edge_score_cut`` hyperparameter is used.

    Returns:
        The filtered edge indices, and the boolean mask (true for the edges
        that are kept) that was applied to obtain them.
    """
    if edge_score_cut is None:
        edge_score_cut = self.hparams["edge_score_cut"]

    edge_mask = edge_score > edge_score_cut
    return edge_index[:, edge_mask], edge_mask
def forward(
    self,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    edge_score_cut: float | None = None,
    with_triplets: bool = True,
) -> typing.Dict[str, typing.Any]:
    """Forward step of the triplet-based Neural Network.

    1. :py:func:`.forward_edges` is called, and outputs the edge logits
       ``edge_output``, with possibly other tensors that can be used for
       the triplets.
    2. The edges are filtered using :py:func:`.filter_edges`.
    3. The triplets are built using
       :py:func:`utils.graphutils.tripletbuilding.from_edge_index_to_triplet_indices`
    4. :py:func:`.forward_triplets` is called, and outputs the triplet
       logits ``triplet_outputs``.

    Steps 3 and 4 are skipped if ``with_triplets`` is false.

    Args:
        x: node features. ``requires_grad`` is switched on, in place, when
            edge checkpointing is enabled.
        edge_index: tensor with shape ``(2, n_edges)`` of the edge indices
        edge_score_cut: Minimal edge score
        with_triplets: whether to include the triplet inference

    Returns:
        Dictionary with ``edge_output``, ``filtered_edge_index`` and
        ``edge_mask``; ``extended_edge_outputs`` and ``extended_indices`` if
        :py:func:`.forward_edges` provided them; and, with triplets,
        ``triplet_outputs`` and ``triplet_indices``.
    """
    # 1. Edge inference ===
    start, end = edge_index
    if not self.edge_checkpointing:
        edge_outputs = self.forward_edges(x=x, start=start, end=end)
    else:
        x.requires_grad = True
        edge_outputs: typing.Dict[str, torch.Tensor] = checkpoint(  # type: ignore
            self.forward_edges, x=x, start=start, end=end, use_reentrant=False
        )
    # The inputs of the edge step are no longer used below: drop the local
    # references to them.
    x = start = end = None  # type: ignore

    # 2. Edge filtering ===
    edge_output = edge_outputs.pop("edge_output")
    filtered_edge_index, edge_mask = self.filter_edges(
        edge_index=edge_index,
        edge_score=torch.sigmoid(edge_output),
        edge_score_cut=edge_score_cut,
    )

    #: Dictionary of tensors outputted by the forward method.
    outputs = {
        "edge_output": edge_output,
        "filtered_edge_index": filtered_edge_index,
        "edge_mask": edge_mask,
    }
    # These two entries are taken out of `edge_outputs` in any case, so that
    # they are not forwarded to the triplet step.
    for output_name in ("extended_edge_outputs", "extended_indices"):
        output_value = edge_outputs.pop(output_name, None)
        if output_value is not None:
            outputs[output_name] = output_value

    if not with_triplets:
        return outputs

    # 3. Triplet building ===
    dict_triplet_indices = from_edge_index_to_triplet_indices(
        edge_index=filtered_edge_index
    )

    # 4. Triplet inference ===
    #: Inputs to pass to the forward for triplets
    #: on top of `dict_triplet_indices`: the filtered edge indices, the mask
    #: to filter edge-based tensors, and what remains of the edge outputs.
    _triplet_inputs = dict(
        filtered_edge_index=filtered_edge_index,
        edge_mask=edge_mask,
        **edge_outputs,
    )
    if not self.triplet_checkpointing:
        triplet_outputs = self.forward_triplets(
            dict_triplet_indices=dict_triplet_indices, **_triplet_inputs
        )
    else:
        triplet_outputs: typing.Dict[str, torch.Tensor] = checkpoint(  # type: ignore
            self.forward_triplets,
            use_reentrant=False,
            dict_triplet_indices=dict_triplet_indices,
            **_triplet_inputs,  # type: ignore
        )

    # Add triplet information to outputs
    outputs["triplet_outputs"] = triplet_outputs
    outputs["triplet_indices"] = dict_triplet_indices
    return outputs
def compute_normalised_loss(
    self,
    output: torch.Tensor,
    truth: torch.Tensor,
) -> torch.Tensor:
    """Compute the class-balanced loss for given output and truth.

    The loss is chosen by :py:attr:`loss`:

    * ``focal``: sigmoid focal loss, with ``alpha`` set to the fraction of
      negative targets and ``gamma`` given by the ``gamma`` hyperparameter
      (``2.0`` by default)
    * ``cross_entropy``: binary cross-entropy with logits, where the positive
      targets are weighted by the negative-to-positive ratio

    Degenerate inputs yield a finite loss instead of NaN: with no sample
    at all, a zero that is still connected to ``output`` is returned; and with
    no positive target, the positive weight of the cross-entropy (which then
    multiplies nothing) is computed with a denominator of 1.

    Args:
        output: logits
        truth: targets. Converted to boolean before the classes are counted.

    Returns:
        Loss averaged over the samples.

    Raises:
        ValueError: if the loss name is not recognised
    """
    loss_name = self.loss
    if loss_name not in ("focal", "cross_entropy"):
        raise ValueError(f"Loss `{loss_name}` is not recognised.")

    # `~` is only a logical negation on boolean tensors.
    truth = truth.bool()
    n_samples = truth.shape[0]
    if n_samples == 0:
        # The mean over zero samples is NaN: return a zero instead, which
        # keeps the dtype, the device and the graph of `output`.
        return output.sum() * 0.0

    n_negatives = (~truth).sum()
    if loss_name == "focal":
        return sigmoid_focal_loss(
            inputs=output,
            targets=truth.float(),
            alpha=n_negatives / n_samples,  # type: ignore
            gamma=self.hparams.get("gamma", 2.0),
            reduction="mean",
        )

    # Without the clamp, no positive target means an infinite weight, which
    # turns into NaN once multiplied by the (zero) targets.
    pos_weight = n_negatives / truth.sum().clamp(min=1)
    return torch.nn.functional.binary_cross_entropy_with_logits(
        input=output,
        target=truth.float(),
        pos_weight=pos_weight,
        reduction="mean",
    )
def inference(
    self,
    batch: Data,
    with_triplets: bool = True,
    with_triplet_truths: bool = False,
    edge_score_cut: float | None = None,
) -> typing.Dict[str, typing.Any]:
    """Run inference (without loss computation).

    Args:
        batch: event graph
        with_triplets: whether to include the forward step on triplets
        with_triplet_truths: whether to also compute the triplet targets.
            Only has an effect if ``with_triplets`` is true.
        edge_score_cut: minimal edge score the edges are required to have

    Returns:
        Output of the forward step, with the additional entry
        ``triplet_truths`` if the triplet targets were requested.
    """
    outputs = self.forward(
        x=self.get_features(batch),
        edge_index=batch["edge_index"],
        edge_score_cut=edge_score_cut,
        with_triplets=with_triplets,
    )
    if not (with_triplets and with_triplet_truths):
        return outputs

    # The triplets are built from the filtered edges, so the edge targets
    # are restricted to these edges as well.
    outputs["triplet_truths"] = get_triplet_truths_from_tensors(
        triplet_indices=outputs["triplet_indices"],
        edge_index=outputs["filtered_edge_index"],
        edge_truth=batch["y"][outputs["edge_mask"]],
        particle_id_hit_idx=batch["particle_id_hit_idx"],
    )
    return outputs
def common_training_validation_step(
    self,
    batch: Data,
    edge_score_cut: float | None = None,
    with_triplets: bool | None = None,
    compute_loss: bool = False,
) -> typing.Dict[str, typing.Any]:
    """Common forward step and loss computation for the training and validation
    steps.

    Args:
        batch: event graph
        edge_score_cut: minimal edge score the edges are required to have
        with_triplets: whether to include the forward step on triplets.
            If ``None``, :py:attr:`with_triplets` is used.
        compute_loss: whether to compute the losses

    Returns:
        Output of the inference, with the additional entries ``edge_truth``
        and ``edge_score``. If the losses are computed, ``edge_loss`` and
        ``overall_loss`` are added, as well as ``triplet_loss`` with triplets.
    """
    if with_triplets is None:
        with_triplets = self.with_triplets

    edge_truth = batch.y
    outputs = self.inference(
        batch=batch,
        edge_score_cut=edge_score_cut,
        with_triplets=with_triplets,
        with_triplet_truths=with_triplets,
    )
    edge_output = outputs["edge_output"]

    if compute_loss:
        # Edges flagged by the column named after the `masked_edges`
        # hyperparameter (if any) do not contribute to the edge loss.
        masked_edges_column = self.hparams.get("masked_edges")
        kept_edges = ~batch[masked_edges_column] if masked_edges_column else None

        extended_edge_outputs = outputs.get("extended_edge_outputs")
        if extended_edge_outputs is None:
            # A single set of edge logits
            if kept_edges is None:
                edge_loss = self.compute_normalised_loss(
                    output=edge_output, truth=edge_truth
                )
            else:
                edge_loss = self.compute_normalised_loss(
                    output=edge_output[kept_edges], truth=edge_truth[kept_edges]
                )
        else:
            # Several sets of edge logits, each with the indices of the edges
            # it refers to: the edge loss is the sum of their losses.
            extended_indices = outputs.get("extended_indices")
            assert extended_indices is not None
            partial_losses = []
            for partial_output, partial_index in zip(
                extended_edge_outputs, extended_indices
            ):
                partial_truth = edge_truth[partial_index]
                if kept_edges is not None:
                    partial_mask = kept_edges[partial_index]
                    partial_output = partial_output[partial_mask]
                    partial_truth = partial_truth[partial_mask]
                partial_losses.append(
                    self.compute_normalised_loss(
                        output=partial_output,
                        truth=partial_truth,
                    )
                )
            edge_loss = sum(partial_losses)
        outputs["edge_loss"] = edge_loss

        if not with_triplets:
            outputs["overall_loss"] = edge_loss
        else:
            triplet_outputs: typing.Dict[str, torch.Tensor] = outputs[
                "triplet_outputs"
            ]
            triplet_truths: typing.Dict[str, torch.Tensor] = outputs[
                "triplet_truths"
            ]
            # A single loss is computed over all the kinds of triplets.
            triplet_loss = self.compute_normalised_loss(
                output=torch.cat(tuple(triplet_outputs.values())),
                truth=torch.cat(tuple(triplet_truths.values())),
            )
            outputs["triplet_loss"] = triplet_loss
            outputs["overall_loss"] = edge_loss + triplet_loss

    outputs["edge_truth"] = edge_truth
    outputs["edge_score"] = torch.sigmoid(edge_output)
    return outputs
def training_step(self, batch, batch_idx):
    """Training step.

    The edge loss (and the triplet loss, with triplets) is computed and
    logged as ``train_loss_edge`` (and ``train_loss_triplet``).

    Args:
        batch: event graph
        batch_idx: index of the batch. Only used to name the batch in the
            warning below when it has no ``event_str`` attribute.

    Returns:
        The overall loss, or ``None`` if the batch was skipped because a
        ``MemoryError`` was raised.
    """
    # Read once, so that the forward step and the logging agree.
    with_triplets = self.with_triplets
    try:
        outputs = self.common_training_validation_step(
            batch=batch, with_triplets=with_triplets, compute_loss=True
        )
        self.log(
            "train_loss_edge",
            outputs["edge_loss"],
            on_epoch=True,
            on_step=self.on_step,
            batch_size=outputs["edge_output"].shape[0],
            prog_bar=True,
        )
        if with_triplets:
            n_triplets = sum(
                triplet_output.shape[0]
                for triplet_output in outputs["triplet_outputs"].values()
            )
            self.log(
                "train_loss_triplet",
                outputs["triplet_loss"],
                on_epoch=True,
                on_step=self.on_step,
                batch_size=n_triplets,
                prog_bar=True,
            )
        return outputs["overall_loss"]
    except MemoryError:
        # `logging` has to be imported at module level for this handler to
        # work. The event name falls back to the batch index, so that skipping
        # a batch cannot itself fail on a missing attribute.
        logging.warning(
            "Skipping %s due to out of memory",
            getattr(batch, "event_str", f"batch {batch_idx}"),
        )
        return None
def log_metrics_gen(
    self,
    loss: torch.Tensor,
    scores: torch.Tensor,
    predictions: torch.Tensor,
    truths: torch.Tensor,
    suffix: str = "",
) -> None:
    """Add an entry to the log.

    The entry holds the loss, the AUC, the efficiency and the purity (all
    named with ``suffix``), together with the current learning rate.

    Args:
        loss: loss to log as ``val_loss{suffix}``
        scores: edge or triplet scores. Used to compute the AUC
        predictions: edge or triplet predicted targets
        truths: edge or triplet targets
        suffix: optional suffix, e.g., ``_edge`` or ``_triplet``
    """
    efficiency, purity = compute_classification_efficiency_purity(
        predictions=predictions,
        truths=truths,
    )

    # The ROC AUC score is not defined when only one class is present in the
    # targets: `roc_auc_score` then raises a `ValueError`, and NaN is logged.
    try:
        auc = roc_auc_score(truths.bool().cpu().detach(), scores.cpu().detach())
    except ValueError:
        auc = np.nan

    metrics = {
        f"val_loss{suffix}": loss,
        f"auc{suffix}": auc,  # type: ignore
        f"eff{suffix}": efficiency,
        f"pur{suffix}": purity,
        # Learning rate of the first parameter group
        "current_lr": self.optimizers().param_groups[0]["lr"],  # type: ignore
    }
    self.log_dict(
        metrics,
        on_epoch=True,
        on_step=False,
        batch_size=predictions.shape[0],
    )
def shared_evaluation(
    self,
    batch: Data,
    log: bool = False,
    with_triplets: bool | None = None,
):
    """Evaluation step. Can be used for validation and test.

    Args:
        batch: event graph
        log: whether to add an entry to the log. The losses are only computed
            in this case.
        with_triplets: whether to include the triplet inference. If ``None``,
            :py:attr:`with_triplets` is used.

    Returns:
        Outputs of :py:func:`common_training_validation_step`, with the
        additional entry ``triplet_scores`` with triplets.
    """
    if with_triplets is None:
        with_triplets = self.with_triplets

    outputs = self.common_training_validation_step(
        batch=batch, with_triplets=with_triplets, compute_loss=log
    )
    edge_score = outputs["edge_score"]

    triplet_scores = None
    if with_triplets:
        triplet_scores = _get_triplet_scores(outputs)
        outputs["triplet_scores"] = triplet_scores

    if not log:
        return outputs

    # Edge metrics
    self.log_metrics_gen(
        loss=outputs["edge_loss"],
        scores=edge_score,
        predictions=edge_score > self.hparams["edge_score_cut"],
        truths=outputs["edge_truth"],
        suffix="_edge",
    )
    # Triplet metrics, computed over all the kinds of triplets at once
    if with_triplets:
        assert triplet_scores is not None
        triplet_score = torch.cat(tuple(triplet_scores.values()))
        self.log_metrics_gen(
            loss=outputs["triplet_loss"],
            scores=triplet_score,
            predictions=triplet_score > self.hparams["triplet_score_cut"],
            truths=torch.cat(tuple(outputs["triplet_truths"].values())),
            suffix="_triplet",
        )
    return outputs
def validation_step(self, batch, batch_idx):
    """Validation step: evaluate the batch, log the metrics and return the
    overall loss.
    """
    return self.shared_evaluation(batch=batch, log=True)["overall_loss"]
[docs] def test_step(self, batch, batch_idx):
outputs = self.shared_evaluation(batch=batch, log=False)
return outputs
@property
def _n_hits(self) -> int:
    """Number of hits of the dummy inputs built for the ONNX export."""
    n_dummy_hits = 200
    return n_dummy_hits
@property
def _n_edges(self) -> int:
    """Number of edges of the dummy inputs built for the ONNX export."""
    n_dummy_edges = 2000
    return n_dummy_edges
@property
def _n_triplets(self) -> int:
    """Number of triplets of the dummy inputs built for the ONNX export."""
    n_dummy_triplets = 1800
    return n_dummy_triplets
@property
def n_hiddens(self) -> int:
    """Number of hidden units, given by the ``hidden`` hyperparameter."""
    n_hidden_units = self.hparams["hidden"]
    return n_hidden_units
@property
def input_kwargs(self) -> typing.Dict[str, typing.Any]:
    """Associates an input name with a dictionary corresponding to
    the keyword arguments used to build a dummy tensor representing the input.

    This dictionary basically gives the ``size`` and ``dtype`` of the tensor:
    ``x`` has 3 float features per hit, whereas the edge (``start``, ``end``)
    and triplet (``triplet_start``, ``triplet_end``) indices are
    one-dimensional tensors of 64-bit integers.
    """
    hit_features = {"size": (self._n_hits, 3), "dtype": torch.float32}
    edge_size = (self._n_edges,)
    triplet_size = (self._n_triplets,)
    return {
        "x": hit_features,
        "start": {"size": edge_size, "dtype": torch.int64},
        "end": {"size": edge_size, "dtype": torch.int64},
        "triplet_start": {"size": triplet_size, "dtype": torch.int64},
        "triplet_end": {"size": triplet_size, "dtype": torch.int64},
    }
@property
def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
    """Associates a subnetwork name with the names of its outputs.

    Here, only the ``edge`` subnetwork is declared, with ``edge_score`` as
    single output.
    """
    outputs_by_subnetwork = {"edge": ["edge_score"]}
    return outputs_by_subnetwork
@property
def input_to_dynamic_axes(self):
    """Associates an input name with its dynamic axes.

    The mapping of the parent class is extended (and overridden, where the
    names collide) with the hit, edge and triplet inputs, whose first axis
    is named ``n_hits``, ``n_edges`` and ``n_triplets``, respectively.
    """
    dynamic_axes = dict(super(TripletGNNBase, self).input_to_dynamic_axes)
    dynamic_axes.update(
        {
            "x": {0: "n_hits"},
            "start": {0: "n_edges"},
            "end": {0: "n_edges"},
            "e": {0: "n_edges"},
            "triplet_start": {0: "n_triplets"},
            "triplet_end": {0: "n_triplets"},
        }
    )
    return dynamic_axes
def _onnx_edge(
    self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
) -> torch.Tensor:
    """Forward pass for the ``edge`` subnetwork.

    Returns:
        The edge scores, i.e., the sigmoid of the ``edge_output`` logits
        given by :py:func:`forward_edges`.
    """
    edge_outputs = self.forward_edges(x, start, end)
    edge_score = torch.sigmoid(edge_outputs["edge_output"])
    return edge_score
def _onnx_edge_all(self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor):
    """Forward pass for the ``edge_all`` subnetwork.

    Returns:
        List of the outputs named in ``subnetwork_to_outputs["edge_all"]``,
        in the same order. ``edge_score`` is the sigmoid of the
        ``edge_output`` logits; any other name is looked up, as is, in what
        :py:func:`forward_edges` returned.
    """
    output_names = self.subnetwork_to_outputs["edge_all"]
    edge_outputs = self.forward_edges(x, start, end)

    selected_outputs = []
    for output_name in output_names:
        if output_name == "edge_score":
            selected_outputs.append(torch.sigmoid(edge_outputs["edge_output"]))
        else:
            selected_outputs.append(edge_outputs[output_name])
    return selected_outputs
def to_onnx(
    self,
    outpath: str,
    mode: str | None = None,
    options: typing.Iterable[str] | None = None,
) -> None:
    """Export the model to ONNX.

    Args:
        outpath: path to the ONNX output file
        mode: subnetwork to save
        options: export options. When ``mode`` is not one of the
            ``subnetwork_groups``, the ``keep_int64`` option is handled here
            and taken out of the options before they are passed to the
            parent ``to_onnx``: unless it is given,
            ``change_input_index_types`` is applied to the exported file.
    """
    options = set(options) if options is not None else set()

    # `check_and_discard` has to run before the parent export, which is given
    # the remaining options. It is only called outside of subnetwork groups.
    convert_index_types = mode not in self.subnetwork_groups and not (
        check_and_discard(options, "keep_int64")
    )

    super(TripletGNNBase, self).to_onnx(outpath=outpath, mode=mode, options=options)

    if convert_index_types:
        from utils.modelutils.export import change_input_index_types

        change_input_index_types(outpath)