Source code for pipeline.TrackBuilding.builder

"""A module that defines the various track builders.
"""
from __future__ import annotations
import typing

import torch
from torch_geometric.data import Data


from utils.modelutils.build import ModelBuilderBase
from utils.tools import tarray
from GNN.triplet_gnn_base import TripletGNNBase
from . import edges2tracks
from . import triplets2tracks, triplets


def batch_from_triplets_to_tracks(
    batch: Data,
    triplet_indices: typing.Dict[str, torch.Tensor],
    edge_index: torch.Tensor,
    triplet_scores: typing.Dict[str, torch.Tensor] | None = None,
    triplet_score_cut: float | typing.Dict[str, float] | None = None,
    edge_score: torch.Tensor | None = None,
    single_edge_score_cut: float | None = None,
    strategy: str | None = None,
) -> Data:
    """Build tracks from triplets and return them as a new ``Data`` object.

    Args:
        batch: event the triplets belong to. Its ``hit_id``, ``event_str``
            and ``truncated_path`` entries are read.
        triplet_indices: dictionary of triplet index tensors
            (presumably one entry per triplet type -- to be confirmed).
        edge_index: edge indices, forwarded to the track-building routine.
        triplet_scores: dictionary of triplet scores. Required when
            ``triplet_score_cut`` is given.
        triplet_score_cut: if not ``None``, the triplets are first filtered
            with ``triplets.get_filtered_triplet_indices`` using this cut
            (a single value or a dictionary of values).
        edge_score: edge scores. Only forwarded when the strategy is ``1cc``.
        single_edge_score_cut: edge score cut. Only forwarded when the
            strategy is ``1cc``.
        strategy: track-building strategy. ``None`` means ``"1cc"``. Any other
            value is forwarded to
            ``triplets2tracks.build_tracks_from_triplets_2cc``.

    Returns:
        A ``Data`` object holding ``hit_id`` (the hit IDs of the hits that
        belong to a track), ``labels`` (the matching track IDs), and the
        ``event_str`` and ``truncated_path`` entries copied from ``batch``.

    Raises:
        AssertionError: if ``triplet_score_cut`` is provided without
            ``triplet_scores``.
    """
    if strategy is None:
        strategy = "1cc"

    if triplet_score_cut is not None:
        # Raised explicitly (rather than with an ``assert`` statement) so that
        # the check is not stripped when Python runs with ``-O``. The exception
        # type and message are unchanged.
        if triplet_scores is None:
            raise AssertionError(
                "As `triplet_score_cut` is provided, the triplet indices are filtered. "
                "However, `triplet_scores` was not provided."
            )
        triplet_indices = triplets.get_filtered_triplet_indices(
            triplet_indices=triplet_indices,
            triplet_scores=triplet_scores,
            triplet_score_cut=triplet_score_cut,
        )

    if strategy == "1cc":
        df_tracks = triplets2tracks.build_tracks_from_triplets_1cc(
            triplet_indices=triplet_indices,
            edge_index=edge_index,
            edge_score=edge_score,
            single_edge_score_cut=single_edge_score_cut,
        )
    else:
        df_tracks = triplets2tracks.build_tracks_from_triplets_2cc(
            triplet_indices=triplet_indices,
            edge_index=edge_index,
            strategy=strategy,
        )

    if df_tracks.shape[0]:
        hit_indices = tarray.series_to_tensor(df_tracks["hit_idx"])  # type: ignore
        track_ids = tarray.series_to_tensor(df_tracks["track_id"])  # type: ignore
    else:
        # No track was built: fall back to empty tensors.
        # ``hit_indices`` is used below to index ``batch["hit_id"]``, so it
        # must have an index dtype whatever dtype ``hit_id`` is stored with.
        hit_indices = torch.zeros(
            size=(0,), dtype=torch.long, device=edge_index.device
        )
        track_ids = torch.zeros(
            size=(0,), dtype=batch["hit_id"].dtype, device=edge_index.device
        )

    return Data(
        hit_id=batch["hit_id"][hit_indices],
        labels=track_ids,
        event_str=batch["event_str"],
        truncated_path=batch["truncated_path"],
    )
def batch_from_edges_to_tracks(batch: Data, edge_index: torch.Tensor):
    """Label the hits of ``batch`` with tracks built directly from edges.

    The labels are computed by ``edges2tracks.build_tracks_from_edges`` and
    stored in ``batch["labels"]``. The batch is modified in place.

    Args:
        batch: event to label. The number of hits is taken from the first
            dimension of ``batch["x"]``.
        edge_index: edge indices the tracks are built from.

    Returns:
        The same ``batch`` object, with its ``labels`` entry set.
    """
    n_hits = batch["x"].shape[0]
    batch["labels"] = edges2tracks.build_tracks_from_edges(
        edge_index=edge_index,
        n_hits=n_hits,
    )
    return batch
