m2m model translation
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

157 lines
5.4 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import argparse
  6. import os
  7. import sys
  8. from timeit import default_timer as timer
  9. import numpy as np
  10. import onnxruntime as onnxrt
  11. float_dict = {
  12. "tensor(float16)": "float16",
  13. "tensor(float)": "float32",
  14. "tensor(double)": "float64",
  15. }
  16. integer_dict = {
  17. "tensor(int32)": "int32",
  18. "tensor(int8)": "int8",
  19. "tensor(uint8)": "uint8",
  20. "tensor(int16)": "int16",
  21. "tensor(uint16)": "uint16",
  22. "tensor(int64)": "int64",
  23. "tensor(uint64)": "uint64",
  24. }
  25. def generate_feeds(sess, symbolic_dims={}):
  26. feeds = {}
  27. for input_meta in sess.get_inputs():
  28. # replace any symbolic dimensions
  29. shape = []
  30. for dim in input_meta.shape:
  31. if not dim:
  32. # unknown dim
  33. shape.append(1)
  34. elif type(dim) == str:
  35. # symbolic dim. see if we have a value otherwise use 1
  36. if dim in symbolic_dims:
  37. shape.append(int(symbolic_dims[dim]))
  38. else:
  39. shape.append(1)
  40. else:
  41. shape.append(dim)
  42. if input_meta.type in float_dict:
  43. feeds[input_meta.name] = np.random.rand(*shape).astype(float_dict[input_meta.type])
  44. elif input_meta.type in integer_dict:
  45. feeds[input_meta.name] = np.random.uniform(high=1000, size=tuple(shape)).astype(
  46. integer_dict[input_meta.type]
  47. )
  48. elif input_meta.type == "tensor(bool)":
  49. feeds[input_meta.name] = np.random.randint(2, size=tuple(shape)).astype("bool")
  50. else:
  51. print("unsupported input type {} for input {}".format(input_meta.type, input_meta.name))
  52. sys.exit(-1)
  53. return feeds
  54. # simple test program for loading onnx model, feeding all inputs and running the model num_iters times.
  55. def run_model(
  56. model_path,
  57. num_iters=1,
  58. debug=None,
  59. profile=None,
  60. symbolic_dims={},
  61. feeds=None,
  62. override_initializers=True,
  63. ):
  64. if debug:
  65. print("Pausing execution ready for debugger to attach to pid: {}".format(os.getpid()))
  66. print("Press key to continue.")
  67. sys.stdin.read(1)
  68. sess_options = None
  69. if profile:
  70. sess_options = onnxrt.SessionOptions()
  71. sess_options.enable_profiling = True
  72. sess_options.profile_file_prefix = os.path.basename(model_path)
  73. sess = onnxrt.InferenceSession(
  74. model_path,
  75. sess_options=sess_options,
  76. providers=onnxrt.get_available_providers(),
  77. )
  78. meta = sess.get_modelmeta()
  79. if not feeds:
  80. feeds = generate_feeds(sess, symbolic_dims)
  81. if override_initializers:
  82. # Starting with IR4 some initializers provide default values
  83. # and can be overridden (available in IR4). For IR < 4 models
  84. # the list would be empty
  85. for initializer in sess.get_overridable_initializers():
  86. shape = [dim if dim else 1 for dim in initializer.shape]
  87. if initializer.type in float_dict:
  88. feeds[initializer.name] = np.random.rand(*shape).astype(float_dict[initializer.type])
  89. elif initializer.type in integer_dict:
  90. feeds[initializer.name] = np.random.uniform(high=1000, size=tuple(shape)).astype(
  91. integer_dict[initializer.type]
  92. )
  93. elif initializer.type == "tensor(bool)":
  94. feeds[initializer.name] = np.random.randint(2, size=tuple(shape)).astype("bool")
  95. else:
  96. print("unsupported initializer type {} for initializer {}".format(initializer.type, initializer.name))
  97. sys.exit(-1)
  98. start = timer()
  99. for i in range(num_iters):
  100. outputs = sess.run([], feeds) # fetch all outputs
  101. end = timer()
  102. print("model: {}".format(meta.graph_name))
  103. print("version: {}".format(meta.version))
  104. print("iterations: {}".format(num_iters))
  105. print("avg latency: {} ms".format(((end - start) * 1000) / num_iters))
  106. if profile:
  107. trace_file = sess.end_profiling()
  108. print("trace file written to: {}".format(trace_file))
  109. return 0, feeds, num_iters > 0 and outputs
  110. if __name__ == "__main__":
  111. parser = argparse.ArgumentParser(description="Simple ONNX Runtime Test Tool.")
  112. parser.add_argument("model_path", help="model path")
  113. parser.add_argument(
  114. "num_iters",
  115. nargs="?",
  116. type=int,
  117. default=1000,
  118. help="model run iterations. default=1000",
  119. )
  120. parser.add_argument(
  121. "--debug",
  122. action="store_true",
  123. help="pause execution to allow attaching a debugger.",
  124. )
  125. parser.add_argument("--profile", action="store_true", help="enable chrome timeline trace profiling.")
  126. parser.add_argument(
  127. "--symbolic_dims",
  128. default={},
  129. type=lambda s: dict(x.split("=") for x in s.split(",")),
  130. help="Comma separated name=value pairs for any symbolic dimensions in the model input. "
  131. "e.g. --symbolic_dims batch=1,seqlen=5. "
  132. "If not provided, the value of 1 will be used for all symbolic dimensions.",
  133. )
  134. args = parser.parse_args()
  135. exit_code, _, _ = run_model(args.model_path, args.num_iters, args.debug, args.profile, args.symbolic_dims)
  136. sys.exit(exit_code)