m2m model translation

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List, Union

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionGemmFastGelu(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "GemmFastGelu", "FastGelu", "GemmFastGelu")
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        else:
            return None

    def get_dimensions(self, input_name: str) -> Union[int, None]:
        graph_input = self.model.find_graph_input(input_name)
        if graph_input:
            return self.get_dimensions_from_tensor_proto(graph_input)

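        # Not a graph input: fall back to symbolic shape inference, which is
        # run only once and then cached on the fusion object.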
        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape({}, update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])

        return None

    def fuse(
        self,
        node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        This pattern comes from the PyTorch BERT model.
        Fuse MatMul followed by FastGelu into a single node:
            [root] --> MatMul --> FastGelu -->
        """
        has_bias = False
        if len(node.input) == 2:
            has_bias = True

        match_nodes = self.model.match_parent_path(node, ["MatMul"], [0])
        if match_nodes is None:
            return
        matmul = match_nodes[0]

        # MatMul input X should have at least two dimensions; the weight must be two-dimensional.
        weight_index = -1
        x_dims = 0
        weight = None
        for i, input in enumerate(matmul.input):
            initializer = self.model.get_initializer(input)
            if initializer is None:
                x_dims = self.get_dimensions(matmul.input[i])
            else:
                weight_index = i
                weight = NumpyHelper.to_array(initializer)

        if weight is None:
            return

        if len(weight.shape) != 2:
            return
        if x_dims < len(weight.shape):
            return

        # The bias weight should be one-dimensional.
        bias_index = -1
        if has_bias:
            bias_weight = None
            for i, input in enumerate(node.input):
                initializer = self.model.get_initializer(input)
                if initializer is None:
                    continue
                bias_index = i
                bias_weight = NumpyHelper.to_array(initializer)
                break
            if bias_weight is None:
                return
            if len(bias_weight.shape) != 1:
                return
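
        # Fuse only when removing the MatMul and FastGelu nodes cannot break
        # other consumers of their outputs.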
        subgraph_nodes = [node, matmul]
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
        ):
            return

        self.nodes_to_remove.extend(subgraph_nodes)
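
        # GemmFastGelu inputs are ordered [X, weight] or [X, weight, bias].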
        inputs = (
            [matmul.input[1 - weight_index], matmul.input[weight_index], node.input[bias_index]]
            if has_bias
            else [matmul.input[1 - weight_index], matmul.input[weight_index]]
        )
        fused_node = helper.make_node(
            "GemmFastGelu",
            inputs=inputs,
            outputs=node.output,
            name=self.model.create_node_name("GemmFastGelu"),
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
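

# Minimal usage sketch, assuming OnnxModel wraps a loaded onnx.ModelProto and
# Fusion.apply() walks the graph and commits nodes_to_add / nodes_to_remove;
# the file names below are placeholders.
if __name__ == "__main__":
    import onnx

    # Load a model that contains MatMul --> FastGelu subgraphs.
    model = OnnxModel(onnx.load("model_with_fastgelu.onnx"))

    # Apply the fusion, then save the optimized model.
    FusionGemmFastGelu(model).apply()
    model.save_model_to_file("model_with_gemmfastgelu.onnx")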