m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

140 lines
5.9 KiB

6 months ago
  1. # --------------------------------------------------------------------------
  2. # Copyright (c) Microsoft, Intel Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import tempfile
  8. import traceback
  9. from pathlib import Path
  10. import onnx
  11. import onnxruntime
  12. from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
  13. from .quant_utils import add_pre_process_metadata
  14. logger = logging.getLogger(__name__)
  15. def quant_pre_process(
  16. input_model_path: str,
  17. output_model_path: str,
  18. skip_optimization: bool = False,
  19. skip_onnx_shape: bool = False,
  20. skip_symbolic_shape: bool = False,
  21. auto_merge: bool = False,
  22. int_max: int = 2**31 - 1,
  23. guess_output_rank: bool = False,
  24. verbose: int = 0,
  25. save_as_external_data: bool = False,
  26. all_tensors_to_one_file: bool = False,
  27. external_data_location: str = "./",
  28. external_data_size_threshold: int = 1024,
  29. ) -> None:
  30. """Shape inference and model optimization, in preparation for quantization.
  31. Args:
  32. input_model_path: Path to the input model file")
  33. output_model_path: Path to the output model file
  34. skip_optimization: Skip model optimization step if true. This may result in ONNX shape
  35. inference failure for some models.
  36. skip_onnx_shape: Skip ONNX shape inference. Symbolic shape inference is most effective
  37. with transformer based models. Skipping all shape inferences may
  38. reduce the effectiveness of quantization, as a tensor with unknown
  39. shape can not be quantized.
  40. skip_symbolic_shape: Skip symbolic shape inference. Symbolic shape inference is most
  41. effective with transformer based models. Skipping all shape
  42. inferences may reduce the effectiveness of quantization, as a tensor
  43. with unknown shape can not be quantized.
  44. auto_merge: For symbolic shape inference, automatically merge symbolic dims when
  45. conflict happens.
  46. int_max: For symbolic shape inference, specify the maximum value for integer to be
  47. treated as boundless for ops like slice
  48. guess_output_rank: Guess output rank to be the same as input 0 for unknown ops
  49. verbose: Logs detailed info of inference, 0: turn off, 1: warnings, 3: detailed
  50. save_as_external_data: Saving an ONNX model to external data
  51. all_tensors_to_one_file: Saving all the external data to one file
  52. external_data_location: The file location to save the external file
  53. external_data_size_threshold: The size threshold for external data
  54. """
  55. with tempfile.TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
  56. temp_path = Path(quant_tmp_dir)
  57. model = None
  58. if not skip_symbolic_shape:
  59. logger.info("Performing symbolic shape inference...")
  60. model = SymbolicShapeInference.infer_shapes(
  61. onnx.load(input_model_path),
  62. int_max,
  63. auto_merge,
  64. guess_output_rank,
  65. verbose,
  66. )
  67. if not skip_optimization:
  68. # Use ORT optimizers (native code) to optimize model
  69. if not skip_symbolic_shape:
  70. # Need to save the inferenced model to file so as to run the optimizer
  71. input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
  72. onnx.save(model, input_model_path)
  73. model = None
  74. opt_model_path = str(temp_path / "optimized.onnx")
  75. try:
  76. sess_option = onnxruntime.SessionOptions()
  77. sess_option.optimized_model_filepath = opt_model_path
  78. sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
  79. _ = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"])
  80. except Exception as e:
  81. logger.error(
  82. "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'."
  83. )
  84. logger.error(traceback.format_exc())
  85. input_model_path = opt_model_path
  86. if not skip_onnx_shape:
  87. # ONNX shape inference.
  88. # According to docs, infer_shapes_path should be used for 2G+ models.
  89. # If the skip optimization is specified, we could be dealing with a
  90. # large model. So be on the safe side, save the model
  91. if model is not None:
  92. input_model_path = str(temp_path / "symbolic_shape_inferred.onnx")
  93. if save_as_external_data:
  94. onnx.save_model(
  95. model,
  96. input_model_path,
  97. save_as_external_data=True,
  98. all_tensors_to_one_file=all_tensors_to_one_file,
  99. size_threshold=external_data_size_threshold,
  100. convert_attribute=False,
  101. )
  102. else:
  103. onnx.save(model, input_model_path)
  104. model = None
  105. inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
  106. onnx.shape_inference.infer_shapes_path(input_model_path, inferred_model_path)
  107. model = onnx.load(inferred_model_path)
  108. if model is None:
  109. model = onnx.load(input_model_path)
  110. add_pre_process_metadata(model)
  111. if save_as_external_data:
  112. onnx.save_model(
  113. model,
  114. output_model_path,
  115. save_as_external_data=True,
  116. all_tensors_to_one_file=all_tensors_to_one_file,
  117. location=external_data_location,
  118. size_threshold=external_data_size_threshold,
  119. convert_attribute=False,
  120. )
  121. else:
  122. onnx.save(model, output_model_path)