You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
2.2 KiB
66 lines
2.2 KiB
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
# --------------------------------------------------------------------------
|
|
|
|
from logging import getLogger
|
|
|
|
from fusion_base import Fusion
|
|
from fusion_utils import NumpyHelper
|
|
from onnx import helper
|
|
from onnx_model import OnnxModel
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
class FusionBiasGelu(Fusion):
    """Fold a bias ``Add`` into the Gelu-style activation that consumes it.

    The pattern handled is ``MatMul -> Add(bias) -> Gelu`` (or ``FastGelu``).
    It is replaced by a single ``com.microsoft`` ``BiasGelu`` / ``FastGelu``
    node whose inputs are the MatMul output and the bias tensor.
    """

    def __init__(self, model: OnnxModel, is_fastgelu):
        """Set up the fusion for the FastGelu or the plain Gelu variant.

        Args:
            model: Model wrapper forwarded unchanged to the ``Fusion`` base.
            is_fastgelu: Truthy selects the ``FastGelu`` configuration,
                falsy selects the ``BiasGelu`` / ``Gelu`` configuration.
        """
        # NOTE(review): the strings are presumably (fused op type, op type to
        # search for[, description]) -- confirm against Fusion.__init__.
        if is_fastgelu:
            base_args = ("FastGelu", "FastGelu", "add bias")
        else:
            base_args = ("BiasGelu", "Gelu")
        super().__init__(model, *base_args)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to merge ``node`` with the bias ``Add`` that feeds it.

        Nothing is changed (the method just returns) unless all of these hold:
        ``node`` has exactly one input, ``match_parent_path`` finds an
        ``Add`` / ``MatMul`` parent chain, one input of that ``Add`` is an
        initializer with a one-dimensional shape, and the model reports that
        removing ``node`` and the ``Add`` is safe.

        Args:
            node: Activation node under consideration; its ``op_type`` decides
                whether the fused node is ``BiasGelu`` or ``FastGelu``.
            input_name_to_nodes: Passed through to ``is_safe_to_fuse_nodes``.
            output_name_to_node: Passed through to ``is_safe_to_fuse_nodes``.

        Returns:
            None. On success the fused node is queued in ``nodes_to_add`` and
            the replaced nodes in ``nodes_to_remove``.
        """
        # An activation that already has more than one input is left alone.
        if len(node.input) != 1:
            return

        activation_type = node.op_type
        if activation_type == "Gelu":
            fused_type = "BiasGelu"
        else:
            fused_type = "FastGelu"

        # NOTE(review): [0, None] presumably means "Add on input 0 of node,
        # MatMul on any input of the Add" -- confirm in match_parent_path.
        parent_chain = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
        if parent_chain is None:
            return
        add_node, matmul_node = parent_chain

        # Pick the first Add input that is backed by an initializer: the bias.
        bias_index = -1
        bias_initializer = None
        for position, tensor_name in enumerate(add_node.input):
            bias_initializer = self.model.get_initializer(tensor_name)
            if bias_initializer is not None:
                bias_index = position
                break
        if bias_initializer is None:
            return

        # The bias must be one-dimensional, otherwise the pattern is skipped.
        bias_values = NumpyHelper.to_array(bias_initializer)
        if len(bias_values.shape) != 1:
            return

        replaced_nodes = [node, add_node]
        safe_to_fuse = self.model.is_safe_to_fuse_nodes(
            replaced_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
        )
        if not safe_to_fuse:
            return

        self.nodes_to_remove.extend(replaced_nodes)

        # The fused node reads the MatMul result plus the bias tensor and
        # takes over the outputs of the original activation node.
        fused_inputs = [matmul_node.output[0], add_node.input[bias_index]]
        fused_name = self.model.create_node_name(fused_type, activation_type + "_AddBias_")
        fused_node = helper.make_node(
            fused_type,
            inputs=fused_inputs,
            outputs=node.output,
            name=fused_name,
        )
        fused_node.domain = "com.microsoft"

        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
|