m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

360 lines
13 KiB

6 months ago
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, Optional, Tuple

from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionFastGelu(Fusion):
  12. def __init__(self, model: OnnxModel):
  13. super().__init__(model, "FastGelu", "Tanh")
  14. def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
  15. if self.fuse_1(tanh_node, input_name_to_nodes, output_name_to_node):
  16. return
  17. if self.fuse_2(tanh_node, input_name_to_nodes, output_name_to_node):
  18. return
  19. if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node):
  20. return
  21. def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]:
  22. """
  23. Fuse Gelu with tanh into one node:
  24. +---------------------------+
  25. | |
  26. | v
  27. [root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul
  28. | (Y=3) (B=0.0447...) (B=0.7978...) (B=1) ^
  29. | |
  30. +------> Mul(B=0.5)--------------------------------------------+
  31. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  32. """
  33. if tanh_node.output[0] not in input_name_to_nodes:
  34. return
  35. children = input_name_to_nodes[tanh_node.output[0]]
  36. if len(children) != 1 or children[0].op_type != "Add":
  37. return
  38. add_after_tanh = children[0]
  39. if not self.model.has_constant_input(add_after_tanh, 1.0):
  40. return
  41. if add_after_tanh.output[0] not in input_name_to_nodes:
  42. return
  43. children = input_name_to_nodes[add_after_tanh.output[0]]
  44. if len(children) != 1 or children[0].op_type != "Mul":
  45. return
  46. mul_after_tanh = children[0]
  47. mul_half = self.model.match_parent(mul_after_tanh, "Mul", None, output_name_to_node)
  48. if mul_half is None:
  49. return
  50. i = self.model.find_constant_input(mul_half, 0.5)
  51. if i < 0:
  52. return
  53. root_input = mul_half.input[0 if i == 1 else 1]
  54. # root_node could be None when root_input is graph input
  55. root_node = self.model.get_parent(mul_half, 0 if i == 1 else 1, output_name_to_node)
  56. mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
  57. if mul_before_tanh is None:
  58. return
  59. i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
  60. if i < 0:
  61. return
  62. add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
  63. if add_before_tanh is None:
  64. return
  65. mul_after_pow = self.model.match_parent(
  66. add_before_tanh,
  67. "Mul",
  68. None,
  69. output_name_to_node,
  70. exclude=[root_node] if root_node else [],
  71. )
  72. if mul_after_pow is None:
  73. return
  74. i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
  75. if i < 0:
  76. return
  77. pow = self.model.match_parent(mul_after_pow, "Pow", 0 if i == 1 else 1, output_name_to_node)
  78. if pow is None:
  79. return
  80. if not self.model.has_constant_input(pow, 3.0):
  81. return
  82. if pow.input[0] != root_input:
  83. return
  84. subgraph_nodes = [
  85. mul_after_tanh,
  86. mul_half,
  87. add_after_tanh,
  88. tanh_node,
  89. mul_before_tanh,
  90. add_before_tanh,
  91. mul_after_pow,
  92. pow,
  93. ]
  94. if not self.model.is_safe_to_fuse_nodes(
  95. subgraph_nodes,
  96. [mul_after_tanh.output[0]],
  97. input_name_to_nodes,
  98. output_name_to_node,
  99. ):
  100. return
  101. self.nodes_to_remove.extend(subgraph_nodes)
  102. fused_node = helper.make_node(
  103. "FastGelu",
  104. inputs=[root_input],
  105. outputs=mul_after_tanh.output,
  106. name=self.model.create_node_name("FastGelu"),
  107. )
  108. fused_node.domain = "com.microsoft"
  109. self.nodes_to_add.append(fused_node)
  110. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  111. return True
  112. def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
  113. """
  114. This pattern is from Tensorflow model.
  115. Fuse Gelu with tanh into one node:
  116. +---------------------------+
  117. | |
  118. | v
  119. [root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul(B=0.5)-->Mul-->
  120. | (Y=3) (B=0.0447...) (B=0.7978...) (B=1) ^
  121. | |
  122. +---------------------------------------------------------------------------+
  123. Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
  124. """
  125. if tanh_node.output[0] not in input_name_to_nodes:
  126. return
  127. children = input_name_to_nodes[tanh_node.output[0]]
  128. if len(children) != 1 or children[0].op_type != "Add":
  129. return
  130. add_after_tanh = children[0]
  131. if not self.model.has_constant_input(add_after_tanh, 1.0):
  132. return
  133. if add_after_tanh.output[0] not in input_name_to_nodes:
  134. return
  135. children = input_name_to_nodes[add_after_tanh.output[0]]
  136. if len(children) != 1 or children[0].op_type != "Mul":
  137. return
  138. mul_half = children[0]
  139. i = self.model.find_constant_input(mul_half, 0.5)
  140. if i < 0:
  141. return
  142. if mul_half.output[0] not in input_name_to_nodes:
  143. return
  144. children = input_name_to_nodes[mul_half.output[0]]
  145. if len(children) != 1 or children[0].op_type != "Mul":
  146. return
  147. mul_after_mul_half = children[0]
  148. root_node = self.model.get_parent(
  149. mul_after_mul_half,
  150. 0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1,
  151. output_name_to_node,
  152. )
  153. if root_node is None:
  154. return
  155. mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
  156. if mul_before_tanh is None:
  157. return
  158. i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
  159. if i < 0:
  160. return
  161. add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
  162. if add_before_tanh is None:
  163. return
  164. mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", None, output_name_to_node, exclude=[root_node])
  165. if mul_after_pow is None:
  166. return
  167. i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
  168. if i < 0:
  169. return
  170. pow = self.model.match_parent(mul_after_pow, "Pow", 0 if i == 1 else 1, output_name_to_node)
  171. if pow is None:
  172. return
  173. if not self.model.has_constant_input(pow, 3.0):
  174. return
  175. if pow.input[0] != root_node.output[0]:
  176. return
  177. subgraph_nodes = [
  178. mul_after_mul_half,
  179. mul_half,
  180. add_after_tanh,
  181. tanh_node,
  182. mul_before_tanh,
  183. add_before_tanh,
  184. mul_after_pow,
  185. pow,
  186. ]
  187. if not self.model.is_safe_to_fuse_nodes(
  188. subgraph_nodes,
  189. [mul_after_mul_half.output[0]],
  190. input_name_to_nodes,
  191. output_name_to_node,
  192. ):
  193. return
  194. self.nodes_to_remove.extend(subgraph_nodes)
  195. fused_node = helper.make_node(
  196. "FastGelu",
  197. inputs=[root_node.output[0]],
  198. outputs=mul_after_mul_half.output,
  199. name=self.model.create_node_name("FastGelu"),
  200. )
  201. fused_node.domain = "com.microsoft"
  202. self.nodes_to_add.append(fused_node)
  203. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  204. return True
  205. def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
  206. """
  207. OpenAI's gelu implementation, also used in Megatron:
  208. Gelu(x) = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))
  209. Fuse subgraph into a FastGelu node:
  210. +------------ Mul (B=0.79788456) -------------------+
  211. | |
  212. +-------------------------------+ |
  213. | | |
  214. | v v
  215. [root] --> Mul (B=0.044715) --> Mul --> Add(B=1) --> Mul --> Tanh --> Add(B=1) --> Mul-->
  216. | ^
  217. | |
  218. +-----------> Mul (B=0.5) --------------------------------------------------------+
  219. """
  220. if tanh_node.output[0] not in input_name_to_nodes:
  221. return
  222. children = input_name_to_nodes[tanh_node.output[0]]
  223. if len(children) != 1 or children[0].op_type != "Add":
  224. return
  225. add_after_tanh = children[0]
  226. if not self.model.has_constant_input(add_after_tanh, 1.0):
  227. return
  228. if add_after_tanh.output[0] not in input_name_to_nodes:
  229. return
  230. children = input_name_to_nodes[add_after_tanh.output[0]]
  231. if len(children) != 1 or children[0].op_type != "Mul":
  232. return
  233. mul_last = children[0]
  234. mul_half = self.model.match_parent(mul_last, "Mul", None, output_name_to_node)
  235. if mul_half is None:
  236. return
  237. i = self.model.find_constant_input(mul_half, 0.5)
  238. if i < 0:
  239. return
  240. root_input = mul_half.input[0 if i == 1 else 1]
  241. mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
  242. if mul_before_tanh is None:
  243. return
  244. add_1 = self.model.match_parent(mul_before_tanh, "Add", None, output_name_to_node)
  245. if add_1 is None:
  246. return
  247. j = self.model.find_constant_input(add_1, 1.0)
  248. if j < 0:
  249. return
  250. mul_7978 = self.model.match_parent(mul_before_tanh, "Mul", None, output_name_to_node)
  251. if mul_7978 is None:
  252. return
  253. k = self.model.find_constant_input(mul_7978, 0.7978, delta=0.0001)
  254. if k < 0:
  255. return
  256. if mul_7978.input[0 if k == 1 else 1] != root_input:
  257. return
  258. mul_before_add_1 = self.model.match_parent(add_1, "Mul", 0 if j == 1 else 1, output_name_to_node)
  259. if mul_before_add_1 is None:
  260. return
  261. if mul_before_add_1.input[0] == root_input:
  262. another = 1
  263. elif mul_before_add_1.input[1] == root_input:
  264. another = 0
  265. else:
  266. return
  267. mul_0447 = self.model.match_parent(mul_before_add_1, "Mul", another, output_name_to_node)
  268. if mul_0447 is None:
  269. return
  270. m = self.model.find_constant_input(mul_0447, 0.0447, delta=0.0001)
  271. if m < 0:
  272. return
  273. if mul_0447.input[0 if m == 1 else 1] != root_input:
  274. return
  275. subgraph_nodes = [
  276. mul_0447,
  277. mul_before_add_1,
  278. add_1,
  279. mul_before_tanh,
  280. tanh_node,
  281. add_after_tanh,
  282. mul_7978,
  283. mul_half,
  284. mul_last,
  285. ]
  286. if not self.model.is_safe_to_fuse_nodes(
  287. subgraph_nodes,
  288. [mul_last.output[0]],
  289. input_name_to_nodes,
  290. output_name_to_node,
  291. ):
  292. return
  293. self.nodes_to_remove.extend(subgraph_nodes)
  294. fused_node = helper.make_node(
  295. "FastGelu",
  296. inputs=[root_input],
  297. outputs=mul_last.output,
  298. name=self.model.create_node_name("FastGelu"),
  299. )
  300. fused_node.domain = "com.microsoft"
  301. self.nodes_to_add.append(fused_node)
  302. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
  303. return True