m2m model translation

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from enum import Enum
from logging import getLogger
from os import name
from sys import path
from typing import Tuple, Union

import numpy as np
from fusion_base import Fusion
from fusion_options import AttentionMaskFormat
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
from shape_infer_helper import SymbolicShapeInferenceHelper, get_shape_from_type_proto

logger = getLogger(__name__)


class AttentionMask:
    """
    Process the attention mask input so it can be consumed by a fused Attention node.
    """

    def __init__(self, model: OnnxModel):
        self.model = model
        # A lookup table with mask input as key, and mask index output as value
        self.mask_indice = {}
        # A lookup table with mask input as key, and cast (to int32) output as value
        self.mask_casted = {}
        self.utils = FusionUtils(model)
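        # The default format converts the 2D attention mask into a 1D mask index
        # (per-sequence count of valid tokens); see process_mask below.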
        self.mask_format = AttentionMaskFormat.MaskIndexEnd

    def set_mask_format(self, mask_format: AttentionMaskFormat):
        self.mask_format = mask_format

    def set_mask_indice(self, mask, mask_index):
        if mask in self.mask_indice:
            assert mask_index == self.mask_indice[mask]
        self.mask_indice[mask] = mask_index

    def get_first_mask(self):
        assert len(self.mask_indice) > 0
        return next(iter(self.mask_indice))

    def process_mask(self, input: str) -> str:
        if self.mask_format == AttentionMaskFormat.NoMask:
            return None

        if input in self.mask_indice:
            return self.mask_indice[input]

        # Add cast to convert int64 to int32
        if self.model.find_graph_input(input):
            casted, input_name = self.utils.cast_graph_input_to_int32(input)
        else:
            input_name, cast_node = self.utils.cast_input_to_int32(input)
            casted = True

        if casted:
            self.mask_casted[input] = input_name

        # Attention supports int32 attention mask (2D) since 1.4.0
        if self.mask_format == AttentionMaskFormat.AttentionMask:
            self.mask_indice[input] = input_name
            return input_name

        # Add a mask processing node to convert attention mask to mask index (1D)
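        # ReduceSum over axis 1 with keepdims=0 yields, for each sequence in the batch,
        # the number of non-zero mask entries, which serves as the 1D mask index.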
        output_name = self.model.create_node_name("mask_index")
        mask_index_node = helper.make_node(
            "ReduceSum",
            inputs=[input_name],
            outputs=[output_name],
            name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
        )
        mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
        self.model.add_node(mask_index_node)

        self.mask_indice[input] = output_name
        return output_name


class FusionAttention(Fusion):
    """
    Fuse Attention subgraph into one Attention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
        use_multi_head_attention: bool = False,
    ):
        attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
        super().__init__(model, attention_op_name, ["SkipLayerNormalization", "LayerNormalization"])
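        # hidden_size and num_heads are user-provided hints; when they disagree with the
        # values detected from the graph, the detected values win (see get_num_heads_and_hidden_size).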
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_mask = attention_mask
        self.use_multi_head_attention = use_multi_head_attention
        self.mask_filter_value = None

        # Flags to show warning only once
        self.num_heads_warning = True
        self.hidden_size_warning = True

    def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]:
        """
        Detect num_heads and hidden_size from Concat node in the following subgraph:

        SkipLayerNormalization or EmbedLayerNormalization
                        /        |
                     MatMul    Shape
                        |        |
                       Add     Gather(indices=0)
                        |        |
                        |      Unsqueeze
                        |        |
                        |     Concat (*, -1, 12, 64)
                        |    /
                       Reshape
                          |
                       Transpose
        """
        if len(concat.input) == 4:
            num_heads = self.model.get_constant_value(concat.input[2])
            head_size = self.model.get_constant_value(concat.input[3])
            if (
                isinstance(num_heads, np.ndarray)
                and num_heads.size == 1
                and isinstance(head_size, np.ndarray)
                and head_size.size == 1
            ):
                return num_heads[0], num_heads[0] * head_size[0]

        return self.num_heads, self.hidden_size

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q

        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """
        # We assume that reshape fusion has already been done, so the shape is a tensor like [0, 0, num_heads, head_size].
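        # For example, a BERT-base reshape target of [0, 0, 12, 64] gives num_heads=12 and hidden_size=768.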
        q_shape = self.model.get_initializer(reshape_q.input[1])
        if q_shape is None:
            concat = self.model.get_parent(reshape_q, 1)
            if concat is not None and concat.op_type == "Concat":
                return self.get_num_heads_and_hidden_size_from_concat(concat)
            logger.debug(f"{reshape_q.input[1]} is not initializer.")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        q_shape_value = NumpyHelper.to_array(q_shape)
        if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
            logger.debug(f"q_shape_value={q_shape_value}. Expected value is like [0, 0, num_heads, head_size].")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        num_heads = q_shape_value[2]
        head_size = q_shape_value[3]
        hidden_size = num_heads * head_size

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def get_add_qk_str(self, add_qk: NodeProto):
        shape_infer = self.model.infer_runtime_shape(update=True)
        if shape_infer is None:
            return

        input_0_shape = shape_infer.get_edge_shape(add_qk.input[0])
        input_1_shape = shape_infer.get_edge_shape(add_qk.input[1])
        if input_0_shape is None or input_1_shape is None:
            logger.debug(f"one of the inputs of {add_qk} is None")
            return None

        if input_0_shape != input_1_shape:
            logger.debug(f"the shapes of the two inputs of {add_qk} are not the same")
            return None

        return add_qk.input[1]

    def create_attention_node(
        self,
        mask_index: str,
        q_matmul: NodeProto,
        k_matmul: NodeProto,
        v_matmul: NodeProto,
        q_add: NodeProto,
        k_add: NodeProto,
        v_add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        add_qk_str: str,
    ) -> Union[NodeProto, None]:
        """Create an Attention node.

        Args:
            mask_index (str): mask input
            q_matmul (NodeProto): MatMul node in the fully connected layer for Q
            k_matmul (NodeProto): MatMul node in the fully connected layer for K
            v_matmul (NodeProto): MatMul node in the fully connected layer for V
            q_add (NodeProto): Add bias node in the fully connected layer for Q
            k_add (NodeProto): Add bias node in the fully connected layer for K
            v_add (NodeProto): Add bias node in the fully connected layer for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name
            add_qk_str (str): name of the tensor added to QxK' before Softmax, or None

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        """
        assert num_heads > 0

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])
        q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
        k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
        v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])

        if q_weight is None:
            print(
                f"{q_matmul.input[1]} is not an initializer. "
                "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
            )
            return None
        if not (k_weight and v_weight and q_bias and k_bias):
            return None

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        # assert q and k have the same shape as expected
        assert qw.shape == kw.shape

        qw_in_size = qw.shape[0]
        kw_in_size = kw.shape[0]
        vw_in_size = vw.shape[0]

        assert qw_in_size == kw_in_size == vw_in_size

        if hidden_size > 0 and hidden_size != qw_in_size:
            logger.warning(
                f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
                "Please provide a correct input hidden size or pass in 0"
            )

        is_qkv_diff_dims = False
        if qw.shape != vw.shape:
            is_qkv_diff_dims = True

        # All the matrices can have the same shape, or the q, k matrices can have the same shape with v being different.
        # For 2d weights, the shapes would be [in_size, out_size].
        # For 3d weights, the shape would be [in_size, a, b] where a*b = out_size.
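        # When Q, K and V weights have the same shape, they are stacked along a new axis into
        # [in_size, 3, out_size] and flattened into a packed [in_size, 3 * out_size] QKV weight;
        # otherwise they are concatenated along the output dimension.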
        qw_out_size = np.prod(qw.shape[1:])
        kw_out_size = np.prod(kw.shape[1:])
        vw_out_size = np.prod(vw.shape[1:])

        qkv_weight_dim = 0
        if is_qkv_diff_dims:
            qkv_weight = np.concatenate((qw, kw, vw), axis=1)
            qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size
        else:
            qkv_weight = np.stack((qw, kw, vw), axis=1)
            qkv_weight_dim = 3 * qw_out_size

        qb = NumpyHelper.to_array(q_bias)
        kb = NumpyHelper.to_array(k_bias)
        vb = NumpyHelper.to_array(v_bias)

        q_bias_shape = np.prod(qb.shape)
        k_bias_shape = np.prod(kb.shape)
        v_bias_shape = np.prod(vb.shape)

        assert q_bias_shape == k_bias_shape == qw_out_size
        assert v_bias_shape == vw_out_size

        qkv_bias_dim = 0
        if is_qkv_diff_dims:
            qkv_bias = np.concatenate((qb, kb, vb), axis=0)
            qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
        else:
            qkv_bias = np.stack((qb, kb, vb), axis=0)
            qkv_bias_dim = 3 * q_bias_shape

        attention_node_name = self.model.create_node_name("Attention")

        if not self.use_multi_head_attention:
            weight = helper.make_tensor(
                name=attention_node_name + "_qkv_weight",
                data_type=TensorProto.FLOAT,
                dims=[qw_in_size, qkv_weight_dim],
                vals=qkv_weight.flatten().tolist(),
            )

            # Sometimes weights and bias are stored in fp16
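            # (TensorProto data_type 10 is FLOAT16); the tensor built above is fp32,
            # so cast it back to fp16 to match the original weight type.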
            if q_weight.data_type == 10:
                weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))

            self.model.add_initializer(weight, self.this_graph_name)

        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[qkv_bias_dim],
            vals=qkv_bias.flatten().tolist(),
        )
        if q_bias.data_type == 10:
            bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
        self.model.add_initializer(bias, self.this_graph_name)

        # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
        if self.use_multi_head_attention:
            if add_qk_str is not None:
                logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
                return None

            attention_inputs = [
                q_matmul.output[0],
                k_matmul.output[0],
                v_matmul.output[0],
                attention_node_name + "_qkv_bias",
            ]
            if mask_index is not None:
                attention_inputs.append(mask_index)

            attention_node = helper.make_node(
                "MultiHeadAttention",
                inputs=attention_inputs,
                outputs=[output],
                name=attention_node_name,
            )
        else:
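            # Inputs of the fused Attention node are positional:
            # (input, qkv_weight, qkv_bias, mask_index, past, add_qk). Optional inputs that are
            # skipped must be passed as "" placeholders so later inputs keep their position.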
            attention_inputs = [
                input,
                attention_node_name + "_qkv_weight",
                attention_node_name + "_qkv_bias",
            ]
            if mask_index is not None:
                attention_inputs.append(mask_index)
            else:
                attention_inputs.append("")

            if add_qk_str is not None:
                attention_inputs.append("")  # no past
                attention_inputs.append(add_qk_str)

            attention_node = helper.make_node(
                "Attention",
                inputs=attention_inputs,
                outputs=[output],
                name=attention_node_name,
            )

        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        if is_qkv_diff_dims:
            attention_node.attribute.extend(
                [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
            )

        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        # Sometimes we cannot fuse SkipLayerNormalization since the Add before LayerNormalization has an output
        # that is used by nodes outside the SkipLayerNormalization pattern.
        # Conceptually we treat the Add before LayerNormalization as a SkipLayerNormalization node since they share the same pattern.
        start_node = normalize_node
        if normalize_node.op_type == "LayerNormalization":
            add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
            if add_before_layernorm is not None:
                start_node = add_before_layernorm
            else:
                return

        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [None, None, 0, 0, 0],
        )
        einsum_node = None
        if qkv_nodes is not None:
            (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
        else:
            # Match Albert
            qkv_nodes = self.model.match_parent_path(
                start_node, ["Add", "Einsum", "Transpose", "MatMul"], [1, None, 0, 0]
            )
            if qkv_nodes is not None:
                (_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes
            else:
                return

        other_inputs = []
        for i, input in enumerate(start_node.input):
            if input not in output_name_to_node:
                continue
            if input == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input)
        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]
  364. """
  365. Match flaubert Mask
  366. |
  367. Mul --> LayerNormalization --> Attention --> MatMul --> Add
  368. | |
  369. | |
  370. +---------------------------------------------------------
  371. """
        mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0)
        if mul_before_layernorm is not None:
            mul_children = input_name_to_nodes[mul_before_layernorm.output[0]]
            if mul_children is not None and len(mul_children) == 2:
                layernorm_node = mul_children[1]
                if layernorm_node.op_type == "LayerNormalization":
                    root_input = layernorm_node.output[0]
                else:
                    return
            elif mul_children is not None and len(mul_children) == 5:
                root_input = mul_before_layernorm.output[0]
            else:
                return
        elif normalize_node.op_type == "LayerNormalization":
            children = input_name_to_nodes[root_input]
            for child in children:
                if child.op_type == "LayerNormalization":
                    root_input = child.output[0]

        children = input_name_to_nodes[root_input]
        children_types = [child.op_type for child in children]
        if children_types.count("MatMul") != 3:
            return

        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (_, _, add_v, matmul_v) = v_nodes

        is_distill = False
        is_distill_add = False
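        # path1/path2 match the common pattern where the mask is added after Div or Mul scaling;
        # path3/path4 (the is_distill* cases) match variants that apply the mask through a Where node.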
        qk_paths = {
            "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
            "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
            "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
            "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
        }

        qk_nodes = None
        for k, v in qk_paths.items():
            qk_nodes = self.model.match_parent_path(matmul_qkv, v[0], v[1])
            if qk_nodes is None:
                continue
            if k == "path3":
                is_distill = True
            if k == "path4":
                is_distill_add = True
            break

        if qk_nodes is None:
            logger.debug("fuse_attention: failed to match qk path")
            return

        add_qk = None
        matmul_qk = None
        where_qk = None
        if is_distill:
            (_, where_qk, matmul_qk, _) = qk_nodes
        elif is_distill_add:
            (_, add_qk, where_qk, matmul_qk) = qk_nodes
        else:
            (_, add_qk, _, matmul_qk) = qk_nodes

        q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None])
        if q_nodes is None:
            q_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Div", "Transpose", "Reshape", "Add", "MatMul"],
                [0, 0, 0, 0, None],
            )
            if q_nodes is None:
                logger.debug("fuse_attention: failed to match q path")
                return
        reshape_q = q_nodes[-3]
        add_q = q_nodes[-2]
        matmul_q = q_nodes[-1]

        k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
        if k_nodes is None:
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
                [1, 0, 0, 0, None],
            )
            if k_nodes is None:
                logger.debug("fuse_attention: failed to match k path")
                return
        add_k = k_nodes[-2]
        matmul_k = k_nodes[-1]

        # Note that Cast might be removed by OnnxRuntime so we match two patterns here.
        mask_nodes = None
        add_qk_str = None
        if is_distill:
            _, mask_nodes, _ = self.model.match_parent_paths(
                where_qk,
                [
                    (["Expand", "Reshape", "Equal"], [0, 0, 0]),
                    (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
                    (["Cast", "Expand", "Reshape", "Equal"], [0, 0, 0, 0]),
                ],
                output_name_to_node,
            )
        elif is_distill_add:
            _, mask_nodes, _ = self.model.match_parent_paths(
                where_qk,
                [
                    (["Cast", "Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0, 0]),
                    (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
                ],
                output_name_to_node,
            )
            if add_qk is not None:
                add_qk_str = self.get_add_qk_str(add_qk)
                if add_qk_str is None:
                    logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}")
                    return
        else:
            _, mask_nodes, _ = self.model.match_parent_paths(
                add_qk,
                [
                    (
                        ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
                        [None, 0, 1, 0, 0],
                    ),
                    (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]),
                ],
                output_name_to_node,
            )
        if mask_nodes is None:
            logger.debug("fuse_attention: failed to match mask path")
            return

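        # The mask subgraph typically multiplies the inverted mask by a large negative constant
        # (-10000 by default). If the model uses a different constant, remember it so it can be
        # written to the mask_filter_value attribute of the fused node.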
        if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
            _, mul_val = self.model.get_constant_input(mask_nodes[0])
            if mul_val != -10000:
                self.mask_filter_value = mul_val

        if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])

            attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv

            q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

            # The number of heads is the same for the Q, K and V paths, so we pass the value detected from the Q path.
            # q_hidden_size is the input hidden size; the hidden sizes for Q, K and V are extracted from the weights as needed.
            new_node = self.create_attention_node(
                mask_index,
                matmul_q,
                matmul_k,
                matmul_v,
                add_q,
                add_k,
                add_v,
                q_num_heads,
                q_hidden_size,
                root_input,
                attention_last_node.output[0],
                add_qk_str,
            )
            if new_node is None:
                return

            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

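            # Albert uses an Einsum instead of MatMul+Reshape after attention. The fused Attention node
            # outputs a 3D tensor [batch, seq_len, hidden_size], so a Reshape back to
            # [batch, seq_len, num_heads, head_size] is inserted to feed the existing Einsum.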
            if einsum_node is not None:
                unique_index = einsum_node.input[0]
                new_edge = "edge_modified_" + unique_index
                shape_tensor = helper.make_tensor(
                    name="shape_modified_tensor" + unique_index,
                    data_type=TensorProto.INT64,
                    dims=[4],
                    vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]).tobytes(),
                    raw=True,
                )
                self.model.add_initializer(shape_tensor, self.this_graph_name)
                self.model.add_node(
                    helper.make_node(
                        "Reshape",
                        [attention_last_node.output[0], shape_tensor.name],
                        [new_edge],
                        "reshape_modified_" + unique_index,
                    ),
                    self.this_graph_name,
                )
                einsum_node.input[0] = new_edge

            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)

            # For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
            self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
            self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
            self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            self.prune_graph = True