m2m model translation
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Implements ONNX's backend API.
"""
import os
import unittest

import packaging.version
from onnx import ModelProto, helper, version
from onnx.backend.base import Backend
from onnx.checker import check_model

from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_device
from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep


class OnnxRuntimeBackend(Backend):
    """
    Implements
    `ONNX's backend API <https://github.com/onnx/onnx/blob/main/docs/ImplementingAnOnnxBackend.md>`_
    with *ONNX Runtime*.

    The backend is mostly used when you need to switch between
    multiple runtimes with the same API.
    `Importing models from ONNX to Caffe2 <https://github.com/onnx/tutorials/blob/master/tutorials/OnnxCaffe2Import.ipynb>`_
    shows how to use *caffe2* as a backend for a converted model.

    Note: This is not the official Python API.
    """  # noqa: E501

    allowReleasedOpsetsOnly = bool(os.getenv("ALLOW_RELEASED_ONNX_OPSET_ONLY", "1") == "1")

    @classmethod
    def is_compatible(cls, model, device=None, **kwargs):
        """
        Return whether the model is compatible with the backend.

        :param model: unused
        :param device: None to use the default device or a string (ex: `'CPU'`)
        :return: boolean
        """
        if device is None:
            device = get_device()
        return cls.supports_device(device)

    @classmethod
    def is_opset_supported(cls, model):
        """
        Return whether the opset of the model is supported by the backend.

        By default, only released onnx opsets are allowed by the backend.
        To test new opsets, set the environment variable
        ALLOW_RELEASED_ONNX_OPSET_ONLY to 0.

        :param model: model whose opsets need to be verified
        :return: boolean, and an error message if the opset is not supported
        """
        if cls.allowReleasedOpsetsOnly:
            for opset in model.opset_import:
                domain = opset.domain if opset.domain else "ai.onnx"
                try:
                    key = (domain, opset.version)
                    if key not in helper.OP_SET_ID_VERSION_MAP:
                        error_message = (
                            "Skipping this test as only released onnx opsets are supported. "
                            "To run this test, set the env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
                            " Got Domain '{0}' version '{1}'.".format(domain, opset.version)
                        )
                        return False, error_message
                except AttributeError:
                    # On some CI pipelines, accessing helper.OP_SET_ID_VERSION_MAP
                    # raises an AttributeError. TODO: investigate the pipelines to
                    # fix this error. Fall back to a simple version check when this
                    # error is encountered.
                    if (domain == "ai.onnx" and opset.version > 12) or (domain == "ai.onnx.ml" and opset.version > 2):
                        error_message = (
                            "Skipping this test as only released onnx opsets are supported. "
                            "To run this test, set the env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
                            " Got Domain '{0}' version '{1}'.".format(domain, opset.version)
                        )
                        return False, error_message
        return True, ""
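
    # Example (sketch, not in the original source): with the default
    # ALLOW_RELEASED_ONNX_OPSET_ONLY=1, a model importing an unreleased opset is
    # rejected and the test harness skips it:
    #
    #   ok, msg = OnnxRuntimeBackend.is_opset_supported(model)
    #   # ok is False; msg asks you to set ALLOW_RELEASED_ONNX_OPSET_ONLY to 0
    #
    # The env variable is read once at class-creation time, so it must be set
    # before this module is imported.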
    @classmethod
    def supports_device(cls, device):
        """
        Check whether the backend is compiled with particular device support.
        In particular it's used in the testing suite.
        """
        if device == "CUDA":
            device = "GPU"
        return device in get_device()
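
    # Note (assumption based on common builds): onnxruntime.get_device() returns a
    # short string such as "CPU" or "GPU". A "CUDA" request is normalized to "GPU"
    # above, so a CUDA-enabled build reports support for it:
    #
    #   OnnxRuntimeBackend.supports_device("CPU")   # True on any build
    #   OnnxRuntimeBackend.supports_device("CUDA")  # True only if get_device() contains "GPU"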
    @classmethod
    def prepare(cls, model, device=None, **kwargs):
        """
        Load the model and create an :class:`onnxruntime.InferenceSession`
        ready to be used as a backend.

        :param model: ModelProto (returned by `onnx.load`),
            string for a filename or bytes for a serialized model
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see :class:`onnxruntime.SessionOptions`
        :return: :class:`onnxruntime.InferenceSession`
        """
        if isinstance(model, OnnxRuntimeBackendRep):
            return model
        elif isinstance(model, InferenceSession):
            return OnnxRuntimeBackendRep(model)
        elif isinstance(model, (str, bytes)):
            options = SessionOptions()
            for k, v in kwargs.items():
                if hasattr(options, k):
                    setattr(options, k, v)
            excluded_providers = os.getenv("ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS", default="").split(",")
            providers = [x for x in get_available_providers() if (x not in excluded_providers)]
            inf = InferenceSession(model, sess_options=options, providers=providers)
            # The backend API is primarily used for ONNX test/validation. As such, we should
            # disable the session.run() fallback, which may hide test failures.
            inf.disable_fallback()
            if device is not None and not cls.supports_device(device):
                raise RuntimeError("Incompatible device expected '{0}', got '{1}'".format(device, get_device()))
            return cls.prepare(inf, device, **kwargs)
        else:
            # type: ModelProto
            # check_model serializes the model anyway, so serialize the model once here
            # and reuse it below in the cls.prepare call to avoid an additional serialization.
            # This only works with onnx >= 1.10.0, hence the version check.
            onnx_version = packaging.version.parse(version.version) or packaging.version.Version("0")
            onnx_supports_serialized_model_check = onnx_version.release >= (1, 10, 0)
            bin_or_model = model.SerializeToString() if onnx_supports_serialized_model_check else model
            check_model(bin_or_model)
            opset_supported, error_message = cls.is_opset_supported(model)
            if not opset_supported:
                raise unittest.SkipTest(error_message)
            # bin_or_model might already be serialized; if it's not, serialize it now,
            # otherwise we'd enter an infinite recursive call.
            bin = bin_or_model
            if not isinstance(bin, (str, bytes)):
                bin = bin.SerializeToString()
            return cls.prepare(bin, device, **kwargs)
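
    # Usage sketch (hypothetical file name "model.onnx"): every accepted input type
    # funnels into an InferenceSession wrapped in an OnnxRuntimeBackendRep:
    #
    #   rep = OnnxRuntimeBackend.prepare("model.onnx")             # filename
    #   rep = OnnxRuntimeBackend.prepare(onnx.load("model.onnx"))  # ModelProto
    #   outputs = rep.run([input_array])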
    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        """
        Compute the prediction.

        :param model: :class:`onnxruntime.InferenceSession` returned
            by function *prepare*
        :param inputs: inputs
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see :class:`onnxruntime.RunOptions`
        :return: predictions
        """
        rep = cls.prepare(model, device, **kwargs)
        return rep.run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        """
        This method is not implemented as it is much more efficient
        to run a whole model than every node independently.
        """
        raise NotImplementedError("It is much more efficient to run a whole model than every node independently.")

is_compatible = OnnxRuntimeBackend.is_compatible
prepare = OnnxRuntimeBackend.prepare
run = OnnxRuntimeBackend.run_model
supports_device = OnnxRuntimeBackend.supports_device
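

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original module): build a tiny
    # single-node Add graph with onnx.helper and run it through this backend.
    # The opset version 13 below is an assumption; any released opset supported
    # by the installed onnxruntime would do.
    import numpy as np
    from onnx import TensorProto

    node = helper.make_node("Add", ["X", "Y"], ["Z"])
    graph = helper.make_graph(
        [node],
        "add_graph",
        [
            helper.make_tensor_value_info("X", TensorProto.FLOAT, [2]),
            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2]),
        ],
        [helper.make_tensor_value_info("Z", TensorProto.FLOAT, [2])],
    )
    model_proto = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

    # prepare is the module-level alias defined above; it checks the model,
    # verifies the opset, and wraps an InferenceSession in a backend rep.
    rep = prepare(model_proto)
    (z,) = rep.run(
        [
            np.array([1.0, 2.0], dtype=np.float32),
            np.array([3.0, 4.0], dtype=np.float32),
        ]
    )
    print(z)  # expected: [4. 6.]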