"""A module that defines the various track builders.
"""
from __future__ import annotations
import typing
import torch
from torch_geometric.data import Data
from utils.modelutils.build import ModelBuilderBase
from utils.tools import tarray
from GNN.triplet_gnn_base import TripletGNNBase
from . import edges2tracks
from . import triplets2tracks, triplets
def batch_from_triplets_to_tracks(
    batch: Data,
    triplet_indices: typing.Dict[str, torch.Tensor],
    edge_index: torch.Tensor,
    triplet_scores: typing.Dict[str, torch.Tensor] | None = None,
    triplet_score_cut: float | typing.Dict[str, float] | None = None,
    edge_score: torch.Tensor | None = None,
    single_edge_score_cut: float | None = None,
    strategy: str | None = None,
) -> Data:
    """Build tracks from triplets and return them as a new ``Data`` object.

    The triplets are optionally filtered by score, then handed to one of the
    ``triplets2tracks`` builders, selected by ``strategy``.

    Args:
        batch: input event. Only its ``hit_id``, ``event_str`` and
            ``truncated_path`` entries are read.
        triplet_indices: dictionary of triplet-index tensors, forwarded to the
            track-building helper (after the optional filtering).
        edge_index: edge indices, forwarded to the track-building helper.
            Its device is also used for the empty output tensors.
        triplet_scores: dictionary of triplet scores. Required when
            ``triplet_score_cut`` is provided, unused otherwise.
        triplet_score_cut: if not ``None``, ``triplet_indices`` is first
            filtered by ``triplets.get_filtered_triplet_indices`` with this cut.
        edge_score: forwarded to the ``1cc`` builder only.
        single_edge_score_cut: forwarded to the ``1cc`` builder only.
        strategy: ``"1cc"`` (also used when ``None``) selects
            ``build_tracks_from_triplets_1cc``. Any other value is forwarded
            as-is to ``build_tracks_from_triplets_2cc``.

    Returns:
        A new ``Data`` object holding ``hit_id`` (``batch["hit_id"]`` indexed
        by the hit indices of the tracks), ``labels`` (the track IDs),
        ``event_str`` and ``truncated_path``.

    Raises:
        ValueError: if ``triplet_score_cut`` is provided but ``triplet_scores``
            is not.
    """
    if strategy is None:
        strategy = "1cc"

    if triplet_score_cut is not None:
        # Explicit raise instead of `assert`: the check must survive `python -O`,
        # otherwise `None` scores would reach the filtering helper.
        if triplet_scores is None:
            raise ValueError(
                "As `triplet_score_cut` is provided, the triplet indices are filtered. "
                "However, `triplet_scores` was not provided."
            )
        triplet_indices = triplets.get_filtered_triplet_indices(
            triplet_indices=triplet_indices,
            triplet_scores=triplet_scores,
            triplet_score_cut=triplet_score_cut,
        )

    if strategy == "1cc":
        df_tracks = triplets2tracks.build_tracks_from_triplets_1cc(
            triplet_indices=triplet_indices,
            edge_index=edge_index,
            edge_score=edge_score,
            single_edge_score_cut=single_edge_score_cut,
        )
    else:
        df_tracks = triplets2tracks.build_tracks_from_triplets_2cc(
            triplet_indices=triplet_indices,
            edge_index=edge_index,
            strategy=strategy,
        )

    # NOTE(review): `df_tracks` is presumably a dataframe with one row per
    # (hit, track) association -- confirm against `triplets2tracks`.
    if df_tracks.shape[0]:
        hit_indices = tarray.series_to_tensor(df_tracks["hit_idx"])  # type: ignore
        track_ids = tarray.series_to_tensor(df_tracks["track_id"])  # type: ignore
    else:
        # No track was built: use empty tensors so that the indexing below
        # still works and yields empty outputs.
        hit_indices = torch.zeros(
            size=(0,), dtype=batch["hit_id"].dtype, device=edge_index.device
        )
        track_ids = torch.zeros(
            size=(0,), dtype=batch["hit_id"].dtype, device=edge_index.device
        )

    return Data(
        hit_id=batch["hit_id"][hit_indices],
        labels=track_ids,
        event_str=batch["event_str"],
        truncated_path=batch["truncated_path"],
        # **{column: batch[column][hit_indices] for column in ["un_x", "un_y", "un_z"]},
    )
def batch_from_edges_to_tracks(batch: Data, edge_index: torch.Tensor):
    """Label the hits of ``batch`` with the tracks built from ``edge_index``.

    Args:
        batch: input event. The number of hits is taken from the first
            dimension of ``batch["x"]``.
        edge_index: edges forwarded to ``edges2tracks.build_tracks_from_edges``.

    Returns:
        The same ``batch`` object, modified in place: the labels returned by
        the builder are stored under ``batch["labels"]``.
    """
    n_hits = batch["x"].shape[0]
    batch["labels"] = edges2tracks.build_tracks_from_edges(
        edge_index=edge_index,
        n_hits=n_hits,
    )
    return batch
