m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.2 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from fusion_utils import NumpyHelper
  8. from onnx import helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionBiasGelu(Fusion):
  12. def __init__(self, model: OnnxModel, is_fastgelu):
  13. if is_fastgelu:
  14. super().__init__(model, "FastGelu", "FastGelu", "add bias")
  15. else:
  16. super().__init__(model, "BiasGelu", "Gelu")
  17. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  18. gelu_op_type = node.op_type
  19. fuse_op_type = "BiasGelu" if gelu_op_type == "Gelu" else "FastGelu"
  20. if len(node.input) != 1:
  21. return
  22. nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
  23. if nodes is None:
  24. return
  25. (add, matmul) = nodes
  26. bias_weight = None
  27. # bias should be one dimension
  28. bias_index = -1
  29. for i, input in enumerate(add.input):
  30. initializer = self.model.get_initializer(input)
  31. if initializer is None:
  32. continue
  33. bias_index = i
  34. bias_weight = NumpyHelper.to_array(initializer)
  35. break
  36. if bias_weight is None:
  37. return
  38. if len(bias_weight.shape) != 1:
  39. return
  40. subgraph_nodes = [node, add]
  41. if not self.model.is_safe_to_fuse_nodes(
  42. subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
  43. ):
  44. return
  45. self.nodes_to_remove.extend(subgraph_nodes)
  46. fused_node = helper.make_node(
  47. fuse_op_type,
  48. inputs=[matmul.output[0], add.input[bias_index]],
  49. outputs=node.output,
  50. name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_"),
  51. )
  52. fused_node.domain = "com.microsoft"
  53. self.nodes_to_add.append(fused_node)
  54. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name