# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict

from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionLayerNormalization(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Layer Normalization subgraph into one node LayerNormalization:
              +----------------------+
              |                      |
              |                      v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
                     (axis=2 or -1)  |      (Y=2)   (axis=2 or -1)  (E-6 or E-12 or 0)  ^
                                     |                                                  |
                                     +--------------------------------------------------+

        It also handles cases of duplicated Sub nodes exported from older versions of PyTorch:
              +----------------------+
              |                      v
              |           +-------> Sub---------------------------------------------+
              |           |                                                         |
              |           |                                                         v
          [Root] --> ReduceMean -->  Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
              |                      ^
              |                      |
              +----------------------+
        """
        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) == 0 or len(children) > 2:
            return

        root_input = node.input[0]

        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return

        if len(children) == 2:
            if children[1].op_type != "Sub" or children[1].input[0] != root_input:
                return

        div_node = None
        for child in children:
            div_node = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
            if div_node is not None:
                break
        if div_node is None:
            return
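        # Match the variance branch upward from Div: Sub --> Pow --> ReduceMean --> Add --> Sqrt.
        # Two path variants are accepted; the second has an extra Cast between Sub and Pow
        # (e.g. from mixed-precision exports). The integer lists give the input index to
        # follow at each step, starting from Div's input 1.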
        path_id, parent_nodes, _ = self.model.match_parent_paths(
            div_node,
            [
                (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
                (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
            ],
            output_name_to_node,
        )
        if path_id < 0:
            return

        sub_node = parent_nodes[-1]
        if sub_node not in children:
            return

        second_add_node = parent_nodes[1]
        i, add_weight = self.model.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            logger.warning(f"epsilon value is not expected: {add_weight}")
            return

        pow_node = parent_nodes[3]
        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return
        mul_node = input_name_to_nodes[div_node.output[0]][0]
        if mul_node.op_type != "Mul":
            return

        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != "Add":
            return

        subgraph_nodes = [node]
        subgraph_nodes.extend(children)
        subgraph_nodes.extend(parent_nodes[:-1])
        subgraph_nodes.extend([last_add_node, mul_node, div_node])

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse LayerNormalization node. Skip.")
            return

        weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
        if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
            return

        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
        if not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"):
            return
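        # Replace the whole subgraph with a single LayerNormalization node. Per the
        # ONNX spec its inputs are X, scale (weight) and bias, and the epsilon constant
        # found on the Add node is carried over as the epsilon attribute.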
        self.nodes_to_remove.extend(subgraph_nodes)

        normalize_node = helper.make_node(
            "LayerNormalization",
            inputs=[node.input[0], weight_input, bias_input],
            outputs=[last_add_node.output[0]],
            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
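
# A minimal usage sketch (not part of the original file; it assumes the sibling
# modules fusion_base.py and onnx_model.py from onnxruntime's transformers tools
# are importable, and "model.onnx" is a placeholder path). Fusion.apply() visits
# every node whose op type matches the search type given to __init__ ("ReduceMean"
# above), calls fuse() on it, then applies the collected nodes_to_add /
# nodes_to_remove edits to the graph:
#
#     import onnx
#
#     onnx_model = OnnxModel(onnx.load("model.onnx"))
#     FusionLayerNormalization(onnx_model).apply()
#     onnx_model.prune_graph()
#     onnx.save(onnx_model.model, "model_fused.onnx")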


class FusionLayerNormalizationTF(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "Add", "TF")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Layer Norm from TensorFlow model (using keras2onnx or tf2onnx):
          +------------------------------------+
          |                                    |
          |                                    |
        (Cast_1)                               |
          |                                    |
          |                                    v                                          (B)                            (B)           (A)
         Add --> (Cast_1) --> ReduceMean -->  Sub --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocal --> Mul --> Mul --> Sub --> Add
          |                       |                                                                                        |      ^             ^
          |                       |                                                                                        |      |             |
          |                       +-----------------------------------(Cast_2)--------------------------------------------|------+             |
          |                                                                                                                v                    |
          +--------------------------------------------------------------------------------------------------------------> Mul----------------+
        """
        return_indice = []
        _, parent_nodes, return_indice = self.model.match_parent_paths(
            node,
            [
                (
                    ["Sub", "Mul", "Mul", "Reciprocal", "Sqrt", "Add", "ReduceMean", "Mul", "Sub", "ReduceMean"],
                    [1, 1, None, 0, 0, 0, None, 0, 0, None],
                ),
                (
                    ["Sub", "Mul", "Mul", "Reciprocal", "Sqrt", "Add", "Cast", "ReduceMean", "Mul", "Sub", "ReduceMean"],
                    [1, 1, None, 0, 0, 0, 0, None, 0, 0, None],
                ),
            ],
            output_name_to_node,
        )  # yapf: disable
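        # match_parent_paths returns the first path that matches. A None entry in an
        # index list means any input index may match; the index actually taken for
        # each None is appended to return_indice, which is checked below.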
        if parent_nodes is None:
            return

        assert len(return_indice) == 3
        if not (return_indice[0] in [0, 1] and return_indice[1] in [0, 1] and return_indice[2] in [0, 1]):
            logger.debug(f"return indices are expected to be in [0, 1], but got {return_indice}")
            return

        (
            sub_node_0,
            mul_node_0,
            mul_node_1,
            reciprocol_node,
            sqrt_node,
            add_node_0,
        ) = parent_nodes[:6]
        reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:]

        cast_node_3 = None
        if len(parent_nodes) == 11:
            cast_node_3 = parent_nodes[6]
            assert cast_node_3.op_type == "Cast"

        mul_node_3 = self.model.match_parent(node, "Mul", 0, output_name_to_node)
        if mul_node_3 is None:
            logger.debug("mul_node_3 not found")
            return

        node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
        root_node = (
            node_before_reduce
            if cast_node_3 is None
            else self.model.get_parent(node_before_reduce, 0, output_name_to_node)
        )
        if root_node is None:
            logger.debug("root node is none")
            return

        i, epsilon = self.model.get_constant_input(add_node_0)
        if epsilon is None or epsilon <= 0 or (epsilon > 1.0e-5 and cast_node_3 is None):
            logger.debug("epsilon is not matched")
            return

        if cast_node_3 is None and (
            reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
        ):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        if cast_node_3 is not None and (
            node_before_reduce.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
        ):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        if mul_node_2.input[0] != mul_node_2.input[1]:
            logger.debug("mul_node_2 shall have two same inputs")
            return
        subgraph_nodes = [
            node,
            sub_node_0,
            mul_node_0,
            mul_node_1,
            reciprocol_node,
            sqrt_node,
            add_node_0,
            reduce_mean_node_0,
            mul_node_2,
            sub_node_1,
            reduce_mean_node_1,
            mul_node_3,
        ]

        if cast_node_3 is not None:
            cast_node_2 = self.model.match_parent(mul_node_0, "Cast", 0, output_name_to_node)
            if cast_node_2 is None:
                logger.debug("cast_node_2 not found")
                return
            subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3])

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            node.output,
            self.model.input_name_to_nodes(),
            self.model.output_name_to_node(),
        ):
            logger.debug("not safe to fuse layer normalization")
            return
        self.nodes_to_remove.extend(subgraph_nodes)

        weight_input = mul_node_1.input[1]
        bias_input = sub_node_0.input[0]

        fused_node = helper.make_node(
            "LayerNormalization",
            inputs=[mul_node_3.input[0], weight_input, bias_input],
            outputs=[node.output[0]],
            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
        )
        fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
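
# The TF variant is applied the same way (a sketch under the same assumptions as
# above; "tf_model.onnx" is a placeholder path). Counting the fused nodes is a
# quick way to verify the pattern matched:
#
#     import onnx
#
#     onnx_model = OnnxModel(onnx.load("tf_model.onnx"))
#     FusionLayerNormalizationTF(onnx_model).apply()
#     print(len(onnx_model.get_nodes_by_op_type("LayerNormalization")), "LayerNorm nodes fused")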