# Source code for scripts.build_tracks

#!/usr/bin/env python3
"""Script that runs the edge filtering, triplet building and filtering,
and track building from triplets (or, with ``--without_triplets``, track
building from edges only).
"""
from __future__ import annotations
import typing
import warnings
import logging
from argparse import ArgumentParser

import torch
from pipeline import load_trained_model

from utils.commonutils.config import load_config
from utils.commonutils.ctests import get_required_test_dataset_names
from utils.commonutils.crun import run_for_different_partitions
from utils.scriptutils import configure_logger, headline
from utils.scriptutils.parser import add_predefined_arguments
from utils.modelutils import checkpoint_utils

from GNN.triplet_gnn_base import TripletGNNBase
from TrackBuilding.builder import TripletTrackBuilder, EdgeTrackBuilder

configure_logger()


# Drop this specific warning so it does not clutter the inference logs.
# NOTE(review): presumably emitted by torch when no input tensor requires
# gradients, which is expected here since only inference is run — confirm.
warnings.filterwarnings(
    action="ignore",
    message="None of the inputs have requires_grad=True. Gradients will be None",
)


def build(
    path_or_config: str | dict,
    partitions: typing.List[str] | None = None,
    checkpoint: TripletGNNBase | str | None = None,
    reproduce: bool = True,
    edge_score_cut: float | None = None,
    triplet_score_cut: float | typing.Dict[str, float] | None = None,
    single_edge_score_cut: float | None = None,
    strategy: str | None = None,
    with_triplets: bool = True,
    **kwargs,
):
    """Run edge filtering, triplet building/filtering and track building.

    Score cuts and the strategy default to the values found in the relevant
    section of the pipeline configuration and can be overridden through the
    arguments below.

    Args:
        path_or_config: Pipeline configuration, or a path to it; loaded with
            ``load_config``.
        partitions: Partitions to process, forwarded to
            ``run_for_different_partitions``. Defaults to
            ``["train", "val", "test"]``.
        checkpoint: Either an already-loaded ``TripletGNNBase`` model, which is
            used as is, or ``None``/a path, in which case the model is loaded
            with ``load_trained_model`` from the configuration.
        reproduce: Forwarded to ``run_for_different_partitions``.
        edge_score_cut: Overrides ``edge_score_cut`` from the configuration.
        triplet_score_cut: Overrides ``triplet_score_cut`` from the
            ``track_building`` configuration. Only used with triplets.
        single_edge_score_cut: Forwarded to ``TripletTrackBuilder``. Only used
            with triplets.
        strategy: Overrides the optional ``strategy`` entry of the
            ``track_building`` configuration. Only used with triplets.
        with_triplets: If ``True``, build tracks from triplets with
            ``TripletTrackBuilder`` and the ``track_building`` configuration
            section; otherwise build them from edges with ``EdgeTrackBuilder``
            and the ``track_building_from_edges`` section.
        **kwargs: Forwarded to ``load_trained_model``.
    """
    if partitions is None:
        # A fresh list per call: a list default would be shared between calls.
        partitions = ["train", "val", "test"]

    config = load_config(path_or_config)
    logging.info(headline("Track Building"))

    track_building_configs = config[
        "track_building" if with_triplets else "track_building_from_edges"
    ]
    if edge_score_cut is None:
        edge_score_cut = track_building_configs["edge_score_cut"]

    logging.info(headline("a) Loading trained model"))
    if checkpoint is None:
        gnn_version_dir = checkpoint_utils.get_last_version_dir_from_config(
            step="gnn", path_or_config=path_or_config
        )
        checkpoint = checkpoint_utils.get_last_artifact(version_dir=gnn_version_dir)

    if isinstance(checkpoint, TripletGNNBase):
        gnn_model = checkpoint
    else:
        # NOTE(review): ``checkpoint`` (resolved above or supplied by the
        # caller as a path) is not passed to ``load_trained_model``, which only
        # receives the configuration. Confirm it loads the intended checkpoint;
        # a user-supplied path is currently ignored.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        gnn_model = load_trained_model(
            path_or_config=config, step="gnn", device=device, **kwargs
        )
    logging.info(f"GNN type: {gnn_model.__class__.__name__}")

    logging.info(headline("b) Running inference"))
    if with_triplets:
        if triplet_score_cut is None:
            triplet_score_cut = track_building_configs["triplet_score_cut"]
        if strategy is None:
            # ``strategy`` is optional in the configuration.
            strategy = track_building_configs.get("strategy")
        track_builder = TripletTrackBuilder(
            gnn_model,
            edge_score_cut=edge_score_cut,
            triplet_score_cut=triplet_score_cut,
            strategy=strategy,
            single_edge_score_cut=single_edge_score_cut,
        )
    else:
        track_builder = EdgeTrackBuilder(
            model=gnn_model,
            edge_score_cut=edge_score_cut,
        )

    logging.info("Edge score cut: " + str(edge_score_cut))
    if with_triplets:
        # The triplet cut is only meaningful when building from triplets.
        logging.info("Triplet score cut: " + str(triplet_score_cut))
    logging.info("Track building configuration: %s", track_building_configs)

    run_for_different_partitions(
        track_builder.infer,
        input_dir=track_building_configs["input_dir"],
        output_dir=track_building_configs["output_dir"],
        partitions=partitions,
        test_dataset_names=get_required_test_dataset_names(config),
        reproduce=reproduce,
    )
if __name__ == "__main__":
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name shown in usage), so the text must be passed as ``description``.
    parser = ArgumentParser(description="Run the GNN inference.")
    add_predefined_arguments(parser, ["pipeline_config", "partitions", "reproduce"])
    # ``type=float``: without it argparse yields strings, which would reach the
    # track builders as score cuts instead of numbers.
    parser.add_argument(
        "--edge_score_cut", type=float, help="Edge score cut.", required=False
    )
    parser.add_argument(
        "--triplet_score_cut", type=float, help="Triplet score cut.", required=False
    )
    parser.add_argument(
        "--without_triplets",
        help="Whether to reconstruct without triplets",
        action="store_true",
    )

    parsed_args = parser.parse_args()
    build(
        path_or_config=parsed_args.pipeline_config,
        partitions=parsed_args.partitions,
        reproduce=parsed_args.reproduce,
        edge_score_cut=parsed_args.edge_score_cut,
        triplet_score_cut=parsed_args.triplet_score_cut,
        with_triplets=not parsed_args.without_triplets,
    )