m2m model translation
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

315 lines
11 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_attention import AttentionMask, FusionAttention
  7. from fusion_reshape import FusionReshape
  8. from onnx import numpy_helper
  9. from onnx_model import OnnxModel
  10. from onnx_model_bert import BertOnnxModel
  11. logger = logging.getLogger(__name__)
  12. class FusionBartEncoderAttention(FusionAttention):
  13. """
  14. Fuse Bart Attention subgraph into one Attention node.
  15. """
  16. def __init__(
  17. self,
  18. model: OnnxModel,
  19. hidden_size: int,
  20. num_heads: int,
  21. attention_mask: AttentionMask,
  22. ):
  23. super().__init__(model, hidden_size, num_heads, attention_mask)
  24. def check_runtime_shape_path(
  25. self,
  26. reshape_qkv_2,
  27. reshape_qkv_1,
  28. reshape_q_2,
  29. reshape_k_2,
  30. reshape_v_2,
  31. root_input,
  32. ):
  33. concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
  34. if concat_qkv_2_path is None:
  35. return False
  36. concat_qkv_2 = concat_qkv_2_path[0]
  37. reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
  38. reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
  39. reshape_qkv_2_path_3 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
  40. if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None or reshape_qkv_2_path_3 is None:
  41. return False
  42. _, gather_1, shape_1 = reshape_qkv_2_path_1
  43. _, gather_2, shape_2 = reshape_qkv_2_path_2
  44. _, _, shape_3 = reshape_qkv_2_path_3
  45. if shape_1.input[0] != root_input or shape_2.input[0] != root_input or shape_3.input[0] != root_input:
  46. return False
  47. reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0])
  48. reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0])
  49. if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None:
  50. return False
  51. if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name:
  52. return False
  53. reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
  54. reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
  55. reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
  56. if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None:
  57. return False
  58. mul_q = reshape_q_2_path[-1]
  59. mul_k = reshape_k_2_path[-1]
  60. mul_v = reshape_v_2_path[-1]
  61. gather_1_out = gather_1.output[0]
  62. if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
  63. return False
  64. return True
  65. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  66. # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
  67. qkv_nodes = self.model.match_parent_path(
  68. normalize_node,
  69. ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
  70. [None, 1, 0, 0, 0, 0],
  71. )
  72. if qkv_nodes is not None:
  73. (
  74. add_out,
  75. matmul_out,
  76. reshape_qkv_2,
  77. transpose_qkv,
  78. reshape_qkv_1,
  79. matmul_qkv,
  80. ) = qkv_nodes
  81. else:
  82. return
  83. other_inputs = []
  84. for i, input in enumerate(normalize_node.input):
  85. if input not in output_name_to_node:
  86. continue
  87. if input == qkv_nodes[0].output[0]:
  88. continue
  89. other_inputs.append(input)
  90. if len(other_inputs) != 1:
  91. return
  92. root_input = other_inputs[0]
  93. children = input_name_to_nodes[root_input]
  94. children_types = [child.op_type for child in children]
  95. if children_types.count("MatMul") != 3:
  96. return
  97. v_nodes = self.model.match_parent_path(
  98. matmul_qkv,
  99. ["Reshape", "Transpose", "Reshape", "Add", "MatMul"],
  100. [1, 0, 0, 0, None],
  101. )
  102. if v_nodes is None:
  103. logger.debug("fuse_attention: failed to match v path")
  104. return
  105. (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes
  106. qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
  107. if qk_nodes is not None:
  108. _, matmul_qk = qk_nodes
  109. else:
  110. return
  111. q_nodes = self.model.match_parent_path(
  112. matmul_qk,
  113. ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"],
  114. [0, 0, 0, 0, 0, 1],
  115. )
  116. if q_nodes is not None:
  117. reshape_q_2, _, reshape_q_1, _, add_q, matmul_q = q_nodes
  118. else:
  119. return
  120. k_nodes = self.model.match_parent_path(
  121. matmul_qk,
  122. ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
  123. [1, 0, 0, 0, 0, 1],
  124. )
  125. if k_nodes is not None:
  126. _, reshape_k_2, _, reshape_k_1, add_k, matmul_k = k_nodes
  127. else:
  128. return
  129. if not self.check_runtime_shape_path(
  130. reshape_qkv_2,
  131. reshape_qkv_1,
  132. reshape_q_2,
  133. reshape_k_2,
  134. reshape_v_2,
  135. root_input,
  136. ):
  137. return
  138. if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_v.input[0] == root_input:
  139. mask_nodes = []
  140. mask_index = None
  141. attention_last_node = reshape_qkv_2
  142. num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1)
  143. if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
  144. logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
  145. return
  146. new_node = self.create_attention_node(
  147. mask_index,
  148. matmul_q,
  149. matmul_k,
  150. matmul_v,
  151. add_q,
  152. add_k,
  153. add_v,
  154. num_heads,
  155. hidden_size,
  156. root_input,
  157. attention_last_node.output[0],
  158. None,
  159. )
  160. if new_node is None:
  161. return
  162. self.nodes_to_add.append(new_node)
  163. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  164. self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
  165. self.nodes_to_remove.extend(qk_nodes)
  166. self.nodes_to_remove.extend(q_nodes)
  167. self.nodes_to_remove.extend(k_nodes)
  168. self.nodes_to_remove.extend(v_nodes)
  169. # Use prune graph to remove mask nodes since they are shared by all attention nodes.
  170. self.nodes_to_remove.extend(mask_nodes)
  171. self.prune_graph = True
