|
|
# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger from typing import Dict
from fusion_base import Fusion from onnx import helper from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionBiasSplitGelu(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "BiasSplitGelu", "Gelu")
def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): """
[root] --->Add --------------------> Slice ---------------> Mul --> | ^ ^ | | | +----------------------------+---Slice --> Gelu---+ | | ^ | |-----| | | | | Mul Mul | ^ ^ v | | Shape ---> Gather --> Add --> Div --+ """
if gelu_node.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[gelu_node.output[0]] if len(children) != 1 or children[0].op_type != "Mul": return mul_after_gelu = children[0]
slice_before_gelu = self.model.match_parent(gelu_node, "Slice", 0, output_name_to_node) if slice_before_gelu is None: return
if self.model.find_constant_input(slice_before_gelu, -1, delta=0.001) != 3: return
add_output = slice_before_gelu.input[0]
start_index_nodes = self.model.match_parent_path( slice_before_gelu, ["Div", "Add", "Gather", "Shape", "Add"], [1, 0, 0, 0, 0], output_name_to_node, # Mul(1) is optional ) if start_index_nodes is None: start_index_nodes = self.model.match_parent_path( slice_before_gelu, ["Mul", "Div", "Add", "Gather", "Shape", "Add"], [1, 0, 0, 0, 0, 0], output_name_to_node, )
if start_index_nodes is None or start_index_nodes[-2].input[0] != add_output: return
end_index_nodes = self.model.match_parent_path(slice_before_gelu, ["Mul", "Div"], [2, 0], output_name_to_node)
if ( end_index_nodes is None or end_index_nodes[1] not in start_index_nodes ): # the Div is parent of both two Mul nodes return
slice_before_mul = self.model.match_parent(mul_after_gelu, "Slice", 0, output_name_to_node) if slice_before_mul is None: return
if ( slice_before_mul.input[2] != slice_before_gelu.input[1] ): # end index of slice_before_mul is start index of slice_before_gelu return
subgraph_nodes = start_index_nodes + [ end_index_nodes[0], mul_after_gelu, gelu_node, slice_before_mul, slice_before_gelu, ] subgraph_output = mul_after_gelu.output[0] if not self.model.is_safe_to_fuse_nodes( subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node ): logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.") return
add_node = start_index_nodes[-1] bias_index, _value = self.model.get_constant_input(add_node) if not isinstance(bias_index, int): return self.nodes_to_remove.extend(subgraph_nodes) node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu") fused_node = helper.make_node( "BiasSplitGelu", inputs=[add_node.input[1 - bias_index], add_node.input[bias_index]], outputs=[subgraph_output], name=node_name, ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[node_name] = self.this_graph_name
|