# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger from typing import Dict from fusion_base import Fusion from fusion_utils import FusionUtils from onnx import helper from onnx_model import OnnxModel logger = getLogger(__name__) class FusionQOrderedGelu(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "QOrderedGelu", ["Gelu", "FastGelu"]) def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): """ INPUT PATTERN Fuse (quantized) Gelu subgraph into one node QOrderedGelu: -> quantized input -> DQ -> Gelu -> Q -> (or) -> quantized input -> DQ -> FastGelu -> Q -> OUTPUT PATTERN -> QOrderedGelu -> """ gelu_children = self.model.get_children(node, input_name_to_nodes) # Should only have 1 child - QuantizeLinear (or) # Should have 2 children - QuantizeLinear + Shape if not ( (len(gelu_children) == 1 and gelu_children[0].op_type == "QuantizeLinear") or ( len(gelu_children) == 2 and gelu_children[0].op_type == "QuantizeLinear" and gelu_children[1].op_type == "Shape" ) ): return downstream_quantize_node = gelu_children[0] downstream_shape_node = None if len(gelu_children) == 2: downstream_shape_node = gelu_children[1] if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model): return # The first input to Gelu should flow through a DequantizeLinear node first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths( node, [(["DequantizeLinear"], [0])], output_name_to_node, ) if first_path_id < 0: return upstream_dequantize_node = first_input_parent_nodes[0] if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model): return # Fusion logic subgraph_nodes = [node] # Gelu/FastGelu subgraph_nodes.extend([downstream_quantize_node, upstream_dequantize_node]) # Relevant Q, DQ nodes if not self.model.is_safe_to_fuse_nodes( subgraph_nodes, [node.output[0], downstream_quantize_node.output[0]] if downstream_shape_node is not None else downstream_quantize_node.output, input_name_to_nodes, output_name_to_node, ): logger.debug(f"It is not safe to fuse QOrderedGelu node. Skip") return self.nodes_to_remove.extend(subgraph_nodes) ordered_gelu_node = helper.make_node( "QOrderedGelu", inputs=[ upstream_dequantize_node.input[0], upstream_dequantize_node.input[1], downstream_quantize_node.input[1], ], outputs=[downstream_quantize_node.output[0]], name=self.model.create_node_name("QOrderedGelu", name_prefix="QOrderedGelu"), ) # Arrange the downstream Shape's input to be fed from the # downstream QuantizeLinear node, so that fusion will # be deemed safe if downstream_shape_node is not None: self.model.replace_node_input( downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0] ) # TODO: We only support CuBlasLt order ORDER_ROW for now. # Once we start supporting other data ordering format(s), we # will support user configuring the data ordering for the op. ordered_gelu_node.attribute.extend([helper.make_attribute("order_X", 1)]) ordered_gelu_node.attribute.extend([helper.make_attribute("order_Y", 1)]) ordered_gelu_node.domain = "com.microsoft" self.nodes_to_add.append(ordered_gelu_node) self.node_name_to_graph_name[ordered_gelu_node.name] = self.this_graph_name