Source code for pipeline.GNN.triplet_gnn_base

"""A module that defines :py:class:`.TripletGNNBase`, the base class of all
triplet-based GNNs in this repository.
"""

from __future__ import annotations
import logging
import typing

import numpy as np
from sklearn.metrics import roc_auc_score
import torch
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import Data
from torchvision.ops import sigmoid_focal_loss

from utils.modelutils.basemodel import ModelBase, check_and_discard
from utils.modelutils.metrics import compute_classification_efficiency_purity
from utils.loaderutils.dataiterator import LazyDatasetBase
from utils.tools import tarray
from utils.graphutils import batch2df
from utils.graphutils.tripletbuilding import (
    from_edge_index_to_triplet_indices,
    get_triplet_truths_from_tensors,
)



def get_df_edges_from_batch_only(batch: Data) -> tarray.DataFrame:
    """Build the dataframe of edges of an event, with particle ID information.

    The hits-particles dataframe is first derived from ``batch`` and then
    handed to :py:func:`batch2df.get_df_edges` together with ``batch`` itself.

    Args:
        batch: event graph

    Returns:
        Dataframe of edges, as returned by :py:func:`batch2df.get_df_edges`.
    """
    return batch2df.get_df_edges(
        batch,
        df_hits_particles=batch2df.get_df_hits_particles(batch=batch),
    )
class TripletGNNLazyDataset(LazyDatasetBase):
    """Lazy dataset used by the triplet-based GNNs.

    It behaves like :py:class:`LazyDatasetBase`, except that
    :py:meth:`fetch_dataset` always asks the parent implementation to load
    the event with ``map_location="cpu"``.
    """

    def __init__(
        self,
        input_dir: str,
        n_events: int | None = None,
        shuffle: bool = False,
        seed: int | None = None,
        **kwargs,
    ):
        """Forward all the arguments, unchanged, to the parent constructor."""
        super().__init__(input_dir, n_events, shuffle, seed, **kwargs)

    def fetch_dataset(self, input_path: str, **kwargs) -> Data:
        """Load the event stored at ``input_path``.

        Args:
            input_path: path to the event to load
            **kwargs: passed to the parent ``fetch_dataset``

        Returns:
            The event returned by the parent ``fetch_dataset``.
        """
        # NOTE: the map location is deliberately fixed to the CPU. Choosing
        # "cuda" whenever it is available was left disabled in the original
        # implementation.
        return super().fetch_dataset(
            input_path=input_path,
            map_location="cpu",
            **kwargs,
        )
