# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Union

import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)


class FusionTnlrAttention(FusionAttention):
    """
    Fuse TNLR Attention subgraph into one Attention node.
    TNLR Attention has an extra addition after the QK nodes and adopts [S, B, NH] as its I/O shape.
    """
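
    # Note: per the docstring above, TNLR runs attention in a sequence-first layout, so fuse()
    # bypasses the Transpose feeding the QKV projection and appends a Transpose (perm=[1, 0, 2])
    # after the fused Attention node to restore the layout its consumers expect.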
    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)
    def create_attention_node(
        self,
        mask_index: str,
        matmul: NodeProto,
        add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        add_qk_str: str,
    ) -> Union[NodeProto, None]:
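        # `matmul` and `add` are expected to be the shared packed QKV projection (MatMul + Add),
        # `input` the tensor feeding that projection, and `add_qk_str` the name of the tensor
        # added to the QK scores (the extra addition mentioned in the class docstring).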
        assert num_heads > 0
        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None
        weight = self.model.get_initializer(matmul.input[1])
        bias = self.model.get_initializer(add.input[1]) or self.model.get_initializer(add.input[0])

        if weight is None or bias is None:
            return None

        qkv_weight = NumpyHelper.to_array(weight)
        qkv_bias = NumpyHelper.to_array(bias)
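
        # The matched projection is expected to already hold the packed Q, K and V weights
        # (shape [hidden_size, 3 * hidden_size]); the original graph splits them with Slice nodes.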
        attention_node_name = self.model.create_node_name("Attention")

        weight = helper.make_tensor(
            name=attention_node_name + "_qkv_weight",
            data_type=TensorProto.FLOAT,
            dims=[hidden_size, 3 * hidden_size],
            vals=qkv_weight.flatten().tolist(),
        )
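
        # ONNX data_type 10 is TensorProto.FLOAT16.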
        # Sometimes weights and bias are stored in fp16
        if weight.data_type == 10:
            weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
        self.model.add_initializer(weight, self.this_graph_name)
        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[3 * hidden_size],
            vals=qkv_bias.flatten().tolist(),
        )
        if bias.data_type == 10:
            bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
        self.model.add_initializer(bias, self.this_graph_name)
        attention_inputs = [
            input,
            attention_node_name + "_qkv_weight",
            attention_node_name + "_qkv_bias",
        ]
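
        # The com.microsoft Attention op takes optional inputs after (input, weight, bias):
        # mask_index, past, and the extra tensor added to the QK scores; empty strings act as
        # placeholders for optional inputs that are skipped.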
        if mask_index is not None:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        if add_qk_str is not None:
            attention_inputs.append("")
            attention_inputs.append(add_qk_str)
        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        return attention_node
    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        # Sometimes we cannot fuse SkipLayerNormalization since the Add before LayerNormalization
        # has an output that is used by nodes outside of the SkipLayerNormalization.
        # Conceptually we treat the Add before LayerNormalization as a SkipLayerNormalization node,
        # since they share the same pattern.
        start_node = normalize_node
        if normalize_node.op_type != "SkipLayerNormalization":
            return
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
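        # match_parent_path walks upward from the given node through parents whose op types match
        # the first list; the second list gives the input index to follow at each step. It returns
        # the matched nodes in order, or None when the pattern does not match.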
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Where", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (_, _, matmul_below, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
        else:
            return
        other_inputs = []
        for i, input in enumerate(start_node.input):
            if input not in output_name_to_node:
                continue

            if input == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input)
        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]
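
        # root_input is the other input of the SkipLayerNormalization, i.e. the tensor that also
        # feeds the QKV projection; it is checked against the matched MatMul before fusing.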
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if v_nodes is None:
            return
        (_, _, _, add, matmul) = v_nodes
        upper_nodes = self.model.match_parent_path(matmul, ["Transpose"], [0])
        transpose = upper_nodes[0]
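        # The Transpose above the shared QKV MatMul is kept so that the fused Attention node can
        # later consume its input directly (see new_node.input[0] = transpose.input[0] below).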
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            return
        (_, add_qk, matmul_qk) = qk_nodes
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            return
        add = q_nodes[-2]
        matmul = q_nodes[-1]
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if k_nodes is None:
            return
        add = k_nodes[-2]
        matmul = k_nodes[-1]
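
        # The Q, K and V paths are expected to share one packed QKV Add/MatMul that is split by
        # Slice nodes, so the add/matmul reassignments above keep pointing at the same two nodes.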
        relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0])
        if relative_position_bias_nodes is None:
            return
        if matmul.input[0] == root_input:
            mask_index = None
            attention_last_node = reshape_qkv
            # The number of heads is the same for all paths, so self.num_heads is passed to create
            # the attention node. self.hidden_size is the input hidden size; it is used as needed,
            # and the hidden sizes for Q and K are extracted appropriately.
            new_node = self.create_attention_node(
                mask_index,
                matmul,
                add,
                self.num_heads,
                self.hidden_size,
                root_input,
                attention_last_node.output[0],
                relative_position_bias_nodes[0].input[0],
            )
            if new_node is None:
                return
            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name
            # Add a transpose node after the attention node
            back_transpose = helper.make_node(
                "Transpose",
                ["back_transpose_in_" + new_node.name],
                [new_node.output[0]],
                "back_transpose_" + new_node.name,
                perm=[1, 0, 2],
            )
            self.model.add_node(back_transpose, self.this_graph_name)
            new_node.input[0] = transpose.input[0]
            new_node.output[0] = "back_transpose_in_" + new_node.name
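
            # The fused Attention node now reads the tensor that fed the matched Transpose
            # (bypassing it), and its output is routed through back_transpose (perm=[1, 0, 2])
            # so that downstream consumers keep the layout they originally received.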
            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True


class TnlrOnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionTnlrAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

    def fuse_attention(self):
        self.attention_fusion.apply()
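
# Minimal usage sketch (illustrative; assumes this module's sibling files are importable, e.g. when
# run from the transformers tool directory, and that OnnxModel provides save_model_to_file; the
# model path and sizes below are placeholders):
#
#   import onnx
#
#   onnx_model = onnx.load("tnlr.onnx")
#   tnlr_model = TnlrOnnxModel(onnx_model, num_heads=12, hidden_size=768)
#   tnlr_model.fuse_attention()
#   tnlr_model.save_model_to_file("tnlr_fused.onnx")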