# Source code for pipeline.TrackBuilding.perfect_trackbuilding
"""Define the best tracking performance we can get.
"""
import torch
from torch_geometric.data import Data
from utils.modelutils.build import BuilderBase
from .builder import batch_from_edges_to_tracks
from GNN import perfect_gnn
class PerfectTrackBuildingBuilder(BuilderBase):
    """Build tracks from the true signal edges instead of predicted ones.

    The batch's ``edge_index`` is replaced by its ``signal_true_edges`` and
    every one of those edges is labelled ``True`` before the batch is handed
    to the selected downstream step.

    Args:
        builder: Which downstream step to run. ``"default"`` calls
            ``batch_from_edges_to_tracks``; ``"triplet"`` delegates to
            ``perfect_gnn.PerfectTripletInferenceBuilder``.
    """

    # Modes that ``construct_downstream`` knows how to dispatch.
    _SUPPORTED_BUILDERS = ("default", "triplet")

    def __init__(self, builder: str) -> None:
        super().__init__()
        self.builder = builder

    def construct_downstream(self, batch: Data):
        """Overwrite the batch's edges with the true ones and build tracks.

        ``batch`` is modified in place: ``batch["edge_index"]`` is set to
        ``batch["signal_true_edges"]`` and ``batch["y"]`` to an all-``True``
        boolean tensor with one entry per edge.

        Args:
            batch: Event data providing a ``"signal_true_edges"`` entry.

        Returns:
            Whatever the selected downstream step returns.

        Raises:
            ValueError: If ``self.builder`` is not a supported mode. The
                check runs before ``batch`` is touched, so an unsupported
                mode no longer yields a silent ``None`` with a half-updated
                batch.
        """
        if self.builder not in self._SUPPORTED_BUILDERS:
            raise ValueError(
                f"Unknown builder {self.builder!r}; "
                f"expected one of {self._SUPPORTED_BUILDERS}."
            )

        # Emulate a perfect GNN: keep only the true edges, all labelled true.
        true_edges = batch["signal_true_edges"]
        batch["edge_index"] = true_edges
        # NOTE(review): the labels are created on the default device; confirm
        # this matches the edges' device if batches are ever moved off CPU.
        batch["y"] = torch.ones(true_edges.shape[1], dtype=torch.bool)

        if self.builder == "default":
            # Run track reconstruction directly on the true edges.
            return batch_from_edges_to_tracks(
                batch=batch, edge_index=batch["edge_index"]
            )

        # Only "triplet" remains: hand over to the perfect triplet builder.
        return perfect_gnn.PerfectTripletInferenceBuilder().construct_downstream(
            batch
        )

    def load_batch(self, input_path: str) -> Data:
        """Load a saved batch from ``input_path`` onto the CPU.

        Args:
            input_path: Path of the file to load with ``torch.load``.

        Returns:
            The object stored in the file, with tensors mapped to the CPU.
        """
        # NOTE(security): torch.load may unpickle arbitrary objects; only use
        # it on files from a trusted source.
        return torch.load(input_path, map_location="cpu")