You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Implements ONNX's backend API.
"""
from typing import Any, Tuple
from onnx.backend.base import BackendRep
from onnxruntime import RunOptions
class OnnxRuntimeBackendRep(BackendRep):
    """
    Backend representation that delegates prediction to a wrapped
    :class:`onnxruntime.InferenceSession`.
    """

    def __init__(self, session):
        """
        :param session: :class:`onnxruntime.InferenceSession`
        """
        self._session = session

    def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
        """
        Runs the wrapped session on *inputs* and returns its outputs.
        See :meth:`onnxruntime.InferenceSession.run`.

        :param inputs: either a list of values, matched by position to the
            inputs declared by the session, or a single value for a model
            that declares exactly one input
        :param kwargs: each keyword whose name is an attribute of
            :class:`onnxruntime.RunOptions` is copied onto the run options;
            any other keyword is ignored
        :return: the outputs produced by the session
        :raises RuntimeError: if *inputs* is not a list and the model does
            not declare exactly one input
        """
        run_options = RunOptions()
        for option_name in kwargs:
            # Only forward settings that RunOptions actually exposes.
            if hasattr(run_options, option_name):
                setattr(run_options, option_name, kwargs[option_name])

        if not isinstance(inputs, list):
            # A bare value is only unambiguous for a single-input model.
            declared = self._session.get_inputs()
            if len(declared) != 1:
                raise RuntimeError("Model expect {0} inputs".format(len(declared)))
            return self._session.run(None, {declared[0].name: inputs}, run_options)

        # Pair each declared input name with the value at the same position.
        feeds = {
            node.name: inputs[position]
            for position, node in enumerate(self._session.get_inputs())
        }
        results = self._session.run(None, feeds, run_options)
        if isinstance(results, list):
            return results

        # Non-list result: look outputs up by name, in declaration order.
        ordered_names = [node.name for node in self._session.get_outputs()]
        return [results[output_name] for output_name in ordered_names]
|