# Source code for scripts.build_graph_using_embedding

#!/usr/bin/env python3
"""A script that runs the graph building using the embedding model learnt
at the previous stage.
"""
from __future__ import annotations
import typing
import logging
from argparse import ArgumentParser

import torch

from Embedding.embedding_base import EmbeddingBase
from Embedding.build_embedding import (
    EmbeddingInferenceBuilder,
    get_squared_distance_max_from_config,
)
from pipeline import load_trained_model
from utils.commonutils.config import load_config
from utils.commonutils.ctests import get_required_test_dataset_names
from utils.commonutils.crun import run_for_different_partitions
from utils.scriptutils import configure_logger, headline
from utils.scriptutils.parser import add_predefined_arguments

# Configure logging once, at import time, so the logging calls made by `run`
# below go through it.
# NOTE(review): the actual setup (handlers, level, format) is defined in
# utils.scriptutils.configure_logger and is not visible here.
configure_logger()


def run(
    path_or_config: str | dict,
    partitions: typing.List[str] | None = None,
    checkpoint: EmbeddingBase | str | None = None,
    reproduce: bool = True,
    use_gpu: bool = True,
    suffix: str | None = None,
    **kwargs,
) -> EmbeddingInferenceBuilder:
    """Run the inference of the metric learning (embedding) stage.

    The trained embedding model is applied to each requested partition, and
    graphs are built by :py:class:`EmbeddingInferenceBuilder` (embedding
    inference followed by a kNN search).

    Args:
        path_or_config: configuration dictionary, or path to the YAML file
            that contains the configuration
        partitions: Partitions to run the inference on. Defaults to
            ``["train", "val", "test"]``. Each entry can be:

            * ``train``: train dataset
            * ``val``: validation dataset
            * ``test``: all the test datasets
            * A specific test dataset name

        checkpoint: Model already loaded, or path to its checkpoint.
            If ``None``, try to find it automatically in the artifact folder
            given the configuration.
        reproduce: whether to delete an existing folder
        use_gpu: whether to use the GPU (if available). Only relevant when the
            model has to be loaded, i.e. when ``checkpoint`` is not an
            already-loaded :py:class:`EmbeddingBase`.
        suffix: suffix appended to ``"embedding"`` to select the configuration
            section, i.e. ``config["embedding" + suffix]``. ``None`` means no
            suffix.
        **kwargs: Other keyword arguments passed to ``load_trained_model``,
            which forwards them to the PyTorch Lightning
            ``LightningModule.load_from_checkpoint`` class method

    Returns:
        The :py:class:`EmbeddingInferenceBuilder` used to build the graphs.
    """
    # Sentinel default: a list literal as a default value would be created
    # once and shared between every call of this function.
    if partitions is None:
        partitions = ["train", "val", "test"]
    if suffix is None:
        suffix = ""

    config = load_config(path_or_config)
    logging.info(headline("Embedding Inference + kNN"))
    embedding_config = config["embedding" + suffix]

    logging.info(headline("a) Loading trained model"))
    if isinstance(checkpoint, EmbeddingBase):
        embedding_model = checkpoint
    else:
        # NOTE(review): a checkpoint given as a path (``str``) is not forwarded
        # to ``load_trained_model``. The model is always located from the
        # configuration, and ``step`` is always ``"embedding"`` regardless of
        # ``suffix``. Confirm against ``load_trained_model``'s signature
        # whether this is intended.
        device = torch.device(
            "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        )
        embedding_model = load_trained_model(
            path_or_config=config, step="embedding", device=device, **kwargs
        )

    logging.info(headline("b) Running inference"))
    squared_distance_max = get_squared_distance_max_from_config(path_or_config=config)
    # Read (rather than pop) the optional processing settings, so the
    # configuration is left untouched. If ``load_config`` hands back the
    # caller's dict, popping would mutate it, and a second call with the same
    # config would silently lose these settings.
    test_processing = embedding_config.get("test_processing")
    training_processing = embedding_config.get("training_processing")
    logging.info("Use distance max %s", squared_distance_max)

    graph_builder = EmbeddingInferenceBuilder(
        embedding_model,
        k_max=embedding_config["k_max"],
        squared_distance_max=squared_distance_max,
    )

    # The train/val partitions use the training processing settings. Any other
    # partition (``test`` or a specific test dataset) uses the test ones.
    list_kwargs = [
        dict(
            processing=(
                training_processing
                if partition in ("train", "val")
                else test_processing
            )
        )
        for partition in partitions
    ]
    run_for_different_partitions(
        graph_builder.infer,
        input_dir=embedding_config["input_dir"],
        output_dir=embedding_config["output_dir"],
        partitions=partitions,
        test_dataset_names=get_required_test_dataset_names(config),
        reproduce=reproduce,
        n_workers=embedding_config.get("n_workers", 1),
        list_kwargs=list_kwargs,
    )
    return graph_builder
if __name__ == "__main__":
    # Command-line entry point: read the pipeline configuration path, the
    # partitions to process and the ``reproduce`` flag, then launch `run`.
    cli_parser = ArgumentParser("Run the embedding inference.")
    predefined_argument_names = ["pipeline_config", "partitions", "reproduce"]
    add_predefined_arguments(cli_parser, predefined_argument_names)
    cli_args = cli_parser.parse_args()

    run(
        path_or_config=cli_args.pipeline_config,
        partitions=cli_args.partitions,
        reproduce=cli_args.reproduce,
    )