m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

592 lines
24 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import argparse
  6. import logging
  7. import sys
  8. from collections import deque
  9. import numpy as np
  10. import onnx
  11. from onnx import ModelProto, TensorProto, helper, numpy_helper
  12. from onnx_model_bert import BertOnnxModel
  13. logger = logging.getLogger(__name__)
  14. class BertOnnxModelTF(BertOnnxModel):
  def __init__(self, model, num_heads, hidden_size):
      """Create the optimizer wrapper for a TensorFlow-exported BERT graph.

      The three arguments are forwarded unchanged to ``BertOnnxModel.__init__``;
      this subclass adds no state of its own at construction time.
      """
      super().__init__(model, num_heads, hidden_size)
  def remove_identity(self):
      """Bypass and delete Identity nodes that do not produce a graph output.

      For each such node, every consumer of its output is rewired to read the
      node's own input instead. Identity nodes whose output is a graph output
      are left in place. The number of removed nodes is logged at INFO level.
      """
      removable = []
      for candidate in self.nodes():
          if candidate.op_type != "Identity":
              continue
          identity_output = candidate.output[0]
          if self.find_graph_output(identity_output):
              # The output name is visible outside the graph; keep the node.
              continue
          self.replace_input_of_all_nodes(identity_output, candidate.input[0])
          removable.append(candidate)

      self.remove_nodes(removable)
      logger.info(f"Removed Identity count: {len(removable)}")
  def match_mask_path(self, add_or_sub_before_softmax):
      """Match the mask subgraph feeding the node that precedes Softmax.

      Three parent-path patterns are tried in a fixed order and the first one
      that matches wins.

      Args:
          add_or_sub_before_softmax: node whose parent chain is searched.

      Returns:
          Whatever ``match_parent_path`` returned for the first matching
          pattern, or None when none of the patterns match.
      """
      candidate_patterns = (
          (["Mul", "Sub", "Reshape", "Cast"], [1, None, 1, 0]),
          (["Mul", "Sub", "Cast", "Slice", "Unsqueeze"], [1, 0, 1, 0, 0]),
          (["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [1, None, 1, 0, 0]),
      )
      for op_types, parent_indices in candidate_patterns:
          matched = self.match_parent_path(add_or_sub_before_softmax, op_types, parent_indices)
          if matched is not None:
              return matched
      return None
  def get_2d_initializers_from_parent_subgraphs(self, current_node):
      """Collect the two-dimensional initializers used by a node's parent subgraph.

      Args:
          current_node: node handed to ``get_parent_subgraph_nodes``.

      Returns:
          A dictionary mapping initializer name to the shape of its array,
          containing only initializers whose array has exactly two dimensions.
      """
      found = {}
      for parent in self.get_parent_subgraph_nodes(current_node, []):
          for input_name in parent.input:
              tensor = self.get_initializer(input_name)
              if not tensor:
                  # This input is not backed by an initializer.
                  continue
              array = numpy_helper.to_array(tensor)
              if array.ndim == 2:
                  found[tensor.name] = array.shape
      return found
  def find_segment_ids(self, segment_embedding, input_ids):
      """Find the tensor that supplies segment ids to the segment embedding lookup.

      Args:
          segment_embedding: name of the segment embedding initializer.
          input_ids: name of the input-ids graph input; a segment-id candidate
              equal to this name is not accepted as a graph input.

      Returns:
          The name of the segment-ids tensor, or None when the embedding is not
          consumed by exactly one node or more than one graph input feeds it.
          When the only candidate is ``input_ids`` itself and the producing
          subgraph matches a known ConstantOfShape pattern, two nodes are added
          to the graph and the name "zeros_for_input_shape" is returned.
      """
      consumers = self.input_name_to_nodes().get(segment_embedding)
      if consumers is None or len(consumers) != 1:
          return None
      lookup_node = consumers[0]

      graph_inputs = self.get_graph_inputs(lookup_node, recursive=True)
      if len(graph_inputs) > 1:
          print("Found multiple candidates of segment_ids", graph_inputs)
          return None

      # Find segment ids in graph inputs. The segment id input must not be the same as input_ids.
      if len(graph_inputs) == 1 and graph_inputs[0] != input_ids:
          return graph_inputs[0]

      # If the segment id candidate is the same as the input_ids, try to assign
      # alternative segment ids and simplify the graph if needed.
      segment_ids = lookup_node.input[1]
      _, segment_id_path, _ = self.match_parent_paths(
          lookup_node,
          [
              (
                  ["ConstantOfShape", "Cast", "Concat", "Slice", "Cast", "Shape"],
                  [1, 0, 0, 0, 0, 0],
              ),
              (
                  [
                      "ConstantOfShape",
                      "Cast",
                      "Concat",
                      "Unsqueeze",
                      "Squeeze",
                      "Slice",
                      "Cast",
                      "Shape",
                  ],
                  [1, 0, 0, 0, 0, 0, 0, 0],
              ),
          ],
          None,
      )

      if segment_id_path and input_ids and input_ids == segment_id_path[-1].input[0]:
          logger.debug("Simplify segment id path...")
          constantofshape_node = segment_id_path[0]
          graph_name = self.get_graph_by_node(constantofshape_node).name
          self.add_node(
              helper.make_node("Shape", inputs=[input_ids], outputs=["input_shape"]),
              graph_name,
          )
          # The `value` attribute of ConstantOfShape is optional. Copy it only
          # when the matched node carries one, instead of indexing attribute[0]
          # unconditionally and failing on a node that relies on the default.
          fill_kwargs = {}
          if len(constantofshape_node.attribute) > 0:
              fill_kwargs["value"] = helper.get_attribute_value(constantofshape_node.attribute[0])
          self.add_node(
              helper.make_node(
                  "ConstantOfShape",
                  inputs=["input_shape"],
                  outputs=["zeros_for_input_shape"],
                  **fill_kwargs,
              ),
              graph_name,
          )
          segment_ids = "zeros_for_input_shape"
      return segment_ids
  def find_input_ids(self, word_embedding):
      """Find the graph input that feeds the word embedding lookup.

      Args:
          word_embedding: name of the word embedding initializer.

      Returns:
          The single graph input reachable from the one node that consumes
          ``word_embedding``; None when the embedding has no consumer, more than
          one consumer, or the number of reachable graph inputs is not exactly one.
      """
      consumers = self.input_name_to_nodes().get(word_embedding)
      if consumers is None or len(consumers) != 1:
          return None

      candidates = self.get_graph_inputs(consumers[0], recursive=True)
      if len(candidates) != 1:
          print("Found multiple candidates of input_ids", candidates)
          return None
      return candidates[0]
  def find_mask_input(self, excluded_graph_inputs):
      """Find the tensor that carries the attention mask.

      Every Softmax node is inspected for an Add/Mul/Sub/Cast/Slice/Unsqueeze
      parent chain whose Mul has a constant input of -10000 and whose Sub has a
      constant input of 1. The graph inputs reachable from that Sub, minus
      ``excluded_graph_inputs``, are the mask candidates.

      Args:
          excluded_graph_inputs: graph input names that must not be reported as
              the mask (the caller in this file passes segment ids and input ids).

      Returns:
          The mask tensor name, or None when no Softmax yields one or when more
          than one candidate is found. When the only reachable graph input is an
          excluded one and the Unsqueeze is fed by a known ConstantOfShape chain
          rooted at that input, a Shape and a ConstantOfShape node are added and
          the Unsqueeze's first input name is returned.
      """
      for node in self.nodes():
          if node.op_type != "Softmax":
              continue

          mask_path = self.match_parent_path(
              node,
              ["Add", "Mul", "Sub", "Cast", "Slice", "Unsqueeze"],
              [0, 1, None, 1, 0, 0],
          )
          if mask_path is None:
              continue
          _, mul_node, sub_node, _, _, unsqueeze_node = mask_path

          if not (self.has_constant_input(mul_node, -10000) and self.has_constant_input(sub_node, 1)):
              continue

          graph_inputs = self.get_graph_inputs(sub_node, recursive=True)
          inputs = [name for name in graph_inputs if name not in excluded_graph_inputs]
          if len(inputs) > 1:
              print("Found multiple candidates of mask input", inputs)
              return None
          if len(inputs) == 1:
              return inputs[0]

          # Duplicated input found. Try to simplify the graph.
          path_to_be_simplified = self.match_parent_path(
              unsqueeze_node,
              [
                  "ConstantOfShape",
                  "Cast",
                  "Concat",
                  "Unsqueeze",
                  "Squeeze",
                  "Slice",
                  "Cast",
                  "Shape",
              ],
              [0, 0, 0, 0, 0, 0, 0, 0],
          )
          duplicated_inputs = [name for name in graph_inputs if name in excluded_graph_inputs]

          # Simplify graph for dynamic axes.
          if (
              path_to_be_simplified
              and len(duplicated_inputs) == 1
              and duplicated_inputs[0] == path_to_be_simplified[-1].input[0]
          ):
              logger.debug("Simplify mask input path...")
              constantofshape_node = path_to_be_simplified[0]
              # The `value` attribute of ConstantOfShape is optional. Copy it
              # only when present rather than indexing attribute[0] blindly.
              fill_kwargs = {}
              if len(constantofshape_node.attribute) > 0:
                  fill_kwargs["value"] = helper.get_attribute_value(constantofshape_node.attribute[0])
              graph_name = self.get_graph_by_node(constantofshape_node).name
              self.add_node(
                  helper.make_node(
                      "Shape",
                      inputs=[duplicated_inputs[0]],
                      outputs=["input_shape_for_mask"],
                  ),
                  graph_name,
              )
              # The new node reuses the name already read by the Unsqueeze, so
              # the Unsqueeze itself needs no rewiring.
              self.add_node(
                  helper.make_node(
                      "ConstantOfShape",
                      inputs=["input_shape_for_mask"],
                      outputs=[unsqueeze_node.input[0]],
                      **fill_kwargs,
                  ),
                  graph_name,
              )
              return unsqueeze_node.input[0]
      return None
  def create_embedding_subgraph(self, normalize_node, word_embedding, segment_embedding, position_embedding):
      """Replace the embedding subgraph with one EmbedLayerNormalization node.

      Args:
          normalize_node: normalization node ending the embedding subgraph; its
              inputs 1 and 2 are passed on as gamma and beta, and consumers of
              its first output are rewired to the new node.
          word_embedding: name of the word embedding initializer.
          segment_embedding: name of the segment embedding initializer.
          position_embedding: name of the position embedding initializer.

      Returns:
          True when the fused node was added to the graph; False when input ids,
          segment ids or the mask input could not be located (in which case
          ``bert_inputs`` is not set and no fused node is created).
      """
      input_ids = self.find_input_ids(word_embedding)
      if input_ids is None:
          logger.info("Failed to find input_ids. Cannot fuse embedding layer.")
          return False

      segment_ids = self.find_segment_ids(segment_embedding, input_ids)
      if segment_ids is None:
          logger.info("Failed to find segment_ids. Cannot fuse embedding layer.")
          return False

      mask_input = self.find_mask_input([segment_ids, input_ids])
      if mask_input is None:
          logger.info("Failed to find input_mask. Cannot fuse embedding layer.")
          return False

      # Recorded before any casting below, so these are the original names.
      self.bert_inputs = [input_ids, segment_ids, mask_input]

      mask_index = self.create_node_name("mask_index")
      self.attention_mask.set_mask_indice(mask_input, mask_index)

      # Graph inputs and intermediate tensors are cast through different
      # helpers, which return their results in opposite order.
      if self.find_graph_input(input_ids).type.tensor_type.elem_type != TensorProto.INT32:
          _, input_ids = self.utils.cast_graph_input_to_int32(input_ids)

      if self.find_graph_input(segment_ids):
          _, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
      else:
          segment_ids, _ = self.utils.cast_input_to_int32(segment_ids)

      if self.find_graph_input(mask_input):
          _, mask_input = self.utils.cast_graph_input_to_int32(mask_input)
      else:
          mask_input, _ = self.utils.cast_input_to_int32(mask_input)

      embed_output = self.create_node_name("embed_output")
      embed_node = onnx.helper.make_node(
          "EmbedLayerNormalization",
          inputs=[
              input_ids,
              segment_ids,
              word_embedding,
              position_embedding,
              segment_embedding,
              normalize_node.input[1],  # gamma
              normalize_node.input[2],  # beta
              mask_input,
          ],
          outputs=[embed_output, mask_index],
          name="EmbedLayer",
      )
      embed_node.domain = "com.microsoft"

      self.replace_input_of_all_nodes(normalize_node.output[0], embed_output)
      self.add_node(embed_node, self.get_graph_by_node(normalize_node).name)
      # Report success explicitly. Falling off the end returned None, which is
      # falsy just like the False used for every failure path above.
      return True
  def process_embedding(self):
      """Detect word, segment and position embeddings and fuse the embedding layer.

      LayerNormalization nodes are scanned for an Add/Reshape/Slice parent chain
      whose Slice reads a two-dimensional initializer (the position embedding).
      When the Add's first parent is another Add with exactly two 2-D
      initializers in its parent subgraph, the one whose first dimension is 2
      is taken as the segment embedding and the other as the word embedding.
      The fused subgraph is then created and the graph pruned. Only the first
      layer that reaches that point is processed.
      """
      logger.info("start processing embedding layer...")
      producers = self.output_name_to_node()

      for layer_norm_node in self.get_nodes_by_op_type("LayerNormalization"):
          pos_embed_path = self.match_parent_path(
              layer_norm_node,
              ["Add", "Reshape", "Slice"],
              [0, 1, 0],
              producers,
          )
          if pos_embed_path is None:
              continue

          add_node, _, slice_node = pos_embed_path
          position_table = self.get_initializer(slice_node.input[0])
          if position_table is None:
              continue

          position_shape = numpy_helper.to_array(position_table).shape
          if len(position_shape) != 2:
              logger.info(f"Failed to find position embedding. name:{position_table.name}, shape:{position_shape}")
              return
          logger.info(f"Found position embedding. name:{position_table.name}, shape:{position_shape}")
          position_embedding = position_table.name

          first_parent = self.get_parent(add_node, 0, producers)
          if first_parent is None or first_parent.op_type != "Add":
              # Not the expected word + segment Add; try the next candidate.
              continue

          embeddings = self.get_2d_initializers_from_parent_subgraphs(first_parent)
          if len(embeddings) != 2:
              logger.warning(f"Failed to find two embeddings (word and segment) from Add node. Found {embeddings}")
              return

          word_embedding = None
          segment_embedding = None
          for name, shape in embeddings.items():
              if shape[0] == 2:
                  segment_embedding = name
                  logger.info(f"Found segment embedding. name:{name}, shape:{shape}")
              else:
                  word_embedding = name
                  logger.info(f"Found words embedding. name:{name}, shape:{shape}")

          if word_embedding is None or segment_embedding is None:
              logger.info("Failed to find both word and segment embedding")
              return

          logger.info("Create Embedding node")
          self.create_embedding_subgraph(
              layer_norm_node,
              word_embedding,
              segment_embedding,
              position_embedding,
          )
          # Prune graph to remove those original embedding nodes.
          self.prune_graph()
          return
  def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node):
      """Check that the Q, K and V projections all read from the same root node.

      Args:
          matmul_q: node of the query projection.
          matmul_k: node of the key projection.
          matmul_v: node of the value projection.
          parent: node expected to produce the first input of all three.
          output_name_to_node: mapping from tensor name to its producer node.

      Returns:
          True when the first input of every projection is produced by
          ``parent``; False otherwise, including when that input has no
          producer in ``output_name_to_node``.
      """
      for projection in (matmul_q, matmul_k, matmul_v):
          root_input = projection.input[0]
          # A graph input or initializer has no producer entry. Treat that as
          # "not the expected root" instead of raising KeyError and aborting.
          root_node = output_name_to_node.get(root_input)
          if root_node is not None and root_node == parent:
              continue
          logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}")
          return False
      return True
  def fuse_attention(self):
      """Fuse TensorFlow-exported attention subgraphs into Attention nodes.

      Starting from every SkipLayerNormalization node and from the Add feeding
      every LayerNormalization node, the method matches the output projection,
      the V, QK, Q and K paths and the mask path. When Q, K and V share the same
      root node, an Attention node is created through ``self.attention_fusion``
      and the matched nodes are queued for removal. Matched nodes are removed
      and the graph updated once at the end; the number of fused nodes is
      logged at INFO level.
      """
      output_name_to_node = self.output_name_to_node()
      nodes_to_remove = []
      attention_count = 0
      root_op_types = ("SkipLayerNormalization", "LayerNormalization", "Reshape")

      def first_match(node, patterns):
          """Return the first parent path in `patterns` that matches, else None."""
          for op_types, parent_indices in patterns:
              path = self.match_parent_path(node, op_types, parent_indices)
              if path is not None:
                  return path
          return None

      # Sometimes we can not fuse skiplayernormalization since the add before
      # layernorm has an output that is used by nodes outside skiplayernorm.
      # Conceptually we treat add before layernorm as a skiplayernorm node
      # since they share the same pattern.
      skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
      layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization")
      start_nodes = [*skip_layer_norm_nodes, *layer_norm_nodes]

      for normalize_node in start_nodes:
          # Resolved before normalize_node may be swapped for its Add parent.
          graph_name = self.get_graph_by_node(normalize_node).name

          # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
          if normalize_node.op_type == "LayerNormalization":
              add_before_layernorm = self.match_parent(normalize_node, "Add", 0)
              if add_before_layernorm is None:
                  continue
              normalize_node = add_before_layernorm

          parent = self.get_parent(normalize_node, 1)
          if parent is None or parent.op_type not in root_op_types:
              parent = self.get_parent(normalize_node, 0)
              if parent is None or parent.op_type not in root_op_types:
                  logger.debug("Failed to match parent of normalize_node")
                  continue

          qkv_nodes = first_match(
              normalize_node,
              [
                  (["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [0, 0, 0, 0, 0]),
                  (["MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0]),
                  (["Add", "Einsum", "Einsum"], [0, 0, 0]),
              ],
          )
          if qkv_nodes is None:
              logger.debug("Failed to match qkv nodes")
              continue
          matmul_qkv = qkv_nodes[-1]

          v_nodes = first_match(
              matmul_qkv,
              [
                  (["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0]),
                  (["Add", "Einsum"], [1, 0]),
              ],
          )
          if v_nodes is None:
              logger.debug("Failed to match v path")
              continue
          add_v, matmul_v = v_nodes[-2], v_nodes[-1]

          qk_nodes = first_match(
              matmul_qkv,
              [
                  (["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]),
                  (["Softmax", "Add", "Einsum"], [0, 0, 0]),
              ],
          )
          if qk_nodes is None:
              logger.debug("Failed to match qk_paths")
              continue
          matmul_qk = qk_nodes[-1]

          q_nodes = first_match(
              matmul_qk,
              [
                  (["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0]),
                  (["Add", "Einsum"], [0, 0]),
              ],
          )
          if q_nodes is None:
              logger.debug("Failed to match q path")
              continue
          add_q, matmul_q = q_nodes[-2], q_nodes[-1]

          k_nodes = first_match(
              matmul_qk,
              [
                  (["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0]),
                  (["Mul", "Add", "Einsum"], [1, 0, 0]),
              ],
          )
          if k_nodes is None:
              logger.debug("Failed to match k path")
              continue
          add_k, matmul_k = k_nodes[-2], k_nodes[-1]

          mask_nodes = self.match_mask_path(qk_nodes[1])
          if mask_nodes is None:
              logger.debug("Cannot find mask_nodes.")
              continue

          if not self.has_constant_input(mask_nodes[1], 1):
              logger.debug("Sub node expected to have an input with constant value 1.0.")
              continue

          # add a squeeze node to convert a 3-d mask to 2-d
          last_mask_node = mask_nodes[-1]
          squeeze_node = self.match_parent_path(last_mask_node, ["Squeeze"], [0]) or self.match_parent_path(
              last_mask_node, ["Expand"], [0]
          )
          squeeze_node_name = "Squeeze_3d_to_2d_mask"
          squeeze_output_name = squeeze_node_name + "_output"
          if squeeze_node is None and len(mask_nodes) == 5 and self.find_graph_input(last_mask_node.input[0]) is None:
              # NOTE(review): the Squeeze reads input[1] of the last mask node
              # while the condition above inspects input[0]; confirm intended.
              mask_input = last_mask_node.input[1]
              self.add_node(
                  helper.make_node(
                      "Squeeze",
                      [mask_input],
                      [squeeze_output_name],
                      squeeze_node_name,
                      axes=[1],
                  ),
                  graph_name,
              )
              last_mask_node.input[0] = squeeze_output_name

          is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node)
          if not is_same_root:
              logger.debug("Root node not matched.")
              continue

          mask_index = self.attention_mask.process_mask(last_mask_node.input[0])
          logger.debug("Create an Attention node.")

          # For tf models, q and v are flipped.
          attention_node = self.attention_fusion.create_attention_node(
              mask_index,
              matmul_k,
              matmul_q,
              matmul_v,
              add_k,
              add_q,
              add_v,
              self.num_heads,
              self.hidden_size,
              parent.output[0],
              qkv_nodes[2].output[0],
              None,
          )
          if attention_node is None:
              continue

          if qkv_nodes[1].op_type == "Einsum":
              # add reshape before einsum
              einsum_name = qkv_nodes[1].name
              new_shape_name = einsum_name + "_newshape"
              reshape_output_name = einsum_name + "_reshape_output"
              head_size = int(self.hidden_size / self.num_heads)
              tensor = helper.make_tensor(
                  name=new_shape_name,
                  data_type=TensorProto.INT64,
                  dims=[4],
                  vals=np.int64([[0, 0, self.num_heads, head_size]]).tobytes(),
                  raw=True,
              )
              self.add_initializer(tensor, graph_name)
              reshape_ = helper.make_node(
                  "Reshape",
                  inputs=[attention_node.output[0], new_shape_name],
                  outputs=[reshape_output_name],
                  name=einsum_name + "_reshape",
              )
              qkv_nodes[1].input[0] = reshape_output_name
              self.add_node(reshape_, graph_name)

          if parent.op_type == "Reshape":
              # Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input
              hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1]
              modified_shape_name = parent.name + "_modified"
              tensor = helper.make_tensor(
                  name=modified_shape_name,
                  data_type=TensorProto.INT64,
                  dims=[3],
                  vals=np.int64([[1, -1, hidden_size]]).tobytes(),
                  raw=True,
              )
              self.add_initializer(tensor, graph_name)
              parent.input[1] = modified_shape_name

          self.add_node(attention_node, graph_name)
          attention_count += 1

          # The first two qkv nodes stay in the graph; everything else that
          # was matched is replaced by the Attention node.
          for matched_group in (qkv_nodes[2:], qk_nodes, q_nodes, k_nodes, v_nodes, mask_nodes):
              nodes_to_remove.extend(matched_group)

      self.remove_nodes(nodes_to_remove)
      self.update_graph()
      logger.info(f"Fused Attention count:{attention_count}")
  def preprocess(self):
      """Run the graph clean-up passes that precede fusion.

      In order: Identity removal, embedding layer processing, and skipping of
      consecutive Reshape nodes.
      """
      for run_pass in (self.remove_identity, self.process_embedding, self.skip_reshape):
          run_pass()
  def skip_reshape(self):
      """Make a Reshape that follows another Reshape read the earlier one's input.

      For every Reshape whose first input is produced by another Reshape, the
      first input is redirected to the producer's own first input. The producer
      node itself is not removed here. The number of rewired nodes is logged
      when it is non-zero.
      """
      skipped = 0
      for reshape_node in self.get_nodes_by_op_type("Reshape"):
          producer = self.get_parent(reshape_node, 0)
          if producer is None or producer.op_type != "Reshape":
              continue
          reshape_node.input[0] = producer.input[0]
          skipped += 1

      if skipped > 0:
          logger.info(f"Skip consequent Reshape count: {skipped}")
  def remove_reshape_before_first_attention(self):
      """Remove the Reshape sitting between EmbedLayerNormalization and Attention.

      Attention nodes are scanned in order. For the first one whose first input
      comes from a Reshape fed by an EmbedLayerNormalization node, consumers of
      the Reshape output are rewired to the Reshape input and the Reshape is
      removed. At most one Reshape is removed per call.
      """
      for attention_node in self.get_nodes_by_op_type("Attention"):
          path = self.match_parent_path(attention_node, ["Reshape", "EmbedLayerNormalization"], [0, 0])
          if path is None:
              continue

          logger.info("Remove Reshape before first Attention node.")
          reshape, _embed_layer_norm = path
          self.replace_input_of_all_nodes(reshape.output[0], reshape.input[0])
          self.remove_node(reshape)
          return
  def postprocess(self):
      """Run the clean-up passes that follow fusion.

      In order: removal of the Reshape before the first Attention node, then
      graph pruning.
      """
      for run_pass in (self.remove_reshape_before_first_attention, self.prune_graph):
          run_pass()