m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

66 lines
2.2 KiB

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionBiasGelu(Fusion):
    """Fold a bias ``Add`` into the ``Gelu``/``FastGelu`` node that consumes it.

    The pattern matched by :meth:`fuse` is ``MatMul -> Add(bias) -> Gelu`` (or
    ``FastGelu``). The ``Add`` and activation nodes are replaced with a single
    ``BiasGelu`` (or ``FastGelu``) node in the ``com.microsoft`` domain that
    takes the ``MatMul`` output and the bias initializer as its two inputs.
    """

    def __init__(self, model: OnnxModel, is_fastgelu):
        """Configure the fusion for either the FastGelu or the Gelu variant.

        Args:
            model: The model wrapper handed to the ``Fusion`` base class.
            is_fastgelu: When truthy, the base class is initialised with
                ``("FastGelu", "FastGelu", "add bias")``; otherwise with
                ``("BiasGelu", "Gelu")``.
        """
        # NOTE(review): the positional arguments are presumably the fused op
        # type, the op type to search for, and an optional description --
        # confirm against Fusion.__init__.
        if is_fastgelu:
            base_args = ("FastGelu", "FastGelu", "add bias")
        else:
            base_args = ("BiasGelu", "Gelu")
        super().__init__(model, *base_args)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse ``node`` with its parent bias ``Add``.

        Returns ``None`` in every case. When the pattern does not match, no
        state is modified. On success the matched nodes are queued in
        ``self.nodes_to_remove`` and the new fused node in
        ``self.nodes_to_add``.

        Args:
            node: The Gelu/FastGelu node under consideration.
            input_name_to_nodes: Passed through to
                ``self.model.is_safe_to_fuse_nodes``.
            output_name_to_node: Passed through to
                ``self.model.is_safe_to_fuse_nodes``.
        """
        activation_op = node.op_type
        if activation_op == "Gelu":
            fused_op = "BiasGelu"
        else:
            fused_op = "FastGelu"

        # Only an activation with exactly one input is a candidate.
        if len(node.input) != 1:
            return

        # NOTE(review): [0, None] presumably means "Add is input 0 of the
        # activation, MatMul may be any input of Add" -- confirm against
        # OnnxModel.match_parent_path.
        matched_path = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
        if matched_path is None:
            return
        add_node, matmul_node = matched_path

        # Locate the first Add input that is backed by an initializer; that
        # input is treated as the bias.
        bias_index = -1
        bias_initializer = None
        for position, input_name in enumerate(add_node.input):
            candidate = self.model.get_initializer(input_name)
            if candidate is not None:
                bias_index = position
                bias_initializer = candidate
                break
        if bias_initializer is None:
            return

        # The bias must be one-dimensional.
        bias_weight = NumpyHelper.to_array(bias_initializer)
        if len(bias_weight.shape) != 1:
            return

        nodes_to_fuse = [node, add_node]
        safe_to_fuse = self.model.is_safe_to_fuse_nodes(
            nodes_to_fuse, [node.output[0]], input_name_to_nodes, output_name_to_node
        )
        if not safe_to_fuse:
            return

        self.nodes_to_remove.extend(nodes_to_fuse)

        name_prefix = activation_op + "_AddBias_"
        fused_node = helper.make_node(
            fused_op,
            inputs=[matmul_node.output[0], add_node.input[bias_index]],
            outputs=node.output,
            name=self.model.create_node_name(fused_op, name_prefix),
        )
        fused_node.domain = "com.microsoft"

        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name