You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
145 lines
6.0 KiB
145 lines
6.0 KiB
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
# --------------------------------------------------------------------------
|
|
"""
|
|
Check OS requirements for ONNX Runtime Python Bindings.
|
|
"""
|
|
import linecache
|
|
import platform
|
|
import warnings
|
|
|
|
|
|
def check_distro_info():
    """Warn when the host operating system is not supported by ONNX Runtime.

    The platform is taken from ``platform.system()`` and handled as follows:

    * Windows: a warning is issued unless ``platform.release()`` names
      Windows 10, Windows 11 or one of the listed Windows Server releases.
    * Linux: the distro name and version are read from ``/etc/os-release``,
      falling back to ``/etc/lsb-release``. No warning is issued.
    * macOS (``darwin``): a warning is issued when the leading numeric
      component of ``platform.release()`` is lower than 11.
    * Any other platform: a warning is issued.

    Returns:
        None. Problems are reported exclusively through ``warnings.warn``.
    """
    os_release_file = "/etc/os-release"
    lsb_release_file = "/etc/lsb-release"

    # Values of ``platform.release().lower()`` accepted as "Windows 10 and
    # above". Comparing against "10" only would wrongly flag Windows 11 and
    # the Windows Server editions.
    supported_windows_releases = (
        "10",
        "11",
        "2016server",
        "2019server",
        "2022server",
        "2025server",
    )

    system = platform.system().lower()

    if system == "windows":
        distro_ver = platform.release().lower()

        if distro_ver not in supported_windows_releases:
            warnings.warn(
                "Unsupported Windows version (%s). ONNX Runtime supports Windows 10 and above, only."
                % distro_ver
            )
    elif system == "linux":
        # Although the 'platform' python module for getting Distro information works well on
        # standard OS images running on real hardware, it is not accurate when running on
        # Azure VMs, Git Bash, Cygwin, etc. The returned values for release and version are
        # unpredictable for virtualized or emulated environments.
        # /etc/os-release and /etc/lsb-release files, on the other hand, are expected to exist
        # and have standard values in all OSes supported by onnxruntime. The former is the
        # current standard file to check OS info and the latter is its predecessor.

        # Newer systems have /etc/os-release with relevant distro info.
        # linecache.getline returns "" when the file or line is missing, so
        # the slices below never raise.
        distro = linecache.getline(os_release_file, 3)[3:-1]
        distro_ver = linecache.getline(os_release_file, 6)[12:-2]

        # Older systems may have /etc/lsb-release instead
        if not distro:
            distro = linecache.getline(lsb_release_file, 1)[11:-1]
            distro_ver = linecache.getline(lsb_release_file, 2)[16:-1]

        # Instead of trying to parse distro specific files,
        # warn the user ONNX Runtime may not work out of the box
        # NOTE(review): the normalized values are not used after this point and
        # no warning is currently emitted for Linux.
        distro = distro.lower()
        distro_ver = distro_ver.lower()
    elif system == "darwin":
        distro_ver = platform.release().lower()

        # NOTE(review): platform.release() presumably reports the Darwin kernel
        # release rather than the macOS product version - confirm whether
        # platform.mac_ver() should be consulted here instead.
        try:
            major = int(distro_ver.split(".")[0])
        except ValueError:
            # Release string is empty or non-numeric; skip the check rather
            # than fail while the package is being imported.
            major = None

        if major is not None and major < 11:
            warnings.warn(
                "Unsupported macOS version (%s). ONNX Runtime supports macOS 11.0 or later." % (distro_ver)
            )
    else:
        warnings.warn(
            "Unsupported platform (%s). ONNX Runtime supports Linux, macOS and Windows platforms, only." % system
        )
|
|
|
|
|
|
def validate_build_package_info():
    """Detect ORTModule and compare the local CUDA runtime with the build info.

    The function first tries to import ``ORTModule``:

    * ``ImportError`` means ORTModule is not present.
    * Any other exception is kept and re-raised at the end, unless it is an
      ``ORTModuleInitException`` (ORTModule present but not ready to run).

    When ORTModule is present, package name, version and CUDA version are
    read from ``build_and_package_info``. If a CUDA version is recorded, the
    cudart version from the build is compared with the versions found by
    ``find_cudart_versions`` in the current environment, and warnings are
    issued on a mismatch. Failures while collecting this information are
    reported as a warning and do not propagate.

    Returns:
        tuple: ``(has_ortmodule, package_name, version, cuda_version)``. The
        three strings are empty when the information is not available.

    Raises:
        Exception: the exception raised while importing ORTModule, when it
            is neither an ``ImportError`` nor an ``ORTModuleInitException``.
    """
    import_ortmodule_exception = None

    has_ortmodule = False
    try:
        from onnxruntime.training.ortmodule import ORTModule  # noqa: F401

        has_ortmodule = True
    except ImportError:
        # ORTModule not present
        has_ortmodule = False
    except Exception as e:
        # This may happen if CUDA is not installed. For any exception other
        # than ORTModule being absent, remember it so that it can be raised
        # once the rest of this function has run.
        try:
            from onnxruntime.training.ortmodule._fallback import ORTModuleInitException

            if isinstance(e, ORTModuleInitException):
                # ORTModule is present but not ready to run yet
                has_ortmodule = True
        except Exception:
            # ORTModule not present
            has_ortmodule = False

        if not has_ortmodule:
            import_ortmodule_exception = e

    package_name = ""
    version = ""
    cuda_version = ""

    if has_ortmodule:
        try:
            # collect onnxruntime package name, version, and cuda version
            from .build_and_package_info import __version__ as version
            from .build_and_package_info import package_name

            try:
                from .build_and_package_info import cuda_version
            except ImportError:
                # Build info carries no cuda_version; keep the "" default.
                pass

            if cuda_version:
                # collect cuda library build info. the library info may not be available
                # when the build environment has none or multiple libraries installed
                try:
                    from .build_and_package_info import cudart_version
                except ImportError:
                    warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
                    cudart_version = None

                def print_build_package_info():
                    """Emit the collected package and build details as warnings."""
                    warnings.warn("onnxruntime training package info: package_name: %s" % package_name)
                    warnings.warn("onnxruntime training package info: __version__: %s" % version)
                    warnings.warn("onnxruntime training package info: cuda_version: %s" % cuda_version)
                    warnings.warn("onnxruntime build info: cudart_version: %s" % cudart_version)

                # collect cuda library info from current environment.
                from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions

                local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
                if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
                    print_build_package_info()
                    warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
                    warnings.warn("WARNING: found cudart versions: %s" % local_cudart_versions)
            else:
                # TODO: rcom (presumably ROCm - confirm)
                pass

        except Exception as e:  # noqa
            # Best effort: a failure to collect build info must not break the import.
            warnings.warn("WARNING: failed to collect onnxruntime version and build info")
            print(e)

    # Compare against None explicitly so the decision does not depend on the
    # truthiness of the stored exception object.
    if import_ortmodule_exception is not None:
        raise import_ortmodule_exception

    return has_ortmodule, package_name, version, cuda_version
|
|
|
|
|
|
# Run the validation once at import time and publish its results as
# module-level names.
(
    has_ortmodule,
    package_name,
    version,
    cuda_version,
) = validate_build_package_info()
|