Source code for pipeline.TrackBuilding.perfect_trackbuilding
"""Define the best tracking performance we can get.
"""
import torch
from torch_geometric.data import Data
from utils.modelutils.build import BuilderBase
from .builder import batch_from_edges_to_tracks
from GNN import perfect_gnn
class PerfectTrackBuildingBuilder(BuilderBase):
    """Build tracks from the true signal edges of each event.

    Instead of using edges predicted by a trained classifier, every edge in
    ``batch["signal_true_edges"]`` is treated as a correctly identified edge.
    The resulting tracks therefore reflect the best the downstream track
    builder can do given perfect edges.

    Args:
        builder: Downstream track-building strategy. ``"default"`` calls
            ``batch_from_edges_to_tracks``; ``"triplet"`` delegates to
            ``perfect_gnn.PerfectTripletInferenceBuilder``.
    """

    # Builder names accepted by ``construct_downstream``.
    _SUPPORTED_BUILDERS = ("default", "triplet")

    def __init__(self, builder: str) -> None:
        super().__init__()
        # The name is validated in ``construct_downstream`` (not here), so
        # constructing the object never raises for an unknown name.
        self.builder = builder

    def construct_downstream(self, batch: Data):
        """Replace the batch's edges with the true signal edges and build tracks.

        The batch is modified in place: ``batch["edge_index"]`` is overwritten
        with ``batch["signal_true_edges"]``, and ``batch["y"]`` is set to an
        all-``True`` boolean tensor with one entry per edge.

        Args:
            batch: Event data containing ``"signal_true_edges"``. Presumably a
                ``(2, num_edges)`` edge-index tensor (TODO confirm); only
                ``shape[1]`` (the edge count) and ``device`` are relied on here.

        Returns:
            Whatever the selected downstream builder returns.

        Raises:
            ValueError: If ``self.builder`` is not a supported builder name.
        """
        # Reject unknown names before touching the batch, rather than
        # mutating it and then falling through to an implicit ``None``.
        if self.builder not in self._SUPPORTED_BUILDERS:
            raise ValueError(
                f"Unknown builder {self.builder!r}; "
                f"expected one of {self._SUPPORTED_BUILDERS}"
            )

        edge_index = batch["signal_true_edges"]
        batch["edge_index"] = edge_index
        # Every true edge is a positive edge by construction. Allocate the
        # labels on the same device as the edges so non-CPU batches work.
        batch["y"] = torch.ones(
            edge_index.shape[1], dtype=torch.bool, device=edge_index.device
        )

        if self.builder == "default":
            # Run track reconstruction directly on the perfect edges.
            return batch_from_edges_to_tracks(
                batch=batch, edge_index=batch["edge_index"]
            )
        # self.builder == "triplet": run perfect triplet-GNN inference.
        return perfect_gnn.PerfectTripletInferenceBuilder().construct_downstream(
            batch
        )

    def load_batch(self, input_path: str) -> Data:
        """Load a saved event from ``input_path``, mapping all tensors to CPU.

        NOTE: ``torch.load`` unpickles the file, so only load trusted inputs.
        """
        return torch.load(input_path, map_location="cpu")