class EdgeTrackBuilder(ModelBuilderBase):
    def __init__(self, model: TripletGNNBase, edge_score_cut: float) -> None:
        """Define a builder that runs the GNN inference to filter out fake
        edges, then builds tracks using a connected component algorithm.

        This builder skips the triplet building. It is used to check how
        badly the track building performs without it.

        Args:
            model: triplet GNN
            edge_score_cut: minimal edge score required to keep an edge
        """
        super().__init__(model=model)
        self.edge_score_cut = float(edge_score_cut)

    def construct_downstream(self, batch: Data):
        """Run the edge-level inference on ``batch`` and build the tracks.

        The ``edge_score_cut`` hyperparameter of the model is overwritten with
        the cut of this builder before the model is evaluated.

        Args:
            batch: event to process

        Returns:
            The output of :py:func:`batch_from_edges_to_tracks`, called with
            the filtered edge indices moved to the CPU.
        """
        self.model: TripletGNNBase  # type hint only
        self.model.hparams["edge_score_cut"] = self.edge_score_cut

        # Inference only: no gradient is needed.
        with torch.no_grad():
            outputs = self.model.shared_evaluation(batch, with_triplets=False)

        filtered_edge_index = outputs["filtered_edge_index"].cpu()
        return batch_from_edges_to_tracks(
            batch=batch, edge_index=filtered_edge_index
        )
class TripletTrackBuilder(ModelBuilderBase):
    def __init__(
        self,
        model: TripletGNNBase,
        edge_score_cut: float,
        triplet_score_cut: float | typing.Dict[str, float],
        single_edge_score_cut: float | None = None,
        strategy: str | None = None,
    ) -> None:
        """Define a builder that runs the GNN inference to get the filtered
        graph of triplets, then builds tracks from the triplets.

        Args:
            model: triplet GNN
            edge_score_cut: minimal edge score to filter out the edges before
                building the triplets
            triplet_score_cut: minimal triplet score to get the graph of
                triplets which the tracks are built from.
            single_edge_score_cut: edge score cut forwarded to
                :py:func:`batch_from_triplets_to_tracks`, which only passes it
                on for the ``1cc`` strategy. If ``None``, the edge scores are
                not forwarded either.
            strategy: strategy used to build the tracks from triplets.

                * ``1cc``: connect left and right elbows by assigning the
                  smallest edge index to every edge connected to an elbow.
                  Only one connected component algorithm is applied using
                  this strategy.
                * ``2cc_without_articulation``: 2 connected components are
                  applied. The first connected component is applied on elbows
                  only.
                * ``2cc_no_multiple_central_hit``: the first connected
                  component algorithm is applied to elbows and articulations
                  that do not share a central hit
                * a further ``2cc`` variant, where the first connected
                  component algorithm is applied to elbows and articulations
                  that do not share their left or right edges.
                  NOTE(review): this variant was previously also documented
                  as ``2cc_no_multiple_central_hit``; its actual name needs
                  to be confirmed.

                The default is ``1cc``.
        """
        super().__init__(model=model)
        self.edge_score_cut = float(edge_score_cut)
        self.triplet_score_cut = triplet_score_cut
        self.single_edge_score_cut = single_edge_score_cut
        self.strategy = strategy

    def construct_downstream(self, batch: Data):
        """Run the triplet inference on ``batch`` and build the tracks.

        The ``edge_score_cut`` hyperparameter of the model is overwritten with
        the cut of this builder before the model is evaluated. The batch and
        the model outputs are moved to the CPU before the tracks are built.

        Args:
            batch: event to process

        Returns:
            The output of :py:func:`batch_from_triplets_to_tracks`.
        """
        self.model: TripletGNNBase  # type hint only
        self.model.hparams["edge_score_cut"] = self.edge_score_cut

        # Inference only: no gradient is needed.
        with torch.no_grad():
            outputs = self.model.shared_evaluation(
                batch, log=False, with_triplets=True
            )

        cpu_batch = batch.cpu()
        cpu_triplet_indices = {
            name: tensor.cpu() for name, tensor in outputs["triplet_indices"].items()
        }
        cpu_triplet_scores = {
            name: tensor.cpu() for name, tensor in outputs["triplet_scores"].items()
        }
        cpu_edge_index = outputs["filtered_edge_index"].cpu()

        # The edge scores are only fetched when a single-edge cut is requested.
        if self.single_edge_score_cut is not None:
            cpu_edge_score = outputs["filtered_edge_score"].cpu()
        else:
            cpu_edge_score = None

        return batch_from_triplets_to_tracks(
            batch=cpu_batch,
            triplet_indices=cpu_triplet_indices,
            triplet_scores=cpu_triplet_scores,
            edge_index=cpu_edge_index,
            edge_score=cpu_edge_score,
            triplet_score_cut=self.triplet_score_cut,
            single_edge_score_cut=self.single_edge_score_cut,
            strategy=self.strategy,
        )

    def load_batch(self, input_path: str) -> Data:
        """Load a PyTorch Data object from its path.

        The object is mapped to the GPU if CUDA is available, and to the CPU
        otherwise.

        Args:
            input_path: path to the file to load

        Returns:
            The object returned by ``torch.load``.
        """
        # NOTE(review): ``torch.load`` relies on pickle; ``input_path`` is
        # expected to point to a trusted file.
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.load(input_path, map_location=target_device)