m2m model translation
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Union

import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)


class FusionTnlrAttention(FusionAttention):
    """
    Fuse the TNLR Attention subgraph into one Attention node.
    TNLR Attention has an extra addition after the QxK' nodes and adopts [S, B, NH] as its I/O shape.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def create_attention_node(
        self,
        mask_index: str,
        matmul: NodeProto,
        add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        add_qk_str: str,
    ) -> Union[NodeProto, None]:
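        # Pack the shared QKV projection weight and bias into new initializers and emit a
        # single com.microsoft Attention node. Returns None when the initializers cannot be
        # found or the hidden size is inconsistent with the number of heads.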
        assert num_heads > 0
        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        weight = self.model.get_initializer(matmul.input[1])
        bias = self.model.get_initializer(add.input[1]) or self.model.get_initializer(add.input[0])

        if weight is None or bias is None:
            return None

        qkv_weight = NumpyHelper.to_array(weight)
        qkv_bias = NumpyHelper.to_array(bias)
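
        # TNLR stores Q, K and V as a single packed [hidden_size, 3 * hidden_size] projection,
        # so its weight and bias can be reused directly as the fused Attention QKV tensors.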
        attention_node_name = self.model.create_node_name("Attention")

        weight = helper.make_tensor(
            name=attention_node_name + "_qkv_weight",
            data_type=TensorProto.FLOAT,
            dims=[hidden_size, 3 * hidden_size],
            vals=qkv_weight.flatten().tolist(),
        )

        # Sometimes weights and bias are stored in fp16
        if weight.data_type == TensorProto.FLOAT16:
            weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
        self.model.add_initializer(weight, self.this_graph_name)

        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[3 * hidden_size],
            vals=qkv_bias.flatten().tolist(),
        )
        if bias.data_type == TensorProto.FLOAT16:
            bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
        self.model.add_initializer(bias, self.this_graph_name)
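
        # The contrib Attention inputs are positional: input, weight, bias, mask_index, past,
        # then the extra addition applied to the QxK' scores. Empty strings keep the optional
        # slots in place so add_qk_str lands in the last position.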
        attention_inputs = [
            input,
            attention_node_name + "_qkv_weight",
            attention_node_name + "_qkv_bias",
        ]
        if mask_index is not None:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        if add_qk_str is not None:
            attention_inputs.append("")
            attention_inputs.append(add_qk_str)

        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        # Sometimes we cannot fuse SkipLayerNormalization because the Add before LayerNorm has an
        # output that is used by nodes outside of SkipLayerNorm. Conceptually we treat that Add as
        # a SkipLayerNorm node since they share the same pattern.
        start_node = normalize_node
        if normalize_node.op_type != "SkipLayerNormalization":
            return

        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Where", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (_, _, matmul_below, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
        else:
            return

        other_inputs = []
        for i, input in enumerate(start_node.input):
            if input not in output_name_to_node:
                continue
            if input == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input)
        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]
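
        # The V, Q and K branches below all trace back through Slice -> Add -> MatMul to the
        # same packed QKV projection.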
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if v_nodes is None:
            return
        (_, _, _, add, matmul) = v_nodes

        upper_nodes = self.model.match_parent_path(matmul, ["Transpose"], [0])
        transpose = upper_nodes[0]

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            return
        (_, add_qk, matmul_qk) = qk_nodes

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            return
        add = q_nodes[-2]
        matmul = q_nodes[-1]

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if k_nodes is None:
            return
        add = k_nodes[-2]
        matmul = k_nodes[-1]

        relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0])
        if relative_position_bias_nodes is None:
            return
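
        # Fuse only when the shared projection MatMul consumes the attention root input.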
        if matmul.input[0] == root_input:
            mask_index = None
            attention_last_node = reshape_qkv

            # The number of heads is the same for the Q, K and V paths, so the configured
            # num_heads and hidden_size are passed through to the fused Attention node.
            new_node = self.create_attention_node(
                mask_index,
                matmul,
                add,
                self.num_heads,
                self.hidden_size,
                root_input,
                attention_last_node.output[0],
                relative_position_bias_nodes[0].input[0],
            )
            if new_node is None:
                return

            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

            # Add a transpose node after the attention node: the fused node reads the pre-transpose
            # input, so its output is transposed back (perm=[1, 0, 2]) under the original output name.
            back_transpose = helper.make_node(
                "Transpose",
                ["back_transpose_in_" + new_node.name],
                [new_node.output[0]],
                "back_transpose_" + new_node.name,
                perm=[1, 0, 2],
            )
            self.model.add_node(back_transpose, self.this_graph_name)
            new_node.input[0] = transpose.input[0]
            new_node.output[0] = "back_transpose_in_" + new_node.name

            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True
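

# TNLR-specific model wrapper: reuses the BertOnnxModel optimization passes but swaps in
# the TNLR attention fusion defined above.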
class TnlrOnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionTnlrAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

    def fuse_attention(self):
        self.attention_fusion.apply()
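

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream file): one way to run the
# TNLR attention fusion on an exported model. The file paths and the
# num_heads / hidden_size values below are placeholders for your own model; in
# practice optimizer.py from the same tool set normally drives these passes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import onnx

    # Load the exported TNLR model ("tnlr.onnx" is a hypothetical path).
    onnx_model = onnx.load("tnlr.onnx")

    # Wrap it and run the attention fusion defined above.
    tnlr_model = TnlrOnnxModel(onnx_model, num_heads=12, hidden_size=768)
    tnlr_model.fuse_attention()

    # OnnxModel keeps the (now modified) ModelProto in the .model attribute.
    onnx.save(tnlr_model.model, "tnlr_fused.onnx")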