Source code for scripts.export_model_to_onnx

#!/usr/bin/env python3
"""A python script to export a model to an ONNX file.
"""
from __future__ import annotations
import typing
import os
from argparse import ArgumentParser, Namespace

import torch
import torchinfo

from pipeline import load_trained_model, instantiate_model_for_training
from utils.scriptutils.parser import add_predefined_arguments
from utils.modelutils.export import check_onnx_integrity
from utils.commonutils.config import load_config, cdirs


def export_model_to_onnx(
    path_or_config: str | dict,
    step: typing.Literal["embedding", "gnn"],
    mode: str | None = None,
    output_path: str | None = None,
    options: typing.Iterable[str] | None = None,
    dummy: bool = False,
) -> None:
    """Export a model of a pipeline to an ONNX file.

    Args:
        path_or_config: Path to the pipeline configuration file or the
            configuration dictionary.
        step: Model step, such as `embedding` or `gnn`.
        mode: Export mode.
        output_path: Path where to save the .onnx file containing the model.
            If not provided, it is derived from the experiment name and step.
        options: Iterable of ONNX export options.
        dummy: Whether to export an untrained model instead of a trained one.
    """
    config = load_config(path_or_config=path_or_config)
    model = (
        instantiate_model_for_training(path_or_config=config, step=step)
        if dummy
        else load_trained_model(path_or_config=config, step=step)
    )
    if torch.cuda.is_available():
        model = model.cuda()
    experiment_name: str = config["common"]["experiment_name"]
    options = set() if options is None else set(options)

    # Print the summary of the model that is going to be exported
    torchinfo.summary(model)
    print("Export options:", options)

    subnetworks = (
        model.subnetwork_groups.get(mode)
        if mode is not None and hasattr(model, "subnetwork_groups")
        else None
    )

    if output_path is None:
        suffix = f"_{'_'.join(sorted(options))}" if options else ""
        print("suffix:", suffix)
        # Special case: the mode maps to a group of subnetworks, each exported
        # to its own file, so the filename keeps a `{subnetwork}` placeholder.
        if subnetworks:
            onnx_filename = f"{experiment_name}{suffix}_{{subnetwork}}.onnx"
        else:
            onnx_filename = (
                f"{experiment_name}{suffix}.onnx"
                if mode is None
                else f"{experiment_name}{suffix}_{mode}.onnx"
            )
        if dummy:
            onnx_filename = "dummy_" + onnx_filename
        output_path = os.path.join(cdirs.export_directory, step, onnx_filename)

    model.to_onnx(outpath=output_path, mode=mode, options=options)

    # Check the integrity of the exported model(s).
    if subnetworks:
        for subnetwork in subnetworks:
            check_onnx_integrity(output_path.format(subnetwork=subnetwork))
    else:
        check_onnx_integrity(output_path)
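
# Example (illustrative only): exporting the embedding step of a pipeline
# directly from Python. The configuration path and the "fp16" option below are
# hypothetical placeholders; any config accepted by `load_config` and any
# option understood by the model's `to_onnx` method work the same way.
#
#     export_model_to_onnx(
#         path_or_config="configs/pipeline.yaml",
#         step="embedding",
#         options=["fp16"],
#         dummy=True,  # export an untrained model, e.g. to validate the graph
#     )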
def get_parsed_args() -> Namespace:
    """Parse the command-line arguments of the script."""
    parser = ArgumentParser(
        description="Export a model of a pipeline to an ONNX file."
    )
    add_predefined_arguments(parser, ["pipeline_config"])
    parser.add_argument(
        "-s",
        "--step",
        required=True,
        help="Model step, such as `embedding` or `gnn`.",
        choices=["embedding", "gnn"],
    )
    parser.add_argument(
        "-m",
        "--mode",
        required=False,
        help="Export mode.",
    )
    parser.add_argument(
        "-o",
        "--output",
        required=False,
        help=(
            "Path where to save the .onnx file containing the model. "
            "If not provided, it is defined from the experiment name and model "
            "parameters."
        ),
    )
    parser.add_argument(
        "-t",
        "--options",
        required=False,
        nargs="+",
        help="List of ONNX export options.",
    )
    parser.add_argument(
        "-d",
        "--dummy",
        action="store_true",
        help="Whether to export an untrained model or not.",
    )
    return parser.parse_args()
if __name__ == "__main__": parsed_args = get_parsed_args() config_path: str = parsed_args.pipeline_config step: typing.Literal["embedding", "gnn"] = parsed_args.step mode: str = parsed_args.mode output_path: str | None = parsed_args.output options: typing.List[str] | None = parsed_args.options dummy: bool = parsed_args.dummy export_model_to_onnx( path_or_config=config_path, step=step, mode=mode, output_path=output_path, options=options, dummy=dummy, )