#!/usr/bin/env python3
"""Script that runs the edge filtering, triplet building and filtering,
and track building from triplets, or from edges only with ``--without_triplets``.
"""
from __future__ import annotations
import typing
import warnings
import logging
from argparse import ArgumentParser
import torch
from pipeline import load_trained_model
from utils.commonutils.config import load_config
from utils.commonutils.ctests import get_required_test_dataset_names
from utils.commonutils.crun import run_for_different_partitions
from utils.scriptutils import configure_logger, headline
from utils.scriptutils.parser import add_predefined_arguments
from utils.modelutils import checkpoint_utils
from GNN.triplet_gnn_base import TripletGNNBase
from TrackBuilding.builder import TripletTrackBuilder, EdgeTrackBuilder
# Apply the project's shared logging configuration for this script.
configure_logger()
# Silence one specific, known-noisy warning message.
# NOTE(review): presumably raised by torch while running the model without
# gradient tracking — confirm where it originates.
warnings.filterwarnings(
    action="ignore",
    message="None of the inputs have requires_grad=True. Gradients will be None",
)
def build(
    path_or_config: str | dict,
    partitions: typing.List[str] | None = None,
    checkpoint: TripletGNNBase | str | None = None,
    reproduce: bool = True,
    edge_score_cut: float | None = None,
    triplet_score_cut: float | typing.Dict[str, float] | None = None,
    single_edge_score_cut: float | None = None,
    strategy: str | None = None,
    with_triplets: bool = True,
    **kwargs,
):
    """Build tracks from the trained GNN, using triplets or edges only.

    Args:
        path_or_config: Path to the pipeline configuration file, or the
            already-loaded configuration dictionary.
        partitions: Partitions to process. ``None`` (the default) means
            ``["train", "val", "test"]``.
        checkpoint: An already-instantiated :class:`TripletGNNBase` model,
            used as-is. Otherwise (``None`` or a path string) the model is
            loaded through ``load_trained_model`` from the configuration.
        reproduce: Forwarded to ``run_for_different_partitions``.
        edge_score_cut: Edge score threshold. Falls back to the
            ``edge_score_cut`` entry of the track-building configuration.
        triplet_score_cut: Triplet score threshold(s); only used when
            ``with_triplets`` is true. Falls back to the ``triplet_score_cut``
            configuration entry.
        single_edge_score_cut: Forwarded to ``TripletTrackBuilder``; only used
            when ``with_triplets`` is true.
        strategy: Track-building strategy; only used when ``with_triplets`` is
            true. Falls back to the optional ``strategy`` configuration entry.
        with_triplets: If true, read the ``track_building`` configuration
            section and use ``TripletTrackBuilder``; otherwise read
            ``track_building_from_edges`` and use ``EdgeTrackBuilder``.
        **kwargs: Forwarded to ``load_trained_model``.
    """
    # Avoid a shared mutable default argument.
    if partitions is None:
        partitions = ["train", "val", "test"]

    config = load_config(path_or_config)
    logging.info(headline("Track Building"))
    section = "track_building" if with_triplets else "track_building_from_edges"
    track_building_configs = config[section]
    if edge_score_cut is None:
        edge_score_cut = track_building_configs["edge_score_cut"]

    logging.info(headline("a) Loading trained model"))
    if checkpoint is None:
        gnn_version_dir = checkpoint_utils.get_last_version_dir_from_config(
            step="gnn", path_or_config=path_or_config
        )
        checkpoint = checkpoint_utils.get_last_artifact(version_dir=gnn_version_dir)
    if isinstance(checkpoint, TripletGNNBase):
        gnn_model = checkpoint
    else:
        # NOTE(review): the checkpoint path (caller-provided or resolved above)
        # is not passed to ``load_trained_model``, which only receives the
        # configuration; confirm that it loads the intended artifact.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        gnn_model = load_trained_model(
            path_or_config=config, step="gnn", device=device, **kwargs
        )
    logging.info(f"GNN type: {gnn_model.__class__.__name__}")

    logging.info(headline("b) Running inference"))
    if with_triplets:
        if triplet_score_cut is None:
            triplet_score_cut = track_building_configs["triplet_score_cut"]
        if strategy is None:
            strategy = track_building_configs.get("strategy")
        track_builder = TripletTrackBuilder(
            gnn_model,
            edge_score_cut=edge_score_cut,
            triplet_score_cut=triplet_score_cut,
            strategy=strategy,
            single_edge_score_cut=single_edge_score_cut,
        )
    else:
        track_builder = EdgeTrackBuilder(
            model=gnn_model,
            edge_score_cut=edge_score_cut,
        )

    logging.info("Edge score cut: " + str(edge_score_cut))
    # The triplet cut is only meaningful (and only resolved) with triplets.
    if with_triplets:
        logging.info("Triplet score cut: " + str(triplet_score_cut))
    logging.info(f"Track building configuration: {track_building_configs}")

    run_for_different_partitions(
        track_builder.infer,
        input_dir=track_building_configs["input_dir"],
        output_dir=track_building_configs["output_dir"],
        partitions=partitions,
        test_dataset_names=get_required_test_dataset_names(config),
        reproduce=reproduce,
    )
if __name__ == "__main__":
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name shown in usage), so the text must be passed as ``description``.
    parser = ArgumentParser(
        description="Build tracks from the trained GNN's edge and triplet scores."
    )
    add_predefined_arguments(parser, ["pipeline_config", "partitions", "reproduce"])
    # ``type=float`` is required: without it argparse yields strings, which
    # would then be used as numeric score thresholds.
    parser.add_argument(
        "--edge_score_cut", type=float, help="Edge score cut.", required=False
    )
    parser.add_argument(
        "--triplet_score_cut", type=float, help="Triplet score cut.", required=False
    )
    parser.add_argument(
        "--without_triplets",
        help="Whether to reconstruct without triplets",
        action="store_true",
    )
    parsed_args = parser.parse_args()
    build(
        path_or_config=parsed_args.pipeline_config,
        partitions=parsed_args.partitions,
        reproduce=parsed_args.reproduce,
        edge_score_cut=parsed_args.edge_score_cut,
        triplet_score_cut=parsed_args.triplet_score_cut,
        with_triplets=not parsed_args.without_triplets,
    )