You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
131 lines
5.7 KiB
131 lines
5.7 KiB
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import inspect
|
|
from collections import abc
|
|
|
|
import torch
|
|
|
|
|
|
def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs):
|
|
# extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433 # noqa
|
|
|
|
def _add_input(name, input):
|
|
"""Returns number of expanded inputs that _add_input processed"""
|
|
|
|
if input is None:
|
|
# Drop all None inputs and return 0.
|
|
return 0
|
|
|
|
num_expanded_non_none_inputs = 0
|
|
if isinstance(input, abc.Sequence):
|
|
# If the input is a sequence (like a list), expand the list so that
|
|
# each element of the list is an input by itself.
|
|
for i, val in enumerate(input):
|
|
# Name each input with the index appended to the original name of the
|
|
# argument.
|
|
num_expanded_non_none_inputs += _add_input(f"{name}_{i}", val)
|
|
|
|
# Return here since the list by itself is not a valid input.
|
|
# All the elements of the list have already been added as inputs individually.
|
|
return num_expanded_non_none_inputs
|
|
elif isinstance(input, abc.Mapping):
|
|
# If the input is a mapping (like a dict), expand the dict so that
|
|
# each element of the dict is an input by itself.
|
|
for key, val in input.items():
|
|
num_expanded_non_none_inputs += _add_input(f"{name}_{key}", val)
|
|
|
|
# Return here since the dict by itself is not a valid input.
|
|
# All the elements of the dict have already been added as inputs individually.
|
|
return num_expanded_non_none_inputs
|
|
|
|
# InputInfo should contain all the names irrespective of whether they are
|
|
# a part of the onnx graph or not.
|
|
input_names.append(name)
|
|
|
|
# A single input non none input was processed, return 1
|
|
return 1
|
|
|
|
input_names = []
|
|
var_positional_idx = 0
|
|
num_expanded_non_none_positional_inputs = 0
|
|
|
|
for input_idx, input_parameter in enumerate(all_input_parameters):
|
|
if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
|
# VAR_POSITIONAL parameter carries all *args parameters from original forward method
|
|
for args_i in range(input_idx, len(inputs)):
|
|
name = f"{input_parameter.name}_{var_positional_idx}"
|
|
var_positional_idx += 1
|
|
inp = inputs[args_i]
|
|
num_expanded_non_none_positional_inputs += _add_input(name, inp)
|
|
elif (
|
|
input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY
|
|
or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
|
or input_parameter.kind == inspect.Parameter.KEYWORD_ONLY
|
|
):
|
|
# All positional non-*args and non-**kwargs are processed here
|
|
name = input_parameter.name
|
|
inp = None
|
|
input_idx += var_positional_idx
|
|
is_positional = True
|
|
if input_idx < len(inputs) and inputs[input_idx] is not None:
|
|
inp = inputs[input_idx]
|
|
elif name in kwargs and kwargs[name] is not None:
|
|
inp = kwargs[name]
|
|
is_positional = False
|
|
num_expanded_non_none_inputs_local = _add_input(name, inp)
|
|
if is_positional:
|
|
num_expanded_non_none_positional_inputs += num_expanded_non_none_inputs_local
|
|
elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
|
# **kwargs is always the last argument of forward()
|
|
for name, inp in kwargs.items():
|
|
if name not in input_names:
|
|
_add_input(name, inp)
|
|
|
|
return input_names
|
|
|
|
|
|
def _flatten_module_input(names, args, kwargs):
|
|
"""Flatten args and kwargs in a single tuple of tensors."""
|
|
# extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110 # noqa
|
|
|
|
def is_primitive_type(value):
|
|
return type(value) in {int, bool, float}
|
|
|
|
def to_tensor(value):
|
|
return torch.tensor(value)
|
|
|
|
ret = [to_tensor(arg) if is_primitive_type(arg) else arg for arg in args]
|
|
ret += [
|
|
to_tensor(kwargs[name]) if is_primitive_type(kwargs[name]) else kwargs[name] for name in names if name in kwargs
|
|
]
|
|
|
|
# if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter
|
|
# happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise.
|
|
if not kwargs:
|
|
ret.append({})
|
|
|
|
return tuple(ret)
|
|
|
|
|
|
def infer_input_info(module: torch.nn.Module, *inputs, **kwargs):
    """
    Infer the input names and order from the arguments used to execute a PyTorch module, for use when exporting
    the model via torch.onnx.export.

    Assumes model is on CPU. Use `module.to(torch.device('cpu'))` if it isn't.

    Input names are derived from the signature of `module.forward`: each argument is named after the parameter it
    binds to, `*args` values are suffixed with their index, list/tuple/dict arguments are expanded into one input per
    element, and `None` arguments are dropped.

    Example usage:
        input_names, inputs_as_tuple = infer_input_info(module, ...)
        torch.onnx.export(module, inputs_as_tuple, 'model.onnx', input_names=input_names, output_names=[...], ...)

    :param module: Module to export. Its `forward` signature is inspected to name the inputs.
    :param inputs: Positional inputs
    :param kwargs: Keyword argument inputs
    :return: Tuple of ordered input names and input values. These can be used directly with torch.onnx.export as the
        `input_names` and `args` arguments.
    """
    # Bound method, so `self` is not part of the inspected parameters.
    module_parameters = inspect.signature(module.forward).parameters.values()
    input_names = _parse_inputs_for_onnx_export(module_parameters, inputs, kwargs)
    inputs_as_tuple = _flatten_module_input(input_names, inputs, kwargs)

    return input_names, inputs_as_tuple
|