# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger
from typing import Dict, List, Union

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionGemmFastGelu(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "GemmFastGelu", "FastGelu", "GemmFastGelu")
        # Lazily-computed symbolic shape inference result (see get_dimensions).
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        else:
            return None

    def get_dimensions(self, input_name: str) -> Union[int, None]:
        graph_input = self.model.find_graph_input(input_name)
        if graph_input:
            return self.get_dimensions_from_tensor_proto(graph_input)

        # Not a graph input: fall back to (cached) symbolic shape inference for intermediate tensors.
        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape({}, update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])

        return None

    def fuse(
        self,
        node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        Fuse MatMul followed by FastGelu into a single GemmFastGelu node.
        This pattern comes from the PyTorch BERT model:

            [root] --> MatMul --> FastGelu -->
        """
        has_bias = False
        if len(node.input) == 2:
            has_bias = True

        match_nodes = self.model.match_parent_path(node, ["MatMul"], [0])
        if match_nodes is None:
            return
        matmul = match_nodes[0]

        # MatMul input X should have at least two dimensions; the weight must be two-dimensional.
        weight_index = -1
        x_dims = 0
        weight = None

        for i, input in enumerate(matmul.input):
            initializer = self.model.get_initializer(input)
            if initializer is None:
                x_dims = self.get_dimensions(matmul.input[i])
            else:
                weight_index = i
                weight = NumpyHelper.to_array(initializer)
        if weight is None:
            return
        if len(weight.shape) != 2:
            return
        # x_dims may be None when the rank of the activation input cannot be determined.
        if x_dims is None or x_dims < len(weight.shape):
            return

        # The bias weight should be one-dimensional.
        bias_index = -1
        if has_bias:
            bias_weight = None
            for i, input in enumerate(node.input):
                initializer = self.model.get_initializer(input)
                if initializer is None:
                    continue
                bias_index = i
                bias_weight = NumpyHelper.to_array(initializer)
                break
            if bias_weight is None:
                return
            if len(bias_weight.shape) != 1:
                return

        subgraph_nodes = [node, matmul]
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
        ):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        # GemmFastGelu inputs are (X, weight[, bias]); the non-initializer MatMul input goes first.
        inputs = (
            [matmul.input[1 - weight_index], matmul.input[weight_index], node.input[bias_index]]
            if has_bias
            else [matmul.input[1 - weight_index], matmul.input[weight_index]]
        )

        fused_node = helper.make_node(
            "GemmFastGelu",
            inputs=inputs,
            outputs=node.output,
            name=self.model.create_node_name("GemmFastGelu"),
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
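
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes the
# onnxruntime transformers tooling, where Fusion.apply() walks the graph and
# calls fuse() on every FastGelu node; the file paths below are placeholders.
#
#   import onnx
#
#   model = onnx.load("model_with_fastgelu.onnx")
#   onnx_model = OnnxModel(model)
#   FusionGemmFastGelu(onnx_model).apply()
#   onnx_model.topological_sort()
#   onnx.save(onnx_model.model, "model_with_gemmfastgelu.onnx")
#
# The fused GemmFastGelu node lives in the com.microsoft domain, so the result
# is intended to run with ONNX Runtime and its contrib operators.
# ---------------------------------------------------------------------------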