m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

942 lines
37 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import collections
  6. import collections.abc
  7. import os
  8. import warnings
  9. from onnxruntime.capi import _pybind_state as C
  def get_ort_device_type(device_type, device_index):
      """
      Translate a device type name into the device-type value exposed by the C bindings.

      :param device_type: "cuda", "cpu" or "ort" (compared exactly, so the match is case-sensitive).
      :param device_index: only used for "ort", where it is forwarded to ``C.get_ort_device``.
      :return: the result of ``C.OrtDevice.cuda()``, ``C.OrtDevice.cpu()`` or
          ``C.get_ort_device(device_index).device_type()`` respectively.
      :raises Exception: if ``device_type`` is none of the names above.
      """
      if device_type == "cuda":
          return C.OrtDevice.cuda()
      if device_type == "cpu":
          return C.OrtDevice.cpu()
      if device_type == "ort":
          ort_device = C.get_ort_device(device_index)
          return ort_device.device_type()
      raise Exception("Unsupported device type: " + device_type)
  def check_and_normalize_provider_args(providers, provider_options, available_provider_names):
      """
      Validates the 'providers' and 'provider_options' arguments and returns a
      normalized version.

      :param providers: Optional sequence of providers in order of decreasing
          precedence. Values can either be provider names or tuples of
          (provider name, options dict).
      :param provider_options: Optional sequence of options dicts corresponding
          to the providers listed in 'providers'.
      :param available_provider_names: The available provider names.
      :return: Tuple of (normalized 'providers' sequence, normalized
          'provider_options' sequence).
      :raises ValueError: if 'providers' or 'provider_options' is not a sequence,
          if the two have different lengths, or if an entry has an unsupported type.

      'providers' can contain either names or names and options. When any options
      are given in 'providers', 'provider_options' should not be used.

      The normalized result is a tuple of:
      1. Sequence of provider names in the same order as 'providers'.
      2. Sequence of corresponding provider options dicts with string keys and
          values. Unspecified provider options yield empty dicts.

      A provider that is not in 'available_provider_names' only produces a warning
      and is still part of the result. A provider that is listed more than once
      produces a warning and only its first occurrence is kept.
      """
      if providers is None:
          return [], []

      # Insertion order is the precedence order, so it must be preserved.
      provider_name_to_options = collections.OrderedDict()

      def set_provider_options(name, options):
          if name not in available_provider_names:
              # The two literals are concatenated; the trailing space keeps the sentences apart.
              warnings.warn(
                  "Specified provider '{}' is not in available provider names. "
                  "Available providers: '{}'".format(name, ", ".join(available_provider_names))
              )

          if name in provider_name_to_options:
              warnings.warn("Duplicate provider '{}' encountered, ignoring.".format(name))
              return

          # Keys and values are both stringified before being stored.
          normalized_options = {str(key): str(value) for key, value in options.items()}
          provider_name_to_options[name] = normalized_options

      if not isinstance(providers, collections.abc.Sequence):
          raise ValueError("'providers' should be a sequence.")

      if provider_options is not None:
          if not isinstance(provider_options, collections.abc.Sequence):
              raise ValueError("'provider_options' should be a sequence.")

          if len(providers) != len(provider_options):
              raise ValueError("'providers' and 'provider_options' should be the same length if both are given.")

          if not all(isinstance(provider, str) for provider in providers):
              raise ValueError("Only string values for 'providers' are supported if 'provider_options' is given.")

          if not all(isinstance(options_for_provider, dict) for options_for_provider in provider_options):
              raise ValueError("'provider_options' values must be dicts.")

          for name, options in zip(providers, provider_options):
              set_provider_options(name, options)

      else:
          for provider in providers:
              if isinstance(provider, str):
                  set_provider_options(provider, {})
              elif (
                  isinstance(provider, tuple)
                  and len(provider) == 2
                  and isinstance(provider[0], str)
                  and isinstance(provider[1], dict)
              ):
                  set_provider_options(provider[0], provider[1])
              else:
                  raise ValueError("'providers' values must be either strings or (string, dict) tuples.")

      return list(provider_name_to_options.keys()), list(provider_name_to_options.values())
  class Session:
      """
      This is the main class used to run a model.

      The accessors below read attributes (``_sess_options``, ``_inputs_meta``,
      ``_providers``, ...) and call ``_reset_session`` / read ``_fallback_providers``,
      none of which are defined in this class; a derived class is expected to supply them.
      """

      def __init__(self):
          # Fallback on execution provider failure starts enabled.
          self._enable_fallback = True
          # self._sess is managed by the derived class and relies on bindings from C.InferenceSession
          self._sess = None

      def get_session_options(self):
          """Return the session options. See :class:`onnxruntime.SessionOptions`."""
          return self._sess_options

      def get_inputs(self):
          """Return the inputs metadata as a list of :class:`onnxruntime.NodeArg`."""
          return self._inputs_meta

      def get_outputs(self):
          """Return the outputs metadata as a list of :class:`onnxruntime.NodeArg`."""
          return self._outputs_meta

      def get_overridable_initializers(self):
          """Return the inputs (including initializers) metadata as a list of :class:`onnxruntime.NodeArg`."""
          return self._overridable_initializers

      def get_modelmeta(self):
          """Return the metadata. See :class:`onnxruntime.ModelMetadata`."""
          return self._model_meta

      def get_providers(self):
          """Return list of registered execution providers."""
          return self._providers

      def get_provider_options(self):
          """Return registered execution providers' configurations."""
          return self._provider_options

      def set_providers(self, providers=None, provider_options=None):
          """
          Register the input list of execution providers. The underlying session is re-created.

          :param providers: Optional sequence of providers in order of decreasing
              precedence. Values can either be provider names or tuples of
              (provider name, options dict). If not provided, then all available
              providers are used with the default precedence.
          :param provider_options: Optional sequence of options dicts corresponding
              to the providers listed in 'providers'.

          'providers' can contain either names or names and options. When any options
          are given in 'providers', 'provider_options' should not be used.

          The list of providers is ordered by precedence. For example
          `['CUDAExecutionProvider', 'CPUExecutionProvider']`
          means execute a node using CUDAExecutionProvider if capable,
          otherwise execute using CPUExecutionProvider.
          """
          # Delegates to the derived class, which recreates the underlying C.InferenceSession.
          self._reset_session(providers, provider_options)

      def disable_fallback(self):
          """
          Disable the fallback mechanism used by run() and run_with_ort_values().
          """
          self._enable_fallback = False

      def enable_fallback(self):
          """
          Enable the fallback mechanism used by run() and run_with_ort_values().

          While enabled, a ``C.EPFail`` raised by the underlying session makes those
          methods switch to ``self._fallback_providers`` via :meth:`set_providers`,
          disable further fallback, and retry the call once.
          """
          self._enable_fallback = True

      def run(self, output_names, input_feed, run_options=None):
          """
          Compute the predictions.

          :param output_names: name of the outputs; when empty or None, all model outputs are requested
          :param input_feed: dictionary ``{ input_name: input_value }``
          :param run_options: See :class:`onnxruntime.RunOptions`.
          :return: list of results, every result is either a numpy array,
              a sparse tensor, a list or a dictionary.
          :raises ValueError: if fewer inputs are fed than the model's input metadata lists.

          ::

              sess.run([output_name], {input_name: x})
          """
          required_count = len(self._inputs_meta)
          fed_count = len(input_feed)
          # the graph may have optional inputs used to override initializers. allow for that.
          if fed_count < required_count:
              raise ValueError("Model requires {} inputs. Input Feed contains {}".format(required_count, fed_count))

          output_names = output_names or [meta.name for meta in self._outputs_meta]

          try:
              return self._sess.run(output_names, input_feed, run_options)
          except C.EPFail as err:
              if not self._enable_fallback:
                  raise
              print("EP Error: {} using {}".format(str(err), self._providers))
              print("Falling back to {} and retrying.".format(self._fallback_providers))
              self.set_providers(self._fallback_providers)
              # Fallback only once.
              self.disable_fallback()
              return self._sess.run(output_names, input_feed, run_options)

      def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None):
          """
          Compute the predictions.

          :param output_names: name of the outputs; when empty or None, all model outputs are requested
          :param input_dict_ort_values: dictionary ``{ input_name: input_ort_value }``
              See ``OrtValue`` class how to create `OrtValue`
              from numpy array or `SparseTensor`
          :param run_options: See :class:`onnxruntime.RunOptions`.
          :return: an array of `OrtValue`
          :raises ValueError: if fewer inputs are fed than the model's input metadata lists.
          :raises TypeError: if the underlying session does not return a ``C.OrtValueVector``.

          ::

              sess.run_with_ort_values([output_name], {input_name: ort_value})
          """

          def _execute(session, names, ort_inputs, options):
              # Unwrap the Python OrtValue wrappers before handing them to the C session.
              c_inputs = {key: wrapped._get_c_value() for key, wrapped in ort_inputs.items()}
              raw = session.run_with_ort_values(c_inputs, names, options)
              if not isinstance(raw, C.OrtValueVector):
                  raise TypeError("run_with_ort_values() must return a instance of type 'OrtValueVector'.")
              return [OrtValue(item) for item in raw]

          required_count = len(self._inputs_meta)
          fed_count = len(input_dict_ort_values)
          # the graph may have optional inputs used to override initializers. allow for that.
          if fed_count < required_count:
              raise ValueError("Model requires {} inputs. Input Feed contains {}".format(required_count, fed_count))

          output_names = output_names or [meta.name for meta in self._outputs_meta]

          try:
              return _execute(self._sess, output_names, input_dict_ort_values, run_options)
          except C.EPFail as err:
              if not self._enable_fallback:
                  raise
              print("EP Error: {} using {}".format(str(err), self._providers))
              print("Falling back to {} and retrying.".format(self._fallback_providers))
              self.set_providers(self._fallback_providers)
              # Fallback only once.
              self.disable_fallback()
              # self._sess is read again here because set_providers() replaced it.
              return _execute(self._sess, output_names, input_dict_ort_values, run_options)

      def end_profiling(self):
          """
          End profiling and return results in a file.

          The results are stored in a filename if the option
          :meth:`onnxruntime.SessionOptions.enable_profiling`.
          """
          return self._sess.end_profiling()

      def get_profiling_start_time_ns(self):
          """
          Return the nanoseconds of profiling's start time
          Comparable to time.monotonic_ns() after Python 3.3
          On some platforms, this timer may not be as precise as nanoseconds
          For instance, on Windows and MacOS, the precision will be ~100ns
          """
          # The attribute of the underlying session is returned as-is (accessed, not called).
          return self._sess.get_profiling_start_time_ns

      def io_binding(self):
          """Return an onnxruntime.IOBinding object."""
          return IOBinding(self)

      def run_with_iobinding(self, iobinding, run_options=None):
          """
          Compute the predictions.

          :param iobinding: the iobinding object that has graph inputs/outputs bind.
          :param run_options: See :class:`onnxruntime.RunOptions`.
          """
          self._sess.run_with_iobinding(iobinding._iobinding, run_options)

      def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices):
          """
          Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead.

          :param run_options: See :class:`onnxruntime.RunOptions`.
          :param feed_names: list of input names.
          :param feeds: list of input OrtValue.
          :param fetch_names: list of output names.
          :param fetches: list of output OrtValue.
          :param fetch_devices: list of output devices.
          """
          self._sess.run_with_ortvaluevector(run_options, feed_names, feeds, fetch_names, fetches, fetch_devices)
  class InferenceSession(Session):
      """
      This is the main class used to run a model.
      """

      def __init__(self, path_or_bytes, sess_options=None, providers=None, provider_options=None, **kwargs):
          """
          :param path_or_bytes: filename (str or path-like object) or serialized ONNX or ORT
              format model in a byte string
          :param sess_options: session options
          :param providers: Optional sequence of providers in order of decreasing
              precedence. Values can either be provider names or tuples of
              (provider name, options dict). If not provided, then all available
              providers are used with the default precedence.
          :param provider_options: Optional sequence of options dicts corresponding
              to the providers listed in 'providers'.
          :raises TypeError: if 'path_or_bytes' is neither a path nor bytes.

          The model type will be inferred unless explicitly set in the SessionOptions.
          To explicitly set:

          ::

              so = onnxruntime.SessionOptions()
              # so.add_session_config_entry('session.load_model_format', 'ONNX') or
              so.add_session_config_entry('session.load_model_format', 'ORT')

          A file extension of '.ort' will be inferred as an ORT format model.
          All other filenames are assumed to be ONNX format models.

          'providers' can contain either names or names and options. When any options
          are given in 'providers', 'provider_options' should not be used.

          The list of providers is ordered by precedence. For example
          `['CUDAExecutionProvider', 'CPUExecutionProvider']`
          means execute a node using `CUDAExecutionProvider`
          if capable, otherwise execute using `CPUExecutionProvider`.

          If creating the session raises ValueError while fallback is enabled, creation
          is retried once with the fallback providers and fallback is then disabled.
          """
          Session.__init__(self)

          if isinstance(path_or_bytes, (str, os.PathLike)):
              # os.fspath() leaves a str unchanged and unwraps path-like objects.
              self._model_path = os.fspath(path_or_bytes)
              self._model_bytes = None
          elif isinstance(path_or_bytes, bytes):
              self._model_path = None
              self._model_bytes = path_or_bytes  # TODO: This is bad as we're holding the memory indefinitely
          else:
              raise TypeError("Unable to load from type '{0}'".format(type(path_or_bytes)))

          self._sess_options = sess_options
          # Kept so that _reset_session() can restore the options the caller originally passed.
          self._sess_options_initial = sess_options
          self._enable_fallback = True
          self._read_config_from_model = os.environ.get("ORT_LOAD_CONFIG_FROM_MODEL") == "1"

          # internal parameters that we don't expect to be used in general so aren't documented
          # Remembered on the instance so that every re-creation of the session uses the same value.
          self._disabled_optimizers = kwargs.get("disabled_optimizers")

          try:
              self._create_inference_session(providers, provider_options, self._disabled_optimizers)
          except ValueError:
              if self._enable_fallback:
                  print("EP Error using {}".format(providers))
                  print("Falling back to {} and retrying.".format(self._fallback_providers))
                  # The retry must use the same disabled optimizers as the first attempt.
                  self._create_inference_session(self._fallback_providers, None, self._disabled_optimizers)
                  # Fallback only once.
                  self.disable_fallback()
              else:
                  raise

      def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
          """
          Create and initialize the underlying C.InferenceSession and cache its metadata.

          :param providers: see :meth:`__init__`.
          :param provider_options: see :meth:`__init__`.
          :param disabled_optimizers: optional iterable of optimizers to disable; converted to a set.
          :raises ValueError: if the provider arguments are invalid, or if no providers are
              given while more than one provider is available.
          """
          available_providers = C.get_available_providers()

          # TensorRT falls back to CUDA, MIGraphX falls back to ROCm. All others fall back to CPU.
          if "TensorrtExecutionProvider" in available_providers:
              self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
          elif "MIGraphXExecutionProvider" in available_providers:
              self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
          else:
              self._fallback_providers = ["CPUExecutionProvider"]

          # validate providers and provider_options before other initialization
          providers, provider_options = check_and_normalize_provider_args(
              providers, provider_options, available_providers
          )
          if providers == [] and len(available_providers) > 1:
              # Disabling fallback makes the constructor re-raise instead of retrying.
              self.disable_fallback()
              raise ValueError(
                  "This ORT build has {} enabled. ".format(available_providers)
                  + "Since ORT 1.9, you are required to explicitly set "
                  + "the providers parameter when instantiating InferenceSession. For example, "
                  "onnxruntime.InferenceSession(..., providers={}, ...)".format(available_providers)
              )

          session_options = self._sess_options if self._sess_options else C.get_default_session_options()
          if self._model_path:
              sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
          else:
              sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)

          if disabled_optimizers is None:
              disabled_optimizers = set()
          elif not isinstance(disabled_optimizers, set):
              # convert to set. assumes iterable
              disabled_optimizers = set(disabled_optimizers)

          # initialize the C++ InferenceSession
          sess.initialize_session(providers, provider_options, disabled_optimizers)

          self._sess = sess
          self._sess_options = self._sess.session_options
          self._inputs_meta = self._sess.inputs_meta
          self._outputs_meta = self._sess.outputs_meta
          self._overridable_initializers = self._sess.overridable_initializers
          self._model_meta = self._sess.model_meta
          self._providers = self._sess.get_providers()
          self._provider_options = self._sess.get_provider_options()
          self._profiling_start_time_ns = self._sess.get_profiling_start_time_ns

      def _reset_session(self, providers, provider_options):
          """Release the underlying session object and create a new one with the given providers."""
          # meta data references session internal structures
          # so they must be set to None to decrement _sess reference count.
          self._sess_options = None
          self._inputs_meta = None
          self._outputs_meta = None
          self._overridable_initializers = None
          self._model_meta = None
          self._providers = None
          self._provider_options = None
          self._profiling_start_time_ns = None

          # create a new C.InferenceSession
          self._sess = None
          self._sess_options = self._sess_options_initial
          # getattr() covers instances that were set up without going through __init__.
          self._create_inference_session(providers, provider_options, getattr(self, "_disabled_optimizers", None))
  class IOBinding:
      """
      This class provides API to bind input/output to a specified device, e.g. GPU.
      """

      def __init__(self, session):
          # Keeps numpy arrays bound through bind_cpu_input() alive, keyed by input name.
          self._numpy_obj_references = {}
          self._iobinding = C.SessionIOBinding(session._sess)

      def bind_cpu_input(self, name, arr_on_cpu):
          """
          bind an input to array on CPU

          :param name: input name
          :param arr_on_cpu: input values as a python array on CPU
          """
          # Hold a reference to the numpy object as the bound OrtValue is backed
          # directly by the data buffer of the numpy object and so the numpy object
          # must be around until this IOBinding instance is around
          self._numpy_obj_references[name] = arr_on_cpu
          self._iobinding.bind_input(name, arr_on_cpu)

      def bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr):
          """
          :param name: input name
          :param device_type: e.g. cpu, cuda
          :param device_id: device id, e.g. 0
          :param element_type: input element type
          :param shape: input shape
          :param buffer_ptr: memory pointer to input data
          """
          target_device = C.OrtDevice(
              get_ort_device_type(device_type, device_id),
              C.OrtDevice.default_memory(),
              device_id,
          )
          self._iobinding.bind_input(name, target_device, element_type, shape, buffer_ptr)

      def bind_ortvalue_input(self, name, ortvalue):
          """
          :param name: input name
          :param ortvalue: OrtValue instance to bind
          """
          self._iobinding.bind_ortvalue_input(name, ortvalue._ortvalue)

      def synchronize_inputs(self):
          """Forward to ``synchronize_inputs()`` of the underlying binding object."""
          self._iobinding.synchronize_inputs()

      def bind_output(
          self,
          name,
          device_type="cpu",
          device_id=0,
          element_type=None,
          shape=None,
          buffer_ptr=None,
      ):
          """
          :param name: output name
          :param device_type: e.g. cpu, cuda, cpu by default
          :param device_id: device id, e.g. 0
          :param element_type: output element type
          :param shape: output shape
          :param buffer_ptr: memory pointer to output data
          :raises ValueError: if `buffer_ptr` is given without `element_type` and `shape`.
          """
          preallocated = buffer_ptr is not None
          # Validate before building the device so that bad arguments fail first.
          if preallocated and (element_type is None or shape is None):
              raise ValueError("`element_type` and `shape` are to be provided if pre-allocated memory is provided")

          target_device = C.OrtDevice(
              get_ort_device_type(device_type, device_id),
              C.OrtDevice.default_memory(),
              device_id,
          )

          if preallocated:
              self._iobinding.bind_output(name, target_device, element_type, shape, buffer_ptr)
              return

          # Reached when the user has not provided any pre-allocated buffer but still
          # would like to bind an output to a specific device (e.g. cuda).
          # Pre-allocating an output buffer may not be an option for the user as :
          # (1) They may not want to use a custom allocator specific to the device they want to bind the output to,
          # in which case ORT will allocate the memory for the user
          # (2) The output has a dynamic shape and hence the size of the buffer may not be fixed across runs
          self._iobinding.bind_output(name, target_device)

      def bind_ortvalue_output(self, name, ortvalue):
          """
          :param name: output name
          :param ortvalue: OrtValue instance to bind
          """
          self._iobinding.bind_ortvalue_output(name, ortvalue._ortvalue)

      def synchronize_outputs(self):
          """Forward to ``synchronize_outputs()`` of the underlying binding object."""
          self._iobinding.synchronize_outputs()

      def get_outputs(self):
          """
          Returns the output OrtValues from the Run() that preceded the call.
          The data buffer of the obtained OrtValues may not reside on CPU memory

          :raises TypeError: if the underlying binding does not return a ``C.OrtValueVector``.
          """
          raw_outputs = self._iobinding.get_outputs()
          if not isinstance(raw_outputs, C.OrtValueVector):
              raise TypeError("get_outputs() must return an instance of type 'OrtValueVector'.")
          return [OrtValue(item) for item in raw_outputs]

      def get_outputs_as_ortvaluevector(self):
          """Return the result of the underlying ``get_outputs()`` without wrapping it."""
          return self._iobinding.get_outputs()

      def copy_outputs_to_cpu(self):
          """Copy output contents to CPU (if on another device). No-op if already on the CPU."""
          return self._iobinding.copy_outputs_to_cpu()

      def clear_binding_inputs(self):
          """Forward to ``clear_binding_inputs()`` of the underlying binding object."""
          self._iobinding.clear_binding_inputs()

      def clear_binding_outputs(self):
          """Forward to ``clear_binding_outputs()`` of the underlying binding object."""
          self._iobinding.clear_binding_outputs()
  class OrtValue:
      """
      A data structure that supports all ONNX data formats (tensors and non-tensors) that allows users
      to place the data backing these on a device, for example, on a CUDA supported device.
      This class provides APIs to construct and deal with OrtValues.
      """

      def __init__(self, ortvalue, numpy_obj=None):
          """
          :param ortvalue: the ``C.OrtValue`` instance to wrap
          :param numpy_obj: optional numpy object to keep alive together with this instance
          :raises ValueError: if `ortvalue` is not a ``C.OrtValue``
          """
          if isinstance(ortvalue, C.OrtValue):
              self._ortvalue = ortvalue
              # Hold a ref count to the numpy object if the OrtValue is backed directly
              # by its data buffer so that it isn't destroyed when the OrtValue is in use
              self._numpy_obj = numpy_obj
          else:
              # An end user won't hit this error
              raise ValueError(
                  "`Provided ortvalue` needs to be of type " + "`onnxruntime.capi.onnxruntime_pybind11_state.OrtValue`"
              )

      def _get_c_value(self):
          """Internal accessor to the wrapped ``C.OrtValue``."""
          return self._ortvalue

      @staticmethod
      def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0):
          """
          Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object
          A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu

          :param numpy_obj: The Numpy object to construct the OrtValue from
          :param device_type: e.g. cpu, cuda, cpu by default
          :param device_id: device id, e.g. 0
          """
          # Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue
          # is backed directly by the data buffer of the numpy object and so the numpy object
          # must be around until this OrtValue instance is around
          return OrtValue(
              C.OrtValue.ortvalue_from_numpy(
                  numpy_obj,
                  C.OrtDevice(
                      get_ort_device_type(device_type, device_id),
                      C.OrtDevice.default_memory(),
                      device_id,
                  ),
              ),
              numpy_obj if device_type.lower() == "cpu" else None,
          )

      @staticmethod
      def ortvalue_from_shape_and_type(shape=None, element_type=None, device_type="cpu", device_id=0):
          """
          Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type

          :param shape: List of integers indicating the shape of the OrtValue
          :param element_type: The data type of the elements in the OrtValue (numpy type)
          :param device_type: e.g. cpu, cuda, cpu by default
          :param device_id: device id, e.g. 0
          :raises ValueError: if `shape` or `element_type` is None
          """
          if shape is None or element_type is None:
              # Both arguments default to None only for signature compatibility; they are mandatory.
              raise ValueError("`element_type` and `shape` are to be provided to construct an OrtValue")

          return OrtValue(
              C.OrtValue.ortvalue_from_shape_and_type(
                  shape,
                  element_type,
                  C.OrtDevice(
                      get_ort_device_type(device_type, device_id),
                      C.OrtDevice.default_memory(),
                      device_id,
                  ),
              )
          )

      @staticmethod
      def ort_value_from_sparse_tensor(sparse_tensor):
          """
          The function will construct an OrtValue instance from a valid SparseTensor
          The new instance of OrtValue will assume the ownership of sparse_tensor
          """
          return OrtValue(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor()))

      def as_sparse_tensor(self):
          """
          The function will return SparseTensor contained in this OrtValue
          """
          return SparseTensor(self._ortvalue.as_sparse_tensor())

      def data_ptr(self):
          """
          Returns the address of the first element in the OrtValue's data buffer
          """
          return self._ortvalue.data_ptr()

      def device_name(self):
          """
          Returns the name of the device where the OrtValue's data buffer resides e.g. cpu, cuda

          The name reported by the underlying object is lower-cased before it is returned.
          """
          return self._ortvalue.device_name().lower()

      def shape(self):
          """
          Returns the shape of the data in the OrtValue
          """
          return self._ortvalue.shape()

      def data_type(self):
          """
          Returns the data type of the data in the OrtValue
          """
          return self._ortvalue.data_type()

      def element_type(self):
          """
          Returns the proto type of the data in the OrtValue
          if the OrtValue is a tensor.
          """
          return self._ortvalue.element_type()

      def has_value(self):
          """
          Returns True if the OrtValue corresponding to an
          optional type contains data, else returns False
          """
          return self._ortvalue.has_value()

      def is_tensor(self):
          """
          Returns True if the OrtValue contains a Tensor, else returns False
          """
          return self._ortvalue.is_tensor()

      def is_sparse_tensor(self):
          """
          Returns True if the OrtValue contains a SparseTensor, else returns False
          """
          return self._ortvalue.is_sparse_tensor()

      def is_tensor_sequence(self):
          """
          Returns True if the OrtValue contains a Tensor Sequence, else returns False
          """
          return self._ortvalue.is_tensor_sequence()

      def numpy(self):
          """
          Returns a Numpy object from the OrtValue.
          Valid only for OrtValues holding Tensors. Throws for OrtValues holding non-Tensors.
          Use accessors to gain a reference to non-Tensor objects such as SparseTensor
          """
          return self._ortvalue.numpy()

      def update_inplace(self, np_arr):
          """
          Update the OrtValue in place with a new Numpy array. The numpy contents
          are copied over to the device memory backing the OrtValue. It can be used
          to update the input values for an InferenceSession with CUDA graph
          enabled or other scenarios where the OrtValue needs to be updated while
          the memory address can not be changed.
          """
          self._ortvalue.update_inplace(np_arr)
  class OrtDevice:
      """
      A data structure that exposes the underlying C++ OrtDevice
      """

      def __init__(self, c_ort_device):
          """
          Internal constructor

          :param c_ort_device: the ``C.OrtDevice`` instance to wrap
          :raises ValueError: if `c_ort_device` is not a ``C.OrtDevice``
          """
          if not isinstance(c_ort_device, C.OrtDevice):
              raise ValueError(
                  "`Provided object` needs to be of type " + "`onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`"
              )
          self._ort_device = c_ort_device

      def _get_c_device(self):
          """
          Internal accessor to underlying object
          """
          return self._ort_device

      @staticmethod
      def make(ort_device_name, device_id):
          """
          Build an OrtDevice wrapper from a device name and a device id.

          :param ort_device_name: device name accepted by ``get_ort_device_type``
          :param device_id: device id, passed both to ``get_ort_device_type`` and to ``C.OrtDevice``
          """
          c_device = C.OrtDevice(
              get_ort_device_type(ort_device_name, device_id),
              C.OrtDevice.default_memory(),
              device_id,
          )
          return OrtDevice(c_device)

      def device_id(self):
          """Return the ``device_id()`` reported by the underlying device object."""
          return self._ort_device.device_id()

      def device_type(self):
          """Return the ``device_type()`` reported by the underlying device object."""
          return self._ort_device.device_type()
  647. class SparseTensor:
  648. """
  649. A data structure that project the C++ SparseTensor object
  650. The class provides API to work with the object.
  651. Depending on the format, the class will hold more than one buffer
  652. depending on the format
  653. """
  def __init__(self, sparse_tensor):
      """
      Internal constructor

      :param sparse_tensor: the ``C.SparseTensor`` instance to wrap
      :raises ValueError: if `sparse_tensor` is not a ``C.SparseTensor``
      """
      if not isinstance(sparse_tensor, C.SparseTensor):
          # An end user won't hit this error
          raise ValueError(
              "`Provided object` needs to be of type " + "`onnxruntime.capi.onnxruntime_pybind11_state.SparseTensor`"
          )
      self._tensor = sparse_tensor
  665. def _get_c_tensor(self):
  666. return self._tensor
  @staticmethod
  def sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device):
      """
      Factory method to construct a SparseTensor in COO format from given arguments

      :param dense_shape: 1-D numpy array(int64) or a python list that contains a dense_shape of the sparse tensor
          must be on cpu memory
      :param values: a homogeneous, contiguous 1-D numpy array that contains non-zero elements of the tensor
          of a type.
      :param coo_indices: contiguous numpy array(int64) that contains COO indices for the tensor. coo_indices may
          have a 1-D shape when it contains a linear index of non-zero values and its length must be equal to
          that of the values. It can also be of 2-D shape, in which case it contains pairs of coordinates for
          each of the nnz values and its length must be exactly twice of the values length.
      :param ort_device: - describes the backing memory owned by the supplied numpy arrays. Only CPU memory is
          supported for non-numeric data types.

      For primitive types, the method will map values and coo_indices arrays into native memory and will use
      them as backing storage. It will increment the reference count for numpy arrays and it will decrement it
      on GC. The buffers may reside in any storage either CPU or GPU.
      For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those
      on other devices and their memory can not be mapped.
      """
      # The C factory takes the underlying device object, not the Python wrapper.
      c_tensor = C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device())
      return SparseTensor(c_tensor)
  @staticmethod
  def sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, ort_device):
      """
      Build a SparseTensor in CSR format from the supplied arrays.

      :param dense_shape: 1-D numpy array(int64) or a python list holding the dense shape
          (rows, cols) of the sparse tensor. Must reside in CPU memory.
      :param values: a contiguous, homogeneous 1-D numpy array holding the non-zero elements of the tensor.
      :param inner_indices: contiguous 1-D numpy array(int64) holding the CSR inner indices of the tensor.
          Its length must equal the length of ``values``.
      :param outer_indices: contiguous 1-D numpy array(int64) holding the CSR outer indices of the tensor.
          Its length must equal the number of rows + 1.
      :param ort_device: describes the backing memory owned by the supplied numpy arrays. Only CPU memory
          is supported for non-numeric data types.
      :return: a new SparseTensor wrapping the native CSR tensor.

      For primitive types, values and indices arrays are mapped into native memory and used as the backing
      storage. The reference count is incremented and then decremented when the tensor is GCed.
      The buffers may reside in any storage, either CPU or GPU.
      For strings and objects, a copy of the arrays is created in CPU memory, as ORT does not support
      those on other devices and their memory can not be mapped.
      """
      # Resolve the native factory first, then the native device descriptor it needs.
      make_csr = C.SparseTensor.sparse_csr_from_numpy
      c_device = ort_device._get_c_device()
      c_sparse_tensor = make_csr(dense_shape, values, inner_indices, outer_indices, c_device)
      return SparseTensor(c_sparse_tensor)
  def values(self):
      """
      Return the non-zero values of the sparse tensor as a numpy array.

      When the data type is numeric, the array is backed by the native memory.
      Otherwise, the returned array contains copies of the strings.
      """
      tensor_values = self._tensor.values()
      return tensor_values
  def as_coo_view(self):
      """
      Return the COO representation of the sparse tensor, which enables querying
      the COO indices. Throws if the instance does not contain COO format.

      The COO indices can be queried as:

      ::

          coo_indices = sparse_tensor.as_coo_view().indices()

      which returns a numpy array that is backed by the native memory.
      """
      coo_view = self._tensor.get_coo_data()
      return coo_view
  def as_csrc_view(self):
      """
      Return the CSR(C) representation of the sparse tensor, which enables querying
      the CSR(C) indices. Throws if the instance does not contain CSR(C) format.

      The indices can be queried as:

      ::

          inner_indices = sparse_tensor.as_csrc_view().inner()
          outer_indices = sparse_tensor.as_csrc_view().outer()

      returning numpy arrays backed by the native memory.
      """
      csrc_view = self._tensor.get_csrc_data()
      return csrc_view
  def as_blocksparse_view(self):
      """
      Return the BlockSparse representation of the sparse tensor, which enables querying
      the BlockSparse indices. Throws if the instance does not contain BlockSparse format.

      The BlockSparse indices can be queried as:

      ::

          block_sparse_indices = sparse_tensor.as_blocksparse_view().indices()

      which returns a numpy array that is backed by the native memory.
      """
      blocksparse_view = self._tensor.get_blocksparse_data()
      return blocksparse_view
  def to_cuda(self, ort_device):
      """
      Return a copy of this instance placed on the specified cuda device.

      :param ort_device: device with name 'cuda' and a valid gpu device id
      :return: a new SparseTensor wrapping the copied native tensor.

      The method will throw if:

      - this instance contains strings
      - this instance is already on GPU. Cross GPU copy is not supported
      - CUDA is not present in this build
      - the specified device is not valid
      """
      c_device = ort_device._get_c_device()
      cuda_copy = self._tensor.to_cuda(c_device)
      return SparseTensor(cuda_copy)
  def format(self):
      """
      Return the OrtSparseFormat enumeration value of this sparse tensor.
      """
      # `format` is read as an attribute of the native tensor, not called.
      sparse_format = self._tensor.format
      return sparse_format
  def dense_shape(self):
      """
      Return a numpy array(int64) holding the dense shape of the sparse tensor.
      """
      shape = self._tensor.dense_shape()
      return shape
  def data_type(self):
      """
      Return the data type of the data held by this sparse tensor, as a string.
      """
      type_name = self._tensor.data_type()
      return type_name
  def device_name(self):
      """
      Return the lower-cased name of the device where the SparseTensor data buffers
      reside, e.g. cpu, cuda.
      """
      raw_name = self._tensor.device_name()
      return raw_name.lower()