# Source code for pipeline.utils.modelutils.export

"""A Python module that defines utilities to export a PyTorch model to ONNX and to check and post-process the exported ONNX file.
"""

import onnx
import torch
from torch.onnx import symbolic_helper


def check_onnx_integrity(inpath: str) -> None:
    """Validate the ONNX model saved at ``inpath``.

    The model is loaded into memory and handed to
    :func:`onnx.checker.check_model` with ``full_check=True``. Nothing is
    returned; a problem with the model surfaces as the exception raised by
    the checker.

    NOTE(review): onnx's checker refuses in-memory models larger than 2GB and
    asks for the file path instead. Switching to the path-based check would
    lift that limit, but confirm how the installed onnx version's path-based
    ``full_check`` behaves (some versions may write inferred shapes back to
    the file) before changing this.

    Args:
        inpath: path to the ONNX file to validate.
    """
    onnx.checker.check_model(onnx.load(inpath), full_check=True)
def change_input_index_types(
    inpath: str,
    target_type: int = onnx.TensorProto.INT32,
    outpath: str | None = None,
) -> None:
    """Change the element type of the INT64 graph inputs of an ONNX model.

    In PyTorch, indices must be INT64. This function loops over the input
    nodes of the ONNX model stored in ``inpath`` and turns every input whose
    element type is :py:data:`onnx.TensorProto.INT64` into ``target_type``.

    Only the declared element type of the graph inputs is changed; the nodes
    of the graph are left untouched (no ``Cast`` node is inserted).

    Args:
        inpath: path to the ONNX file.
        target_type: ``onnx.TensorProto`` element type to assign to the INT64
            input nodes (``onnx.TensorProto.INT32`` by default).
        outpath: path where to save the altered ONNX model. If not provided,
            the model is saved to ``inpath``, overwriting it.
    """
    if outpath is None:
        outpath = inpath
    onnx_model = onnx.load(inpath)
    # Named ``graph_input`` rather than ``input`` to avoid shadowing the builtin.
    for graph_input in onnx_model.graph.input:
        tensor_type = graph_input.type.tensor_type
        if tensor_type.elem_type == onnx.TensorProto.INT64:
            tensor_type.elem_type = target_type
    onnx.save(proto=onnx_model, f=outpath)
def convert_model_to_fp16(inpath: str, outpath: str | None = None) -> None:
    """Convert the ONNX model stored in ``inpath`` to fp16.

    The conversion is delegated to
    ``onnxconverter_common.float16.convert_float_to_float16`` with its default
    options.

    Args:
        inpath: path to the ONNX model to convert.
        outpath: path where the converted model is written. If not provided,
            the model at ``inpath`` is overwritten.

    Notes:
        See https://onnxruntime.ai/docs/performance/model-optimizations/float16.html.
    """
    # Imported lazily so the optional dependency is only required when used.
    from onnxconverter_common import float16

    destination = inpath if outpath is None else outpath
    converted = float16.convert_float_to_float16(onnx.load(inpath))
    onnx.save(converted, destination)
class TRTScatterAddOp(torch.autograd.Function):
    """A fake scatter-add operator for ONNX export.

    It is meant to be used with a custom TensorRT plugin that implements the
    actual scatter-add operation. ``forward`` only produces a correctly shaped
    placeholder, and ``symbolic`` emits a ``tensorrt_scatter::scatter_add``
    node into the exported graph.

    Notes:
        For reference:
        https://leimao.github.io/blog/PyTorch-Custom-ONNX-Operator-Export/
    """

    @staticmethod
    def forward(ctx, source, index, h) -> torch.Tensor:
        """Return a zero tensor with the shape of the scatter-add result.

        Args:
            ctx: autograd context (unused).
            source: source tensor for the scattering; its second dimension
                gives the number of output columns.
            index: index tensor for the scattering (unused here).
            h: tensor whose first dimension gives the number of output rows.

        Returns:
            A zero tensor of shape ``(h.size(0), source.size(1))`` with the
            dtype and device of ``source``.
        """
        # The real operation would be
        # ``scatter_add(source, index, dim=0, dim_size=h.size(0))``. For export
        # only the output shape has to be right, so zeros are returned instead.
        return torch.zeros(
            (h.size(0), source.size(1)), dtype=source.dtype, device=source.device
        )

    @staticmethod
    def symbolic(g, source, index, h):
        """TensorRT-exportable scatter add.

        Args:
            g: graph being populated by the ONNX exporter.
            source: source input value for the scattering.
            index: index input value for the scattering.
            h: input value whose first dimension gives the number of rows of
                the output (the ``dim_size`` of the scatter-add).

        Returns:
            The output of the emitted ``tensorrt_scatter::scatter_add`` node,
            typed like ``source`` with sizes ``[h_size[0], source_size[1]]``.

        Raises:
            ValueError: if the tensor sizes of ``source`` or ``h`` are not
                available to the exporter, or their rank is too low to derive
                the output shape.
        """
        source_size = symbolic_helper._get_tensor_sizes(source)
        h_size = symbolic_helper._get_tensor_sizes(h)
        # Explicit checks rather than ``assert`` so they still run under
        # ``python -O``, and so a low rank gives a clear message instead of an
        # IndexError/TypeError when indexing below.
        if source_size is None or len(source_size) < 2:
            raise ValueError(
                "TRTScatterAddOp export requires `source` to have known sizes "
                f"with rank >= 2, got {source_size}"
            )
        if not h_size:
            raise ValueError(
                "TRTScatterAddOp export requires `h` to have known sizes "
                f"with rank >= 1, got {h_size}"
            )
        output_type = source.type().with_sizes([h_size[0], source_size[1]])
        return g.op("tensorrt_scatter::scatter_add", source, index, h).setType(
            output_type
        )