m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

250 lines
9.7 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from typing import Dict, Optional
  7. from fusion_base import Fusion
  8. from onnx import helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionGelu(Fusion):
  12. def __init__(self, model: OnnxModel):
  13. super().__init__(model, "Gelu", "Erf")
  14. def fuse(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
  15. if self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node):
  16. return
  17. if self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node):
  18. return
  19. self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
  20. def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
  21. """
  22. This pattern is from PyTorch model
  23. Fuse Gelu with Erf into one node:
  24. Pattern 1:
  25. +-------Mul(0.5)---------------------+
  26. | |
  27. | v
  28. [root] --> Div -----> Erf --> Add --> Mul -->
  29. (B=1.4142...) (1)
  30. Pattern 2:
  31. +------------------------------------+
  32. | |
  33. | v
  34. [root] --> Div -----> Erf --> Add --> Mul -->Mul -->
  35. (B=1.4142...) (1) (0.5)
  36. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  37. """
  38. if erf_node.output[0] not in input_name_to_nodes:
  39. return
  40. children = input_name_to_nodes[erf_node.output[0]]
  41. if len(children) != 1 or children[0].op_type != "Add":
  42. return
  43. add_after_erf = children[0]
  44. if not self.model.has_constant_input(add_after_erf, 1):
  45. return
  46. if add_after_erf.output[0] not in input_name_to_nodes:
  47. return
  48. children = input_name_to_nodes[add_after_erf.output[0]]
  49. if len(children) != 1 or children[0].op_type != "Mul":
  50. return
  51. mul_after_erf = children[0]
  52. div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
  53. if div is None:
  54. return
  55. if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
  56. return
  57. subgraph_input = div.input[0]
  58. another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
  59. if subgraph_input == mul_after_erf.input[another]: # pattern 2
  60. children = input_name_to_nodes[mul_after_erf.output[0]]
  61. if len(children) != 1 or children[0].op_type != "Mul":
  62. return
  63. mul_half = children[0]
  64. if not self.model.has_constant_input(mul_half, 0.5):
  65. return
  66. subgraph_output = mul_half.output[0]
  67. else: # pattern 1
  68. mul_half = self.model.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
  69. if mul_half is None:
  70. return
  71. if not self.model.has_constant_input(mul_half, 0.5):
  72. return
  73. if subgraph_input not in mul_half.input:
  74. return
  75. subgraph_output = mul_after_erf.output[0]
  76. subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
  77. if not self.model.is_safe_to_fuse_nodes(
  78. subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
  79. ):
  80. return
  81. self.nodes_to_remove.extend(subgraph_nodes)
  82. fused_node = helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output])
  83. fused_node.domain = "com.microsoft"
  84. self.nodes_to_add.append(fused_node)
  85. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  86. return True
  87. def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
  88. """
  89. This pattern is from Keras model
  90. Fuse Gelu with Erf into one node:
  91. +------------------------------------------+
  92. | |
  93. | v
  94. [root] --> Div -----> Erf --> Add --> Mul -->Mul
  95. (B=1.4142...) (A=1) (A=0.5)
  96. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  97. """
  98. if erf_node.output[0] not in input_name_to_nodes:
  99. return
  100. children = input_name_to_nodes[erf_node.output[0]]
  101. if len(children) != 1 or children[0].op_type != "Add":
  102. return
  103. add_after_erf = children[0]
  104. if not self.model.has_constant_input(add_after_erf, 1):
  105. return
  106. if add_after_erf.output[0] not in input_name_to_nodes:
  107. return
  108. children = input_name_to_nodes[add_after_erf.output[0]]
  109. if len(children) != 1 or children[0].op_type != "Mul":
  110. return
  111. mul_after_erf = children[0]
  112. if not self.model.has_constant_input(mul_after_erf, 0.5):
  113. return
  114. if mul_after_erf.output[0] not in input_name_to_nodes:
  115. return
  116. children = input_name_to_nodes[mul_after_erf.output[0]]
  117. if len(children) != 1 or children[0].op_type != "Mul":
  118. return
  119. mul = children[0]
  120. div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
  121. if div is None:
  122. return
  123. sqrt_node = None
  124. if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
  125. sqrt_node = self.model.match_parent(div, "Sqrt", 1, output_name_to_node)
  126. if sqrt_node is None:
  127. return
  128. if not self.model.has_constant_input(sqrt_node, 2.0):
  129. return
  130. root_node = self.model.get_parent(div, 0, output_name_to_node)
  131. if root_node is None:
  132. return
  133. if root_node.output[0] not in mul.input:
  134. return
  135. subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
  136. if sqrt_node:
  137. subgraph_nodes.append(sqrt_node)
  138. if not self.model.is_safe_to_fuse_nodes(
  139. subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node
  140. ):
  141. return
  142. self.nodes_to_remove.extend(subgraph_nodes)
  143. fused_node = helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]])
  144. fused_node.domain = "com.microsoft"
  145. self.nodes_to_add.append(fused_node)
  146. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  147. return True
  148. def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
  149. """
  150. This pattern is from TensorFlow model
  151. Fuse Gelu with Erf into one node:
  152. +----------------------------------------------+
  153. | |
  154. | v
  155. [root] --> Mul -----> Erf --> Add --> Mul -->Mul
  156. (A=0.7071067690849304) (B=1) (B=0.5)
  157. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  158. """
  159. if erf_node.output[0] not in input_name_to_nodes:
  160. return
  161. children = input_name_to_nodes[erf_node.output[0]]
  162. if len(children) != 1 or children[0].op_type != "Add":
  163. return
  164. add_after_erf = children[0]
  165. if not self.model.has_constant_input(add_after_erf, 1):
  166. return
  167. if add_after_erf.output[0] not in input_name_to_nodes:
  168. return
  169. children = input_name_to_nodes[add_after_erf.output[0]]
  170. if len(children) != 1 or children[0].op_type != "Mul":
  171. return
  172. mul_half = children[0]
  173. if not self.model.has_constant_input(mul_half, 0.5):
  174. return
  175. first_mul = self.model.match_parent(erf_node, "Mul", 0, output_name_to_node)
  176. if first_mul is None:
  177. return
  178. i = self.model.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
  179. if i < 0:
  180. return
  181. root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node)
  182. if root_node is None:
  183. return
  184. if mul_half.output[0] not in input_name_to_nodes:
  185. return
  186. children = input_name_to_nodes[mul_half.output[0]]
  187. if len(children) != 1 or children[0].op_type != "Mul":
  188. return
  189. last_mul = children[0]
  190. if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]):
  191. return
  192. subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
  193. if not self.model.is_safe_to_fuse_nodes(
  194. subgraph_nodes,
  195. [last_mul.output[0]],
  196. input_name_to_nodes,
  197. output_name_to_node,
  198. ):
  199. return
  200. self.nodes_to_remove.extend(subgraph_nodes)
  201. fused_node = helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]])
  202. fused_node.domain = "com.microsoft"
  203. self.nodes_to_add.append(fused_node)
  204. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  205. return True