m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

492 lines
19 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import argparse
  6. import logging
  7. import sys
  8. from collections import deque
  9. import numpy as np
  10. import onnx
  11. from onnx import ModelProto, TensorProto, numpy_helper
  12. from onnx_model_bert_tf import BertOnnxModelTF
  13. logger = logging.getLogger(__name__)
  14. class BertOnnxModelKeras(BertOnnxModelTF):
  15. def __init__(self, model, num_heads, hidden_size):
  16. super().__init__(model, num_heads, hidden_size)
  17. def match_mask_path(self, add_or_sub_before_softmax):
  18. mask_nodes = self.match_parent_path(
  19. add_or_sub_before_softmax,
  20. ["Mul", "Sub", "Reshape", "Cast"],
  21. [1, None, 1, 0],
  22. )
  23. if mask_nodes is not None:
  24. return mask_nodes
  25. mask_nodes = self.match_parent_path(
  26. add_or_sub_before_softmax,
  27. ["Mul", "Sub", "Cast", "Slice", "Unsqueeze"],
  28. [1, 1, 1, 0, 0],
  29. )
  30. if mask_nodes is not None:
  31. return mask_nodes
  32. mask_nodes = self.match_parent_path(
  33. add_or_sub_before_softmax,
  34. ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
  35. [1, None, 1, 0, 0],
  36. )
  37. return mask_nodes
  38. def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node):
  39. reshape_nodes = []
  40. for x in [matmul_q, matmul_k, matmul_v]:
  41. root_input = x.input[0]
  42. root_node = output_name_to_node[root_input]
  43. if root_node == parent:
  44. continue
  45. if root_node.op_type == "Reshape" and root_node.input[0] == parent.output[0]:
  46. reshape_nodes.append(root_node)
  47. continue
  48. logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}")
  49. return False, []
  50. return True, reshape_nodes
  51. def fuse_attention(self):
  52. input_name_to_nodes = self.input_name_to_nodes()
  53. output_name_to_node = self.output_name_to_node()
  54. nodes_to_remove = []
  55. attention_count = 0
  56. skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
  57. for normalize_node in skip_layer_norm_nodes:
  58. # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
  59. parent = self.get_parent(normalize_node, 0)
  60. if parent is None or parent.op_type not in [
  61. "SkipLayerNormalization",
  62. "EmbedLayerNormalization",
  63. ]:
  64. if parent.op_type == "Add":
  65. parent = self.get_parent(normalize_node, 1)
  66. if parent is None or parent.op_type not in [
  67. "SkipLayerNormalization",
  68. "EmbedLayerNormalization",
  69. ]:
  70. logger.debug(
  71. "First input for skiplayernorm: {}".format(parent.op_type if parent is not None else None)
  72. )
  73. continue
  74. else:
  75. logger.debug(
  76. "First input for skiplayernorm: {}".format(parent.op_type if parent is not None else None)
  77. )
  78. continue
  79. else:
  80. # TODO: shall we add back the checking of children op types.
  81. pass
  82. qkv_nodes = self.match_parent_path(
  83. normalize_node,
  84. ["Add", "Reshape", "MatMul", "Reshape", "Transpose", "MatMul"],
  85. [None, 0, 0, 0, 0, 0],
  86. )
  87. if qkv_nodes is None:
  88. logger.debug("Failed to match qkv nodes")
  89. continue
  90. (
  91. add,
  92. extra_reshape_0,
  93. matmul,
  94. reshape_qkv,
  95. transpose_qkv,
  96. matmul_qkv,
  97. ) = qkv_nodes
  98. logger.debug("Matched qkv nodes")
  99. v_nodes = self.match_parent_path(
  100. matmul_qkv,
  101. ["Transpose", "Reshape", "Add", "Reshape", "MatMul"],
  102. [1, 0, 0, 0, 0],
  103. )
  104. if v_nodes is None:
  105. logger.debug("Failed to match v path")
  106. continue
  107. (transpose_v, reshape_v, add_v, extra_reshape_1, matmul_v) = v_nodes
  108. qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Sub", "MatMul"], [0, 0, 0])
  109. if qk_nodes is not None:
  110. (softmax_qk, sub_qk, matmul_qk) = qk_nodes
  111. q_nodes = self.match_parent_path(
  112. matmul_qk,
  113. ["Mul", "Transpose", "Reshape", "Add", "Reshape", "MatMul"],
  114. [0, None, 0, 0, 0, 0],
  115. )
  116. if q_nodes is not None:
  117. (
  118. mul_q,
  119. transpose_q,
  120. reshape_q,
  121. add_q,
  122. extra_reshape_2,
  123. matmul_q,
  124. ) = q_nodes
  125. else:
  126. qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, None])
  127. if qk_nodes is None:
  128. qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Div", "MatMul"], [0, 0, 0, None])
  129. if qk_nodes is None:
  130. logger.debug("Failed to match qk path")
  131. continue
  132. (softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes
  133. q_nodes = self.match_parent_path(
  134. matmul_qk,
  135. ["Transpose", "Reshape", "Add", "Reshape", "MatMul"],
  136. [0, 0, 0, 0, 0],
  137. )
  138. if q_nodes is not None:
  139. (transpose_q, reshape_q, add_q, extra_reshape_2, matmul_q) = q_nodes
  140. if q_nodes is None:
  141. logger.debug("Failed to match q path")
  142. continue
  143. k_nodes = self.match_parent_path(
  144. matmul_qk,
  145. ["Transpose", "Reshape", "Add", "Reshape", "MatMul"],
  146. [1, 0, 0, 0, 0],
  147. )
  148. if k_nodes is None:
  149. logger.debug("Failed to match k path")
  150. continue
  151. (transpose_k, reshape_k, add_k, extra_reshape_3, matmul_k) = k_nodes
  152. mask_nodes = self.match_mask_path(qk_nodes[1])
  153. if mask_nodes is None:
  154. logger.debug("Failed to match mask path")
  155. continue
  156. if not self.has_constant_input(mask_nodes[1], 1):
  157. logger.debug("Sub node expected to have an input with constant value 1.0.")
  158. continue
  159. is_same_root, reshape_nodes = self.check_attention_input(
  160. matmul_q, matmul_k, matmul_v, parent, output_name_to_node
  161. )
  162. if is_same_root:
  163. mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
  164. logger.debug("Create an Attention node.")
  165. attention_node = self.attention_fusion.create_attention_node(
  166. mask_index,
  167. matmul_q,
  168. matmul_k,
  169. matmul_v,
  170. add_q,
  171. add_k,
  172. add_v,
  173. self.num_heads,
  174. self.hidden_size,
  175. parent.output[0],
  176. reshape_qkv.output[0],
  177. None,
  178. )
  179. if attention_node is None:
  180. continue
  181. self.add_node(attention_node)
  182. attention_count += 1
  183. nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
  184. nodes_to_remove.extend(qk_nodes)
  185. nodes_to_remove.extend(q_nodes)
  186. nodes_to_remove.extend(k_nodes)
  187. nodes_to_remove.extend(v_nodes)
  188. nodes_to_remove.extend(mask_nodes)
  189. nodes_to_remove.extend(reshape_nodes)
  190. nodes_to_remove.append(extra_reshape_0)
  191. self.replace_node_input(add, extra_reshape_0.output[0], matmul.output[0])
  192. else:
  193. logger.debug("Root node not matched.")
  194. continue
  195. self.remove_nodes(nodes_to_remove)
  196. self.update_graph()
  197. logger.info(f"Fused Attention count:{attention_count}")
  198. def preprocess(self):
  199. self.process_embedding()
  200. self.fuse_mask()
  201. self.skip_reshape()
  202. def skip_reshape(self):
  203. input_name_to_nodes = self.input_name_to_nodes()
  204. output_name_to_node = self.output_name_to_node()
  205. nodes_to_remove = []
  206. attention_count = 0
  207. count = 0
  208. reshape_nodes = self.get_nodes_by_op_type("Reshape")
  209. for reshape_node in reshape_nodes:
  210. parent = self.get_parent(reshape_node, 0)
  211. if parent is not None and parent.op_type == "Reshape":
  212. reshape_node.input[0] = parent.input[0]
  213. count += 1
  214. if count > 0:
  215. logger.info(f"Skip consequent Reshape count: {count}")
  216. def fuse_embedding(self, node, output_name_to_node):
  217. assert node.op_type == "LayerNormalization"
  218. logger.debug(f"start fusing embedding from node with output={node.output[0]}...")
  219. word_embed_path = self.match_parent_path(node, ["Add", "Add", "Gather"], [0, 0, 0], output_name_to_node)
  220. if word_embed_path is None:
  221. logger.debug("failed to match word_embed_path")
  222. return False
  223. skip_node, add_node, gather_node = word_embed_path
  224. word_initializer = self.get_initializer(gather_node.input[0])
  225. if word_initializer is None:
  226. logger.debug("failed to get word initializer")
  227. return False
  228. temp = numpy_helper.to_array(word_initializer)
  229. if len(temp.shape) == 2:
  230. logger.info("Found word embedding. name:{}, shape:{}".format(word_initializer.name, temp.shape))
  231. word_embedding = word_initializer.name
  232. else:
  233. logger.info("Failed to find word embedding. name:{}, shape:{}".format(word_initializer.name, temp.shape))
  234. return False
  235. pos_initializer = self.get_initializer(add_node.input[1])
  236. if pos_initializer is not None:
  237. temp = numpy_helper.to_array(pos_initializer)
  238. if len(temp.shape) == 3 and temp.shape[0] == 1:
  239. tensor = numpy_helper.from_array(temp.reshape((temp.shape[1], temp.shape[2])), "position_embedding")
  240. self.add_initializer(tensor)
  241. logger.info("Found position embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape[1:]))
  242. position_embedding = "position_embedding"
  243. else:
  244. logger.info(
  245. "Failed to find position embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape)
  246. )
  247. return False
  248. else:
  249. pos_embed_path = self.match_parent_path(add_node, ["Gather", "Slice"], [1, 1], output_name_to_node)
  250. if pos_embed_path is None:
  251. logger.debug("failed to match pos_embed_path")
  252. return False
  253. pos_gather, pos_slice = pos_embed_path
  254. pos_initializer = self.get_initializer(pos_gather.input[0])
  255. if pos_initializer is None:
  256. logger.debug("failed to get pos initializer")
  257. return False
  258. temp = numpy_helper.to_array(pos_initializer)
  259. if len(temp.shape) == 2:
  260. logger.info("Found word embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape))
  261. position_embedding = pos_initializer.name
  262. else:
  263. logger.info(
  264. "Failed to find position embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape)
  265. )
  266. return False
  267. gather = self.get_parent(skip_node, 1, output_name_to_node)
  268. if gather is None or gather.op_type != "Gather":
  269. logger.debug("failed to get gather")
  270. return False
  271. segment_initializer = self.get_initializer(gather.input[0])
  272. if segment_initializer is None:
  273. logger.debug("failed to get segment initializer")
  274. return False
  275. temp = numpy_helper.to_array(segment_initializer)
  276. if len(temp.shape) == 2:
  277. logger.info("Found segment embedding. name:{}, shape:{}".format(segment_initializer.name, temp.shape))
  278. segment_embedding = segment_initializer.name
  279. else:
  280. logger.info(
  281. "Failed to find segment embedding. name:{}, shape:{}".format(segment_initializer.name, temp.shape)
  282. )
  283. return False
  284. logger.info("Create Embedding node")
  285. self.create_embedding_subgraph(node, word_embedding, segment_embedding, position_embedding)
  286. return True
  287. def process_embedding(self):
  288. """
  289. Automatically detect word, segment and position embeddings.
  290. """
  291. logger.info("start processing embedding layer...")
  292. output_name_to_node = self.output_name_to_node()
  293. for node in self.nodes():
  294. if node.op_type == "LayerNormalization":
  295. if self.fuse_embedding(node, output_name_to_node):
  296. return
  297. break
  298. def fuse_mask(self):
  299. nodes_to_remove = []
  300. for node in self.nodes():
  301. if node.op_type == "Mul" and self.has_constant_input(node, -10000):
  302. mask_path = self.match_parent_path(node, ["Sub", "Cast", "Slice", "Unsqueeze"], [0, 1, 0, 0])
  303. if mask_path is None:
  304. continue
  305. sub_node, cast_node, slice_node, unsqueeze_node = mask_path
  306. mask_input_name = self.attention_mask.get_first_mask()
  307. if unsqueeze_node.input[0] != mask_input_name:
  308. print("Cast input {} is not mask input {}".format(unsqueeze_node.input[0], mask_input_name))
  309. continue
  310. unsqueeze_added_1 = onnx.helper.make_node(
  311. "Unsqueeze",
  312. inputs=[mask_input_name],
  313. outputs=["mask_fuse_unsqueeze1_output"],
  314. name="Mask_UnSqueeze_1",
  315. axes=[1],
  316. )
  317. unsqueeze_added_2 = onnx.helper.make_node(
  318. "Unsqueeze",
  319. inputs=["mask_fuse_unsqueeze1_output"],
  320. outputs=["mask_fuse_unsqueeze2_output"],
  321. name="Mask_UnSqueeze_2",
  322. axes=[2],
  323. )
  324. # self.replace_node_input(cast_node, cast_node.input[0], 'mask_fuse_unsqueeze2_output')
  325. cast_node_2 = onnx.helper.make_node(
  326. "Cast",
  327. inputs=["mask_fuse_unsqueeze2_output"],
  328. outputs=["mask_fuse_cast_output"],
  329. )
  330. cast_node_2.attribute.extend([onnx.helper.make_attribute("to", 1)])
  331. self.replace_node_input(sub_node, sub_node.input[1], "mask_fuse_cast_output")
  332. nodes_to_remove.extend([slice_node, unsqueeze_node, cast_node])
  333. self.add_node(unsqueeze_added_1)
  334. self.add_node(unsqueeze_added_2)
  335. self.add_node(cast_node_2)
  336. self.remove_nodes(nodes_to_remove)
  337. # Prune graph is done after removing nodes to remove island nodes.
  338. if len(nodes_to_remove) > 0:
  339. self.prune_graph()
  340. logger.info("Fused mask" if len(nodes_to_remove) > 0 else "Failed to fuse mask")
  341. def remove_extra_reshape(self):
  342. skiplayernorm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
  343. reshape_removed = 0
  344. for skiplayernorm_node in skiplayernorm_nodes:
  345. path = self.match_parent_path(
  346. skiplayernorm_node,
  347. [
  348. "Add",
  349. "Reshape",
  350. "MatMul",
  351. "Reshape",
  352. "Gelu",
  353. "Add",
  354. "Reshape",
  355. "MatMul",
  356. "SkipLayerNormalization",
  357. ],
  358. [0, 0, 0, 0, 0, 0, 0, 0, 0],
  359. )
  360. if path is None:
  361. continue
  362. (
  363. add_1,
  364. reshape_1,
  365. matmul_1,
  366. reshape_2,
  367. gelu,
  368. add_2,
  369. reshape_3,
  370. matmul_2,
  371. skiplayernorm,
  372. ) = path
  373. add_2.input[0] = matmul_2.output[0]
  374. self.remove_node(reshape_3)
  375. matmul_1.input[0] = gelu.output[0]
  376. self.remove_node(reshape_2)
  377. add_1.input[0] = matmul_1.output[0]
  378. self.remove_node(reshape_1)
  379. reshape_removed += 3
  380. return reshape_removed
  381. def remove_extra_reshape_2(self):
  382. skiplayernorm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
  383. reshape_removed = 0
  384. for skiplayernorm_node in skiplayernorm_nodes:
  385. path = self.match_parent_path(
  386. skiplayernorm_node,
  387. [
  388. "Add",
  389. "Reshape",
  390. "MatMul",
  391. "Reshape",
  392. "Gelu",
  393. "Add",
  394. "Reshape",
  395. "MatMul",
  396. "Reshape",
  397. "SkipLayerNormalization",
  398. ],
  399. [None, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  400. ) # yapf: disable
  401. if path is None:
  402. continue
  403. (
  404. add_1,
  405. reshape_1,
  406. matmul_1,
  407. reshape_2,
  408. gelu,
  409. add_2,
  410. reshape_3,
  411. matmul_2,
  412. reshape_4,
  413. skiplayernorm,
  414. ) = path
  415. matmul_2.input[0] = skiplayernorm.output[0]
  416. self.remove_node(reshape_4)
  417. add_2.input[0] = matmul_2.output[0]
  418. self.remove_node(reshape_3)
  419. matmul_1.input[0] = gelu.output[0]
  420. self.remove_node(reshape_2)
  421. add_1.input[0] = matmul_1.output[0]
  422. self.remove_node(reshape_1)
  423. reshape_removed += 4
  424. return reshape_removed
  425. def postprocess(self):
  426. reshape_removed = self.remove_extra_reshape() + self.remove_extra_reshape_2()
  427. logger.info(f"Remove {reshape_removed} Reshape nodes.")
  428. self.prune_graph()