# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Union

import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)
class FusionTnlrAttention(FusionAttention):
    """
    Fuse the TNLR attention subgraph into a single Attention node.
    TNLR attention has an extra addition after the QK nodes and adopts [S, B, NH] as its I/O shape.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)
    def create_attention_node(
        self,
        mask_index: str,
        matmul: NodeProto,
        add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        add_qk_str: str,
    ) -> Union[NodeProto, None]:
        assert num_heads > 0
        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        weight = self.model.get_initializer(matmul.input[1])
        bias = self.model.get_initializer(add.input[1]) or self.model.get_initializer(add.input[0])
        if weight is None or bias is None:
            return None

        qkv_weight = NumpyHelper.to_array(weight)
        qkv_bias = NumpyHelper.to_array(bias)

        # Sometimes weights and bias are stored in fp16. Remember the source data types
        # before the variables are rebound to the newly created tensors below.
        weight_is_fp16 = weight.data_type == TensorProto.FLOAT16
        bias_is_fp16 = bias.data_type == TensorProto.FLOAT16

        attention_node_name = self.model.create_node_name("Attention")

        weight = helper.make_tensor(
            name=attention_node_name + "_qkv_weight",
            data_type=TensorProto.FLOAT,
            dims=[hidden_size, 3 * hidden_size],
            vals=qkv_weight.flatten().tolist(),
        )
        if weight_is_fp16:
            weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
        self.model.add_initializer(weight, self.this_graph_name)

        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[3 * hidden_size],
            vals=qkv_bias.flatten().tolist(),
        )
        if bias_is_fp16:
            bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
        self.model.add_initializer(bias, self.this_graph_name)

        attention_inputs = [
            input,
            attention_node_name + "_qkv_weight",
            attention_node_name + "_qkv_bias",
        ]
        if mask_index is not None:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        if add_qk_str is not None:
            # The empty string is a placeholder for the optional past input; add_qk_str is
            # then wired to the optional input that adds a bias to the QK scores.
            attention_inputs.append("")
            attention_inputs.append(add_qk_str)

        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        return attention_node
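    # A sketch of the TNLR attention subgraph that fuse() below matches, read directly from
    # the match_parent_path calls (each line lists a child followed by its matched parents):
    #   SkipLayerNormalization <- Where <- Add <- MatMul <- Reshape <- Transpose <- MatMul (qkv)
    #   MatMul (qkv) <- Transpose <- Reshape <- Slice <- Add <- MatMul            (v path)
    #   MatMul (qkv) <- Softmax <- Add <- MatMul (qk)                             (qk path)
    #   MatMul (qk)  <- Mul <- Transpose <- Reshape <- Slice <- Add <- MatMul     (q path)
    #   MatMul (qk)  <- Transpose <- Reshape <- Slice <- Add <- MatMul            (k path)
    #   Add (qk)     <- Reshape <- Where                                          (relative position bias)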
    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        # Sometimes we cannot fuse SkipLayerNormalization since the Add before LayerNormalization
        # has an output that is used by nodes outside of SkipLayerNormalization.
        # Conceptually we treat the Add before LayerNormalization as a SkipLayerNormalization node
        # since they share the same pattern.
        start_node = normalize_node
        if normalize_node.op_type != "SkipLayerNormalization":
            return

        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Where", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (_, _, matmul_below, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
        else:
            return

        other_inputs = []
        for i, input in enumerate(start_node.input):
            if input not in output_name_to_node:
                continue
            if input == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input)
        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if v_nodes is None:
            return
        (_, _, _, add, matmul) = v_nodes

        upper_nodes = self.model.match_parent_path(matmul, ["Transpose"], [0])
        if upper_nodes is None:
            return
        transpose = upper_nodes[0]

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            return
        (_, add_qk, matmul_qk) = qk_nodes

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            return
        add = q_nodes[-2]
        matmul = q_nodes[-1]

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Slice", "Add", "MatMul"],
            [1, 0, 0, 0, 1],
        )
        if k_nodes is None:
            return
        add = k_nodes[-2]
        matmul = k_nodes[-1]

        relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0])
        if relative_position_bias_nodes is None:
            return
        if matmul.input[0] == root_input:
            mask_index = None
            attention_last_node = reshape_qkv

            # The number of heads is the same for all paths, so self.num_heads is passed when
            # creating the Attention node. self.hidden_size is the input hidden size; the hidden
            # sizes for Q and K are derived from the weights as needed.
            new_node = self.create_attention_node(
                mask_index,
                matmul,
                add,
                self.num_heads,
                self.hidden_size,
                root_input,
                attention_last_node.output[0],
                relative_position_bias_nodes[0].input[0],
            )
            if new_node is None:
                return

            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

            # Add a Transpose node after the Attention node: the fused node consumes the tensor
            # before the original Transpose, so its output is transposed back (perm=[1, 0, 2])
            # for the downstream consumers.
            back_transpose = helper.make_node(
                "Transpose",
                ["back_transpose_in_" + new_node.name],
                [new_node.output[0]],
                "back_transpose_" + new_node.name,
                perm=[1, 0, 2],
            )
            self.model.add_node(back_transpose, self.this_graph_name)
            new_node.input[0] = transpose.input[0]
            new_node.output[0] = "back_transpose_in_" + new_node.name

            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)

            # Use prune_graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True
class TnlrOnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionTnlrAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

    def fuse_attention(self):
        self.attention_fusion.apply()
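
# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original module): one
# way this fusion might be driven end to end. The file names and the num_heads /
# hidden_size values are hypothetical placeholders for an exported TNLR model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import onnx

    tnlr_model = onnx.load("tnlr.onnx")  # hypothetical input path
    optimizer = TnlrOnnxModel(tnlr_model, num_heads=12, hidden_size=768)
    optimizer.fuse_attention()  # applies FusionTnlrAttention over the graph
    onnx.save(optimizer.model, "tnlr_fused.onnx")  # hypothetical output path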