# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import tempfile
from pathlib import Path

from .calibrate import CalibrationDataReader, CalibrationMethod, create_calibrator
from .onnx_quantizer import ONNXQuantizer
from .qdq_quantizer import QDQQuantizer
from .quant_utils import QuantFormat, QuantizationMode, QuantType, load_model, model_has_pre_process_metadata
from .registry import IntegerOpsRegistry, QLinearOpsRegistry


class QuantConfig:
    def __init__(
        self,
        activation_type=QuantType.QUInt8,
        weight_type=QuantType.QInt8,
        op_types_to_quantize=None,
        nodes_to_quantize=None,
        nodes_to_exclude=None,
        per_channel=False,
        reduce_range=False,
        optimize_model=True,
        use_external_data_format=False,
    ):
        """
        This is the base class for both static and dynamic quantization configurations.

        Args:
            activation_type:
                quantization data type of activation. Please refer to
                https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
            weight_type:
                quantization data type of weight. Please refer to
                https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
            op_types_to_quantize:
                specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
                It quantizes all supported operators by default.
            nodes_to_quantize:
                List of node names to quantize. When this list is not None, only the nodes in this list
                are quantized.
                example:
                [
                    'Conv__224',
                    'Conv__252'
                ]
            nodes_to_exclude:
                List of node names to exclude. The nodes in this list will be excluded from quantization
                when it is not None.
            per_channel: quantize weights per channel
            reduce_range:
                quantize weights with 7 bits. It may improve accuracy for some models running on non-VNNI machines,
                especially in per-channel mode
            optimize_model: Deprecating soon! Optimize model before quantization. NOT recommended: optimization
                changes the computation graph, making debugging of quantization loss difficult.
            use_external_data_format: option used for large (>2GB) models. False by default.
        """
        nodes_to_exclude = nodes_to_exclude or []
        nodes_to_quantize = nodes_to_quantize or []
        op_types_to_quantize = op_types_to_quantize or []
        self.op_types_to_quantize = op_types_to_quantize
        self.per_channel = per_channel
        self.reduce_range = reduce_range
        self.weight_type = weight_type
        self.activation_type = activation_type
        self.nodes_to_quantize = nodes_to_quantize
        self.nodes_to_exclude = nodes_to_exclude
        self.optimize_model = optimize_model
        self.use_external_data_format = use_external_data_format


class StaticQuantConfig(QuantConfig):
    def __init__(
        self,
        calibration_data_reader: CalibrationDataReader,
        calibrate_method=CalibrationMethod.MinMax,
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        op_types_to_quantize=None,
        nodes_to_quantize=None,
        nodes_to_exclude=None,
        per_channel=False,
        reduce_range=False,
        optimize_model=True,
        use_external_data_format=False,
        extra_options=None,
    ):
        """
        This is the derived class for static quantization configuration.

        Args:
            calibration_data_reader:
                a calibration data reader. It enumerates calibration data and generates inputs for the
                original model.
            calibrate_method:
                Currently supported calibration methods are MinMax, Entropy and Percentile.
            quant_format: QuantFormat{QOperator, QDQ}.
                QOperator format quantizes the model with quantized operators directly.
                QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
            extra_options:
                key value pair dictionary for various options in different cases. Currently used:
                    extra.Sigmoid.nnapi = True/False (Default is False)
                    ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
                    WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
                    EnableSubgraph = True/False : Default is False. If enabled, subgraphs will be quantized.
                        Dynamic mode is currently supported. More will be supported in the future.
                    ForceQuantizeNoInputCheck = True/False :
                        By default, some latent operators like maxpool and transpose are not quantized if their
                        input is not already quantized. Set to True to force such operators to always quantize
                        their input and thus generate quantized output. This behavior can be disabled per node
                        using nodes_to_exclude.
                    MatMulConstBOnly = True/False:
                        Default is False for static mode. If enabled, only MatMul nodes with a const B will be
                        quantized.
                    AddQDQPairToWeight = True/False :
                        Default is False, which quantizes the floating-point weight and feeds it to a solely
                        inserted DeQuantizeLinear node. If True, the weight stays in floating point and both
                        QuantizeLinear/DeQuantizeLinear nodes are inserted on it.
                    OpTypesToExcludeOutputQuantization = list of op type :
                        Default is []. If any op type is specified, the output of ops with that op type is not
                        quantized.
                    DedicatedQDQPair = True/False :
                        Default is False. When inserting a QDQ pair, multiple nodes can share a single QDQ pair
                        as their input. If True, an identical and dedicated QDQ pair is created for each node.
                    QDQOpTypePerChannelSupportToAxis = dictionary :
                        Default is {}. Sets the channel axis for specific op types, for example: {'MatMul': 1}.
                        It is effective only when per-channel quantization is supported and per_channel is True.
                        If a specific op type supports per-channel quantization but no channel axis is explicitly
                        specified, the default channel axis is used.
                    CalibTensorRangeSymmetric = True/False :
                        Default is False. If enabled, the final range of a tensor during calibration is explicitly
                        made symmetric around the central point "0".
                    CalibMovingAverage = True/False :
                        Default is False. If enabled, a moving average of the minimum and maximum values is
                        computed when the selected calibration method is MinMax.
                    CalibMovingAverageConstant = float :
                        Default is 0.01. Constant smoothing factor to use when computing the moving average of
                        the minimum and maximum values. Effective only when the selected calibration method is
                        MinMax and CalibMovingAverage is set to True.
            execution_provider : An enum indicating the execution provider, such as CPU, TRT, NNAPI, SNPE, etc.

        Raises:
            ValueError: Raise ValueError if execution provider is unknown
        """
        super().__init__(
            activation_type=activation_type,
            weight_type=weight_type,
            op_types_to_quantize=op_types_to_quantize,
            nodes_to_quantize=nodes_to_quantize,
            nodes_to_exclude=nodes_to_exclude,
            per_channel=per_channel,
            reduce_range=reduce_range,
            optimize_model=optimize_model,
            use_external_data_format=use_external_data_format,
        )
        self.calibration_data_reader = calibration_data_reader
        self.calibrate_method = calibrate_method
        self.quant_format = quant_format
        self.extra_options = extra_options or {}
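

# Illustrative usage sketch, not part of the upstream API: a minimal
# CalibrationDataReader that feeds random data to drive calibration, paired
# with a StaticQuantConfig. The input name "input" and the shape
# (1, 3, 224, 224) are assumptions for the example and must be matched to
# the model being quantized.
class _RandomCalibrationDataReader(CalibrationDataReader):
    def __init__(self, input_name="input", shape=(1, 3, 224, 224), num_batches=8):
        import numpy as np

        # Pre-generate a fixed number of feed dicts, one per calibration batch.
        self._batches = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(num_batches)]
        )

    def get_next(self):
        # Return one feed dict per call, then None once the data is exhausted.
        return next(self._batches, None)


def _example_static_quant_config():
    # Hypothetical helper: builds a StaticQuantConfig with the recommended
    # QDQ format and int8 activations/weights, with per-channel weights.
    return StaticQuantConfig(
        calibration_data_reader=_RandomCalibrationDataReader(),
        calibrate_method=CalibrationMethod.MinMax,
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        per_channel=True,
    )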


class DynamicQuantConfig(QuantConfig):
    def __init__(
        self,
        weight_type=QuantType.QInt8,
        op_types_to_quantize=None,
        nodes_to_quantize=None,
        nodes_to_exclude=None,
        per_channel=False,
        reduce_range=False,
        optimize_model=True,
        use_external_data_format=False,
        extra_options=None,
    ):
        """
        This is the class for dynamic quantization configuration.

        Args:
            extra_options: key value pair dictionary for various options in different cases. Currently used:
                extra.Sigmoid.nnapi = True/False (Default is False)
                ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
                WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
                EnableSubgraph = True/False :
                    Default is False. If enabled, subgraphs will be quantized. Dynamic mode is currently
                    supported. More will be supported in the future.
                ForceQuantizeNoInputCheck = True/False :
                    By default, some latent operators like maxpool and transpose are not quantized if their
                    input is not already quantized. Set to True to force such operators to always quantize
                    their input and thus generate quantized output. This behavior can be disabled per node
                    using nodes_to_exclude.
                MatMulConstBOnly = True/False:
                    Default is True for dynamic mode. If enabled, only MatMul nodes with a const B will be
                    quantized.
            execution_provider : An enum indicating the execution provider, such as CPU, TRT, NNAPI, SNPE, etc.

        Raises:
            ValueError: Raise ValueError if execution provider is unknown
        """
        super().__init__(
            op_types_to_quantize=op_types_to_quantize,
            per_channel=per_channel,
            reduce_range=reduce_range,
            weight_type=weight_type,
            nodes_to_quantize=nodes_to_quantize,
            nodes_to_exclude=nodes_to_exclude,
            optimize_model=optimize_model,
            use_external_data_format=use_external_data_format,
        )
        self.extra_options = extra_options or {}
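

# Illustrative sketch, not part of the upstream API: a dynamic configuration
# needs no calibration data reader; only the weight type and optional node
# filters are set. The extra_options value shown simply restates the
# dynamic-mode default.
def _example_dynamic_quant_config():
    return DynamicQuantConfig(
        weight_type=QuantType.QInt8,
        per_channel=False,
        extra_options={"MatMulConstBOnly": True},  # default for dynamic mode
    )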


def check_static_quant_arguments(quant_format: QuantFormat, activation_type: QuantType, weight_type: QuantType):
    if activation_type == QuantType.QInt8 and weight_type == QuantType.QUInt8:
        raise ValueError(
            "ONNXRuntime quantization doesn't support data format: "
            "activation_type=QuantType.QInt8, weight_type=QuantType.QUInt8"
        )

    if activation_type == QuantType.QInt8 and weight_type == QuantType.QInt8 and quant_format != QuantFormat.QDQ:
        logging.warning(
            "Please use QuantFormat.QDQ for activation type QInt8 and weight type QInt8. "
            "Otherwise it will lead to bad performance on x64."
        )
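

# Illustrative sketch, not part of the upstream API: the only rejected
# combination is int8 activations with uint8 weights; int8/int8 outside of
# QuantFormat.QDQ merely logs a performance warning.
def _example_check_static_quant_arguments():
    # Accepted, and the recommended pairing on x64:
    check_static_quant_arguments(QuantFormat.QDQ, QuantType.QInt8, QuantType.QInt8)
    try:
        # Rejected: raises ValueError.
        check_static_quant_arguments(QuantFormat.QDQ, QuantType.QInt8, QuantType.QUInt8)
    except ValueError:
        pass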


def quantize_static(
    model_input,
    model_output,
    calibration_data_reader: CalibrationDataReader,
    quant_format=QuantFormat.QDQ,
    op_types_to_quantize=None,
    per_channel=False,
    reduce_range=False,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
    nodes_to_quantize=None,
    nodes_to_exclude=None,
    optimize_model=True,
    use_external_data_format=False,
    calibrate_method=CalibrationMethod.MinMax,
    extra_options=None,
):
    """
    Given an onnx model and a calibration data reader, create a quantized onnx model and save it into a file.

    Since 1.11, it is recommended to use QuantFormat.QDQ with activation_type = QuantType.QInt8 and
    weight_type = QuantType.QInt8. If the model is targeted to GPU/TRT, symmetric activation and weight are
    required. If the model is targeted to CPU, asymmetric activation and symmetric weight are recommended for
    a balance of performance and accuracy.

    Args:
        model_input: file path of the model to quantize
        model_output: file path of the quantized model
        calibration_data_reader: a calibration data reader. It
            enumerates calibration data and generates inputs for the
            original model.
        quant_format: QuantFormat{QOperator, QDQ}.
            QOperator format quantizes the model with quantized operators directly.
            QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
        activation_type:
            quantization data type of activation. Please refer to
            https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
        calibrate_method:
            Currently supported calibration methods are MinMax, Entropy and Percentile.
            Please use CalibrationMethod.MinMax, CalibrationMethod.Entropy or CalibrationMethod.Percentile
            as options.
        op_types_to_quantize:
            specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
            It quantizes all supported operators by default.
        per_channel: quantize weights per channel
        reduce_range:
            quantize weights with 7 bits. It may improve accuracy for some models running on non-VNNI machines,
            especially in per-channel mode
        weight_type:
            quantization data type of weight. Please refer to
            https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
        nodes_to_quantize:
            List of node names to quantize. When this list is not None, only the nodes in this list
            are quantized.
            example:
            [
                'Conv__224',
                'Conv__252'
            ]
        nodes_to_exclude:
            List of node names to exclude. The nodes in this list will be excluded from quantization
            when it is not None.
        optimize_model: Deprecating soon! Optimize model before quantization. NOT recommended: optimization
            changes the computation graph, making debugging of quantization loss difficult.
        use_external_data_format: option used for large (>2GB) models. False by default.
        extra_options:
            key value pair dictionary for various options in different cases. Currently used:
                extra.Sigmoid.nnapi = True/False (Default is False)
                ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
                WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
                EnableSubgraph = True/False : Default is False. If enabled, subgraphs will be quantized.
                    Dynamic mode is currently supported. More will be supported in the future.
                ForceQuantizeNoInputCheck = True/False :
                    By default, some latent operators like maxpool and transpose are not quantized if their
                    input is not already quantized. Set to True to force such operators to always quantize
                    their input and thus generate quantized output. This behavior can be disabled per node
                    using nodes_to_exclude.
                MatMulConstBOnly = True/False:
                    Default is False for static mode. If enabled, only MatMul nodes with a const B will be
                    quantized.
                AddQDQPairToWeight = True/False :
                    Default is False, which quantizes the floating-point weight and feeds it to a solely
                    inserted DeQuantizeLinear node. If True, the weight stays in floating point and both
                    QuantizeLinear/DeQuantizeLinear nodes are inserted on it.
                OpTypesToExcludeOutputQuantization = list of op type :
                    Default is []. If any op type is specified, the output of ops with that op type is not
                    quantized.
                DedicatedQDQPair = True/False :
                    Default is False. When inserting a QDQ pair, multiple nodes can share a single QDQ pair
                    as their input. If True, an identical and dedicated QDQ pair is created for each node.
                QDQOpTypePerChannelSupportToAxis = dictionary :
                    Default is {}. Sets the channel axis for specific op types, for example: {'MatMul': 1}.
                    It is effective only when per-channel quantization is supported and per_channel is True.
                    If a specific op type supports per-channel quantization but no channel axis is explicitly
                    specified, the default channel axis is used.
                CalibTensorRangeSymmetric = True/False :
                    Default is False. If enabled, the final range of a tensor during calibration is explicitly
                    made symmetric around the central point "0".
                CalibMovingAverage = True/False :
                    Default is False. If enabled, a moving average of the minimum and maximum values is
                    computed when the selected calibration method is MinMax.
                CalibMovingAverageConstant = float :
                    Default is 0.01. Constant smoothing factor to use when computing the moving average of
                    the minimum and maximum values. Effective only when the selected calibration method is
                    MinMax and CalibMovingAverage is set to True.
    """
    extra_options = extra_options or {}
    nodes_to_exclude = nodes_to_exclude or []
    nodes_to_quantize = nodes_to_quantize or []
    op_types_to_quantize = op_types_to_quantize or []
    mode = QuantizationMode.QLinearOps

    if not op_types_to_quantize or len(op_types_to_quantize) == 0:
        op_types_to_quantize = list(QLinearOpsRegistry.keys())

    model = load_model(Path(model_input), optimize_model)

    pre_processed: bool = model_has_pre_process_metadata(model)
    if not pre_processed:
        logging.warning(
            "Please consider pre-processing before quantization. See "
            "https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification"
            "/cpu/ReadMe.md "
        )

    calib_extra_options_keys = [
        ("CalibTensorRangeSymmetric", "symmetric"),
        ("CalibMovingAverage", "moving_average"),
        ("CalibMovingAverageConstant", "averaging_constant"),
    ]
    calib_extra_options = {
        key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options
    }

    with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
        calibrator = create_calibrator(
            model,
            op_types_to_quantize,
            augmented_model_path=Path(quant_tmp_dir).joinpath("augmented_model.onnx").as_posix(),
            calibrate_method=calibrate_method,
            use_external_data_format=use_external_data_format,
            extra_options=calib_extra_options,
        )
        calibrator.collect_data(calibration_data_reader)
        tensors_range = calibrator.compute_range()
        del calibrator

    check_static_quant_arguments(quant_format, activation_type, weight_type)

    if quant_format is QuantFormat.QOperator:
        quantizer = ONNXQuantizer(
            model,
            per_channel,
            reduce_range,
            mode,
            True,  # static
            weight_type,
            activation_type,
            tensors_range,
            nodes_to_quantize,
            nodes_to_exclude,
            op_types_to_quantize,
            extra_options,
        )
    else:
        quantizer = QDQQuantizer(
            model,
            per_channel,
            reduce_range,
            mode,
            True,  # static
            weight_type,
            activation_type,
            tensors_range,
            nodes_to_quantize,
            nodes_to_exclude,
            op_types_to_quantize,
            extra_options,
        )

    quantizer.quantize_model()
    quantizer.model.save_model_to_file(model_output, use_external_data_format)

    if not pre_processed:
        logging.warning(
            "Please consider pre-processing before quantization. See "
            "https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification"
            "/cpu/ReadMe.md "
        )
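

# Illustrative end-to-end sketch, not part of the upstream API: statically
# quantize a model using the example data reader defined above. The file
# names are placeholders, and the CalibMovingAverage options only take
# effect with the MinMax calibration method.
def _example_quantize_static():
    quantize_static(
        "model.onnx",
        "model.quant.onnx",
        calibration_data_reader=_RandomCalibrationDataReader(),
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        per_channel=True,
        extra_options={"CalibMovingAverage": True, "CalibMovingAverageConstant": 0.01},
    )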


def quantize_dynamic(
    model_input: Path,
    model_output: Path,
    op_types_to_quantize=None,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QInt8,
    nodes_to_quantize=None,
    nodes_to_exclude=None,
    optimize_model=True,
    use_external_data_format=False,
    extra_options=None,
):
    """Given an onnx model, create a quantized onnx model and save it into a file.

    Args:
        model_input: file path of the model to quantize
        model_output: file path of the quantized model
        op_types_to_quantize:
            specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
            It quantizes all supported operators by default.
        per_channel: quantize weights per channel
        reduce_range:
            quantize weights with 7 bits. It may improve accuracy for some models running on non-VNNI machines,
            especially in per-channel mode
        weight_type:
            quantization data type of weight. Please refer to
            https://onnxruntime.ai/docs/performance/quantization.html for more details on data type selection
        nodes_to_quantize:
            List of node names to quantize. When this list is not None, only the nodes in this list
            are quantized.
            example:
            [
                'Conv__224',
                'Conv__252'
            ]
        nodes_to_exclude:
            List of node names to exclude. The nodes in this list will be excluded from quantization
            when it is not None.
        optimize_model: Deprecating soon! Optimize model before quantization. NOT recommended: optimization
            changes the computation graph, making debugging of quantization loss difficult.
        use_external_data_format: option used for large (>2GB) models. False by default.
        extra_options:
            key value pair dictionary for various options in different cases. Currently used:
                extra.Sigmoid.nnapi = True/False (Default is False)
                ActivationSymmetric = True/False: symmetrize calibration data for activations (default is False).
                WeightSymmetric = True/False: symmetrize calibration data for weights (default is True).
                EnableSubgraph = True/False :
                    Default is False. If enabled, subgraphs will be quantized. Dynamic mode is currently
                    supported. More will be supported in the future.
                ForceQuantizeNoInputCheck = True/False :
                    By default, some latent operators like maxpool and transpose are not quantized if their
                    input is not already quantized. Set to True to force such operators to always quantize
                    their input and thus generate quantized output. This behavior can be disabled per node
                    using nodes_to_exclude.
                MatMulConstBOnly = True/False:
                    Default is True for dynamic mode. If enabled, only MatMul nodes with a const B will be
                    quantized.
    """
    extra_options = extra_options or {}
    nodes_to_exclude = nodes_to_exclude or []
    nodes_to_quantize = nodes_to_quantize or []
    op_types_to_quantize = op_types_to_quantize or []

    mode = QuantizationMode.IntegerOps

    if not op_types_to_quantize or len(op_types_to_quantize) == 0:
        op_types_to_quantize = list(IntegerOpsRegistry.keys())

    model = load_model(Path(model_input), optimize_model)

    if "MatMulConstBOnly" not in extra_options:
        extra_options["MatMulConstBOnly"] = True

    quantizer = ONNXQuantizer(
        model,
        per_channel,
        reduce_range,
        mode,
        False,  # static
        weight_type,
        QuantType.QUInt8,  # dynamic activation only supports uint8
        None,
        nodes_to_quantize,
        nodes_to_exclude,
        op_types_to_quantize,
        extra_options,
    )

    quantizer.quantize_model()
    quantizer.model.save_model_to_file(model_output, use_external_data_format)
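

# Illustrative sketch, not part of the upstream API: dynamic quantization
# needs no calibration data; weights are quantized ahead of time and
# activations are quantized on the fly at inference. Paths are placeholders.
def _example_quantize_dynamic():
    quantize_dynamic(
        Path("model.onnx"),
        Path("model.quant.onnx"),
        weight_type=QuantType.QInt8,
        per_channel=False,
    )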


def quantize(
    model_input: Path,
    model_output: Path,
    quant_config: QuantConfig,
):
    """Quantize a model with QuantConfig.

    Args:
        model_input (Path): Path to the model to quantize.
        model_output (Path): Path to save the quantized model.
        quant_config (QuantConfig): Quantization Configuration.
    """
    if isinstance(quant_config, StaticQuantConfig):
        quantize_static(
            model_input,
            model_output,
            quant_config.calibration_data_reader,
            calibrate_method=quant_config.calibrate_method,
            quant_format=quant_config.quant_format,
            activation_type=quant_config.activation_type,
            weight_type=quant_config.weight_type,
            op_types_to_quantize=quant_config.op_types_to_quantize,
            nodes_to_quantize=quant_config.nodes_to_quantize,
            nodes_to_exclude=quant_config.nodes_to_exclude,
            per_channel=quant_config.per_channel,
            reduce_range=quant_config.reduce_range,
            optimize_model=quant_config.optimize_model,
            use_external_data_format=quant_config.use_external_data_format,
            extra_options=quant_config.extra_options,
        )
    elif isinstance(quant_config, DynamicQuantConfig):
        quantize_dynamic(
            model_input,
            model_output,
            weight_type=quant_config.weight_type,
            op_types_to_quantize=quant_config.op_types_to_quantize,
            nodes_to_quantize=quant_config.nodes_to_quantize,
            nodes_to_exclude=quant_config.nodes_to_exclude,
            per_channel=quant_config.per_channel,
            reduce_range=quant_config.reduce_range,
            optimize_model=quant_config.optimize_model,
            use_external_data_format=quant_config.use_external_data_format,
            extra_options=quant_config.extra_options,
        )
    else:
        raise TypeError("Invalid quantization config type, it must be either StaticQuantConfig or DynamicQuantConfig.")
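

# Illustrative sketch, not part of the upstream API: the config-driven entry
# point dispatches on the config type, so one call site serves both static
# and dynamic quantization. Paths are placeholders and the helper configs
# are the hypothetical examples defined above.
def _example_quantize_with_config(static=True):
    config = _example_static_quant_config() if static else _example_dynamic_quant_config()
    quantize(Path("model.onnx"), Path("model.quant.onnx"), config)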