m2m model translation
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

import numpy as np
from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import TensorProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)
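
# This module fuses the GPT-2 attention subgraph (with past/present state) into
# a single com.microsoft Attention node plus a MatMul + Add pair for the output
# projection. It matches several export variants produced by different PyTorch
# and Transformers versions.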

class FusionGptAttentionPastBase(Fusion):
    """Base class for GPT Attention Fusion with past state"""

    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", ["LayerNormalization", "SkipLayerNormalization"], "with past")
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from attention mask name to the name of its int32 cast
        self.mask_filter_value = None

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        # Pattern 1:
        #                     {past}
        #                      /   \
        #                     /     \
        #    Gather(axes=0, indices=0)   Gather(indices=1)
        #          |                          |
        #    Transpose (perm=0,1,3,2)         |
        #          |                          |
        #       Concat_k                  Concat_v
        #          |                       /
        #    Transpose (perm=0,1,3,2)     /
        #          |                     /
        #      Unsqueeze        Unsqueeze
        #           \              /
        #            \            /
        #               Concat
        #                 |
        #             {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather.op_type != "Gather":
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if self.model.find_constant_input(gather, 1) != 1:
            logger.debug("match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        parent = self.model.get_parent(concat_k, 0, output_name_to_node)
        if parent.op_type == "Gather":
            gather_past_k = parent
        else:
            past_k_nodes = self.model.match_parent_path(concat_k, ["Transpose", "Gather"], [0, 0])
            if past_k_nodes is None:
                logger.debug("match_past_pattern_1: failed to match Transpose and Gather")
                return None
            gather_past_k = past_k_nodes[-1]

        if self.model.find_constant_input(gather_past_k, 0) != 1:
            logger.debug("match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be the same tensor on both paths")
            return None

        return past
    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        # Pattern 2:
        #       Split (QKV)
        #       / |   |
        #      /  |   +----------------------+
        #         |                          |
        #         |         {past}           |
        #         |           |              |
        #      Reshape      Split         Reshape
        #         |        /     \           |
        #  Transpose_k  Squeeze  Squeeze  Transpose_v
        #         |       |        \        /
        #         +-------|---+     \      /
        #                 |   |      \    /
        #               Concat_k    Concat_v
        #                  |           |
        #              Unsqueeze   Unsqueeze
        #                   \         /
        #                    \       /
        #                      Concat
        #                        |
        #                    {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze.op_type != "Squeeze":
            logger.debug("match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            # Before opset 13, axes of Squeeze and split of Split are node attributes.
            if not FusionUtils.check_node_attribute(squeeze, "axes", [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not FusionUtils.check_node_attribute(split, "split", [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None
        else:
            # From opset 13 on, axes and split are inputs instead of attributes.
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None

        if not FusionUtils.check_node_attribute(split, "axis", 0, default_value=0):
            logger.debug("match_past_pattern_2: axis attribute of Split is not as expected in past path")
            return None
        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k, ["Squeeze", "Split"], [0, 0])
        if past_k_nodes is None:
            logger.debug("match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be the same tensor on both paths")
            return None

        return past
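
    # Walk forward from concat_v through its Unsqueeze child to the Concat whose
    # output is the {present} state (see the diagrams above).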
    def match_present(self, concat_v, input_name_to_nodes):
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, "Unsqueeze", input_name_to_nodes, recursive=False
        )
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None

        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v, "Concat", input_name_to_nodes, recursive=False
        )
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present
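
    # The com.microsoft Attention op expects an int32 mask, so the mask input is
    # cast to int32 here; results are cached in self.casted_attention_mask so a
    # mask shared across layers is only cast once.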
    def cast_attention_mask(self, input_name):
        if input_name in self.casted_attention_mask:
            attention_mask_input_name = self.casted_attention_mask[input_name]
        elif self.model.find_graph_input(input_name):
            casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        else:
            attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        return attention_mask_input_name

class FusionGptAttention(FusionGptAttentionPastBase):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    """

    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, num_heads)

    def create_attention_node(
        self,
        fc_weight,
        fc_bias,
        gemm_qkv,
        past,
        present,
        input,
        output,
        mask,
        is_unidirectional,
    ):
        attention_node_name = self.model.create_node_name("GptAttention")
        attention_node = helper.make_node(
            "Attention",
            inputs=[input, fc_weight, fc_bias, mask, past],
            outputs=[attention_node_name + "_output", present],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend(
            [
                helper.make_attribute("num_heads", self.num_heads),
                helper.make_attribute("unidirectional", 1 if is_unidirectional else 0),
            ]
        )
        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        matmul_node = helper.make_node(
            "MatMul",
            inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
            outputs=[attention_node_name + "_matmul_output"],
            name=attention_node_name + "_matmul",
        )

        add_node = helper.make_node(
            "Add",
            inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
            outputs=[output],
            name=attention_node_name + "_add",
        )
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])
        self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
        self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
        self.node_name_to_graph_name[add_node.name] = self.this_graph_name
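
    # fuse() matches one attention block working backwards from the normalization
    # node that follows it: first the QKV output-projection path, then the V, FC,
    # QK/mask, Q, and K paths, and finally the past/present state. Any failed
    # match leaves the graph untouched.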
    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        past = None
        present = None
        return_indice = []

        is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"

        qkv_nodes = None
        if not is_normalize_node_skiplayernorm:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                [0, None, 0, 0, 0, 0, 0],
                output_name_to_node=output_name_to_node,
                return_indice=return_indice,
            )  # yapf: disable
        else:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                [None, 0, 0, 0, 0, 0],
                output_name_to_node=output_name_to_node,
                return_indice=return_indice,
            )  # yapf: disable

        if qkv_nodes is None:
            return

        another_input = None
        if not is_normalize_node_skiplayernorm:
            (
                add_qkv,
                reshape_qkv,
                gemm_qkv,
                reshape_1,
                reshape_2,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes
            another_input = add_qkv.input[1 - return_indice[0]]
        else:
            (
                reshape_qkv,
                gemm_qkv,
                reshape_1,
                reshape_2,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes

        v_nodes = self.model.match_parent_path(matmul_qkv, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_fc) = v_nodes

        # Try to match the FC pattern using Gemm + LayerNormalization
        fc_nodes = self.model.match_parent_path(
            split_fc,
            ["Reshape", "Gemm", "Reshape", "LayerNormalization"],
            [0, 0, 0, 0],
            output_name_to_node,
        )

        # Try to match the pattern using Gemm + SkipLayerNormalization
        if fc_nodes is None:
            fc_nodes = self.model.match_parent_path(
                split_fc,
                ["Reshape", "Gemm", "Reshape", "SkipLayerNormalization"],
                [0, 0, 0, 0],
                output_name_to_node,
            )

        # Try to match the pattern using MatMul
        if fc_nodes is None:
            # LayerNormalization
            fc_nodes = self.model.match_parent_path(
                split_fc,
                ["Add", "MatMul", "LayerNormalization"],
                [0, None, 0],
                output_name_to_node,
            )

            # SkipLayerNormalization
            if fc_nodes is None:
                fc_nodes = self.model.match_parent_path(
                    split_fc,
                    ["Add", "MatMul", "SkipLayerNormalization"],
                    [0, None, 0],
                    output_name_to_node,
                )

            if fc_nodes is None:
                logger.debug("fuse_attention: failed to match fc path")
                return

            fc_weight = fc_nodes[1].input[1]
            i, _ = self.model.get_constant_input(fc_nodes[0])
            fc_bias = fc_nodes[0].input[i]
        else:
            fc_weight = fc_nodes[1].input[1]
            fc_bias = fc_nodes[1].input[2]

        layernorm_before_attention = fc_nodes[-1]

        # `another_input` will be non-None only if:
        # (1) SkipLayerNorm fusion wasn't turned ON, or
        # (2) SkipLayerNorm fusion was turned ON but the upstream layer's LayerNorm + Add was not
        #     fused into a SkipLayerNorm. This can happen if the shapes of the Add inputs are different.
        # So, keep the following check whether SkipLayerNorm fusion is turned ON or OFF.
        if another_input is not None and another_input not in layernorm_before_attention.input:
            logger.debug("Upstream Add and (Skip)LayerNormalization shall have one same input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        concat_k_to_match = None
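
        # Two variants of the Q*K score path exist: older exports apply the causal
        # mask arithmetically (Softmax <- Sub <- Mul <- Div <- MatMul), while exports
        # from PyTorch 1.5.0 / Transformers 2.9.0 onward use a Where node instead.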
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0])
        if qk_nodes is not None:
            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
            mask_nodes = self.model.match_parent_path(
                sub_qk,
                [
                    "Mul",
                    "Sub",
                    "Slice",
                    "Slice",
                    "Unsqueeze",
                    "Sub",
                    "Squeeze",
                    "Slice",
                    "Shape",
                    "Div",
                ],
                [1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
            )  # yapf: disable
            if mask_nodes is None:
                logger.debug("fuse_attention: failed to match unidirectional mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[3]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return

            if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
                _, mul_val = self.model.get_constant_input(mask_nodes[0])
                if mul_val != -10000:
                    self.mask_filter_value = -mul_val
        else:
            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [
                    (["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0]),
                    (["Softmax", "Add", "Where", "Div", "MatMul"], [0, 0, None, 1, 0]),
                ],
                output_name_to_node,
            )
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [
                        (
                            ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze", "Reshape"],
                            [None, 0, 1, 0, 0, 0],
                        ),
                        (
                            ["Mul", "Sub", "Unsqueeze", "Unsqueeze", "Reshape"],
                            [None, 0, 1, 0, 0],
                        ),
                        (
                            ["Mul", "Sub", "Unsqueeze", "Unsqueeze"],
                            [None, 0, 1, 0],
                        ),  # useless cast and reshape are removed.
                    ],
                    output_name_to_node,
                )  # yapf: disable
                if input_mask_nodes is None:
                    logger.debug("fuse_attention: failed to match input attention mask path")
                    return

                if len(input_mask_nodes) > 1 and input_mask_nodes[0].op_type == "Mul":
                    _, mul_val = self.model.get_constant_input(input_mask_nodes[0])
                    if mul_val != -10000:
                        self.mask_filter_value = mul_val

            mask_nodes = self.model.match_parent_path(
                where_qk,
                [
                    "Cast",
                    "Slice",
                    "Slice",
                    "Unsqueeze",
                    "Sub",
                    "Squeeze",
                    "Slice",
                    "Shape",
                ],
                [0, 0, 0, 1, 0, 0, 0, 0],
                output_name_to_node,
            )  # yapf: disable
            if mask_nodes is None:
                # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep.
                logger.debug("fuse_attention: failed to match mask path")
                return
            slice_mask = mask_nodes[2]

            div_or_concat = self.model.get_parent(mask_nodes[-1], 0, output_name_to_node)
            if div_or_concat.op_type == "Div":
                div_mask = div_or_concat
                if div_qk != div_mask:
                    logger.debug("fuse_attention: skip since div_qk != div_mask")
                    return
            elif div_or_concat.op_type == "Concat":
                concat_k_to_match = div_or_concat
            else:
                logger.debug("fuse_attention: failed to match mask path")

        # Validate that the mask data is either lower triangular (unidirectional) or all ones
        mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0]))
        if not (
            len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1) and mask_data.shape[2] == mask_data.shape[3]
        ):
            logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
            return
        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
            return
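
        # Trace the Q and K inputs of the score MatMul back to the Split of the
        # shared QKV projection; Q, K, and V must all come from the same Split.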
        q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_fc != split_q:
            logger.debug("fuse_attention: skip since split_fc != split_q")
            return

        k_nodes = self.model.match_parent_path(matmul_qk, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
        if k_nodes is None:
            # This pattern is from pytorch 1.7.1 and transformers 4.6.1
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Transpose", "Concat", "Transpose", "Reshape", "Split"],
                [1, 0, 1, 0, 0],
            )
            if k_nodes is None:
                logger.debug("fuse_attention: failed to match k path")
                return
            else:
                (_, concat_k, transpose_k, reshape_k, split_k) = k_nodes
        else:
            (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_fc != split_k:
            logger.debug("fuse_attention: skip since split_fc != split_k")
            return

        if concat_k_to_match and concat_k != concat_k_to_match:
            logger.debug("fuse_attention: skip since concat_k != concat_k_to_match")
            return

        attention_mask_input_name = ""
        if input_mask_nodes is not None:
            input_name = input_mask_nodes[-1].input[0]
            attention_mask_input_name = self.cast_attention_mask(input_name)

        # Match past and present paths
        past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or self.match_past_pattern_2(
            concat_k, concat_v, output_name_to_node
        )
        if past is None:
            logger.info("fuse_attention: failed to match past path")
            return

        if not self.model.find_graph_input(past):
            logger.debug("past is not graph input.")
            # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.

        present = self.match_present(concat_v, input_name_to_nodes)
        if present is None:
            logger.info("fuse_attention: failed to match present path")
            return

        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph output")
            return

        self.create_attention_node(
            fc_weight,
            fc_bias,
            gemm_qkv,
            past,
            present,
            layernorm_before_attention.output[0],
            reshape_qkv.output[0],
            attention_mask_input_name,
            is_unidirectional,
        )

        # we rely on prune_graph() to clean old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True
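

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). It assumes
# a GPT-2 model with past state already exported to ONNX; the file names and
# num_heads=12 (GPT-2 small) are assumptions for demonstration. Fusion.apply()
# (inherited from fusion_base) drives fuse() over each candidate
# (Skip)LayerNormalization node, and OnnxModel.prune_graph() removes the
# subgraph nodes the fusion replaced.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import onnx

    model = OnnxModel(onnx.load("gpt2_past.onnx"))  # hypothetical input path
    fusion = FusionGptAttention(model, num_heads=12)
    fusion.apply()  # match and fuse every attention block it can find
    model.prune_graph()  # drop the now-orphaned subgraph nodes
    onnx.save(model.model, "gpt2_past_fused.onnx")  # hypothetical output path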