# Source code for pipeline.GNN.perfect_gnn
"""Replace the GNN by a perfect inference in order to understand
what is the best result that can be obtained with the current pipeline.
"""
import torch
from torch_geometric.data import Data
from utils.modelutils.build import BuilderBase
from utils.graphutils.tripletbuilding import (
from_edge_index_to_triplet_indices,
get_triplet_truths_from_tensors,
)
from utils.graphutils.edgeutils import compute_edge_labels_from_pid_only
from TrackBuilding.builder import batch_from_triplets_to_tracks
class PerfectInferenceBuilder(BuilderBase):
    """Generate perfect inference, that is, the edge score is equal to the truth."""

    def construct_downstream(self, batch: Data, pid: bool = False) -> Data:
        """Store the edge truth in ``batch.scores`` and return the batch.

        Args:
            batch: Graph data. ``batch.y`` is read when ``pid`` is false.
                When ``pid`` is true, ``batch["y_pid"]`` is read; if that key
                is absent it is first computed from ``batch.edge_index`` and
                ``batch.particle_id``.
            pid: If true, use the ``y_pid`` labels as scores; otherwise use
                ``batch.y``.

        Returns:
            The same ``batch`` object, with ``scores`` set (and ``y_pid``
            added when it had to be computed).
        """
        if pid:
            # Compute the PID-only labels lazily and cache them on the batch
            # so they are not recomputed if the key is already there.
            if "y_pid" not in batch:
                batch["y_pid"] = compute_edge_labels_from_pid_only(
                    edge_indices=batch.edge_index,
                    particle_ids=batch.particle_id,
                )
            batch.scores = batch["y_pid"]
        else:
            batch.scores = batch.y
        return batch
class PerfectTripletInferenceBuilder(BuilderBase):
    """Build tracks from the triplets flagged as true.

    Only the edges selected by ``batch["y"]`` are kept, triplets are built
    from them, and only the triplets flagged by
    ``get_triplet_truths_from_tensors`` are handed to the track builder.
    """

    def construct_downstream(self, batch: Data) -> Data:
        """Turn the edges selected by ``batch["y"]`` into tracks.

        Args:
            batch: Graph data providing ``edge_index``, ``y`` and
                ``particle_id_hit_idx``.

        Returns:
            The value returned by ``batch_from_triplets_to_tracks``.
        """
        # NOTE(review): assumes batch["y"] is a boolean mask over the edge
        # columns -- TODO confirm against the upstream pipeline.
        filtered_edge_index = batch["edge_index"][:, batch["y"]]

        # Build triplets
        triplet_indices = from_edge_index_to_triplet_indices(
            edge_index=filtered_edge_index
        )

        # Every edge that survived the filter above is passed as true, so the
        # edge truth is an all-True vector with one entry per remaining edge.
        triplet_truths = get_triplet_truths_from_tensors(
            triplet_indices=triplet_indices,
            edge_index=filtered_edge_index,
            edge_truth=torch.ones(
                size=(filtered_edge_index.shape[1],),
                device=filtered_edge_index.device,
                dtype=torch.bool,
            ),
            particle_id_hit_idx=batch["particle_id_hit_idx"],
        )

        # Keep, for each triplet type, only the columns flagged as true.
        batch = batch_from_triplets_to_tracks(
            batch=batch,
            triplet_indices={
                triplet_name: triplet_index[:, triplet_truths[triplet_name]]
                for triplet_name, triplet_index in triplet_indices.items()
            },
            edge_index=filtered_edge_index,
        )
        return batch

    def load_batch(self, input_path: str) -> Data:
        """Load a PyTorch Data object from its path.

        Might apply necessary pre-processing.

        Args:
            input_path: Path of the file passed to ``torch.load``.

        Returns:
            The loaded object, mapped to ``"cuda"`` when CUDA is available
            and to ``"cpu"`` otherwise.
        """
        # NOTE: torch.load unpickles the file; only load trusted inputs.
        return torch.load(
            input_path, map_location="cuda" if torch.cuda.is_available() else "cpu"
        )