m2m model translation

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
from enum import Enum

import onnx
import onnx.numpy_helper
from onnx import TensorProto
from onnx import onnx_pb as onnx_proto

from .onnx_quantizer import ONNXQuantizer
from .quant_utils import (
    DEQUANT_OP_NAME,
    QUANT_OP_NAME,
    QuantizedValue,
    QuantizedValueType,
    __producer__,
    __version__,
    add_dequant_output_suffix,
    add_dequant_suffix,
    add_quant_input_suffix,
    add_quant_output_suffix,
    add_quant_suffix,
    find_by_name,
)
from .registry import CreateQDQQuantizer


class QDQQuantTensorType(Enum):
    ACTIVATION = 0
    WEIGHT = 1
    BIAS = 2


class QDQTensorQuantInfo:
    def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provider=None, axis=None):
        self.tensor_type = tensor_type
        self.quant_para_provider = quant_para_provider
        self.axis = axis
        self.is_shared = quant_para_provider is not None


class QDQQuantizer(ONNXQuantizer):
    def __init__(
        self,
        model,
        per_channel,
        reduce_range,
        mode,
        static,
        weight_qType,
        activation_qType,
        tensors_range,
        nodes_to_quantize,
        nodes_to_exclude,
        op_types_to_quantize,
        extra_options=None,
    ):
        ONNXQuantizer.__init__(
            self,
            model,
            per_channel,
            reduce_range,
            mode,
            static,
            weight_qType,
            activation_qType,
            tensors_range,
            nodes_to_quantize,
            nodes_to_exclude,
            op_types_to_quantize,
            extra_options,
        )
        self.tensors_to_quantize = {}
        self.bias_to_quantize = []
        self.nodes_to_remove = []

        # Guard against the default extra_options=None so the membership checks below do not fail.
        extra_options = extra_options or {}

        # Specific op types for which output qdq quantization is excluded.
        # In TRT, quantizing the outputs of weighted ops such as Conv, MatMul, and Gemm is not recommended,
        # because those ops may be followed by nodes that require high-resolution inputs.
        # Adding QDQ to those ops' outputs may end up with worse accuracy,
        # so we don't recommend adding QDQ to a node's output under such conditions.
        self.op_types_to_exclude_output_quantization = (
            []
            if "OpTypesToExcludeOutputQuantization" not in extra_options
            else extra_options["OpTypesToExcludeOutputQuantization"]
        )

        # As an optimization, we quantize DequantizeLinear's input directly so that the QuantizeLinear for a weight
        # can be removed. In some cases, for example a QDQ BERT model for TensorRT, QDQ should always appear as a pair,
        # so we need a way to disable this optimization and add a full QDQ pair to the weight.
        self.add_qdq_pair_to_weight = (
            False if "AddQDQPairToWeight" not in extra_options else extra_options["AddQDQPairToWeight"]
        )

        # The default behavior is that multiple nodes can share a QDQ pair as their input.
        # In TRT, a QDQ pair can't be shared between nodes, so this creates dedicated QDQ pairs for each node.
        self.dedicated_qdq_pair = (
            False if "DedicatedQDQPair" not in extra_options else extra_options["DedicatedQDQPair"]
        )
        if self.dedicated_qdq_pair:
            self.tensor_to_its_receiving_nodes = {}

        # Let the user set the channel axis for specific op types. This is effective only when per-channel
        # quantization is supported for the op and per_channel is True.
        self.qdq_op_type_per_channel_support_to_axis = (
            {}
            if "QDQOpTypePerChannelSupportToAxis" not in extra_options
            else extra_options["QDQOpTypePerChannelSupportToAxis"]
        )
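    # Illustrative example (assumed values, not exhaustive): a caller targeting TensorRT-style QDQ graphs
    # might pass extra_options such as
    #     {
    #         "OpTypesToExcludeOutputQuantization": ["Conv", "MatMul", "Gemm"],
    #         "AddQDQPairToWeight": True,
    #         "DedicatedQDQPair": True,
    #         "QDQOpTypePerChannelSupportToAxis": {"MatMul": 1},
    #     }
    # Only the four keys read above are consumed by this class; the concrete values shown are assumptions.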
    def _is_tensor_quantizable(self, tensor_name):
        """
        Check whether a tensor can be quantized.
        """
        weight = find_by_name(tensor_name, self.model.initializer())
        if weight is not None:
            if weight.data_type == onnx_proto.TensorProto.FLOAT:
                return True
        elif tensor_name in self.value_infos.keys():
            vi = self.value_infos[tensor_name]
            if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type == TensorProto.FLOAT:
                return True
        else:
            logging.warning(
                "failed to infer the type of tensor: {}. Skip quantizing it. Please check if this is expected.".format(
                    tensor_name
                )
            )

        return False

    def __quantize_tensor(self, tensor_name, quant_sharing_param=None, tensor_type=QDQQuantTensorType.ACTIVATION):
        """
        Quantize a tensor. If quant_sharing_param is not None, the tensor named tensor_name is quantized with the
        same quantization parameters as the tensor named by quant_sharing_param.

        Args:
            tensor_name: name of the tensor to quantize
            quant_sharing_param: name of the tensor that provides the quantization parameters
            tensor_type: QDQQuantTensorType, default ACTIVATION
        """
        if self._is_tensor_quantizable(tensor_name):
            if quant_sharing_param:
                self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
                    tensor_type=tensor_type, quant_para_provider=quant_sharing_param
                )
            elif tensor_name not in self.tensors_to_quantize:
                self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type)

    def quantize_activation_tensor(self, tensor_name, quant_sharing_param=None):
        """
        Quantize an activation tensor.

        Args:
            tensor_name: name of the tensor to quantize
            quant_sharing_param: name of the tensor that provides the quantization parameters
        """
        return self.__quantize_tensor(tensor_name, quant_sharing_param, QDQQuantTensorType.ACTIVATION)

    def quantize_weight_tensor(self, tensor_name, quant_sharing_param=None):
        """
        Quantize a weight tensor.

        Args:
            tensor_name: name of the tensor to quantize
            quant_sharing_param: name of the tensor that provides the quantization parameters
        """
        return self.__quantize_tensor(tensor_name, quant_sharing_param, QDQQuantTensorType.WEIGHT)

    def quantize_weight_tensor_per_channel(self, tensor_name, axis):
        weight = find_by_name(tensor_name, self.model.initializer())
        if weight:
            if weight.data_type == onnx_proto.TensorProto.FLOAT:
                self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
                    tensor_type=QDQQuantTensorType.WEIGHT, axis=axis
                )
        else:
            logging.warning(
                "only support per-channel quantization on weight. Tensor: {} is not quantized.".format(tensor_name)
            )

    def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta=1.0):
        weight = find_by_name(bias_name, self.model.initializer())
        if weight is not None:
            if weight.data_type == onnx_proto.TensorProto.FLOAT:
                self.bias_to_quantize.append((bias_name, input_name, weight_name, beta))
        else:
            logging.warning("Expected {} to be a weight".format(bias_name))

    def remove_node(self, node):
        self.nodes_to_remove.append(node)

    def remove_nodes(self):
        self.model.remove_nodes(self.nodes_to_remove)

    def quantize_model(self):
        """
        Let the per-op QDQ quantizers register tensors to quantize, then insert Q/DQ pairs for normal,
        shared-parameter, and bias tensors, and finally remove the nodes that were replaced.
        """
        for node in self.model.nodes():
            if self.should_quantize_node(node):
                op_quantizer = CreateQDQQuantizer(self, node)
                op_quantizer.quantize()

                if self.dedicated_qdq_pair:
                    for tensor_name in node.input:
                        if tensor_name not in self.tensor_to_its_receiving_nodes:
                            self.tensor_to_its_receiving_nodes[tensor_name] = []
                        self.tensor_to_its_receiving_nodes[tensor_name].append(node)

        self._quantize_normal_tensors()
        self._quantize_sharing_param_tensors()
        self._quantize_bias_tensors()
        self.remove_nodes()
        if not self.add_qdq_pair_to_weight:
            self.model.clean_initializers()

        self.model.model.producer_name = __producer__
        self.model.model.producer_version = __version__

        return self.model.model

    def try_replacing_upstream_output(self, upstream_output_name, output_name):
        if (
            output_name in self.quantization_params.keys()
            and len(self.model.input_name_to_nodes()[upstream_output_name]) == 1
            and not self.model.is_graph_output(upstream_output_name)
            and not self.model.is_graph_input(upstream_output_name)
        ):
            self.model.replace_output_of_all_nodes(upstream_output_name, output_name)
            if upstream_output_name in self.tensors_to_quantize:
                del self.tensors_to_quantize[upstream_output_name]
            return True
        return False

    def _create_qdq_nodes(
        self, q_input, q_output, quant_node_name, dq_input, dq_output, dequant_node_name, scale_name, zp_name, axis=None
    ):
        qlinear_node = onnx.helper.make_node(
            QUANT_OP_NAME,
            [q_input, scale_name, zp_name],
            [q_output],
            quant_node_name,
            axis=axis,
        )
        dequant_node = onnx.helper.make_node(
            DEQUANT_OP_NAME,
            [dq_input, scale_name, zp_name],
            [dq_output],
            dequant_node_name,
            axis=axis,
        )
        self.model.add_nodes([qlinear_node, dequant_node])
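    # For illustration only: given an activation tensor "X" consumed by a Conv node, the pair created by the
    # helper above rewires the graph roughly as
    #     X -> QuantizeLinear -> <X quant output> -> DequantizeLinear -> <X dequant output> -> Conv
    # with both nodes sharing the same scale and zero-point inputs; the exact tensor and node names come from
    # the add_*_suffix helpers in quant_utils.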
    def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None):
        weight_name = weight_proto.name
        if axis is not None:
            if self.opset_version < 13:
                raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.")
            q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel(
                weight_name, onnx_proto.TensorProto.INT8, axis, keep_float_weight=self.add_qdq_pair_to_weight
            )
        else:
            q_weight_name, zp_name, scale_name = self.quantize_initializer(
                weight_proto,
                self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType,
                keep_float_weight=self.add_qdq_pair_to_weight,
            )
        weight_dequant_output = add_dequant_output_suffix(weight_name)
        self.model.replace_input_of_all_nodes(weight_name, weight_dequant_output)
        if self.add_qdq_pair_to_weight:
            weight_quant_output = add_quant_output_suffix(weight_name)

            self._create_qdq_nodes(
                weight_name,
                weight_quant_output,
                add_quant_suffix(weight_name),
                weight_quant_output,
                weight_dequant_output,
                add_dequant_suffix(weight_name),
                scale_name,
                zp_name,
                axis,
            )
        else:
            dequant_node = onnx.helper.make_node(
                DEQUANT_OP_NAME,
                [q_weight_name, scale_name, zp_name],
                [weight_dequant_output],
                add_dequant_suffix(weight_name),
                axis=axis,
            )
            self.model.add_node(dequant_node)

    def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name):
        if (
            self.dedicated_qdq_pair
            and tensor_name in self.tensor_to_its_receiving_nodes
            and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1
        ):
            num_dedicated_qdq_pair = len(self.tensor_to_its_receiving_nodes[tensor_name])
            for i in range(num_dedicated_qdq_pair):
                postfix = f"_{i + 1}"
                tensor_name_quant_output_postfix = add_quant_output_suffix(tensor_name) + postfix
                tensor_name_dequant_output_postfix = add_dequant_output_suffix(tensor_name) + postfix
                quant_node_name_postfix = add_quant_suffix(tensor_name) + postfix
                dequant_node_name_postfix = add_dequant_suffix(tensor_name) + postfix
                self._create_qdq_nodes(
                    tensor_name,
                    tensor_name_quant_output_postfix,
                    quant_node_name_postfix,
                    tensor_name_quant_output_postfix,
                    tensor_name_dequant_output_postfix,
                    dequant_node_name_postfix,
                    scale_name,
                    zp_name,
                )

                node = self.tensor_to_its_receiving_nodes[tensor_name][i]
                self.model.replace_node_input(node, tensor_name, tensor_name_dequant_output_postfix)
                if i == 0:
                    quantized_value = QuantizedValue(
                        tensor_name,
                        tensor_name_dequant_output_postfix,
                        scale_name,
                        zp_name,
                        QuantizedValueType.Input,
                    )
                    self.quantized_value_map[tensor_name] = quantized_value
        else:
            q_input = tensor_name
            dq_output = add_dequant_output_suffix(tensor_name)
            if self.model.is_graph_output(tensor_name):
                q_input = add_quant_input_suffix(tensor_name)
                dq_output = tensor_name
                self.model.replace_output_of_all_nodes(tensor_name, q_input)
            else:
                self.model.replace_input_of_all_nodes(tensor_name, dq_output)

            self._create_qdq_nodes(
                q_input,
                add_quant_output_suffix(tensor_name),
                add_quant_suffix(tensor_name),
                add_quant_output_suffix(tensor_name),
                dq_output,
                add_dequant_suffix(tensor_name),
                scale_name,
                zp_name,
            )

            quantized_value = QuantizedValue(
                tensor_name,
                dq_output,
                scale_name,
                zp_name,
                QuantizedValueType.Input,
            )
            self.quantized_value_map[tensor_name] = quantized_value

    def _quantize_normal_tensors(self):
        for tensor_name, tensor_info in self.tensors_to_quantize.copy().items():
            if tensor_name in self.quantized_value_map.keys():
                continue

            if not tensor_info.is_shared:
                # Quantize the input
                initializer = find_by_name(tensor_name, self.model.initializer())
                if initializer:
                    self._add_qdq_pair_for_initializer(initializer, tensor_info.tensor_type, tensor_info.axis)
                else:
                    used_scale, used_zp = self.find_quant_scale_zp(tensor_name)
                    data_found, scale_name, zp_name, _, _ = self._get_quantization_params(
                        tensor_name, used_scale, used_zp
                    )

                    if not data_found:
                        raise ValueError(
                            f"Quantization parameters are not specified for param {tensor_name}. "
                            "In static mode quantization params for inputs and outputs of nodes to be quantized are required."
                        )

                    self._add_qdq_pair_for_activation(tensor_name, scale_name, zp_name)

                del self.tensors_to_quantize[tensor_name]

    def _quantize_sharing_param_tensors(self):
        while self.tensors_to_quantize:
            for tensor_name, tensor_info in self.tensors_to_quantize.copy().items():
                tensor_provider_name = tensor_info.quant_para_provider
                if tensor_provider_name in self.quantized_value_map:
                    del self.tensors_to_quantize[tensor_name]

                    quantized_value = self.quantized_value_map[tensor_provider_name]
                    # Quantize the input
                    initializer = find_by_name(tensor_name, self.model.initializer())
                    if initializer is not None:
                        raise ValueError("Quantization parameter shared mode is not supported for weight yet")
                    self._add_qdq_pair_for_activation(tensor_name, quantized_value.scale_name, quantized_value.zp_name)

    def _quantize_bias_tensors(self):
        for bias_name, input_name, weight_name, beta in self.bias_to_quantize:
            if bias_name in self.quantized_value_map.keys():
                continue
            # Quantize the input
            self.quantize_bias_static(bias_name, input_name, weight_name, beta)
            self.model.remove_initializer(find_by_name(bias_name, self.model.initializer()))
            quant_value = self.quantized_value_map[bias_name]
            inputs = [quant_value.q_name, quant_value.scale_name, quant_value.zp_name]
            node_name = add_dequant_suffix(bias_name)
            if quant_value.axis is not None:
                dequant_node = onnx.helper.make_node(
                    "DequantizeLinear",
                    inputs,
                    [bias_name],
                    node_name,
                    axis=quant_value.axis,
                )
            else:
                dequant_node = onnx.helper.make_node(
                    "DequantizeLinear",
                    inputs,
                    [bias_name],
                    node_name,
                )
            self.model.add_node(dequant_node)

    def is_tensor_quantized(self, tensor_name):
        return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize
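For reference, end users normally do not construct QDQQuantizer directly; it is driven by onnxruntime.quantization.quantize_static when QuantFormat.QDQ is selected. A minimal usage sketch follows, assuming a model at "model_fp32.onnx" and a list of calibration feeds named `feeds` (both placeholders):

import onnx
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class MyCalibrationReader(CalibrationDataReader):
    # Placeholder reader: yields dicts mapping graph input names to numpy arrays, then None when exhausted.
    def __init__(self, feeds):
        self._iter = iter(feeds)

    def get_next(self):
        return next(self._iter, None)

feeds = [...]  # placeholder: a few representative {input_name: array} batches

quantize_static(
    "model_fp32.onnx",               # placeholder input path
    "model_qdq_int8.onnx",           # placeholder output path
    MyCalibrationReader(feeds),
    quant_format=QuantFormat.QDQ,    # selects the QuantizeLinear/DequantizeLinear rewriting implemented in this file
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QUInt8,
    extra_options={"AddQDQPairToWeight": True},  # one of the keys consumed in __init__ above
)

With QuantFormat.QDQ, quantize_static runs calibration to obtain tensors_range, instantiates this QDQQuantizer with it, and calls quantize_model(), so the extra_options keys documented in __init__ are the knobs exposed at that level.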