# Source code for scripts.build_graph_using_embedding
#!/usr/bin/env python3
"""A script that runs the graph building using the embedding model learnt
at the previous stage.
"""
from __future__ import annotations
import typing
import logging
from argparse import ArgumentParser
import torch
from Embedding.embedding_base import EmbeddingBase
from Embedding.build_embedding import (
EmbeddingInferenceBuilder,
get_squared_distance_max_from_config,
)
from pipeline import load_trained_model
from utils.commonutils.config import load_config
from utils.commonutils.ctests import get_required_test_dataset_names
from utils.commonutils.crun import run_for_different_partitions
from utils.scriptutils import configure_logger, headline
from utils.scriptutils.parser import add_predefined_arguments
# Set up logging through the project helper at import time (module-level side
# effect), so that the `logging.info` calls made by `run` use that setup.
configure_logger()
def run(
    path_or_config: str | dict,
    partitions: typing.List[str] | None = None,
    checkpoint: EmbeddingBase | str | None = None,
    reproduce: bool = True,
    use_gpu: bool = True,
    suffix: str | None = None,
    **kwargs,
):
    """Run the inference of the metric learning stage.

    The trained embedding model is wrapped in an
    :py:class:`EmbeddingInferenceBuilder` (configured with ``k_max`` and
    ``squared_distance_max``), whose ``infer`` method is then run on every
    requested partition.

    Args:
        path_or_config: configuration dictionary, or path to the YAML file that
            contains the configuration
        partitions: Partitions to run the inference on. Defaults to
            ``["train", "val", "test"]``. Each element can be:

            * ``train``: train dataset
            * ``val``: validation dataset
            * ``test``: all the test datasets
            * A specific test dataset name

        checkpoint: Model already loaded (an :py:class:`EmbeddingBase`
            instance), or path to its checkpoint. If it is not an already-loaded
            model, the model is obtained with :py:func:`load_trained_model` from
            the configuration. A checkpoint *path* is currently not forwarded to
            the loader; a warning is logged when one is given.
        reproduce: whether to delete an existing folder
        use_gpu: whether to use the GPU (if available)
        suffix: suffix appended to ``"embedding"`` to select the configuration
            section, i.e. ``config["embedding" + suffix]``. ``None`` means no
            suffix.
        **kwargs: Other keyword arguments passed to :py:func:`load_trained_model`
            (presumably forwarded to Lightning's
            ``LightningModule.load_from_checkpoint`` — TODO confirm)

    Returns:
        The :py:class:`EmbeddingInferenceBuilder` used to run the inference.
    """
    if partitions is None:
        # Fresh list per call: a list literal as default would be shared
        # between calls (and with anything that mutates it downstream).
        partitions = ["train", "val", "test"]
    if suffix is None:
        suffix = ""
    config = load_config(path_or_config)
    logging.info(headline("Embedding Inference + kNN"))
    embedding_config = config["embedding" + suffix]

    logging.info(headline("a) Loading trained model"))
    if isinstance(checkpoint, EmbeddingBase):
        embedding_model = checkpoint
    else:
        if checkpoint is not None:
            # NOTE(review): the checkpoint path is not passed to
            # `load_trained_model`, which presumably locates the model from the
            # configuration. Surface this instead of silently ignoring it.
            logging.warning(
                "Checkpoint path %r is ignored; loading the trained model "
                "from the configuration instead.",
                checkpoint,
            )
        device = torch.device(
            "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        )
        # NOTE(review): `step` is always "embedding", even when `suffix` selects
        # another configuration section — confirm this is intended.
        embedding_model = load_trained_model(
            path_or_config=config, step="embedding", device=device, **kwargs
        )

    logging.info(headline("b) Running inference"))
    squared_distance_max = get_squared_distance_max_from_config(path_or_config=config)
    # Read (rather than pop) the processing options: `config` may be the very
    # dict the caller passed in, and popping would mutate it, so a second call
    # with the same dict would silently lose these options.
    test_processing = embedding_config.get("test_processing")
    training_processing = embedding_config.get("training_processing")
    logging.info("Use distance max %s", squared_distance_max)

    graph_builder = EmbeddingInferenceBuilder(
        embedding_model,
        k_max=embedding_config["k_max"],
        squared_distance_max=squared_distance_max,
    )
    run_for_different_partitions(
        graph_builder.infer,
        input_dir=embedding_config["input_dir"],
        output_dir=embedding_config["output_dir"],
        partitions=partitions,
        test_dataset_names=get_required_test_dataset_names(config),
        reproduce=reproduce,
        n_workers=embedding_config.get("n_workers", 1),
        # One kwargs dict per partition, in the same order as `partitions`:
        # train/val use the training processing, every other partition the
        # test processing.
        list_kwargs=[
            dict(
                processing=training_processing
                if partition in ("train", "val")
                else test_processing
            )
            for partition in partitions
        ],
    )
    return graph_builder
if __name__ == "__main__":
    # Command-line entry point: register the predefined options this script
    # understands, parse them, and forward them to `run`.
    cli_parser = ArgumentParser("Run the embedding inference.")
    cli_options = ["pipeline_config", "partitions", "reproduce"]
    add_predefined_arguments(cli_parser, cli_options)
    cli_args = cli_parser.parse_args()
    run(
        cli_args.pipeline_config,
        partitions=cli_args.partitions,
        reproduce=cli_args.reproduce,
    )