Source code for pipeline.utils.modelutils.export
"""A python module that defines utilities to export a PyTorch model to ONNX.
"""
import onnx
import torch
from torch.onnx import symbolic_helper


def check_onnx_integrity(inpath: str) -> None:
    """Check the integrity of an ONNX model stored in ``inpath``."""
    onnx_model = onnx.load(inpath)
    onnx.checker.check_model(onnx_model, full_check=True)
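
# Example usage (a minimal sketch; the path is hypothetical and assumes a
# model has already been exported with ``torch.onnx.export``):
#
#     check_onnx_integrity("exported_model.onnx")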


def convert_model_to_fp16(inpath: str, outpath: str | None = None) -> None:
    """Convert an ONNX model to fp16.

    If ``outpath`` is None, the converted model overwrites the file at ``inpath``.

    Notes:
        See https://onnxruntime.ai/docs/performance/model-optimizations/float16.html.
    """
    from onnxconverter_common import float16

    if outpath is None:
        outpath = inpath
    model = onnx.load(inpath)
    model_fp16 = float16.convert_float_to_float16(model)
    onnx.save(model_fp16, outpath)
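
# Example usage (a minimal sketch; paths are hypothetical). Passing an explicit
# ``outpath`` keeps the original fp32 model; omitting it converts in place:
#
#     convert_model_to_fp16("exported_model.onnx", "exported_model_fp16.onnx")
#     check_onnx_integrity("exported_model_fp16.onnx")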


class TRTScatterAddOp(torch.autograd.Function):
    """A fake scatter-add operator for ONNX export, used together with a custom
    TensorRT plugin that implements the actual scatter-add operation.

    Notes:
        For reference: https://leimao.github.io/blog/PyTorch-Custom-ONNX-Operator-Export/
    """

    @staticmethod
    def forward(ctx, source, index, h) -> torch.Tensor:
        # The real scatter add is not needed here; only the output shape has to
        # be right for the export trace.
        # return scatter_add(source, index, dim=0, dim_size=h.size(0))
        return torch.zeros(
            (h.size(0), source.size(1)), dtype=source.dtype, device=source.device
        )

    @staticmethod
    def symbolic(g, source, index, h):
        """TensorRT-exportable scatter add.

        Args:
            g: The graph being populated.
            source: Source input tensor for the scattering.
            index: Index input tensor for the scattering.
            h: Tensor whose first dimension gives the number of elements in
                the output tensor.
        """
        source_size = symbolic_helper._get_tensor_sizes(source)
        h_size = symbolic_helper._get_tensor_sizes(h)
        assert source_size is not None
        assert h_size is not None
        output_type = source.type().with_sizes([h_size[0], source_size[1]])
        return g.op("tensorrt_scatter::scatter_add", source, index, h).setType(
            output_type
        )
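
# Example usage (a minimal sketch; ``ScatterAddLayer`` and the tensor shapes are
# hypothetical). The exported graph contains a ``tensorrt_scatter::scatter_add``
# node that must be backed by a matching TensorRT plugin at inference time; the
# PyTorch ``forward`` above only returns zeros of the correct shape, so the real
# scatter add has to be used outside of export:
#
#     class ScatterAddLayer(torch.nn.Module):
#         def forward(self, source, index, h):
#             return TRTScatterAddOp.apply(source, index, h)
#
#     source = torch.randn(6, 8)
#     index = torch.tensor([0, 0, 1, 1, 2, 2])
#     h = torch.randn(3, 8)
#     torch.onnx.export(ScatterAddLayer(), (source, index, h), "scatter_add.onnx")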