Image Parsing Application
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

import numpy as np
from fusion_base import Fusion
from fusion_gpt_attention import FusionGptAttentionPastBase
from fusion_utils import FusionUtils
from onnx import TensorProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)

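# Compare a constant matched from the graph (for example, the divisor of the Q/K
# Div nodes, which is expected to be sqrt(sqrt(head_size))) against its expected
# value with a small floating point tolerance.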
def is_close(value, expected_value):
    return abs(value - expected_value) <= 1e-6


class FusionGptAttentionMegatron(FusionGptAttentionPastBase):
    """
    Fuse GPT-2 Attention with past state subgraph from Megatron into one Attention node.
    """

    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, num_heads)

    def fuse_attention_node(
        self,
        matmul_before_split,
        add_before_split,
        past,
        present,
        input,
        reshape_qkv,
        mask,
    ):
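        """Create a single com.microsoft Attention node that replaces the matched
        Megatron attention subgraph. Inputs are the layer-norm output, the fused QKV
        weight and bias taken from the MatMul/Add before the Split, the attention mask
        cast to int32, and the past key/value state; outputs are the attention output
        (taking over reshape_qkv's output) and the present key/value state.
        """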
        attention_node_name = self.model.create_node_name("GptAttention")
        int32_mask = self.cast_attention_mask(mask)
        output = reshape_qkv.output[0]

        i = 1 if (add_before_split.input[0] == matmul_before_split.output[0]) else 0
        attention_node = helper.make_node(
            "Attention",
            inputs=[
                input,
                matmul_before_split.input[1],
                add_before_split.input[i],
                int32_mask,
                past,
            ],
            outputs=[output, present],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend(
            [
                helper.make_attribute("num_heads", self.num_heads),
                helper.make_attribute("unidirectional", 0),  # unidirectional shall not be ON for 4D attention mask
            ]
        )
        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        nodes_to_add = [attention_node]
        self.nodes_to_add.extend(nodes_to_add)
        for node in nodes_to_add:
            self.node_name_to_graph_name[node.name] = self.this_graph_name

        self.nodes_to_remove.append(reshape_qkv)

        # we rely on prune_graph() to clean old subgraph nodes
        self.prune_graph = True
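    # Match the Megatron attention-mask subgraph feeding the Sub before Softmax:
    # the mask graph input is sliced twice, and the scores become
    # scores * mask - (1 - mask) * 10000. Returns the graph input that feeds the
    # first Slice, or None if the pattern does not match.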
    def match_mask(self, sub_qk, mul_qk, matmul_qk, layernorm_before_attention):
        mask_nodes = self.model.match_parent_path(
            sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0]
        )  # yapf: disable
        if mask_nodes is None:
            logger.debug("fuse_attention: failed to match unidirectional mask path")
            return None
        (mul_mask, sub_mask, last_slice_mask, slice_mask) = mask_nodes

        if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
            _, mul_val = self.model.get_constant_input(mask_nodes[0])
            if mul_val != 10000:
                self.mask_filter_value = -mul_val

        if mul_qk.input[1] != last_slice_mask.output[0]:
            logger.debug("fuse_attention failed: mul_qk.input[1] != last_slice_mask.output[0]")
            return None

        if not self.utils.check_node_input_value(mul_mask, 1, 10000.0):
            logger.debug("fuse_attention failed: mul_mask input 1 is not constant 10000.0")
            return None

        if not self.utils.check_node_input_value(sub_mask, 0, 1.0):
            logger.debug("fuse_attention failed: sub_mask input 0 is not constant 1.0")
            return None
        if not self.model.find_graph_input(slice_mask.input[0]):
            logger.info("expect slice_mask input 0 to be graph input")
            return None
        if not self.utils.check_node_input_value(last_slice_mask, 1, [0]):
            logger.debug("fuse_attention failed: last_slice_mask input 1 (starts) is not constant [0]")
            return None

        if not self.utils.check_node_input_value(last_slice_mask, 3, [3]):
            logger.debug("fuse_attention failed: last_slice_mask input 3 (axes) is not constant [3]")
            return None

        if not self.utils.check_node_input_value(last_slice_mask, 4, [1]):
            logger.debug("fuse_attention failed: last_slice_mask input 4 (steps) is not constant [1]")
            return None

        if not self.utils.check_node_input_value(slice_mask, 3, [2]):
            logger.debug("fuse_attention failed: slice_mask input 3 (axes) is not constant [2]")
            return None

        if not self.utils.check_node_input_value(slice_mask, 4, [1]):
            logger.debug("fuse_attention failed: slice_mask input 4 (steps) is not constant [1]")
            return None
        last_slice_path = self.model.match_parent_path(
            last_slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
        )
        if last_slice_path is None or last_slice_path[-1] != matmul_qk:
            logger.debug("fuse_attention: failed to match last slice path")
            return None

        first_slice_path = self.model.match_parent_path(
            slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
        )
        if first_slice_path is None or first_slice_path[-1] != matmul_qk:
            logger.debug("fuse_attention: failed to match first slice path")
            return None

        first_slice_sub = self.model.match_parent_path(
            slice_mask,
            ["Unsqueeze", "Sub", "Gather", "Shape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        if first_slice_sub is None or first_slice_sub[-1] != matmul_qk:
            logger.debug("fuse_attention: failed to match first slice sub path")
            return None

        first_slice_sub_1 = self.model.match_parent_path(
            slice_mask,
            ["Unsqueeze", "Sub", "Gather", "Shape", "LayerNormalization"],
            [1, 0, 1, 0, 0],
        )
        if first_slice_sub_1 is None:
            first_slice_sub_1 = self.model.match_parent_path(
                slice_mask,
                ["Unsqueeze", "Sub", "Gather", "Shape", "SkipLayerNormalization"],
                [1, 0, 1, 0, 0],
            )
        if first_slice_sub_1 is None or first_slice_sub_1[-1] != layernorm_before_attention:
            logger.debug("fuse_attention: failed to match first slice sub path 1")
            return None

        return slice_mask.input[0]
    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
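        """Attempt to fuse the Megatron GPT-2 attention subgraph whose output feeds
        normalize_node (a LayerNormalization or SkipLayerNormalization). The Q, K, V,
        mask, past and present paths are matched one by one; any mismatch aborts the
        fusion for this node.
        """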
        past = None
        present = None

        is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
        qkv_nodes = None

        if not is_normalize_node_skiplayernorm:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                [0, 1, None, 0, 0, 0],
                output_name_to_node=output_name_to_node,
            )  # yapf: disable
        else:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                [1, None, 0, 0, 0],
                output_name_to_node=output_name_to_node,
            )  # yapf: disable

        if qkv_nodes is None:
            return

        skip_input = None
        if not is_normalize_node_skiplayernorm:
            (
                add_skip,
                add_after_attention,
                matmul_after_attention,
                reshape_qkv,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes
            skip_input = add_skip.input[0]
        else:
            (
                add_after_attention,
                matmul_after_attention,
                reshape_qkv,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes
            skip_input = normalize_node.input[0]
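        # Value path: input 1 of the final MatMul is Concat(past_v, v), traced back
        # through Transpose/Reshape to the Split of the fused QKV Add + MatMul that
        # follows the layer norm before attention.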
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            [
                "Concat",
                "Transpose",
                "Reshape",
                "Split",
                "Add",
                "MatMul",
                "LayerNormalization",
            ],
            [1, 1, 0, 0, 0, None, 0],
        )  # yapf: disable
        if v_nodes is None:
            v_nodes = self.model.match_parent_path(
                matmul_qkv,
                [
                    "Concat",
                    "Transpose",
                    "Reshape",
                    "Split",
                    "Add",
                    "MatMul",
                    "SkipLayerNormalization",
                ],
                [1, 1, 0, 0, 0, None, 0],
            )  # yapf: disable

        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return

        (
            concat_v,
            transpose_v,
            reshape_v,
            split_v,
            add_before_split,
            matmul_before_split,
            layernorm_before_attention,
        ) = v_nodes

        if (
            layernorm_before_attention.op_type == "LayerNormalization"
            and skip_input != layernorm_before_attention.input[0]
        ):
            logger.debug("fuse_attention: skip_input != layernorm_before_attention.input[0]")
            return
        if (
            layernorm_before_attention.op_type == "SkipLayerNormalization"
            and skip_input != layernorm_before_attention.output[3]
        ):
            logger.debug("fuse_attention: skip_input != layernorm_before_attention.output[3]")
            return
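        # Score path: Softmax(Sub(Mul(scores, mask), (1 - mask) * 10000)) over axis 3,
        # where scores come from the Q x K MatMul and mask is the sliced attention
        # mask matched in match_mask.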
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "MatMul"], [0, 0, 0, 0])
        if qk_nodes is None:
            logger.debug("fuse_attention: failed to match qk path")
            return None
        (softmax_qk, sub_qk, mul_qk, matmul_qk) = qk_nodes
        if self.model.get_node_attribute(softmax_qk, "axis") != 3:
            logger.debug("fuse_attention failed: softmax_qk axis != 3")
            return None

        attention_mask = self.match_mask(sub_qk, mul_qk, matmul_qk, layernorm_before_attention)

        q_nodes = self.model.match_parent_path(matmul_qk, ["Div", "Transpose", "Reshape", "Split"], [0, 0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (div_q, transpose_q, reshape_q, split_q) = q_nodes

        if split_v != split_q:
            logger.debug("fuse_attention: skip since split_v != split_q")
            return

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Concat", "Transpose", "Reshape", "Split"],
            [1, 0, 0, 1, 0, 0],
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        (div_k, _, concat_k, transpose_k, reshape_k, split_k) = k_nodes

        if split_v != split_k:
            logger.debug("fuse_attention: skip since split_v != split_k")
            return
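        # The Reshape in the K path targets shape [0, 0, num_heads, head_size]; read
        # num_heads and head_size from that constant, then verify that the Div nodes
        # divide both Q and K by sqrt(sqrt(head_size)), i.e. the scores are scaled by
        # 1/sqrt(head_size) overall.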
        i, value = self.model.get_constant_input(reshape_k)
        if not (
            isinstance(value, np.ndarray)
            and list(value.shape) == [4]
            and value[0] == 0
            and value[1] == 0
            and value[2] > 0
            and value[3] > 0
        ):
            logger.debug("fuse_attention: reshape constant input is not [0, 0, N, H]")
            return

        num_heads = value[2]
        if num_heads != self.num_heads:
            logger.info(f"Detected num_heads={num_heads}. Ignore user specified value {self.num_heads}")
            self.num_heads = num_heads

        hidden_size_per_head = value[3]
        i, value = self.model.get_constant_input(div_k)
        expected_value = float(np.sqrt(np.sqrt(hidden_size_per_head)))
        if not is_close(value, expected_value):
            logger.debug(f"fuse_attention: div_k value={value} expected={expected_value}")
            return

        i, value = self.model.get_constant_input(div_q)
        if not is_close(value, expected_value):
            logger.debug(f"fuse_attention: div_q value={value} expected={expected_value}")
            return

        # Match past and present paths
        past = self.match_past_pattern_2(concat_k, concat_v, output_name_to_node)
        if past is None:
            logger.debug("fuse_attention: match past failed")
            return

        if not self.model.find_graph_input(past):
            logger.debug("fuse_attention: past is not graph input.")
            # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.

        present = self.match_present(concat_v, input_name_to_nodes)
        if present is None:
            logger.debug("fuse_attention: match present failed")
            return

        if not self.model.find_graph_output(present):
            logger.info("fuse_attention: expect present to be graph output")
            return

        self.fuse_attention_node(
            matmul_before_split,
            add_before_split,
            past,
            present,
            layernorm_before_attention.output[0],
            reshape_qkv,
            attention_mask,
        )
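
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): this fusion is normally
# driven by the onnxruntime transformers optimizer, but it can be applied
# directly to a loaded model roughly as follows. The model path and the
# num_heads value below are placeholders for illustration only.
#
#     import onnx
#
#     onnx_model = OnnxModel(onnx.load("gpt2_megatron_with_past.onnx"))  # hypothetical path
#     fusion = FusionGptAttentionMegatron(onnx_model, num_heads=16)      # assumed head count
#     fusion.apply()            # walks the graph and calls fuse() on candidate nodes
#     onnx_model.prune_graph()  # drop the subgraph nodes replaced by Attention
#     onnx.save(onnx_model.model, "gpt2_megatron_fused.onnx")
# ---------------------------------------------------------------------------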