# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
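"""Graph fusion that folds a quantized (QDQ) multi-head attention subgraph into a
single com.microsoft QOrderedAttention node."""
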
from logging import getLogger
from typing import Tuple

import numpy as np
from fusion_attention import AttentionMask
from fusion_base import Fusion
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionQOrderedAttention(Fusion):
    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_mask = attention_mask

        # Emit each "detected value differs from user-specified value" warning at most once
        self.num_heads_warning = True
        self.hidden_size_warning = True

        super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization")

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q

        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """
        # We assume that the reshape fusion has already run, so the shape is a tensor like [0, 0, num_heads, head_size]
        q_shape = self.model.get_initializer(reshape_q.input[1])

        if q_shape is None:
            logger.debug(f"{reshape_q.input[1]} is not an initializer.")

            # Check if the second input to Reshape flows through a Constant node
            # TODO: Investigate why FusionAttention doesn't have such logic
            constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1])

            if constant_node is None:
                return self.num_heads, self.hidden_size  # Fall back to user-specified value
            else:
                constant_node = constant_node[0]

                if len(constant_node.attribute) != 1:
                    return self.num_heads, self.hidden_size  # Fall back to user-specified value

                # This is assuming it is a Tensor attribute (this is a safe assumption)
                q_shape = constant_node.attribute[0].t

        q_shape_value = NumpyHelper.to_array(q_shape)
        if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
            logger.debug(f"q_shape_value={q_shape_value}. Expected values are like [0, 0, num_heads, head_size].")
            return self.num_heads, self.hidden_size  # Fall back to user-specified value

        num_heads = q_shape_value[2]
        head_size = q_shape_value[3]
        hidden_size = num_heads * head_size

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size
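
    # fuse() matches the following quantized attention subgraph, walking upward
    # from a QOrderedLayerNormalization node, and replaces it with a single
    # com.microsoft QOrderedAttention node:
    #   - output path:  Add <- MatMul <- Reshape <- Transpose <- DQ <- Q <- MatMul(probs x V)
    #   - Q/K/V paths:  Transpose <- Reshape <- DQ <- Q <- Add(bias) <- MatMul(input x DQ(weight))
    #   - score path:   DQ <- Q <- Softmax <- Add(mask) <- Div <- DQ <- Q <- MatMul(Q x K)
    #   - mask path:    Mul <- Sub <- Cast <- Unsqueeze <- Unsqueeze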
    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        add_before_layernorm = self.model.match_parent_path(
            normalize_node,
            ["QuantizeLinear", "Add"],
            [0, 0],
        )

        if add_before_layernorm is not None:
            start_node = add_before_layernorm[-1]
        else:
            return

        # Input QDQ nodes
        dequantize_input = self.model.match_parent_path(
            start_node,
            ["DequantizeLinear"],
            [None],
        )

        if dequantize_input is None:
            logger.debug("fuse_qordered_attention: failed to match input qdq nodes path")
            return

        dequantize_input = dequantize_input[-1]

        # QKV nodes
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"],
            [None, None, 0, 0, 0, 0, 0],
        )

        if qkv_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qkv path")
            return

        (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qkv, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qkv, self.model):
            return

        # Identify the root input to the Attention node
        other_inputs = []
        for input_name in start_node.input:
            if input_name not in output_name_to_node:
                continue

            if input_name == qkv_nodes[0].output[0]:
                continue

            other_inputs.append(input_name)

        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]

        # V nodes
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if v_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match v path")
            return

        (_, _, dequantize_v, quantize_v, add_v, matmul_v) = v_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_v, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v, self.model):
            return

        # V MatMul weight
        dequantize_v_matmul_weight = self.model.match_parent_path(matmul_v, ["DequantizeLinear"], [1])

        if dequantize_v_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match v weight path")
            return

        dequantize_v_matmul_weight = dequantize_v_matmul_weight[0]

        if self.model.get_constant_value(dequantize_v_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v_matmul_weight, self.model, False):
            return

        # QK nodes
        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            [
                "DequantizeLinear",
                "QuantizeLinear",
                "Softmax",
                "Add",
                "Div",
                "DequantizeLinear",
                "QuantizeLinear",
                "MatMul",
            ],
            [0, 0, 0, 0, None, 0, 0, 0],
        )

        if qk_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qk path")
            return

        (
            dequantize_qk_softmax,
            quantize_qk_softmax,
            softmax_qk,
            add_qk,
            div_qk,
            dequantize_qk,
            quantize_qk,
            matmul_qk,
        ) = qk_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk, self.model):
            return

        # Q nodes
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [0, 0, 0, 0, 0, None],
        )

        if q_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match q path")
            return

        (_, reshape_q, dequantize_q, quantize_q, add_q, matmul_q) = q_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_q, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q, self.model):
            return

        # Q MatMul weight
        dequantize_q_matmul_weight = self.model.match_parent_path(matmul_q, ["DequantizeLinear"], [1])

        if dequantize_q_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match q weight path")
            return

        dequantize_q_matmul_weight = dequantize_q_matmul_weight[0]

        if self.model.get_constant_value(dequantize_q_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q_matmul_weight, self.model, False):
            return

        # K nodes
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if k_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match k path")
            return

        (_, _, dequantize_k, quantize_k, add_k, matmul_k) = k_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_k, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k, self.model):
            return

        # K MatMul weight
        dequantize_k_matmul_weight = self.model.match_parent_path(matmul_k, ["DequantizeLinear"], [1])

        if dequantize_k_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match k weight path")
            return

        dequantize_k_matmul_weight = dequantize_k_matmul_weight[0]

        if self.model.get_constant_value(dequantize_k_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k_matmul_weight, self.model, False):
            return

        # Mask nodes
        mask_nodes = self.model.match_parent_path(
            add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]
        )

        if mask_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match mask_nodes path")
            return

        # Ascertain `qkv_hidden_sizes` attribute value
        q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
        k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
        v_weight = self.model.get_initializer(dequantize_v_matmul_weight.input[0])

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        qw_out_size = np.prod(qw.shape[1:])
        kw_out_size = np.prod(kw.shape[1:])
        vw_out_size = np.prod(vw.shape[1:])

        # Form the QOrderedAttention node only if Q, K and V share the same root input
        if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])

            # Ascertain `num_heads` and `hidden_size`
            num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
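
            # Inputs gathered below, in order: the quantized input and its scale,
            # the Q/K/V projection output scales, the quantized Q/K/V weights and
            # their scales, the Q/K/V biases, the QK^T output scale, the softmax
            # output scale, the attention output scale, and (optionally) the mask index.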
            # Formulate the inputs
            # Actual quantized input
            attention_inputs = [dequantize_input.input[0]]
            attention_inputs.append(dequantize_input.input[1])

            attention_inputs.append(dequantize_q.input[1])
            attention_inputs.append(dequantize_k.input[1])
            attention_inputs.append(dequantize_v.input[1])

            attention_inputs.append(dequantize_q_matmul_weight.input[0])
            attention_inputs.append(dequantize_k_matmul_weight.input[0])
            attention_inputs.append(dequantize_v_matmul_weight.input[0])

            attention_inputs.append(dequantize_q_matmul_weight.input[1])
            attention_inputs.append(dequantize_k_matmul_weight.input[1])
            attention_inputs.append(dequantize_v_matmul_weight.input[1])

            if self.model.get_initializer(add_q.input[0]):
                attention_inputs.append(add_q.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_q.input[1])

            if self.model.get_initializer(add_k.input[0]):
                attention_inputs.append(add_k.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_k.input[1])

            if self.model.get_initializer(add_v.input[0]):
                attention_inputs.append(add_v.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_v.input[1])

            attention_inputs.append(quantize_qk.input[1])
            attention_inputs.append(quantize_qk_softmax.input[1])
            attention_inputs.append(dequantize_qkv.input[1])

            # Mask input
            if mask_index is not None:
                attention_inputs.append(mask_index)
            else:
                attention_inputs.append("")

            # The MatMul weight 'B' and 'bias' need some post-processing
            # Transpose weight 'B' from order ROW to order COL
            # This offline transpose is needed only while using the CUDA EP
            # TODO: Make this fusion logic EP-agnostic?
            q_weight_tensor = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(q_weight_tensor)

            k_weight_tensor = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(k_weight_tensor)

            v_weight_tensor = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(v_weight_tensor)

            # Name and create the Attention node
            attention_node_name = self.model.create_node_name("QOrderedAttention")

            attention_node = helper.make_node(
                "QOrderedAttention",
                inputs=attention_inputs,
                outputs=[reshape_qkv.output[0]],
                name=attention_node_name,
            )

            self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0])
            self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0])
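
            # Ordering attributes: based on the ROW -> COL weight transpose above, the
            # assumption is that order_weight=0 denotes a column-major weight layout while
            # order_input/order_output=1 denote row-major (cublasLt-style ordering).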
            attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
            attention_node.attribute.extend([helper.make_attribute("order_input", 1)])
            attention_node.attribute.extend([helper.make_attribute("order_weight", 0)])
            attention_node.attribute.extend([helper.make_attribute("order_output", 1)])
            attention_node.attribute.extend(
                [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
            )

            attention_node.domain = "com.microsoft"

            self.nodes_to_add.append(attention_node)
            self.node_name_to_graph_name[attention_node.name] = self.this_graph_name

            self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)
            self.nodes_to_remove.extend(
                [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight]
            )

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True
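

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module; example values and the
# standalone flow are assumptions). This fusion is normally driven by the
# transformers optimizer, which constructs it with the model's hidden size,
# number of heads and a shared AttentionMask helper, then calls apply()
# (inherited from fusion_base.Fusion). A minimal standalone run could look like:
#
#   import onnx
#
#   model = OnnxModel(onnx.load("quantized_model.onnx"))  # hypothetical path
#   fusion = FusionQOrderedAttention(
#       model, hidden_size=768, num_heads=12, attention_mask=AttentionMask(model)
#   )
#   fusion.apply()
#   model.save_model_to_file("quantized_model_fused.onnx")
#
# Note: the pattern match starts from QOrderedLayerNormalization nodes, so the
# corresponding LayerNormalization fusion must have been applied beforehand.
# ---------------------------------------------------------------------------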