m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

145 lines
6.0 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. """
  6. Check OS requirements for ONNX Runtime Python Bindings.
  7. """
  8. import linecache
  9. import platform
  10. import warnings
  11. def check_distro_info():
  12. __my_distro__ = ""
  13. __my_distro_ver__ = ""
  14. __my_system__ = platform.system().lower()
  15. __OS_RELEASE_FILE__ = "/etc/os-release"
  16. __LSB_RELEASE_FILE__ = "/etc/lsb-release"
  17. if __my_system__ == "windows":
  18. __my_distro__ = __my_system__
  19. __my_distro_ver__ = platform.release().lower()
  20. if __my_distro_ver__ != "10":
  21. warnings.warn(
  22. "Unsupported Windows version (%s). ONNX Runtime supports Windows 10 and above, only."
  23. % __my_distro_ver__
  24. )
  25. elif __my_system__ == "linux":
  26. """Although the 'platform' python module for getting Distro information works well on standard OS images
  27. running on real hardware, it is not accurate when running on Azure VMs, Git Bash, Cygwin, etc.
  28. The returned values for release and version are unpredictable for virtualized or emulated environments.
  29. /etc/os-release and /etc/lsb_release files, on the other hand, are guaranteed to exist and have standard values
  30. in all OSes supported by onnxruntime. The former is the current standard file to check OS info and the latter
  31. is its predecessor.
  32. """
  33. # Newer systems have /etc/os-release with relevant distro info
  34. __my_distro__ = linecache.getline(__OS_RELEASE_FILE__, 3)[3:-1]
  35. __my_distro_ver__ = linecache.getline(__OS_RELEASE_FILE__, 6)[12:-2]
  36. # Older systems may have /etc/os-release instead
  37. if not __my_distro__:
  38. __my_distro__ = linecache.getline(__LSB_RELEASE_FILE__, 1)[11:-1]
  39. __my_distro_ver__ = linecache.getline(__LSB_RELEASE_FILE__, 2)[16:-1]
  40. # Instead of trying to parse distro specific files,
  41. # warn the user ONNX Runtime may not work out of the box
  42. __my_distro__ = __my_distro__.lower()
  43. __my_distro_ver__ = __my_distro_ver__.lower()
  44. elif __my_system__ == "darwin":
  45. __my_distro__ = __my_system__
  46. __my_distro_ver__ = platform.release().lower()
  47. if int(__my_distro_ver__.split(".")[0]) < 11:
  48. warnings.warn(
  49. "Unsupported macOS version (%s). ONNX Runtime supports macOS 11.0 or later." % (__my_distro_ver__)
  50. )
  51. else:
  52. warnings.warn(
  53. "Unsupported platform (%s). ONNX Runtime supports Linux, macOS and Windows platforms, only." % __my_system__
  54. )
  55. def validate_build_package_info():
  56. import_ortmodule_exception = None
  57. has_ortmodule = False
  58. try:
  59. from onnxruntime.training.ortmodule import ORTModule # noqa
  60. has_ortmodule = True
  61. except ImportError:
  62. # ORTModule not present
  63. has_ortmodule = False
  64. except Exception as e:
  65. # this may happen if Cuda is not installed, we want to raise it after
  66. # for any exception other than not having ortmodule, we want to continue
  67. # device version validation and raise the exception after.
  68. try:
  69. from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
  70. if isinstance(e, ORTModuleInitException):
  71. # ORTModule is present but not ready to run yet
  72. has_ortmodule = True
  73. except Exception:
  74. # ORTModule not present
  75. has_ortmodule = False
  76. if not has_ortmodule:
  77. import_ortmodule_exception = e
  78. package_name = ""
  79. version = ""
  80. cuda_version = ""
  81. if has_ortmodule:
  82. try:
  83. # collect onnxruntime package name, version, and cuda version
  84. from .build_and_package_info import __version__ as version
  85. from .build_and_package_info import package_name
  86. try:
  87. from .build_and_package_info import cuda_version
  88. except: # noqa
  89. pass
  90. if cuda_version:
  91. # collect cuda library build info. the library info may not be available
  92. # when the build environment has none or multiple libraries installed
  93. try:
  94. from .build_and_package_info import cudart_version
  95. except: # noqa
  96. warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
  97. cudart_version = None
  98. def print_build_package_info():
  99. warnings.warn("onnxruntime training package info: package_name: %s" % package_name)
  100. warnings.warn("onnxruntime training package info: __version__: %s" % version)
  101. warnings.warn("onnxruntime training package info: cuda_version: %s" % cuda_version)
  102. warnings.warn("onnxruntime build info: cudart_version: %s" % cudart_version)
  103. # collection cuda library info from current environment.
  104. from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions
  105. local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
  106. if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
  107. print_build_package_info()
  108. warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
  109. warnings.warn("WARNING: found cudart versions: %s" % local_cudart_versions)
  110. else:
  111. # TODO: rcom
  112. pass
  113. except Exception as e: # noqa
  114. warnings.warn("WARNING: failed to collect onnxruntime version and build info")
  115. print(e)
  116. if import_ortmodule_exception:
  117. raise import_ortmodule_exception
  118. return has_ortmodule, package_name, version, cuda_version
  119. has_ortmodule, package_name, version, cuda_version = validate_build_package_info()