m2m model translation
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List, Tuple, Union

from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionEmbedLayerNoMask(Fusion):
    """
    Fuse embedding layer into one node (EmbedLayerNormalization).
    It supports the following model types: BERT, DistilBert, ALBert.
    """
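    # Typical usage is through the transformers optimizer (optimizer.optimize_model applies this fusion as
    # part of BERT graph optimization). A minimal direct sketch, assuming this tools directory is on
    # sys.path and "bert.onnx" is a placeholder path, looks roughly like:
    #
    #     import onnx
    #     model = OnnxModel(onnx.load("bert.onnx"))
    #     FusionEmbedLayerNormalization(model, use_mask_index=True).apply()
    #     model.save_model_to_file("bert_fused.onnx")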

    def __init__(self, model: OnnxModel, description: str = "no mask"):
        super().__init__(
            model,
            "EmbedLayerNormalization",
            ["LayerNormalization", "SkipLayerNormalization"],
            description,
        )
        self.utils = FusionUtils(model)
        self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True)

        # The following will be reset in each fuse call of FusionEmbedLayerNormalization
        self.attention = None
        self.embed_node = None

    def match_two_gather(self, add: NodeProto) -> Union[None, Tuple[NodeProto, NodeProto]]:
        gather_0_path = self.model.match_parent_path(add, ["Gather"], [0])
        if gather_0_path is None:
            return None

        gather_1_path = self.model.match_parent_path(add, ["Gather"], [1])
        if gather_1_path is None:
            return None

        return gather_0_path[0], gather_1_path[0]

    def check_attention_subgraph(
        self,
        layernorm: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        is_distil_bert: bool,
    ) -> bool:
        """Check that LayerNormalization has a child of Attention node or a subgraph like Attention.

        Args:
            layernorm (NodeProto): LayerNormalization node
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            is_distil_bert (bool): whether it is DistilBert or not

        Returns:
            bool: whether there is an Attention node or an Attention-like subgraph
        """
        self.attention = self.model.find_first_child_by_type(
            layernorm, "Attention", input_name_to_nodes, recursive=False
        )
        if self.attention is not None:
            return True

        if layernorm.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[layernorm.output[0]]
        children_types = sorted([child.op_type for child in children])

        # Try to find MultiHeadAttention
        if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
            for node in children:
                if node.op_type == "SkipLayerNormalization":
                    path1 = self.model.match_parent_path(
                        node,
                        ["Add", "MatMul", "MultiHeadAttention", "MatMul"],
                        [None, None, 0, 0],
                    )
                    if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
                        self.cross_attention = path1[2]
                        return True

        # In case the user disables attention fusion, check whether the subgraph looks like Attention.
        # For Albert, there is MatMul+Add after the embedding layer before attention.
        if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
            grandchildren = input_name_to_nodes[children[0].output[0]]
            if (
                len(grandchildren) == 1
                and grandchildren[0].op_type == "Add"
                and grandchildren[0].output[0] in input_name_to_nodes
            ):
                nodes = input_name_to_nodes[grandchildren[0].output[0]]
                for node in nodes:
                    if node.op_type == "Attention":
                        self.attention = node
                        return True
                children_types = sorted([child.op_type for child in nodes])

        # Two Shape nodes might be merged by ORT
        if is_distil_bert:
            # SkipLayerNormalization might exist when the model has been optimized by ORT first.
            if (
                children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
                and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
                and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
            ):
                logger.debug("No Attention-like subgraph in children of LayerNormalization")
                return False
        else:
            if children_types != ["Add", "MatMul", "MatMul", "MatMul"] and children_types != [
                "MatMul",
                "MatMul",
                "MatMul",
                "SkipLayerNormalization",
            ]:
                logger.debug("No Attention-like subgraph in children of LayerNormalization")
                return False

        return True

    def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
        """Match position embedding path from input_ids to Gather for DistilBert.

        Pattern is like the following:
                 (input_ids)
                      |
                    Shape
                      |   \
                      |    Gather (indices=1)
                      |       |
                      |      Cast (optional)
                      |       |
                      |      Range (start=0, end=*, delta=1)
                      |       |
                      |    Unsqueeze
                      |    /
                     Expand
                        |
                      Gather
        """
        # remove after tests pass
        path1 = self.model.match_parent_path(position_embedding_gather, ["Expand", "Shape"], [1, 1])
        if path1 is None:
            path1 = self.model.match_parent_path(
                position_embedding_gather,
                ["Expand", "Where", "Reshape", "Shape"],
                [1, 1, 2, 0],
            )
            if path1 is None:
                return False

        expand, shape = path1[0], path1[-1]
        if shape.input[0] != input_ids:
            return False

        _, path2, _ = self.model.match_parent_paths(
            expand,
            [
                (["Unsqueeze", "Range", "Cast", "Gather", "Shape"], [0, 0, 1, 0, 0]),
                (["Unsqueeze", "Range", "Gather", "Shape"], [0, 0, 1, 0]),
            ],
            output_name_to_node,
        )
        if path2 is None:
            return False

        range_node = path2[1]
        if not (
            self.utils.check_node_input_value(range_node, 0, 0) and self.utils.check_node_input_value(range_node, 2, 1)
        ):
            return False

        gather_node = path2[-2]
        if not (self.utils.check_node_input_value(gather_node, 1, 1)):
            return False

        shape_node = path2[-1]
        if shape_node.input[0] != input_ids:
            return False

        return True

    def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node):
        """Match position embedding path from input_ids to Gather for Roberta.

        Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):
          (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
                               |                                               ^
                               V                                               |
                               +-----------------------------------------------+

        Roberta new pattern from transformers v4.9:
          (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
                               |                                                           ^
                               V                                                           |
                               +-----------------------------------------------------------+

        start_node = position_embedding_gather
        start_index = 1

        # match optional Cast node.
        parent = self.model.get_parent(start_node, start_index, output_name_to_node)
        if parent is None:
            return
        if parent.op_type == "Cast":
            if OnnxModel.get_node_attribute(parent, "to") != 7:
                return
            start_node = parent
            start_index = 0

        i, path, return_indices = self.model.match_parent_paths(
            start_node,
            [(['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
             (['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
            output_name_to_node)

        if path is not None:
            # constant input of Add shall be 1.
            i, value = self.model.get_constant_input(path[0])
            if value != 1:
                return False

            _, self.padding_word_id = self.model.get_constant_input(path[-1])

            return input_ids == path[-1].input[0]
        """

        return False

    def match_position_embedding_bert(self, position_embedding_gather, input_ids, output_name_to_node):
        """Match position embedding path from input_ids to Gather for BERT.

        BERT Embedding Layer Pattern:
                                    (input_ids)
                                   /         \
                                  /          Shape
                                 /             |
                                /            Gather (indices=1)
                               /               |
                              /               Add (optional, B=0)
                             /                 |
                    Gather (segment_ids)    Unsqueeze (axes=0)
                        \          |           |
                         \       Gather      Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
                          \       /            |
                             Add            Gather
                               \            /
                                    Add
                                     |
                            LayerNormalization
        """
        path = self.model.match_parent_path(
            position_embedding_gather,
            ["Slice", "Unsqueeze"],
            [1, 2],
            output_name_to_node,
        )
        if path is None:
            return False

        slice, unsqueeze = path
        slice_weight = self.model.get_constant_value(slice.input[0])
        if not (
            slice_weight is not None
            and len(slice_weight.shape) == 2
            and slice_weight.shape[0] == 1
            and self.utils.check_node_input_value(slice, 1, [0])
            and self.utils.check_node_input_value(slice, 3, [1])
            and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))
        ):
            return False

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
                return False
        else:
            if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                return False

        node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
        if node is None:
            return False
        if node.op_type == "Add":
            if not self.utils.check_node_input_value(node, 1, 0):
                return False
            gather = self.model.get_parent(node, 0, output_name_to_node)
        else:
            gather = node

        if gather is None or gather.op_type != "Gather":
            return False
        if not (self.utils.check_node_input_value(gather, 1, 1)):
            return False

        shape = self.model.get_parent(gather, 0, output_name_to_node)
        if shape is None or shape.op_type != "Shape":
            return False

        return input_ids == shape.input[0]

    def match_position_embedding(self, position_embedding_gather, input_ids, output_name_to_node):
        if self.match_position_embedding_bert(position_embedding_gather, input_ids, output_name_to_node):
            return True

        # TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel
        # related: https://github.com/huggingface/transformers/issues/10736
        # if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node):
        #     return True

        if self.match_position_embedding_distilbert(position_embedding_gather, input_ids, output_name_to_node):
            return True

        return False

    def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
        """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs."""
        input_ids = word_embedding_gather.input[1]
        segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
        position_ids = position_embedding_gather.input[1]

        if self.shape_infer_helper is not None:
            input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids)
            position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids)
            assert input_ids_shape and position_ids_shape
            if not (
                len(input_ids_shape) == 2
                and len(position_ids_shape) == 2
                and input_ids_shape[1] == position_ids_shape[1]
            ):
                logger.info(
                    "Cannot fuse EmbedLayerNormalization: input_ids and position_ids do not match in the 2nd dimension: {} vs {}".format(
                        input_ids_shape, position_ids_shape
                    )
                )
                return False

            if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids):
                logger.info(
                    "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids do not have the same shape: {} != {}".format(
                        input_ids_shape,
                        self.shape_infer_helper.get_edge_shape(segment_ids),
                    )
                )
                return False

        word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0])
        if word_embedding_table is None or len(word_embedding_table.shape) != 2:
            logger.info("Cannot fuse EmbedLayerNormalization: word embedding table is not expected")
            return False

        position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0])
        if (
            position_embedding_table is None
            or len(position_embedding_table.shape) != 2
            or (word_embedding_table.shape[1] != position_embedding_table.shape[1])
        ):
            logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected")
            return False

        if segment_ids:
            segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0])
            if (
                segment_embedding_table is None
                or len(segment_embedding_table.shape) != 2
                or (word_embedding_table.shape[1] != segment_embedding_table.shape[1])
            ):
                logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
                return False

        # In the normal case, the word embedding table is the largest and the segment embedding table is the
        # smallest, while the position embedding table is in between.
        # TODO: use other information (like initializer names) to identify different embedding weights automatically.
        if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
            logger.warning(
                f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
            )

        if segment_ids:
            if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
                logger.warning(
                    f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
                )

            if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
                logger.warning(
                    f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
                )

        return True

    def cast_to_int32(self, input_name: str) -> Tuple[str, Union[None, NodeProto]]:
        """Cast a graph input or node input to int32.

        Args:
            input_name (str): name of graph input or node input

        Returns:
            A tuple of the casted input name and the Cast node:
                int32_output (str): if the input is already int32, this is the input name; otherwise it is the output name of the Cast node.
                input_cast_node (Union[None, NodeProto]): the Cast node. It is None if the input is already int32.
        """
        input_cast_node = None
        graph_input = self.model.find_graph_input(input_name)
        if graph_input is not None:
            if graph_input.type.tensor_type.elem_type != TensorProto.INT32:
                int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
            else:
                int32_output = input_name
        else:
            int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)

        return int32_output, input_cast_node

    def create_fused_node(
        self,
        input_ids: str,
        layernorm: NodeProto,
        word_embedding_gather: NodeProto,
        position_embedding_gather: NodeProto,
        segment_embedding_gather: Union[None, NodeProto],
        position_ids: str = None,
        embedding_sum_output=False,
    ):
        """Create an EmbedLayerNormalization node. Note that segment embedding is optional.

        Args:
            input_ids (str): input_ids for word embeddings
            layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
            word_embedding_gather (NodeProto): the Gather node for word embedding
            position_embedding_gather (NodeProto): the Gather node for position embedding
            segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None.
            position_ids (str, optional): explicit position_ids input, or None.
            embedding_sum_output (bool, optional): whether to add an extra output for the embedding sum.

        Returns:
            NodeProto: the EmbedLayerNormalization node created.
        """
        nodes_to_add = []
        input_ids, _ = self.cast_to_int32(input_ids)

        node_name = self.model.create_node_name("EmbedLayerNormalization")

        if layernorm.op_type == "LayerNormalization":
            gamma = layernorm.input[1]
            beta = layernorm.input[2]
        else:  # SkipLayerNormalization
            gamma = layernorm.input[2]
            beta = layernorm.input[3]

        embed_node_inputs = None
        if segment_embedding_gather is not None:
            segment_ids, _ = self.cast_to_int32(segment_embedding_gather.input[1])

            embed_node_inputs = [
                input_ids,
                segment_ids,
                word_embedding_gather.input[0],
                position_embedding_gather.input[0],
                segment_embedding_gather.input[0],
                gamma,
                beta,
            ]
        else:  # no segment embedding
            embed_node_inputs = [
                input_ids,
                "",
                word_embedding_gather.input[0],
                position_embedding_gather.input[0],
                "",
                gamma,
                beta,
            ]

        if position_ids is not None:
            # Add an empty input for the mask before position_ids.
            embed_node_inputs.append("")
            position_ids, _ = self.cast_to_int32(position_ids)
            embed_node_inputs.append(position_ids)

        embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
        if embedding_sum_output:
            embed_node_outputs.append(node_name + "_embedding_sum")

        embed_node = helper.make_node(
            "EmbedLayerNormalization",
            embed_node_inputs,
            outputs=embed_node_outputs,
            name=node_name,
        )
        embed_node.domain = "com.microsoft"

        # Pass the attribute "epsilon" from the normalization node to EmbedLayerNormalization.
        for att in layernorm.attribute:
            if att.name == "epsilon":
                embed_node.attribute.extend([att])

        # Set a default value of 1e-12 if no attribute is found.
        # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
        if len(embed_node.attribute) == 0:
            embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

        # Make sure the new EmbedLayerNormalization node is the last one in self.nodes_to_add.
        nodes_to_add.append(embed_node)
        for node in nodes_to_add:
            self.node_name_to_graph_name[node.name] = self.this_graph_name
        self.nodes_to_add.extend(nodes_to_add)

        self.embed_node = embed_node
        return embed_node

    def finish_fusion(self, layernorm, embed_node):
        self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
        # use prune_graph to remove nodes that are no longer needed
        self.prune_graph = True

    def is_embedding_sum_needed(self, add_before_layer_norm):
        """Check whether the Add before layer norm has more than one consumer, so its sum output must be kept.

        Args:
            add_before_layer_norm (NodeProto): Add node before any LayerNormalization node in topological order of the graph

        Returns:
            bool: whether an extra embedding-sum output is needed on the EmbedLayerNormalization node
        """
        nodes = self.model.get_children(add_before_layer_norm)

        return len(nodes) > 1

    def fuse_gpt2(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
        # graph checks
        # gpt2 has no segment embedding, subgraph pattern is like
        #     input_ids    position_ids
        #         |             |
        #      Gather        Gather
        #            \       /
        #             Add _ _ _ _ _
        #              |           |
        #     LayerNormalization   |
        #              |           |
        #          Attention       |
        #              |           |
        #           Matmul         |
        #              |          /
        #             Add        /
        #               \       /
        #                  Add
        two_gather = self.match_two_gather(add_before_layernorm)
        if two_gather is None:
            return False

        word_embedding_gather, position_embedding_gather = two_gather
        input_ids = word_embedding_gather.input[1]
        position_ids = position_embedding_gather.input[1]

        if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
            return False

        if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
            return False

        # If add_before_layernorm is an Add node, then add_output is its first output.
        # If add_before_layernorm is a SkipLayerNormalization node, then add_output is its (optional) fourth output.
        add_output = None
        optional_embedding_sum_output = False
        if (add_before_layernorm.op_type == "Add" and self.is_embedding_sum_needed(add_before_layernorm)) or (
            add_before_layernorm.op_type == "SkipLayerNormalization" and len(add_before_layernorm.output) >= 4
        ):
            optional_embedding_sum_output = True
            add_output = (
                add_before_layernorm.output[0]
                if add_before_layernorm.op_type == "Add"
                else add_before_layernorm.output[3]
            )

        # make the fused node
        embed_node = self.create_fused_node(
            input_ids,
            layernorm,
            word_embedding_gather,
            position_embedding_gather,
            None,
            position_ids,
            optional_embedding_sum_output,
        )

        # direct the output to the other Add too
        self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
        if optional_embedding_sum_output:
            self.model.replace_input_of_all_nodes(add_output, embed_node.output[2])

        return True

    def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
        """Fuse embedding layer for DistilBert

        Args:
            layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
            add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            output_name_to_node (Dict[str, NodeProto]): map from output name to node
        """

        # DistilBert has no segment embedding, subgraph pattern is like
        #       input_ids
        #        |      \
        #        |     (position_embedding_subgraph)
        #        |        |
        #      Gather   Gather
        #           \   /
        #            Add
        #             |
        #     LayerNormalization
        two_gather = self.match_two_gather(add_before_layernorm)
        if two_gather is None:
            return False

        word_embedding_gather, position_embedding_gather = two_gather
        input_ids = word_embedding_gather.input[1]

        if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=True):
            return False

        if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
            return False

        if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
            return False

        embed_node = self.create_fused_node(
            input_ids, layernorm, word_embedding_gather, position_embedding_gather, None
        )
        self.finish_fusion(layernorm, embed_node)
        return True

    def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
        """Fuse embedding layer for Bert

        Args:
            layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
            add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            output_name_to_node (Dict[str, NodeProto]): map from output name to node
        """

        add_2_gather = self.model.match_parent_path(add_before_layernorm, ["Add"], [0])
        if add_2_gather is None:
            return False

        two_gather = self.match_two_gather(add_2_gather[0])
        if two_gather is None:
            return False

        word_embedding_gather, segment_embedding_gather = two_gather

        input_ids = word_embedding_gather.input[1]

        if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
            return False

        position_embedding_path = self.model.match_parent_path(add_before_layernorm, ["Gather"], [1])
        if position_embedding_path is None:
            return False

        position_embedding_gather = position_embedding_path[0]
        if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
            if not self.match_position_embedding(segment_embedding_gather, input_ids, output_name_to_node):
                return False
            # position and segment embeddings are switched
            temp = segment_embedding_gather
            segment_embedding_gather = position_embedding_gather
            position_embedding_gather = temp

        if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather):
            return False

        embed_node = self.create_fused_node(
            input_ids,
            layernorm,
            word_embedding_gather,
            position_embedding_gather,
            segment_embedding_gather,
        )
        self.finish_fusion(layernorm, embed_node)
        return True

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        if node.op_type == "LayerNormalization":
            first_add_path = self.model.match_parent_path(node, ["Add"], [0])
            if first_add_path is None:
                return
            add_before_layernorm = first_add_path[0]
        else:  # SkipLayerNormalization
            add_before_layernorm = node  # Add is fused into SkipLayerNormalization

        if self.fuse_gpt2(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
            return

        if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
            return

        if self.fuse_bert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
            return

class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
    def __init__(self, model: OnnxModel, use_mask_index=False):
        super().__init__(model, "with mask")
        self.use_mask_index = use_mask_index

    def replace_mask(self, mask_int32, attention_nodes):
        # Inputs of EmbedLayerNorm: input_ids, segment_ids (optional), word_embedding, position_embedding,
        #                           segment_embedding (optional), gamma, beta, mask (optional), position_ids (optional)
        embed_node = self.embed_node
        if len(embed_node.input) == 7:
            embed_node.input.append(mask_int32)
            logger.debug("append mask to %s", embed_node.name)
        elif len(embed_node.input) > 7 and embed_node.input[7] == "":
            embed_node.input[7] = mask_int32
            logger.debug("replace mask in %s", embed_node.name)
        else:
            logger.debug("skip mask in %s", embed_node.name)
            return

        for attention_node in attention_nodes:
            logger.debug("update mask_index in %s", attention_node.name)
            if attention_node.op_type == "Attention":
                attention_node.input[3] = embed_node.output[1]
            elif attention_node.op_type == "MultiHeadAttention":
                attention_node.input[4] = embed_node.output[1]

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Reset attention and embed_node so that we know fusion is successful when they are not None.
        self.attention = None
        self.cross_attention = None
        self.embed_node = None
        super().fuse(node, input_name_to_nodes, output_name_to_node)

        if self.embed_node is None:
            return

        if not self.use_mask_index:
            logger.debug("--use_mask_index is not set: EmbedLayerNormalization will not have mask")
            self.increase_counter("EmbedLayerNormalization(no mask)")
            return

        if self.attention is None and self.cross_attention is None:
            logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
            self.increase_counter("EmbedLayerNormalization(no mask)")
            return

        if self.attention:
            mask_int32 = self.attention.input[3]
        else:
            mask_int32 = self.cross_attention.input[4]

        children_nodes = input_name_to_nodes[mask_int32]
        if self.model.find_graph_input(mask_int32):
            attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
            self.replace_mask(mask_int32, attention_nodes)
            self.increase_counter("EmbedLayerNormalization(with mask)")
            return

        if mask_int32 not in output_name_to_node:
            logger.debug("EmbedLayerNormalization will not have mask since %s is not a node output", mask_int32)
            self.increase_counter("EmbedLayerNormalization(no mask)")
            return

        node = output_name_to_node[mask_int32]
        if node.op_type in ["ReduceSum", "Cast"]:
            attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
            if node.op_type == "ReduceSum":
                mask_int32 = node.input[0]
                if len(children_nodes) == len(attention_nodes):
                    self.nodes_to_remove.append(node)
            self.replace_mask(mask_int32, attention_nodes)
            self.increase_counter("EmbedLayerNormalization(with mask)")
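
In practice this fusion is reached through the public optimizer entry point rather than by importing this module directly. The following is a minimal sketch, assuming onnxruntime with its transformers tooling is installed; the model path, num_heads, and hidden_size values are placeholders for an exported BERT model.

# Illustrative sketch: drive the EmbedLayerNormalization fusion via the public optimizer API.
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

options = FusionOptions("bert")
options.enable_embed_layer_norm = True  # the fusion implemented in this file

opt_model = optimizer.optimize_model(
    "bert.onnx",      # placeholder input path
    model_type="bert",
    num_heads=12,     # placeholder: should match the exported model
    hidden_size=768,  # placeholder: should match the exported model
    optimization_options=options,
)
opt_model.save_model_to_file("bert_optimized.onnx")
print(opt_model.get_fused_operator_statistics())  # EmbedLayerNormalization should appear in the counts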