def _get_triplet_scores( outputs: typing.Dict[str, typing.Dict[str, torch.Tensor]] ) -> typing.Dict[str, torch.Tensor]: return ( { triplet_name: torch.sigmoid(triplet_output) for triplet_name, triplet_output in outputs["triplet_outputs"].items() } if "triplet_scores" not in outputs else outputs["triplet_scores"] )
class TripletGNNBase(ModelBase):
    """The base class for triplet-based models, which first classify edges,
    then triplets.

    This class does not implement :py:meth:`forward_edges` nor the three
    ``triplet_output_step_*`` methods: they raise ``NotImplementedError`` here.
    """

    # ------------------------------------------------------------------
    # Hyperparameter helpers and datasets
    # ------------------------------------------------------------------
    @property
    def loss(self) -> str:
        """Name of the loss, read from the ``loss`` hyperparameter
        (``"focal"`` when it is not set)."""
        return self.hparams.get("loss", "focal")

    def get_lazy_dataset(self, *args, **kwargs) -> TripletGNNLazyDataset:
        """Build a :py:class:`TripletGNNLazyDataset` from the given arguments."""
        return TripletGNNLazyDataset(*args, **kwargs)

    def get_lazy_dataset_partition(
        self, partition: str, *args, **kwargs
    ) -> LazyDatasetBase:
        """Delegate to the parent ``get_lazy_dataset_partition``.

        All partitions are handled identically: the call is forwarded to the
        parent class with ``partition`` passed as a keyword argument.

        Args:
            partition: name of the partition
            *args, **kwargs: passed to the parent implementation
        """
        return super().get_lazy_dataset_partition(
            *args, partition=partition, **kwargs
        )

    @property
    def with_triplets(self) -> bool:
        """Whether the triplet step is currently enabled.

        Driven by the ``triplets_step`` hyperparameter:

        * not set (or ``None``): always ``True``;
        * ``-1``: always ``False``;
        * otherwise: ``True`` once ``self.global_step`` has reached that value.
        """
        triplets_step = self.hparams.get("triplets_step")
        if triplets_step is None:
            return True
        if triplets_step == -1:
            return False
        return self.global_step >= triplets_step

    @property
    def triplet_checkpointing(self) -> bool:
        """Whether :py:meth:`forward_triplets` is run through
        :py:func:`torch.utils.checkpoint.checkpoint` (hyperparameter
        ``triplet_checkpointing``, ``True`` when not set)."""
        return self.hparams.get("triplet_checkpointing", True)

    @property
    def edge_checkpointing(self) -> bool:
        """Whether :py:meth:`forward_edges` is run through
        :py:func:`torch.utils.checkpoint.checkpoint` (hyperparameter
        ``edge_checkpointing``, ``True`` when not set)."""
        return self.hparams.get("edge_checkpointing", True)

    # ------------------------------------------------------------------
    # Triplet output steps (to be provided by subclasses)
    # ------------------------------------------------------------------
    def triplet_output_step_articulation(self, triplet_indices, *args, **kwargs):
        """Compute the outputs of the ``articulation`` triplets.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Triplet articulation step not implemented.")

    def triplet_output_step_elbow_left(self, triplet_indices, *args, **kwargs):
        """Compute the outputs of the ``elbow_left`` triplets.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Triplet left elbow step not implemented.")

    def triplet_output_step_elbow_right(self, triplet_indices, *args, **kwargs):
        """Compute the outputs of the ``elbow_right`` triplets.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Triplet right elbow step not implemented.")

    def triplet_output_step(
        self, dict_triplet_indices, *args, **kwargs
    ) -> typing.Dict[str, torch.Tensor]:
        """Run the three per-type triplet output steps.

        Args:
            dict_triplet_indices: associates ``articulation``, ``elbow_left``
                and ``elbow_right`` with the corresponding triplet indices.
            *args, **kwargs: passed, unchanged, to each of the three steps.

        Returns:
            A dictionary with the keys ``articulation``, ``elbow_left`` and
            ``elbow_right`` (in that order), each associated with the output
            of the matching ``triplet_output_step_*`` method.
        """
        dict_triplet_outputs = {
            "articulation": self.triplet_output_step_articulation(
                *args,
                triplet_indices=dict_triplet_indices["articulation"],
                **kwargs,
            ),
            "elbow_left": self.triplet_output_step_elbow_left(
                *args,
                triplet_indices=dict_triplet_indices["elbow_left"],
                **kwargs,
            ),
            "elbow_right": self.triplet_output_step_elbow_right(
                *args,
                triplet_indices=dict_triplet_indices["elbow_right"],
                **kwargs,
            ),
        }
        return dict_triplet_outputs  # type: ignore

    # ------------------------------------------------------------------
    # Forward steps
    # ------------------------------------------------------------------
    def forward_edges(
        self,
        x: torch.Tensor,
        start: torch.Tensor,
        end: torch.Tensor,
    ) -> typing.Dict[str, torch.Tensor]:
        """Forward step for edge classification.

        Args:
            x: Hit features
            start: tensor of start indices
            end: tensor of end indices

        Returns:
            A dictionary of tensors. Should at least contain ``edge_output``,
            the logits of each edges.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Forward step for edges is not implemented.")

    def forward_triplets(
        self,
        dict_triplet_indices: typing.Dict[str, torch.Tensor],
        *args,
        **kwargs,
    ) -> typing.Dict[str, torch.Tensor]:
        """Forward step for triplet building and classification.

        Args:
            dict_triplet_indices: associates ``articulation``, ``elbow_left``
                and ``elbow_right`` with the corresponding triplet indices.
            args, kwargs: Other arguments to pass to the triplet output step.

        Returns:
            A dictionary that associates ``articulation``, ``elbow_left``
            and ``elbow_right`` with the logits of the corresponding triplets.
        """
        dict_triplet_outputs: typing.Dict[str, torch.Tensor] = self.triplet_output_step(
            dict_triplet_indices, *args, **kwargs
        )
        return dict_triplet_outputs

    def filter_edges(
        self,
        edge_index: torch.Tensor,
        edge_score: torch.Tensor,
        edge_score_cut: float | None = None,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Keep the edges whose score is strictly above a threshold.

        Args:
            edge_index: tensor of the edge indices; its columns are selected
                with the mask
            edge_score: score of each edge
            edge_score_cut: minimal edge score. If ``None``, the
                ``edge_score_cut`` hyperparameter is used instead.

        Returns:
            A 2-tuple ``(filtered_edge_index, edge_mask)`` where ``edge_mask``
            is ``edge_score > cut`` and ``filtered_edge_index`` is
            ``edge_index[:, edge_mask]``.
        """
        edge_score_cut_: float = (
            self.hparams["edge_score_cut"] if edge_score_cut is None else edge_score_cut
        )
        edge_mask = edge_score > edge_score_cut_
        filtered_edge_index = edge_index[:, edge_mask]
        return filtered_edge_index, edge_mask

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_score_cut: float | None = None,
        with_triplets: bool = True,
    ) -> typing.Dict[str, typing.Any]:
        """Forward step of the triplet-based Neural Network.

        1. :py:func:`.forward_edges` method is called, and outputs the edge
           logits ``edge_output``, with possibly other tensors that can be
           used for the triplets.
        2. The edges are filtered using the :py:func:`.filter_edges` method,
           applied to the sigmoid of the edge logits.
        3. The triplets are built using
           :py:func:`utils.graphutils.tripletbuilding.from_edge_index_to_triplet_indices`
        4. :py:func:`.forward_triplets` method is called, and outputs the
           triplets logits ``triplet_outputs``.

        Steps 3 and 4 are skipped when ``with_triplets`` is false.

        Args:
            x: node features
            edge_index: tensor with shape ``(2, n_edges)`` of the edge indices
            edge_score_cut: Minimal edge score
            with_triplets: whether to include the triplet inference

        Returns:
            A dictionary that contains ``edge_output``,
            ``filtered_edge_index`` and ``edge_mask``; ``extended_edge_outputs``
            and ``extended_indices`` if :py:func:`.forward_edges` returned
            them with a non-``None`` value; and, if ``with_triplets``,
            ``triplet_outputs`` and ``triplet_indices``.
        """
        # 1. Edge inference ===
        start, end = edge_index
        if self.edge_checkpointing:
            # This sets the flag in place, on the tensor given by the caller.
            # NOTE(review): this may be unnecessary with
            # ``use_reentrant=False`` -- confirm before removing.
            x.requires_grad = True
            edge_outputs: typing.Dict[str, torch.Tensor] = checkpoint(  # type: ignore
                self.forward_edges, x=x, start=start, end=end, use_reentrant=False
            )
        else:
            edge_outputs = self.forward_edges(x=x, start=start, end=end)

        # Drop the local references to the inputs, which are no longer needed.
        x = None  # type: ignore
        start = None
        end = None

        # 2. Edge filtering ===
        edge_output = edge_outputs.pop("edge_output")
        filtered_edge_index, edge_mask = self.filter_edges(
            edge_index=edge_index,
            edge_score=torch.sigmoid(edge_output),
            edge_score_cut=edge_score_cut,
        )

        #: Dictionary of tensors outputted by the forward method.
        outputs = {}
        outputs["edge_output"] = edge_output
        outputs["filtered_edge_index"] = filtered_edge_index
        outputs["edge_mask"] = edge_mask

        # These two entries are moved out of `edge_outputs`, so that they are
        # not passed to the forward step for triplets below.
        for output_name in ["extended_edge_outputs", "extended_indices"]:
            if (output_value := edge_outputs.pop(output_name, None)) is not None:
                outputs[output_name] = output_value

        if with_triplets:
            # 3. Triplet building ===
            dict_triplet_indices = from_edge_index_to_triplet_indices(
                edge_index=filtered_edge_index
            )

            # 4. triplet inference ===
            #: Inputs to pass to the forward for triplets
            #: on top of `dict_triplet_indices`.
            _triplet_inputs = dict(
                # Filtered edge indices
                filtered_edge_index=filtered_edge_index,
                # Indices to filter edge-based tensors
                edge_mask=edge_mask,
                # Edge outputs given by the forward for edges
                **edge_outputs,
            )
            if self.triplet_checkpointing:
                triplet_outputs: typing.Dict[str, torch.Tensor] = checkpoint(  # type: ignore
                    self.forward_triplets,
                    use_reentrant=False,
                    dict_triplet_indices=dict_triplet_indices,
                    **_triplet_inputs,  # type: ignore
                )
            else:
                triplet_outputs = self.forward_triplets(
                    dict_triplet_indices=dict_triplet_indices, **_triplet_inputs
                )

            # Add triplet information to outputs
            outputs["triplet_outputs"] = triplet_outputs
            outputs["triplet_indices"] = dict_triplet_indices

        return outputs

    # ------------------------------------------------------------------
    # Loss and inference
    # ------------------------------------------------------------------
    def compute_normalised_loss(
        self,
        output: torch.Tensor,
        truth: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the class-weighted loss for given output and truth.

        The loss is selected by :py:attr:`loss`:

        * ``"focal"``: sigmoid focal loss, with ``alpha`` set to the fraction
          of false targets and ``gamma`` read from the ``gamma``
          hyperparameter (``2.0`` when not set);
        * ``"cross_entropy"``: binary cross-entropy with logits, with
          ``pos_weight`` set to the number of false targets divided by the
          number of true targets.

        Both use a ``"mean"`` reduction.

        Args:
            output: logits
            truth: targets

        Returns:
            The loss.

        Raises:
            ValueError: if the loss name is not recognised.
        """
        loss_name = self.loss
        # The targets are counted on a boolean view: `~` applied to an
        # integer tensor would be a bitwise NOT rather than a logical
        # negation. This is a no-op for targets that are already boolean.
        truth_mask = truth.bool()
        n_false = (~truth_mask).sum()

        if loss_name == "focal":
            weights = n_false / truth.shape[0]
            gamma = self.hparams.get("gamma", 2.0)
            loss = sigmoid_focal_loss(
                inputs=output,
                targets=truth.float(),
                alpha=weights,  # type: ignore
                gamma=gamma,
                reduction="mean",
            )
        elif loss_name == "cross_entropy":
            # Without any true target, dividing by zero would give an
            # infinite `pos_weight`, which turns the loss into NaN. The count
            # is therefore floored at 1; in that case `pos_weight` multiplies
            # no term of the loss anyway.
            n_true = truth_mask.sum().clamp(min=1)
            pos_weight = n_false / n_true
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                input=output,
                target=truth.float(),
                pos_weight=pos_weight,
                reduction="mean",
            )
        else:
            raise ValueError(f"Loss `{loss_name}` is not recognised.")
        return loss

    def inference(
        self,
        batch: Data,
        with_triplets: bool = True,
        with_triplet_truths: bool = False,
        edge_score_cut: float | None = None,
    ) -> typing.Dict[str, typing.Any]:
        """Run inference (without loss computation).

        Args:
            batch: event graph
            with_triplets: whether to include the forward step on triplets
            with_triplet_truths: whether to also compute the triplet truths,
                stored under ``triplet_truths``. Ignored if ``with_triplets``
                is false.
            edge_score_cut: minimal edge score the edges are required to have

        Returns:
            Output of the forward step, with ``triplet_truths`` added if
            requested.
        """
        input_data = self.get_features(batch)
        outputs = self.forward(
            x=input_data,
            edge_index=batch["edge_index"],
            edge_score_cut=edge_score_cut,
            with_triplets=with_triplets,
        )
        if with_triplets and with_triplet_truths:
            outputs["triplet_truths"] = get_triplet_truths_from_tensors(
                triplet_indices=outputs["triplet_indices"],
                edge_index=outputs["filtered_edge_index"],
                # Truth of the edges that passed the edge filtering
                edge_truth=batch["y"][outputs["edge_mask"]],
                particle_id_hit_idx=batch["particle_id_hit_idx"],
            )
        return outputs

    def common_training_validation_step(
        self,
        batch: Data,
        edge_score_cut: float | None = None,
        with_triplets: bool | None = None,
        compute_loss: bool = False,
    ) -> typing.Dict[str, typing.Any]:
        """Common forward step and loss computation for the training and
        validation steps.

        Args:
            batch: event graph
            edge_score_cut: minimal edge score the edges are required to have
            with_triplets: whether to include the forward step on triplets.
                If ``None``, :py:attr:`with_triplets` is used.
            compute_loss: whether to compute the losses. If so, ``edge_loss``
                and ``overall_loss`` are added to the outputs, as well as
                ``triplet_loss`` if ``with_triplets``.

        Returns:
            Output of the inference step, with the losses if requested, and
            with ``edge_truth`` (``batch.y``) and ``edge_score`` (sigmoid of
            the edge logits).
        """
        if with_triplets is None:
            with_triplets = self.with_triplets

        edge_truth = batch.y
        outputs = self.inference(
            batch=batch,
            edge_score_cut=edge_score_cut,
            with_triplets=with_triplets,
            with_triplet_truths=with_triplets,
        )
        edge_output = outputs["edge_output"]

        if compute_loss:
            # Edges kept for the loss: the negation of the batch column named
            # by the `masked_edges` hyperparameter, if that one is set.
            edge_mask = (
                ~batch[masked_edges_column]
                if (masked_edges_column := self.hparams.get("masked_edges"))
                else None
            )

            if (
                extended_edge_outputs := outputs.get("extended_edge_outputs")
            ) is not None:
                # One loss per extended output, each evaluated on the edges
                # selected by the matching extended index. The edge loss is
                # the sum of these losses.
                extended_indices = outputs.get("extended_indices")
                assert extended_indices is not None

                edge_losses = []
                for extended_edge_output, extended_index in zip(
                    extended_edge_outputs, extended_indices
                ):
                    extended_edge_truth = edge_truth[extended_index]
                    if edge_mask is not None:
                        extended_edge_mask = edge_mask[extended_index]
                        extended_edge_output = extended_edge_output[extended_edge_mask]
                        extended_edge_truth = extended_edge_truth[extended_edge_mask]

                    edge_losses.append(
                        self.compute_normalised_loss(
                            output=extended_edge_output,
                            truth=extended_edge_truth,
                        )
                    )
                edge_loss = sum(edge_losses)
            elif edge_mask is not None:
                edge_loss = self.compute_normalised_loss(
                    output=edge_output[edge_mask], truth=edge_truth[edge_mask]
                )
            else:
                edge_loss = self.compute_normalised_loss(
                    output=edge_output, truth=edge_truth
                )
            outputs["edge_loss"] = edge_loss

            if with_triplets:
                triplet_outputs: typing.Dict[str, torch.Tensor] = outputs[
                    "triplet_outputs"
                ]
                triplet_truths: typing.Dict[str, torch.Tensor] = outputs[
                    "triplet_truths"
                ]
                # A single loss over all the triplet types. The values of the
                # two dictionaries are concatenated in their own iteration
                # order, so both are expected to list the triplet types in
                # the same order.
                triplet_loss = self.compute_normalised_loss(
                    output=torch.cat(tuple(triplet_outputs.values())),
                    truth=torch.cat(tuple(triplet_truths.values())),
                )
                outputs["triplet_loss"] = triplet_loss
                outputs["overall_loss"] = edge_loss + triplet_loss
            else:
                outputs["overall_loss"] = edge_loss

        outputs["edge_truth"] = edge_truth
        outputs["edge_score"] = torch.sigmoid(edge_output)
        return outputs

    # ------------------------------------------------------------------
    # Training, validation and test steps
    # ------------------------------------------------------------------
    def training_step(self, batch, batch_idx):
        """Training step.

        Runs the common step with loss computation, and logs
        ``train_loss_edge`` and, if the triplets are enabled,
        ``train_loss_triplet``.

        Args:
            batch: event graph
            batch_idx: index of the batch (unused)

        Returns:
            The overall loss, or ``None`` if a ``MemoryError`` was raised, in
            which case a warning is logged instead.
        """
        try:
            # Read once, so that the forward step and the logging below agree.
            with_triplets = self.with_triplets
            outputs = self.common_training_validation_step(
                batch=batch, with_triplets=with_triplets, compute_loss=True
            )
            overall_loss = outputs["overall_loss"]
            edge_loss = outputs["edge_loss"]
            edge_output = outputs["edge_output"]

            self.log(
                "train_loss_edge",
                edge_loss,
                on_epoch=True,
                on_step=self.on_step,
                batch_size=edge_output.shape[0],
                prog_bar=True,
            )
            if with_triplets:
                triplet_loss = outputs["triplet_loss"]
                triplet_outputs = outputs["triplet_outputs"]
                self.log(
                    "train_loss_triplet",
                    triplet_loss,
                    on_epoch=True,
                    on_step=self.on_step,
                    # Total number of triplets, all types included
                    batch_size=sum(
                        triplet_output.shape[0]
                        for triplet_output in triplet_outputs.values()
                    ),
                    prog_bar=True,
                )

            return overall_loss
        except MemoryError:
            # NOTE(review): only the built-in `MemoryError` is caught here.
            # Presumably, out-of-memory errors raised by CUDA are of another
            # type and are not skipped -- confirm whether this is intended.
            logging.warning(f"Skipping {batch.event_str} due to out of memory")
            return None

    def log_metrics_gen(
        self,
        loss: torch.Tensor,
        scores: torch.Tensor,
        predictions: torch.Tensor,
        truths: torch.Tensor,
        suffix: str = "",
    ) -> None:
        """Add entry to the log.

        The entries ``val_loss{suffix}``, ``auc{suffix}``, ``eff{suffix}``,
        ``pur{suffix}`` and ``current_lr`` are logged per epoch.

        Args:
            loss: overall loss
            scores: edge or triplet scores. Used to compute the AUC
            predictions: edge or triplet predicted targets
            truths: edge or triplet targets
            suffix: optional suffix, e.g., ``_edge`` or ``_triplet``
        """
        eff, pur = compute_classification_efficiency_purity(
            predictions=predictions,
            truths=truths,
        )

        # Fix error: "ValueError: Only one class present in y_true.
        # ROC AUC score is not defined in that case"
        try:
            auc = roc_auc_score(truths.bool().cpu().detach(), scores.cpu().detach())
        except ValueError:
            auc = np.nan

        # Learning rate of the first parameter group of the optimizer
        current_lr = self.optimizers().param_groups[0]["lr"]  # type: ignore
        self.log_dict(
            {
                f"val_loss{suffix}": loss,
                f"auc{suffix}": auc,  # type: ignore
                f"eff{suffix}": eff,
                f"pur{suffix}": pur,
                "current_lr": current_lr,
            },
            on_epoch=True,
            on_step=False,
            batch_size=predictions.shape[0],
        )

    def shared_evaluation(
        self,
        batch: Data,
        log: bool = False,
        with_triplets: bool | None = None,
    ):
        """Evaluation step. Can be used for validation and test.

        Args:
            batch: event graph
            log: whether to add an entry to the log. The losses are only
                computed in that case.
            with_triplets: whether to include the triplet inference. If
                ``None``, :py:attr:`with_triplets` is used.

        Returns:
            Output of the common step, with ``triplet_scores`` added if
            ``with_triplets``.
        """
        if with_triplets is None:
            with_triplets = self.with_triplets

        outputs = self.common_training_validation_step(
            batch=batch, with_triplets=with_triplets, compute_loss=log
        )
        edge_score = outputs["edge_score"]

        if with_triplets:
            triplet_scores = _get_triplet_scores(outputs)
            outputs["triplet_scores"] = triplet_scores
        else:
            triplet_scores = None

        if log:
            self.log_metrics_gen(
                loss=outputs["edge_loss"],
                scores=edge_score,
                predictions=edge_score > self.hparams["edge_score_cut"],
                truths=outputs["edge_truth"],
                suffix="_edge",
            )
            if with_triplets:
                assert triplet_scores is not None
                # The scores and the truths are concatenated in the iteration
                # order of their respective dictionary.
                triplet_score = torch.cat(tuple(triplet_scores.values()))
                self.log_metrics_gen(
                    loss=outputs["triplet_loss"],
                    scores=triplet_score,
                    predictions=triplet_score > self.hparams["triplet_score_cut"],
                    truths=torch.cat(tuple(outputs["triplet_truths"].values())),
                    suffix="_triplet",
                )

        return outputs

    def validation_step(self, batch, batch_idx):
        """Validation step: evaluate with logging, and return the overall loss."""
        outputs = self.shared_evaluation(batch=batch, log=True)
        return outputs["overall_loss"]

    def test_step(self, batch, batch_idx):
        """Test step: evaluate without logging nor loss computation, and
        return all the outputs."""
        outputs = self.shared_evaluation(batch=batch, log=False)
        return outputs

    # ------------------------------------------------------------------
    # ONNX export
    # ------------------------------------------------------------------
    @property
    def _n_hits(self) -> int:
        """Dummy number of hits used for ONNX export."""
        return 200

    @property
    def _n_edges(self) -> int:
        """Dummy number of edges used for ONNX export."""
        return 2000

    @property
    def _n_triplets(self) -> int:
        """Dummy number of triplets used for ONNX export."""
        return 1800

    @property
    def n_hiddens(self) -> int:
        """Number of hidden units"""
        return self.hparams["hidden"]

    @property
    def input_kwargs(self) -> typing.Dict[str, typing.Any]:
        """Associates an input name with a dictionary corresponding to the
        keyword arguments used to build a dummy tensor representing the input.
        This dictionary basically gives the ``size`` and ``dtype`` of the
        tensor.
        """
        return {
            "x": dict(size=(self._n_hits, 3), dtype=torch.float32),
            "start": dict(size=(self._n_edges,), dtype=torch.int64),
            "end": dict(size=(self._n_edges,), dtype=torch.int64),
            "triplet_start": dict(size=(self._n_triplets,), dtype=torch.int64),
            "triplet_end": dict(size=(self._n_triplets,), dtype=torch.int64),
        }

    @property
    def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
        """Associates a subnetwork name with the names of its outputs."""
        return {"edge": ["edge_score"]}

    @property
    def input_to_dynamic_axes(self):
        """Associates an input name with its dynamic axes: the entries of the
        parent class, extended (or overridden) with the ones defined here."""
        return {
            **super().input_to_dynamic_axes,
            "x": {0: "n_hits"},
            "start": {0: "n_edges"},
            "end": {0: "n_edges"},
            "e": {0: "n_edges"},
            "triplet_start": {0: "n_triplets"},
            "triplet_end": {0: "n_triplets"},
        }

    def _onnx_edge(
        self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass for the ``edge`` subnetwork: the sigmoid of the edge
        logits."""
        edge_output = self.forward_edges(x, start, end)["edge_output"]
        return torch.sigmoid(edge_output)

    def _onnx_edge_all(self, x: torch.Tensor, start: torch.Tensor, end: torch.Tensor):
        """Forward pass for the ``edge_all`` subnetwork.

        Returns:
            The list of the outputs named in
            ``self.subnetwork_to_outputs["edge_all"]``, in that order, taken
            from the outputs of :py:func:`.forward_edges`. ``edge_score`` is
            computed as the sigmoid of ``edge_output``.
        """
        # NOTE: `subnetwork_to_outputs` has no `edge_all` entry in this class.
        # Presumably, subclasses that use this method provide it.
        output_names = self.subnetwork_to_outputs["edge_all"]
        outputs = self.forward_edges(x, start, end)
        return [
            (
                torch.sigmoid(outputs["edge_output"])
                if output_name == "edge_score"
                else outputs[output_name]
            )
            for output_name in output_names
        ]

    def to_onnx(
        self,
        outpath: str,
        mode: str | None = None,
        options: typing.Iterable[str] | None = None,
    ) -> None:
        """Export the model to ONNX

        Args:
            outpath: path to the ONNX output file
            mode: subnetwork to save
            options: export options, passed to the parent ``to_onnx``. When
                ``mode`` is not in ``self.subnetwork_groups``, the
                ``keep_int64`` option is handled here first, through
                ``check_and_discard``: unless it evaluates as true,
                ``change_input_index_types`` is applied to the exported file.
        """
        options = set() if options is None else set(options)

        # The `keep_int64` option only applies when `mode` is not a group of
        # subnetworks.
        change_index_types = False
        if mode not in self.subnetwork_groups:
            change_index_types = not check_and_discard(options, "keep_int64")

        super().to_onnx(outpath=outpath, mode=mode, options=options)

        if change_index_types:
            # Imported here, as in the original implementation.
            from utils.modelutils.export import change_input_index_types

            change_input_index_types(outpath)