class FusionBartReshape(FusionReshape):
    """Replace runtime-computed Reshape target shapes with constant shapes where possible.

    Matches a Reshape whose shape input is a 4-element Concat built from
    Shape/Gather subgraphs and constant initializers, and — when the dynamic
    parts provably come from the same root input as the Reshape's data — swaps
    the Concat for a constant shape via ``replace_reshape_node``.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model)

    def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
        """Attempt the constant-shape rewrite for one Reshape node; no-op when the pattern does not match."""
        # The shape input must be produced by a 4-way Concat.
        if reshape_node.input[1] not in output_name_to_node:
            return
        concat_node = output_name_to_node[reshape_node.input[1]]
        if concat_node.op_type != "Concat" or len(concat_node.input) != 4:
            return

        # First Concat element: Unsqueeze(Gather(Shape(x), idx)).
        path0 = self.model.match_parent_path(
            concat_node,
            ["Unsqueeze", "Gather", "Shape"],
            [0, 0, 0],
            output_name_to_node,
        )
        if path0 is None:
            return
        (_, gather_0, shape_0) = path0

        shape = []
        gather_value = self.model.get_constant_value(gather_0.input[1])
        # Gather index 0 means "dim 0 of the input"; 0 in a Reshape shape copies that dim.
        if gather_value == 0:
            shape.append(0)

        # Second Concat element may be another Shape/Gather path or a constant initializer.
        path1 = self.model.match_parent_path(
            concat_node,
            ["Unsqueeze", "Gather", "Shape"],
            [1, 0, 0],
            output_name_to_node,
        )
        if path1 is None:
            # Case 1: dims 1-3 are constant initializers; expect shape like (0, -1, a, b).
            input_1_proto = self.model.get_initializer(concat_node.input[1])
            input_2_proto = self.model.get_initializer(concat_node.input[2])
            input_3_proto = self.model.get_initializer(concat_node.input[3])
            if input_1_proto is None or input_2_proto is None or input_3_proto is None:
                return

            input_1 = numpy_helper.to_array(input_1_proto)
            input_2 = numpy_helper.to_array(input_2_proto)
            input_3 = numpy_helper.to_array(input_3_proto)
            if len(input_1) != 1 or len(input_2) != 1 or len(input_3) != 1:
                return
            # Require (-1, positive, positive) so the constant shape is unambiguous.
            if not (input_1[0] == -1 and input_2[0] > 0 and input_3[0] > 0):
                return

            shape.extend(input_1)
            shape.extend(input_2)
            shape.extend(input_3)

            # The reshaped data must come from Add(MatMul(root, ...)) reading the same root as Shape.
            gemm_path = self.model.match_parent_path(reshape_node, ["Add", "MatMul"], [0, 1], output_name_to_node)
            if gemm_path is None:
                return

            top_matmul = gemm_path[-1]
            root_input = top_matmul.input[0]
            if shape_0.input[0] != root_input:
                return

            self.replace_reshape_node(shape, reshape_node, concat_node)
        else:
            # Case 2: dims 0-1 are dynamic (Shape/Gather), dims 2-3 are constants.
            (_, gather_1, shape_1) = path1

            gather_value = self.model.get_constant_value(gather_1.input[1])
            # Gather index 1 means "dim 1 of the input"; again encoded as 0 in the Reshape shape.
            if gather_value == 1:
                shape.append(0)

            input_2_proto = self.model.get_initializer(concat_node.input[2])
            input_3_proto = self.model.get_initializer(concat_node.input[3])
            if input_2_proto is None or input_3_proto is None:
                return

            input_2 = numpy_helper.to_array(input_2_proto)
            input_3 = numpy_helper.to_array(input_3_proto)
            if len(input_2) != 1 or len(input_3) != 1:
                return
            if not (input_2[0] > 0 and input_3[0] > 0):
                return

            shape.extend(input_2)
            shape.extend(input_3)

            # Here the data path includes a scaling Mul (presumably the Q scaling — TODO confirm).
            gemm_path = self.model.match_parent_path(
                reshape_node, ["Mul", "Add", "MatMul"], [0, 0, 1], output_name_to_node
            )
            if gemm_path is None:
                return

            top_matmul = gemm_path[-1]
            root_input = top_matmul.input[0]
            # Both dynamic dims must be derived from the same root input as the data.
            if shape_0.input[0] != root_input or shape_1.input[0] != root_input:
                return

            self.replace_reshape_node(shape, reshape_node, concat_node)
  251. class BartOnnxModel(BertOnnxModel):
  252. def __init__(self, model, num_heads, hidden_size):
  253. super().__init__(model, num_heads, hidden_size)
  254. self.attention_mask = AttentionMask(self)
  255. self.attention_fusion = FusionBartEncoderAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
  256. self.bart_reshape_fusion_preprocess = FusionBartReshape(self)
  257. def fuse_attention(self):
  258. self.attention_fusion.apply()
  259. def preprocess(self):
  260. self.adjust_reshape_and_expand()
  261. self.bart_reshape_fusion_preprocess.apply()