# Source code for pipeline.TrackBuilding.perfect_trackbuilding

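"""Perfect track building from truth edges.

Defines the best tracking performance we can get from the track-building
stage. It feeds the stage the signal truth edges, labelled as true, as if
they came from a perfect edge classifier.
"""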
import torch
from torch_geometric.data import Data
from utils.modelutils.build import BuilderBase
from .builder import batch_from_edges_to_tracks
from GNN import perfect_gnn


class PerfectTrackBuildingBuilder(BuilderBase):
    """Build tracks from ground-truth ("perfect") edges.

    The batch's edge list is replaced by its signal truth edges, and every
    edge is labelled as true. The selected downstream builder therefore sees
    the output of a perfect edge classifier.

    Args:
        builder: Downstream track-building strategy. ``"default"`` calls
            ``batch_from_edges_to_tracks``. ``"triplet"`` calls
            ``perfect_gnn.PerfectTripletInferenceBuilder().construct_downstream``.
    """

    def __init__(self, builder: str) -> None:
        super().__init__()
        self.builder = builder

    def construct_downstream(self, batch: Data):
        """Apply perfect edge classification to ``batch`` and build tracks.

        ``batch`` is modified in place. ``edge_index`` is overwritten with
        ``signal_true_edges``, and ``y`` is set to an all-``True`` boolean
        tensor with one entry per edge (``edge_index.shape[1]``).

        Args:
            batch: Event data; must contain ``signal_true_edges``.

        Returns:
            Whatever the selected downstream builder returns.

        Raises:
            ValueError: If ``self.builder`` is neither ``"default"`` nor
                ``"triplet"``.
        """
        # Validate before touching ``batch``. An unsupported builder used to
        # mutate the batch and then silently return None.
        if self.builder not in ("default", "triplet"):
            raise ValueError(
                f"Unknown builder {self.builder!r}; "
                "expected 'default' or 'triplet'"
            )

        # Perfect GNN inference: use the truth edges and mark all of them true.
        batch["edge_index"] = batch["signal_true_edges"]
        edge_index = batch["edge_index"]
        # Allocate the labels on the same device as the edges they describe.
        batch["y"] = torch.ones(
            edge_index.shape[1], dtype=torch.bool, device=edge_index.device
        )

        if self.builder == "default":
            # Run track reconstruction directly on the truth edges.
            return batch_from_edges_to_tracks(batch=batch, edge_index=edge_index)

        # self.builder == "triplet"
        return perfect_gnn.PerfectTripletInferenceBuilder().construct_downstream(
            batch
        )

    def load_batch(self, input_path: str) -> Data:
        """Load a saved batch from ``input_path`` with tensors mapped to CPU.

        NOTE(review): ``torch.load`` unpickles the file and can execute
        arbitrary code. Only load files from trusted sources. On PyTorch
        versions where ``weights_only`` defaults to True, loading a pickled
        ``Data`` object may need ``weights_only=False``. Confirm against the
        pinned torch version.
        """
        return torch.load(input_path, map_location="cpu")