m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

131 lines
5.7 KiB

6 months ago
  1. # Copyright (c) Microsoft Corporation. All rights reserved.
  2. # Licensed under the MIT License.
  3. import inspect
  4. from collections import abc
  5. import torch
  6. def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs):
  7. # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433 # noqa
  8. def _add_input(name, input):
  9. """Returns number of expanded inputs that _add_input processed"""
  10. if input is None:
  11. # Drop all None inputs and return 0.
  12. return 0
  13. num_expanded_non_none_inputs = 0
  14. if isinstance(input, abc.Sequence):
  15. # If the input is a sequence (like a list), expand the list so that
  16. # each element of the list is an input by itself.
  17. for i, val in enumerate(input):
  18. # Name each input with the index appended to the original name of the
  19. # argument.
  20. num_expanded_non_none_inputs += _add_input(f"{name}_{i}", val)
  21. # Return here since the list by itself is not a valid input.
  22. # All the elements of the list have already been added as inputs individually.
  23. return num_expanded_non_none_inputs
  24. elif isinstance(input, abc.Mapping):
  25. # If the input is a mapping (like a dict), expand the dict so that
  26. # each element of the dict is an input by itself.
  27. for key, val in input.items():
  28. num_expanded_non_none_inputs += _add_input(f"{name}_{key}", val)
  29. # Return here since the dict by itself is not a valid input.
  30. # All the elements of the dict have already been added as inputs individually.
  31. return num_expanded_non_none_inputs
  32. # InputInfo should contain all the names irrespective of whether they are
  33. # a part of the onnx graph or not.
  34. input_names.append(name)
  35. # A single input non none input was processed, return 1
  36. return 1
  37. input_names = []
  38. var_positional_idx = 0
  39. num_expanded_non_none_positional_inputs = 0
  40. for input_idx, input_parameter in enumerate(all_input_parameters):
  41. if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
  42. # VAR_POSITIONAL parameter carries all *args parameters from original forward method
  43. for args_i in range(input_idx, len(inputs)):
  44. name = f"{input_parameter.name}_{var_positional_idx}"
  45. var_positional_idx += 1
  46. inp = inputs[args_i]
  47. num_expanded_non_none_positional_inputs += _add_input(name, inp)
  48. elif (
  49. input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY
  50. or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  51. or input_parameter.kind == inspect.Parameter.KEYWORD_ONLY
  52. ):
  53. # All positional non-*args and non-**kwargs are processed here
  54. name = input_parameter.name
  55. inp = None
  56. input_idx += var_positional_idx
  57. is_positional = True
  58. if input_idx < len(inputs) and inputs[input_idx] is not None:
  59. inp = inputs[input_idx]
  60. elif name in kwargs and kwargs[name] is not None:
  61. inp = kwargs[name]
  62. is_positional = False
  63. num_expanded_non_none_inputs_local = _add_input(name, inp)
  64. if is_positional:
  65. num_expanded_non_none_positional_inputs += num_expanded_non_none_inputs_local
  66. elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
  67. # **kwargs is always the last argument of forward()
  68. for name, inp in kwargs.items():
  69. if name not in input_names:
  70. _add_input(name, inp)
  71. return input_names
  72. def _flatten_module_input(names, args, kwargs):
  73. """Flatten args and kwargs in a single tuple of tensors."""
  74. # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110 # noqa
  75. def is_primitive_type(value):
  76. return type(value) in {int, bool, float}
  77. def to_tensor(value):
  78. return torch.tensor(value)
  79. ret = [to_tensor(arg) if is_primitive_type(arg) else arg for arg in args]
  80. ret += [
  81. to_tensor(kwargs[name]) if is_primitive_type(kwargs[name]) else kwargs[name] for name in names if name in kwargs
  82. ]
  83. # if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter
  84. # happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise.
  85. if not kwargs:
  86. ret.append({})
  87. return tuple(ret)
  88. def infer_input_info(module: torch.nn.Module, *inputs, **kwargs):
  89. """
  90. Infer the input names and order from the arguments used to execute a PyTorch module for usage exporting
  91. the model via torch.onnx.export.
  92. Assumes model is on CPU. Use `module.to(torch.device('cpu'))` if it isn't.
  93. Example usage:
  94. input_names, inputs_as_tuple = infer_input_info(module, ...)
  95. torch.onnx.export(module, inputs_as_type, 'model.onnx', input_names=input_names, output_names=[...], ...)
  96. :param module: Module
  97. :param inputs: Positional inputs
  98. :param kwargs: Keyword argument inputs
  99. :return: Tuple of ordered input names and input values. These can be used directly with torch.onnx.export as the
  100. `input_names` and `inputs` arguments.
  101. """
  102. module_parameters = inspect.signature(module.forward).parameters.values()
  103. input_names = _parse_inputs_for_onnx_export(module_parameters, inputs, kwargs)
  104. inputs_as_tuple = _flatten_module_input(input_names, inputs, kwargs)
  105. return input_names, inputs_as_tuple