class EdgeTrackBuilder(ModelBuilderBase):
    """Track builder that goes straight from filtered edges to tracks."""

    def __init__(self, model: TripletGNNBase, edge_score_cut: float) -> None:
        """Define the builder that performs GNN inference to filter out fake
        edges, then builds tracks from the remaining edges.

        This builder skips the triplet building. It is used to check how bad
        the track building performs without it.

        Args:
            model: triplet GNN
            edge_score_cut: minimal edge score. It is converted to ``float``
                and written to ``model.hparams["edge_score_cut"]`` before
                every inference.
        """
        super().__init__(model=model)
        self.edge_score_cut = float(edge_score_cut)

    def construct_downstream(self, batch: Data):
        """Run the model on ``batch`` and label its hits with track IDs.

        Args:
            batch: input event, passed to ``model.shared_evaluation``

        Returns:
            The output of ``batch_from_edges_to_tracks``, called with ``batch``
            and the ``filtered_edge_index`` output of the model moved to CPU.
        """
        # Type hint only: tells static checkers which model class is stored.
        self.model: TripletGNNBase

        # The cut is communicated to the model through its hyperparameters.
        self.model.hparams["edge_score_cut"] = self.edge_score_cut

        # Inference only: no gradient tracking is needed.
        with torch.no_grad():
            outputs = self.model.shared_evaluation(batch, with_triplets=False)

        filtered_edge_index = outputs["filtered_edge_index"].cpu()
        return batch_from_edges_to_tracks(
            batch=batch, edge_index=filtered_edge_index
        )
class TripletTrackBuilder(ModelBuilderBase):
    """Track builder that runs the triplet GNN, then builds tracks from triplets."""

    def __init__(
        self,
        model: TripletGNNBase,
        edge_score_cut: float,
        triplet_score_cut: float | typing.Dict[str, float],
        single_edge_score_cut: float | None = None,
        strategy: str | None = None,
    ) -> None:
        """Define the builder that performs GNN inference to get the filtered
        graph of triplets, then builds tracks from the triplets.

        Args:
            model: triplet GNN
            edge_score_cut: minimal edge score to filter out the edges before
                building the triplets. It is converted to ``float`` and
                written to ``model.hparams["edge_score_cut"]`` before every
                inference.
            triplet_score_cut: minimal triplet score to get the graph of
                triplets which the tracks are built from. Either a single
                value or a dictionary of values.
            single_edge_score_cut: forwarded to
                ``batch_from_triplets_to_tracks``. When it is not ``None``,
                the ``filtered_edge_score`` output of the model is forwarded
                as well.
            strategy: strategy used to build the tracks from triplets.

                * ``1cc``: connect left and right elbows by assigning the
                  smallest edge index to every edge connected to an elbow.
                  Only one connected component algorithm is applied using
                  this strategy.
                * ``2cc_without_articulation``: 2 connected components are
                  applied. The first connected component is applied on elbows
                  only.
                * ``2cc_no_multiple_central_hit``: the first connected
                  component algorithm is applied to elbows and articulations
                  that do not share a central hit
                * ``2cc_no_multiple_central_hit``: the first connected
                  component algorithm is applied to elbows and articulations
                  that do not share their left or right edges.
                  NOTE(review): this name duplicates the previous entry; it
                  presumably should be a different strategy name -- confirm
                  against ``triplets2tracks``.

                The default is ``1cc``.
        """
        super().__init__(model=model)
        self.edge_score_cut = float(edge_score_cut)
        self.triplet_score_cut = triplet_score_cut
        self.single_edge_score_cut = single_edge_score_cut
        self.strategy = strategy

    def construct_downstream(self, batch: Data):
        """Run the model on ``batch`` and build the tracks from its triplets.

        Args:
            batch: input event, passed to ``model.shared_evaluation``

        Returns:
            The output of ``batch_from_triplets_to_tracks``, called with the
            batch and the model outputs moved to CPU.
        """
        # Type hint only: tells static checkers which model class is stored.
        self.model: TripletGNNBase

        # The cut is communicated to the model through its hyperparameters.
        self.model.hparams["edge_score_cut"] = self.edge_score_cut

        # Inference only: no gradient tracking is needed.
        with torch.no_grad():
            outputs = self.model.shared_evaluation(
                batch, log=False, with_triplets=True
            )

        raw_triplet_indices, raw_triplet_scores, raw_edge_index = (
            outputs[key]
            for key in ("triplet_indices", "triplet_scores", "filtered_edge_index")
        )

        # Everything is moved to CPU before the track building.
        cpu_batch = batch.cpu()
        cpu_triplet_indices = {
            name: tensor.cpu() for name, tensor in raw_triplet_indices.items()
        }
        cpu_triplet_scores = {
            name: tensor.cpu() for name, tensor in raw_triplet_scores.items()
        }
        cpu_edge_index = raw_edge_index.cpu()

        # The edge scores are only looked up when a single-edge cut is set.
        if self.single_edge_score_cut is None:
            cpu_edge_score = None
        else:
            cpu_edge_score = outputs["filtered_edge_score"].cpu()

        return batch_from_triplets_to_tracks(
            batch=cpu_batch,
            triplet_indices=cpu_triplet_indices,
            triplet_scores=cpu_triplet_scores,
            edge_index=cpu_edge_index,
            edge_score=cpu_edge_score,
            triplet_score_cut=self.triplet_score_cut,
            single_edge_score_cut=self.single_edge_score_cut,
            strategy=self.strategy,
        )

    def load_batch(self, input_path: str) -> Data:
        """Load a PyTorch Data object from its path.

        The object is mapped to ``"cuda"`` when CUDA is available, and to
        ``"cpu"`` otherwise.

        Args:
            input_path: path to the file to load

        Returns:
            The object returned by ``torch.load``
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): `torch.load` relies on pickle; `input_path` must point
        # to a trusted file.
        return torch.load(input_path, map_location=device)