# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger from typing import Dict, List, Union from fusion_base import Fusion from fusion_utils import NumpyHelper from onnx import NodeProto, TensorProto, helper from onnx_model import OnnxModel logger = getLogger(__name__) class FusionGemmFastGelu(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "GemmFastGelu", "FastGelu", "GemmFastGelu") self.shape_infer = None self.shape_infer_done = False def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]: if tensor_proto.type.tensor_type.HasField("shape"): return len(tensor_proto.type.tensor_type.shape.dim) else: return None def get_dimensions(self, input_name: str) -> Union[int, None]: graph_input = self.model.find_graph_input(input_name) if graph_input: return self.get_dimensions_from_tensor_proto(graph_input) if not self.shape_infer_done: self.shape_infer = self.model.infer_runtime_shape({}, update=True) self.shape_infer_done = True if self.shape_infer is not None: return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name]) return None def fuse( self, node: NodeProto, input_name_to_nodes: Dict[str, List[NodeProto]], output_name_to_node: Dict[str, NodeProto], ): """ This pattern is from PyTorch bert model Fuse MatMul with FastGelu into one node: [root] --> MatMul --> FastGelu --> """ has_bias = False if len(node.input) == 2: has_bias = True match_nodes = self.model.match_parent_path(node, ["MatMul"], [0]) if match_nodes is None: return matmul = match_nodes[0] # matmul input X should >= two dimension, input weight should be two dimension weight_index = -1 x_dims = 0 weight = None for i, input in enumerate(matmul.input): initializer = self.model.get_initializer(input) if initializer is None: x_dims = self.get_dimensions(matmul.input[i]) else: weight_index = i weight = NumpyHelper.to_array(initializer) if weight is None: return if len(weight.shape) != 2: return if x_dims < len(weight.shape): return # bias weight should be one dimension bias_index = -1 if has_bias: bias_weight = None for i, input in enumerate(node.input): initializer = self.model.get_initializer(input) if initializer is None: continue bias_index = i bias_weight = NumpyHelper.to_array(initializer) break if bias_weight is None: return if len(bias_weight.shape) != 1: return subgraph_nodes = [node, matmul] if not self.model.is_safe_to_fuse_nodes( subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node ): return self.nodes_to_remove.extend(subgraph_nodes) inputs = ( [matmul.input[1 - weight_index], matmul.input[weight_index], node.input[bias_index]] if has_bias else [matmul.input[1 - weight_index], matmul.input[weight_index]] ) fused_node = helper.make_node( "GemmFastGelu", inputs=inputs, outputs=node.output, name=self.model.create_node_name("GemmFastGelu"), ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name