m2m model translation
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionSkipLayerNormalization(Fusion):
    """
    Fuse Add + LayerNormalization into one node: SkipLayerNormalization
    Note: This fusion does not check the input shape of Add and LayerNormalization.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipLayerNormalization", "LayerNormalization")
        # Updating shape inference is needed since other fusions might add new edges
        # that do not have shape info yet. The concrete values for batch_size and
        # seq_len are arbitrary bindings for the symbolic dimensions.
        self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)

        if self.shape_infer_helper is None:
            # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op.
            logger.warning("symbolic shape inference disabled or failed.")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        add = self.model.get_parent(node, 0, output_name_to_node)

        # In some models there is an input_ids -> Gather -> Add -> LayerNorm chain where
        # one input of the Add node is an initializer with a fixed shape; such an Add
        # should not be fused into SkipLayerNorm.
        if add is None:
            return

        for add_input in add.input:
            if self.model.get_initializer(add_input) is not None:
                return

        # The Add node should have exactly two parent nodes.
        if len(self.model.get_parents(add)) != 2:
            return

        if self.shape_infer_helper is not None:
            if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
                logger.debug(
                    "skip SkipLayerNormalization fusion since shapes of inputs (%s, %s) are not the same",
                    add.input[0],
                    add.input[1],
                )
                return
        else:
            logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed")
            return
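
        # Embedding-lookup guard (intent inferred from the match below): if the Add is
        # fed by a Gather whose indices are not a graph input, only fuse when those
        # indices are produced by a ConstantOfShape node.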
        gather_path = self.model.match_parent_path(add, ["Gather"], [None])
        if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
            if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None:
                return

        residual_add_has_multiple_consumers = False
        add_children = self.model.get_children(add, input_name_to_nodes)

        # The residual Add before the LayerNormalization produces an output that is
        # consumed by nodes other than the LayerNormalization itself. We can still go
        # ahead with the SkipLayerNormalization fusion, but we need to preserve the
        # output of Add, and that output must then be produced by SkipLayerNormalization.
        if len(add_children) != 1:
            residual_add_has_multiple_consumers = True

        outputs_to_keep = node.output
        if residual_add_has_multiple_consumers:
            outputs_to_keep.extend([add.output[0]])

        outputs = [node.output[0]]

        # Skip the other optional outputs of SkipLayerNormalization before adding the Add's output.
        if residual_add_has_multiple_consumers:
            outputs.extend(["", "", add.output[0]])
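
        # Per the com.microsoft contrib op, SkipLayerNormalization takes
        # (input, skip, gamma, beta) and exposes optional extra outputs; output
        # slot 3 carries the input+skip sum, which is why two empty output names
        # were inserted above when the residual Add output must be preserved.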
        if (
            add is not None
            and add.op_type == "Add"
            and self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node)
        ):
            self.nodes_to_remove.extend([add, node])

            inputs = [add.input[0], add.input[1], node.input[1], node.input[2]]
            normalize_node = helper.make_node(
                "SkipLayerNormalization",
                inputs=inputs,
                outputs=outputs,
                name=self.model.create_node_name("SkipLayerNormalization", name_prefix="SkipLayerNorm"),
            )
            normalize_node.domain = "com.microsoft"

            # Pass attribute "epsilon" from layernorm node to SkipLayerNormalization
            for att in node.attribute:
                if att.name == "epsilon":
                    normalize_node.attribute.extend([att])

            # Set default epsilon if no epsilon exists from layernorm
            if len(normalize_node.attribute) == 0:
                normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

            self.nodes_to_add.append(normalize_node)
            self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
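

# A minimal usage sketch (hedged: "model.onnx" is a placeholder path, and in
# practice this fusion is driven by the onnxruntime transformers optimizer
# rather than invoked directly):
#
#     import onnx
#     m = OnnxModel(onnx.load("model.onnx"))
#     FusionSkipLayerNormalization(m).apply()
#     m.save_model_to_file("model_fused.onnx")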


class FusionBiasSkipLayerNormalization(Fusion):
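    """
    Fuse a bias Add (fed by MatMul, possibly through a Cast) that precedes a
    SkipLayerNormalization node into that node's optional bias input.
    """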

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipLayerNormalization", "SkipLayerNormalization", "add bias")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        if len(node.input) != 4:
            return

        return_indice = []
        nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], None, return_indice)
        if nodes is None:
            # In the fp16 case, there can be a Cast between the MatMul and the bias Add.
            nodes = self.model.match_parent_path(
                node, ["Add", "Cast", "MatMul"], [None, None, None], None, return_indice
            )
            if nodes is None:
                return

        assert len(return_indice) == 2 or len(return_indice) == 3

        # return_indice[0] is the input index of this SkipLayerNormalization node
        # that the matched Add feeds; indices >= 2 would be gamma/beta, so skip those.
        add_input_index = return_indice[0]
        if add_input_index >= 2:
            return

        # The Cast variant matches three nodes; take the first (Add) and the last (MatMul).
        (add, matmul) = (nodes[0], nodes[-1])

        # The bias should be a one-dimensional initializer.
        bias_weight = None
        bias_index = -1
        for i, input in enumerate(add.input):
            initializer = self.model.get_initializer(input)
            if initializer is None:
                continue
            bias_index = i
            bias_weight = NumpyHelper.to_array(initializer)
            break

        if bias_weight is None:
            logger.debug("Bias weight not found")
            return

        if len(bias_weight.shape) != 1:
            logger.debug("Bias weight is not 1D")
            return

        subgraph_nodes = [node, add]
        if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, input_name_to_nodes, output_name_to_node):
            logger.debug("Skip fusing SkipLayerNormalization with Bias since it is not safe")
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        # Map to the (input, skip, gamma, beta, bias) inputs of SkipLayerNormalization:
        # the MatMul output replaces the bias Add, and the Add's initializer becomes
        # the optional bias input.
        inputs = [
            node.input[1 - add_input_index],
            matmul.output[0],
            node.input[2],
            node.input[3],
            add.input[bias_index],
        ]
        new_node = helper.make_node(
            "SkipLayerNormalization",
            inputs=inputs,
            outputs=node.output,
            name=self.model.create_node_name("SkipLayerNormalization", "SkipLayerNorm_AddBias_"),
        )
        new_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from skiplayernorm node to skiplayernorm(add bias)
        for att in node.attribute:
            if att.name == "epsilon":
                new_node.attribute.extend([att])

        # Set default epsilon if no epsilon exists from skiplayernorm
        if len(new_node.attribute) == 0:
            new_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
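

# The bias variant is intended to run after the base fusion, since it matches
# SkipLayerNormalization nodes (same hedges as the sketch above):
#
#     FusionSkipLayerNormalization(m).apply()
#     FusionBiasSkipLayerNormalization(m).apply()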