# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

import onnx
from fusion_gpt_attention import FusionGptAttention
from fusion_gpt_attention_megatron import FusionGptAttentionMegatron
from fusion_gpt_attention_no_past import FusionGptAttentionNoPast
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)


class Gpt2OnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)

    def fuse_attention(self):
        # A graph with a single input or a single output carries no past key/value
        # tensors, so the no-past attention fusion applies. Otherwise run the regular
        # GPT-2 attention fusion, then the Megatron variant.
        if len(self.model.graph.input) == 1 or len(self.model.graph.output) == 1:
            fusion = FusionGptAttentionNoPast(self, self.num_heads)
            fusion.apply()
        else:
            fusion = FusionGptAttention(self, self.num_heads)
            fusion.apply()
            fusion = FusionGptAttentionMegatron(self, self.num_heads)
            fusion.apply()

    def postprocess(self):
        """
        Remove extra reshape nodes.
        """
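        # The exporter typically emits each fully connected layer as
        # Reshape (3D to 2D) -> Gemm -> Reshape (2D to 3D), since Gemm only accepts
        # 2D inputs. The loop below swaps each matched Gemm for MatMul + Add, which
        # operate directly on the 3D activation, so both surrounding Reshape nodes
        # become dead and are dropped by prune_graph().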
        logger.debug("start postprocessing...")

        input_name_to_nodes = self.input_name_to_nodes()
        output_name_to_node = self.output_name_to_node()

        reshape_count = 0
        for gemm_node in self.get_nodes_by_op_type("Gemm"):
            # Expect a Reshape right after the Gemm that restores the 3D shape.
            reshape_after_gemm = self.find_first_child_by_type(
                gemm_node, "Reshape", input_name_to_nodes, recursive=False
            )
            if reshape_after_gemm is None:
                continue

            # The Gemm input is expected to come through a Reshape fed by FastGelu,
            # LayerNormalization or SkipLayerNormalization; otherwise skip this Gemm.
            nodes = self.match_parent_path(gemm_node, ["Reshape", "FastGelu"], [0, 0], output_name_to_node)
            if nodes is None:
                nodes = self.match_parent_path(
                    gemm_node,
                    ["Reshape", "LayerNormalization"],
                    [0, 0],
                    output_name_to_node,
                )
            if nodes is None:
                nodes = self.match_parent_path(
                    gemm_node,
                    ["Reshape", "SkipLayerNormalization"],
                    [0, 0],
                    output_name_to_node,
                )
            if nodes is None:
                continue

            (reshape_before_gemm, root_node) = nodes

            # Rebuild the fully connected computation as MatMul + Add so the
            # surrounding Reshape nodes are no longer needed. The MatMul input is a
            # placeholder name that gets wired to the root node output below.
            matmul_node_name = self.create_node_name("MatMul", "FullyConnect_MatMul")
            matmul_node = onnx.helper.make_node(
                "MatMul",
                inputs=[matmul_node_name + "_input", gemm_node.input[1]],
                outputs=[matmul_node_name + "_output"],
                name=matmul_node_name,
            )

            add_node_name = self.create_node_name("Add", "FullyConnect_Add")
            add_node = onnx.helper.make_node(
                "Add",
                inputs=[matmul_node_name + "_output", gemm_node.input[2]],
                outputs=[add_node_name + "_output"],
                name=add_node_name,
            )

            # Route consumers of the Reshape after Gemm to the new Add output.
            self.replace_input_of_all_nodes(reshape_after_gemm.output[0], add_node_name + "_output")

            # Link root node output with the MatMul input.
            self.replace_input_of_all_nodes(root_node.output[0], matmul_node_name + "_input")
            root_node.output[0] = matmul_node_name + "_input"

            self.add_node(matmul_node)
            self.add_node(add_node)

            # The Reshape nodes before and after the Gemm are both bypassed now.
            reshape_count += 2

        self.prune_graph()
        logger.info(f"postprocess: remove Reshape count:{reshape_count}")
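

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes a
# GPT-2 model already exported to ONNX at a placeholder path, and the GPT-2
# base configuration (12 heads, hidden size 768); adjust both for your model.
# Normally this class is driven by the transformer optimizer tooling rather
# than run standalone.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    model_proto = onnx.load("gpt2.onnx")  # placeholder path
    gpt2_model = Gpt2OnnxModel(model_proto, num_heads=12, hidden_size=768)

    # Fuse attention subgraphs, then clean up the extra Reshape nodes.
    gpt2_model.fuse_attention()
    gpt2_model.postprocess()

    onnx.save(gpt2_model.model, "gpt2_optimized.onnx")  # placeholder path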