图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

53 lines
1.7 KiB

  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. """
  6. Implements ONNX's backend API.
  7. """
  8. from typing import Any, Tuple
  9. from onnx.backend.base import BackendRep
  10. from onnxruntime import RunOptions
  11. class OnnxRuntimeBackendRep(BackendRep):
  12. """
  13. Computes the prediction for a pipeline converted into
  14. an :class:`onnxruntime.InferenceSession` node.
  15. """
  16. def __init__(self, session):
  17. """
  18. :param session: :class:`onnxruntime.InferenceSession`
  19. """
  20. self._session = session
  21. def run(self, inputs, **kwargs): # type: (Any, **Any) -> Tuple[Any, ...]
  22. """
  23. Computes the prediction.
  24. See :meth:`onnxruntime.InferenceSession.run`.
  25. """
  26. options = RunOptions()
  27. for k, v in kwargs.items():
  28. if hasattr(options, k):
  29. setattr(options, k, v)
  30. if isinstance(inputs, list):
  31. inps = {}
  32. for i, inp in enumerate(self._session.get_inputs()):
  33. inps[inp.name] = inputs[i]
  34. outs = self._session.run(None, inps, options)
  35. if isinstance(outs, list):
  36. return outs
  37. else:
  38. output_names = [o.name for o in self._session.get_outputs()]
  39. return [outs[name] for name in output_names]
  40. else:
  41. inp = self._session.get_inputs()
  42. if len(inp) != 1:
  43. raise RuntimeError("Model expect {0} inputs".format(len(inp)))
  44. inps = {inp[0].name: inputs}
  45. return self._session.run(None, inps, options)