# Source code for pipeline.GNN.perfect_gnn

"""Replace the GNN by a perfect inference in order to understand
what is the best result that can be obtained with the current pipeline.
"""

import torch
from torch_geometric.data import Data
from utils.modelutils.build import BuilderBase
from utils.graphutils.tripletbuilding import (
    from_edge_index_to_triplet_indices,
    get_triplet_truths_from_tensors,
)
from utils.graphutils.edgeutils import compute_edge_labels_from_pid_only
from TrackBuilding.builder import batch_from_triplets_to_tracks


class PerfectInferenceBuilder(BuilderBase):
    """Stand-in for the GNN whose edge scores are exactly the truth labels."""

    def construct_downstream(self, batch: Data, pid: bool = False):
        """Write the truth edge labels into ``batch.scores``.

        Args:
            batch: Graph to annotate. It is modified in place.
            pid: When false, the scores are read from ``batch.y``. When true,
                they are read from ``batch["y_pid"]``; if that key is absent
                it is first computed by ``compute_edge_labels_from_pid_only``
                from ``batch.edge_index`` and ``batch.particle_id`` and stored
                on the batch.

        Returns:
            The same ``batch`` object, with ``scores`` set.
        """
        if not pid:
            batch.scores = batch.y
            return batch

        # Only compute the PID-based labels when the batch does not already
        # carry them; the result is kept on the batch under "y_pid".
        if "y_pid" not in batch:
            batch["y_pid"] = compute_edge_labels_from_pid_only(
                edge_indices=batch.edge_index,
                particle_ids=batch.particle_id,
            )
        batch.scores = batch["y_pid"]
        return batch


class PerfectTripletInferenceBuilder(BuilderBase):
    """Build tracks from the true triplets only, instead of inferred ones."""

    def construct_downstream(self, batch: Data):
        """Turn the true edges of ``batch`` into tracks via true triplets.

        Steps:
            1. Keep the columns of ``batch["edge_index"]`` selected by
               ``batch["y"]``.
            2. Build triplet indices from those edges.
            3. Compute the triplet truths, flagging every kept edge as true.
            4. Keep, for each triplet type, only the triplets flagged true and
               hand them to ``batch_from_triplets_to_tracks``.

        Args:
            batch: Graph holding ``edge_index``, ``y`` and
                ``particle_id_hit_idx``.

        Returns:
            Whatever ``batch_from_triplets_to_tracks`` returns for this batch.
        """
        # NOTE(review): assumes batch["y"] is a boolean mask over the edges;
        # an integer tensor would be used as column indices instead. Confirm.
        true_edge_index = batch["edge_index"][:, batch["y"]]

        # Build triplets
        triplet_indices = from_edge_index_to_triplet_indices(
            edge_index=true_edge_index
        )

        # Every edge that survived the selection above is marked as true.
        all_edges_true = torch.ones(
            size=(true_edge_index.shape[1],),
            device=true_edge_index.device,
            dtype=torch.bool,
        )
        triplet_truths = get_triplet_truths_from_tensors(
            triplet_indices=triplet_indices,
            edge_index=true_edge_index,
            edge_truth=all_edges_true,
            particle_id_hit_idx=batch["particle_id_hit_idx"],
        )

        # Per triplet type, drop the triplets that are not flagged as true.
        true_triplet_indices = {}
        for triplet_name, triplet_index in triplet_indices.items():
            true_triplet_indices[triplet_name] = triplet_index[
                :, triplet_truths[triplet_name]
            ]

        return batch_from_triplets_to_tracks(
            batch=batch,
            triplet_indices=true_triplet_indices,
            edge_index=true_edge_index,
        )

    def load_batch(self, input_path: str) -> Data:
        """Load a saved batch from ``input_path`` with ``torch.load``.

        The tensors are mapped to ``"cuda"`` when ``torch.cuda.is_available()``
        is true, and to ``"cpu"`` otherwise.

        Args:
            input_path: Path of the file to load.

        Returns:
            The object stored in the file.
        """
        # NOTE(review): torch.load unpickles the file, so it should only be
        # pointed at trusted inputs.
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.load(input_path, map_location=target_device)