m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

117 lines
4.2 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from typing import Dict
  7. from fusion_base import Fusion
  8. from fusion_utils import FusionUtils
  9. from onnx import helper
  10. from onnx_model import OnnxModel
  11. logger = getLogger(__name__)
  12. class FusionQOrderedGelu(Fusion):
  13. def __init__(self, model: OnnxModel):
  14. super().__init__(model, "QOrderedGelu", ["Gelu", "FastGelu"])
  15. def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
  16. """
  17. INPUT PATTERN
  18. Fuse (quantized) Gelu subgraph into one node QOrderedGelu:
  19. -> quantized input -> DQ -> Gelu -> Q ->
  20. (or)
  21. -> quantized input -> DQ -> FastGelu -> Q ->
  22. OUTPUT PATTERN
  23. -> QOrderedGelu ->
  24. """
  25. gelu_children = self.model.get_children(node, input_name_to_nodes)
  26. # Should only have 1 child - QuantizeLinear (or)
  27. # Should have 2 children - QuantizeLinear + Shape
  28. if not (
  29. (len(gelu_children) == 1 and gelu_children[0].op_type == "QuantizeLinear")
  30. or (
  31. len(gelu_children) == 2
  32. and gelu_children[0].op_type == "QuantizeLinear"
  33. and gelu_children[1].op_type == "Shape"
  34. )
  35. ):
  36. return
  37. downstream_quantize_node = gelu_children[0]
  38. downstream_shape_node = None
  39. if len(gelu_children) == 2:
  40. downstream_shape_node = gelu_children[1]
  41. if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
  42. return
  43. # The first input to Gelu should flow through a DequantizeLinear node
  44. first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
  45. node,
  46. [(["DequantizeLinear"], [0])],
  47. output_name_to_node,
  48. )
  49. if first_path_id < 0:
  50. return
  51. upstream_dequantize_node = first_input_parent_nodes[0]
  52. if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model):
  53. return
  54. # Fusion logic
  55. subgraph_nodes = [node] # Gelu/FastGelu
  56. subgraph_nodes.extend([downstream_quantize_node, upstream_dequantize_node]) # Relevant Q, DQ nodes
  57. if not self.model.is_safe_to_fuse_nodes(
  58. subgraph_nodes,
  59. [node.output[0], downstream_quantize_node.output[0]]
  60. if downstream_shape_node is not None
  61. else downstream_quantize_node.output,
  62. input_name_to_nodes,
  63. output_name_to_node,
  64. ):
  65. logger.debug(f"It is not safe to fuse QOrderedGelu node. Skip")
  66. return
  67. self.nodes_to_remove.extend(subgraph_nodes)
  68. ordered_gelu_node = helper.make_node(
  69. "QOrderedGelu",
  70. inputs=[
  71. upstream_dequantize_node.input[0],
  72. upstream_dequantize_node.input[1],
  73. downstream_quantize_node.input[1],
  74. ],
  75. outputs=[downstream_quantize_node.output[0]],
  76. name=self.model.create_node_name("QOrderedGelu", name_prefix="QOrderedGelu"),
  77. )
  78. # Arrange the downstream Shape's input to be fed from the
  79. # downstream QuantizeLinear node, so that fusion will
  80. # be deemed safe
  81. if downstream_shape_node is not None:
  82. self.model.replace_node_input(
  83. downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0]
  84. )
  85. # TODO: We only support CuBlasLt order ORDER_ROW for now.
  86. # Once we start supporting other data ordering format(s), we
  87. # will support user configuring the data ordering for the op.
  88. ordered_gelu_node.attribute.extend([helper.make_attribute("order_X", 1)])
  89. ordered_gelu_node.attribute.extend([helper.make_attribute("order_Y", 1)])
  90. ordered_gelu_node.domain = "com.microsoft"
  91. self.nodes_to_add.append(ordered_gelu_node)
  92. self.node_name_to_graph_name[ordered_gelu_node.name] = self.this_graph_name