#!/usr/bin/env python3
"""A Python script to export a pipeline model to an ONNX file.
"""
from __future__ import annotations
import typing
import os
from argparse import ArgumentParser, Namespace
import torch
import torchinfo
from pipeline import load_trained_model, instantiate_model_for_training
from utils.scriptutils.parser import add_predefined_arguments
from utils.modelutils.export import check_onnx_integrity
from utils.commonutils.config import load_config, cdirs
def _default_onnx_filename(
    experiment_name: str,
    options: typing.Collection[str],
    mode: str | None,
    has_subnetworks: bool,
    dummy: bool,
) -> str:
    """Build the default ONNX file name (without directory) for an export.

    The name is ``[dummy_]<experiment_name>[_<sorted options>][_<tail>].onnx``
    where ``<tail>`` is the literal ``{subnetwork}`` placeholder when several
    subnetworks are exported, else the mode (if any).
    """
    suffix = f"_{'_'.join(sorted(options))}" if options else ""
    if has_subnetworks:
        # `{subnetwork}` is kept as a literal placeholder, to be formatted
        # with each subnetwork name later on.
        filename = f"{experiment_name}{suffix}_{{subnetwork}}.onnx"
    elif mode is None:
        filename = f"{experiment_name}{suffix}.onnx"
    else:
        filename = f"{experiment_name}{suffix}_{mode}.onnx"
    return f"dummy_{filename}" if dummy else filename


def export_model_to_onnx(
    path_or_config: str | dict,
    step: typing.Literal["embedding", "gnn"],
    mode: str | None = None,
    output_path: str | None = None,
    options: typing.Iterable[str] | None = None,
    dummy: bool = False,
) -> None:
    """Export a model of a pipeline to an ONNX file.

    Args:
        path_or_config: Path to the pipeline configuration file or the
            configuration dictionary.
        step: Model step, such as ``embedding`` or ``gnn``.
        mode: Export mode, forwarded to ``model.to_onnx``. If the model has
            a ``subnetwork_groups`` mapping and ``mode`` is one of its keys,
            one ONNX file per subnetwork of that group is checked after
            export.
        output_path: Path where to save the .onnx file containing the model.
            If not provided, it is built inside
            ``cdirs.export_directory/<step>`` from the experiment name, the
            export options, the mode and ``dummy``. When several subnetworks
            are exported, the path is expected to contain a
            ``{subnetwork}`` placeholder, which is formatted with each
            subnetwork name for the integrity checks.
        options: Export options, forwarded to ``model.to_onnx`` as a set.
            They also appear (sorted, ``_``-joined) in the default file name.
        dummy: If ``True``, export a freshly instantiated (untrained) model
            instead of loading the trained one; the default file name is
            then prefixed with ``dummy_``.
    """
    config = load_config(path_or_config=path_or_config)
    model = (
        instantiate_model_for_training(path_or_config=config, step=step)
        if dummy
        else load_trained_model(path_or_config=config, step=step)
    )
    if torch.cuda.is_available():
        model = model.cuda()
    experiment_name: str = config["common"]["experiment_name"]
    options = set() if options is None else set(options)

    # Print the summary of the model that is going to be exported
    torchinfo.summary(model)
    print("Export options:", options)

    # Subnetworks exported as separate files for this mode, if any.
    subnetworks = (
        model.subnetwork_groups.get(mode)
        if mode is not None and hasattr(model, "subnetwork_groups")
        else None
    )

    if output_path is None:
        onnx_filename = _default_onnx_filename(
            experiment_name=experiment_name,
            options=options,
            mode=mode,
            has_subnetworks=bool(subnetworks),
            dummy=dummy,
        )
        export_dir = os.path.join(cdirs.export_directory, step)
        # Create the default export directory up front, in case
        # `model.to_onnx` does not create missing parent directories itself.
        os.makedirs(export_dir, exist_ok=True)
        output_path = os.path.join(export_dir, onnx_filename)

    model.to_onnx(outpath=output_path, mode=mode, options=options)

    # Check model integrities.
    if subnetworks:
        for subnetwork in subnetworks:
            check_onnx_integrity(output_path.format(subnetwork=subnetwork))
    else:
        check_onnx_integrity(output_path)
def get_parsed_args() -> Namespace:
    """Parse the command-line arguments of the ONNX export script.

    Returns:
        The parsed arguments: ``pipeline_config`` (registered by
        ``add_predefined_arguments``), ``step``, ``mode``, ``output``,
        ``options`` and ``dummy``.
    """
    # The text must go to `description`: the first positional argument of
    # ArgumentParser is `prog`, i.e. the program name shown in the usage line.
    parser = ArgumentParser(
        description="Export a pipeline model (embedding or GNN) to an ONNX file."
    )
    add_predefined_arguments(parser, ["pipeline_config"])
    parser.add_argument(
        "-s",
        "--step",
        required=True,
        help="Model step, such as `embedding` or `gnn`.",
        choices=["embedding", "gnn"],
    )
    parser.add_argument(
        "-m",
        "--mode",
        required=False,
        help="Export mode.",
    )
    parser.add_argument(
        "-o",
        "--output",
        required=False,
        help=(
            "Path where to save the .onnx file containing the model. "
            "If not provided, it is defined from the experiment name, the "
            "export options and the export mode."
        ),
    )
    parser.add_argument(
        "-t",
        "--options",
        required=False,
        nargs="+",
        help="List of ONNX export options",
    )
    parser.add_argument(
        "-d",
        "--dummy",
        action="store_true",
        help="Whether to export an untrained model or not.",
    )
    return parser.parse_args()
if __name__ == "__main__":
    parsed_args = get_parsed_args()
    config_path: str = parsed_args.pipeline_config
    step: typing.Literal["embedding", "gnn"] = parsed_args.step
    # Optional CLI arguments are None when omitted on the command line.
    mode: str | None = parsed_args.mode
    output_path: str | None = parsed_args.output
    options: list[str] | None = parsed_args.options
    dummy: bool = parsed_args.dummy
    export_model_to_onnx(
        path_or_config=config_path,
        step=step,
        mode=mode,
        output_path=output_path,
        options=options,
        dummy=dummy,
    )