m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

98 lines
3.6 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import onnx
  7. from fusion_gpt_attention import FusionGptAttention
  8. from fusion_gpt_attention_megatron import FusionGptAttentionMegatron
  9. from fusion_gpt_attention_no_past import FusionGptAttentionNoPast
  10. from onnx_model_bert import BertOnnxModel
  11. logger = logging.getLogger(__name__)
  class Gpt2OnnxModel(BertOnnxModel):
      """GPT-2 flavoured graph optimizer built on top of ``BertOnnxModel``.

      Overrides attention fusion to use the GPT attention fusion passes and
      adds a post-processing step that rewrites ``Reshape -> Gemm -> Reshape``
      chains into a ``MatMul`` followed by an ``Add``.
      """

      def __init__(self, model, num_heads, hidden_size):
          """Forward all arguments unchanged to ``BertOnnxModel.__init__``.

          Args:
              model: The model object handed to the base class; this class
                  later reads ``self.model.graph`` from it.
              num_heads: Passed to the base class; ``fuse_attention`` reads it
                  back as ``self.num_heads`` when building the fusion passes.
              hidden_size: Passed to the base class; not used directly here.
          """
          super().__init__(model, num_heads, hidden_size)

      def fuse_attention(self):
          """Apply the GPT attention fusion passes that fit this graph.

          If the graph has exactly one input or exactly one output, only
          ``FusionGptAttentionNoPast`` is applied (presumably the model was
          exported without past state -- confirm against the exporter).
          Otherwise ``FusionGptAttention`` is applied first, followed by
          ``FusionGptAttentionMegatron``.
          """
          graph = self.model.graph
          if len(graph.input) == 1 or len(graph.output) == 1:
              fusion_classes = (FusionGptAttentionNoPast,)
          else:
              # Order matters: the plain GPT fusion runs before the Megatron one.
              fusion_classes = (FusionGptAttention, FusionGptAttentionMegatron)

          for fusion_class in fusion_classes:
              fusion_class(self, self.num_heads).apply()

      def postprocess(self):
          """
          Remove extra reshape nodes.

          For every ``Gemm`` node that has a direct ``Reshape`` child and whose
          first input is a ``Reshape`` fed by a ``FastGelu``,
          ``LayerNormalization`` or ``SkipLayerNormalization`` node (tried in
          that order), a new ``MatMul`` + ``Add`` pair is created. The root
          node's output is renamed to feed the ``MatMul`` and every consumer of
          the trailing ``Reshape`` is redirected to the ``Add`` output. The
          bypassed nodes are not removed here; ``self.prune_graph()`` is called
          once at the end, presumably to drop them.

          Gemm nodes that do not match the pattern, or that have no direct
          ``Reshape`` child, are left untouched.

          NOTE(review): the rewrite reuses Gemm inputs 1 and 2 as-is and never
          reads the Gemm attributes (alpha, beta, transA, transB), so it assumes
          they have their default values -- confirm for the models in use.
          """
          logger.debug("start postprocessing...")

          # Lookup maps are built once up front and reused for every Gemm node.
          input_name_to_nodes = self.input_name_to_nodes()
          output_name_to_node = self.output_name_to_node()

          # Producers of the Reshape feeding the Gemm, tried in this order.
          root_op_types = ("FastGelu", "LayerNormalization", "SkipLayerNormalization")

          reshape_count = 0
          for gemm_node in self.get_nodes_by_op_type("Gemm"):
              reshape_after_gemm = self.find_first_child_by_type(
                  gemm_node, "Reshape", input_name_to_nodes, recursive=False
              )
              if reshape_after_gemm is None:
                  # The helper lives in a base class not shown here and is
                  # expected to return None when there is no Reshape child.
                  # Without a trailing Reshape there is nothing to bypass, and
                  # dereferencing ``.output`` below would raise AttributeError.
                  continue

              nodes = None
              for root_op_type in root_op_types:
                  nodes = self.match_parent_path(
                      gemm_node, ["Reshape", root_op_type], [0, 0], output_name_to_node
                  )
                  if nodes is not None:
                      break
              if nodes is None:
                  continue

              # Only the root node is needed; the Reshape in front of the Gemm
              # is simply bypassed by the rewiring below.
              _, root_node = nodes

              matmul_node_name = self.create_node_name("MatMul", "FullyConnect_MatMul")
              matmul_input = matmul_node_name + "_input"
              matmul_output = matmul_node_name + "_output"
              matmul_node = onnx.helper.make_node(
                  "MatMul",
                  inputs=[matmul_input, gemm_node.input[1]],
                  outputs=[matmul_output],
                  name=matmul_node_name,
              )

              add_node_name = self.create_node_name("Add", "FullyConnect_Add")
              add_output = add_node_name + "_output"
              add_node = onnx.helper.make_node(
                  "Add",
                  inputs=[matmul_output, gemm_node.input[2]],
                  outputs=[add_output],
                  name=add_node_name,
              )

              # Consumers of the trailing Reshape now read the Add output. One
              # call is enough: afterwards no node reads the old name.
              self.replace_input_of_all_nodes(reshape_after_gemm.output[0], add_output)

              # Link root node output with MatMul
              self.replace_input_of_all_nodes(root_node.output[0], matmul_input)
              root_node.output[0] = matmul_input

              self.add_node(matmul_node)
              self.add_node(add_node)

              # Two Reshape nodes (before and after the Gemm) are bypassed.
              reshape_count += 2

          self.prune_graph()
          logger.info(f"postprocess: remove Reshape count:{reshape_count}")