m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

485 lines
20 KiB

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import List, Optional
from fusion_attention import AttentionMask, FusionAttention
from fusion_biasgelu import FusionBiasGelu
from fusion_embedlayer import FusionEmbedLayerNormalization
from fusion_fastgelu import FusionFastGelu
from fusion_gelu import FusionGelu
from fusion_gelu_approximation import FusionGeluApproximation
from fusion_gemmfastgelu import FusionGemmFastGelu
from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF
from fusion_options import AttentionMaskFormat, FusionOptions
from fusion_qordered_attention import FusionQOrderedAttention
from fusion_qordered_gelu import FusionQOrderedGelu
from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
from fusion_qordered_matmul import FusionQOrderedMatMul
from fusion_reshape import FusionReshape
from fusion_shape import FusionShape
from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
from fusion_utils import FusionUtils
from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper
from onnx_model import OnnxModel
# Module-level logger; named after this module so output can be filtered per module.
logger = getLogger(name=__name__)
class BertOptimizationOptions(FusionOptions):
    """This class is deprecated; use FusionOptions instead."""

    def __init__(self, model_type):
        """Log a deprecation warning, then initialize FusionOptions.

        Args:
            model_type: forwarded unchanged to FusionOptions.__init__.
        """
        logger.warning("BertOptimizationOptions is deprecated. Please use FusionOptions instead.")
        super().__init__(model_type)
class BertOnnxModel(OnnxModel):
    """OnnxModel subclass that runs the BERT graph-fusion pipeline.

    ``optimize`` applies a sequence of Fusion* passes (LayerNormalization, Gelu, Attention,
    EmbedLayerNormalization, ...); most of them can be switched off through FusionOptions.
    """

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize BERT ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
        """
        # Either both values are left for auto-detection (0), or hidden_size must be divisible by num_heads.
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        # The same AttentionMask instance is handed to both attention fusions below.
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.qordered_attention_fusion = FusionQOrderedAttention(
            self, self.hidden_size, self.num_heads, self.attention_mask
        )

        self.utils = FusionUtils(self)

    def fuse_attention(self):
        """Run the attention fusion, then the quantized (QOrdered) attention fusion."""
        self.attention_fusion.apply()
        # Only relevant in models with Q-DQ nodes
        self.qordered_attention_fusion.apply()

    def fuse_gelu(self):
        """Run the Gelu, FastGelu and QOrderedGelu fusions, in that order."""
        fusion = FusionGelu(self)
        fusion.apply()
        fusion = FusionFastGelu(self)
        fusion.apply()
        # Only relevant in models with Q-DQ nodes
        fusion = FusionQOrderedGelu(self)
        fusion.apply()

    def fuse_bias_gelu(self, is_fastgelu):
        """Run FusionBiasGelu; ``is_fastgelu`` is forwarded to the fusion unchanged."""
        fusion = FusionBiasGelu(self, is_fastgelu)
        fusion.apply()

    def gelu_approximation(self):
        """Run FusionGeluApproximation (opt-in via FusionOptions.enable_gelu_approximation)."""
        fusion = FusionGeluApproximation(self)
        fusion.apply()

    def fuse_gemm_fast_gelu(self):
        """Run FusionGemmFastGelu (opt-in via FusionOptions.enable_gemm_fast_gelu)."""
        fusion = FusionGemmFastGelu(self)
        fusion.apply()

    def fuse_add_bias_skip_layer_norm(self):
        """Run FusionBiasSkipLayerNormalization (folds the Add bias before SkipLayerNormalization)."""
        fusion = FusionBiasSkipLayerNormalization(self)
        fusion.apply()

    def fuse_reshape(self):
        """Run FusionReshape."""
        fusion = FusionReshape(self)
        fusion.apply()

    def fuse_shape(self):
        """Run FusionShape."""
        fusion = FusionShape(self)
        fusion.apply()

    def fuse_embed_layer(self, use_mask_index):
        """Run FusionEmbedLayerNormalization; ``use_mask_index`` is forwarded to the fusion."""
        fusion = FusionEmbedLayerNormalization(self, use_mask_index)
        fusion.apply()

    def fuse_layer_norm(self):
        """Run FusionLayerNormalization, FusionLayerNormalizationTF and FusionQOrderedLayerNormalization."""
        fusion = FusionLayerNormalization(self)
        fusion.apply()

        fusion = FusionLayerNormalizationTF(self)
        fusion.apply()

        # Only relevant in models with Q-DQ nodes
        fusion = FusionQOrderedLayerNormalization(self)
        fusion.apply()

    def fuse_skip_layer_norm(self):
        """Run FusionSkipLayerNormalization."""
        fusion = FusionSkipLayerNormalization(self)
        fusion.apply()

    def fuse_qordered_matmul(self):
        """Run FusionQOrderedMatMul. Only relevant in models with Q-DQ nodes."""
        fusion = FusionQOrderedMatMul(self)
        fusion.apply()

    def fuse_qordered_mamtul(self):
        """Misspelled legacy name of ``fuse_qordered_matmul``; kept for backward compatibility."""
        self.fuse_qordered_matmul()

    def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool):
        """
        Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention).
        Returns a list of the graph input names based on the filter whether it is casted or not.

        Args:
            op_type: operator type of the nodes to inspect.
            input_indices: which inputs of those nodes to inspect; indices past a node's input count are skipped.
            casted: when False, collect graph inputs consumed directly by the nodes; when True, collect
                graph inputs that reach the nodes through a single Cast node.

        Returns:
            List[str]: graph input names in node order. Not de-duplicated: an input shared by several
            nodes appears once per node.
        """
        graph_inputs = []
        output_name_to_node = self.output_name_to_node()
        nodes = self.get_nodes_by_op_type(op_type)
        for node in nodes:
            bert_inputs = [node.input[i] for i in input_indices if i < len(node.input)]
            for bert_input in bert_inputs:
                if self.find_graph_input(bert_input) is not None:
                    if not casted:
                        graph_inputs.append(bert_input)
                elif bert_input in output_name_to_node:
                    parent = output_name_to_node[bert_input]
                    if casted and parent.op_type == "Cast" and self.find_graph_input(parent.input[0]) is not None:
                        graph_inputs.append(parent.input[0])
        return graph_inputs

    def get_graph_inputs_from_fused_nodes(self, casted: bool):
        """Collect graph inputs feeding fused EmbedLayerNormalization (inputs 0, 1, 7) and Attention (input 3).

        NOTE(review): EmbedLayerNormalization inputs 0/1/7 are presumably input_ids, segment_ids and mask;
        confirm against the com.microsoft op schema. Attention input 3 is its optional mask_index.

        Args:
            casted: see ``get_graph_inputs_from_node_type``.
        """
        inputs = self.get_graph_inputs_from_node_type("EmbedLayerNormalization", [0, 1, 7], casted)
        inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted)
        return inputs

    def change_graph_input_type(
        self,
        graph: GraphProto,
        graph_input: ValueInfoProto,
        new_type: int = TensorProto.INT32,
    ):
        """Change graph input type, and add Cast node if needed.

        Consumers that are not Cast nodes are rewired through one new Cast node that converts the
        input back to its original data type. Existing Cast children that already cast to ``new_type``
        become redundant: their consumers are rewired to the graph input and the Cast node is removed,
        unless its output is also a graph output. Cast children to any other type are left alone.

        Args:
            graph (GraphProto): graph
            graph_input (ValueInfoProto): input of the graph
            new_type (int, optional): new data type. Defaults to TensorProto.INT32.

        Returns:
            NodeProto: a new Cast node that added. None if Cast node is not added.
            List[NodeProto]: Cast nodes that have been removed.
        """
        assert isinstance(graph, GraphProto)
        assert isinstance(graph_input, ValueInfoProto)
        assert self.find_graph_input(graph_input.name)

        # Nothing to do when the input already has the requested type.
        if graph_input.type.tensor_type.elem_type == int(new_type):
            return None, []

        new_cast_node = None
        nodes_to_remove = []

        input_name_to_nodes = self.input_name_to_nodes()
        if graph_input.name in input_name_to_nodes:
            nodes = input_name_to_nodes[graph_input.name]

            # For children that is not Cast node, insert a Cast node to convert int32 to original data type.
            nodes_not_cast = [node for node in nodes if node.op_type != "Cast"]
            if nodes_not_cast:
                node_name = self.create_node_name("Cast")
                output_name = node_name + "_" + graph_input.name
                # The Cast output keeps the original type/shape information of the input.
                new_value_info = graph.value_info.add()
                new_value_info.CopyFrom(graph_input)
                new_value_info.name = output_name
                new_cast_node = helper.make_node(
                    "Cast",
                    [graph_input.name],
                    [output_name],
                    to=int(graph_input.type.tensor_type.elem_type),
                    name=node_name,
                )
                graph.node.extend([new_cast_node])

                for node in nodes_not_cast:
                    OnnxModel.replace_node_input(node, graph_input.name, output_name)

            # For children that is Cast node, no need to insert Cast.
            # When the children is Cast to int32, we can remove that Cast node since input type is int32 now.
            # Casts to other types still have consumers reading their output, so they must stay.
            nodes_cast = [node for node in nodes if node.op_type == "Cast"]
            for node in nodes_cast:
                if OnnxModel.get_node_attribute(node, "to") == int(new_type):
                    self.replace_input_of_all_nodes(node.output[0], graph_input.name)
                    if not self.find_graph_output(node.output[0]):
                        nodes_to_remove.append(node)
            if nodes_to_remove:
                self.remove_nodes(nodes_to_remove)

        graph_input.type.tensor_type.elem_type = int(new_type)
        return new_cast_node, nodes_to_remove

    def change_graph_inputs_to_int32(self):
        """Change data type of all graph inputs to int32 type, and add Cast node if needed.

        Applies ``change_graph_input_type`` to every graph input, whatever its current type.
        """
        graph = self.graph()
        add_cast_count = 0
        remove_cast_count = 0
        for graph_input in graph.input:
            new_node, removed_nodes = self.change_graph_input_type(graph, graph_input, TensorProto.INT32)
            if new_node:
                add_cast_count += 1
            remove_cast_count += len(removed_nodes)
        logger.info(
            f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes."
        )

    def use_dynamic_axes(self, dynamic_batch_dim="batch_size", dynamic_seq_len="max_seq_len"):
        """
        Update input and output shape to use dynamic axes.

        Graph inputs that feed fused BERT nodes (directly or through a Cast) get ``dynamic_batch_dim``
        as dimension 0 and, unless ``dynamic_seq_len`` is None, ``dynamic_seq_len`` as dimension 1.
        Every graph output gets ``dynamic_batch_dim`` as dimension 0. Dimensions a tensor does not
        declare (e.g. no shape information) are skipped instead of raising IndexError.
        """
        bert_graph_inputs = set(
            self.get_graph_inputs_from_fused_nodes(casted=True) + self.get_graph_inputs_from_fused_nodes(casted=False)
        )

        for graph_input in self.model.graph.input:
            if graph_input.name not in bert_graph_inputs:
                continue
            dims = graph_input.type.tensor_type.shape.dim
            if len(dims) > 0:
                dims[0].dim_param = dynamic_batch_dim
            if dynamic_seq_len is not None and len(dims) > 1:
                dims[1].dim_param = dynamic_seq_len

        for output in self.model.graph.output:
            dims = output.type.tensor_type.shape.dim
            if len(dims) > 0:
                dims[0].dim_param = dynamic_batch_dim

    def preprocess(self):
        """Graph clean-up run before the Reshape/attention fusions in ``optimize``."""
        self.adjust_reshape_and_expand()

    def adjust_reshape_and_expand(self):
        """Simplify Reshape nodes.

        * A Reshape whose constant "shape" input has no elements is removed, and its consumers read
          the Reshape's data input directly.
        * For a path Slice -> Reshape -> Expand -> Expand -> Reshape where the first Expand's shape
          has 2 values, the preceding Reshape's shape has 1 value, and Expand shape[1] equals that
          value, the final Reshape is rewired to read the Slice output directly.
        """
        nodes_to_remove = []
        # The loop below only rewrites node inputs (node outputs are untouched and removal is deferred),
        # so the output-name lookup can be built once instead of once per Reshape node.
        output_name_to_node = self.output_name_to_node()
        for node in self.nodes():
            if node.op_type != "Reshape":
                continue

            # Clean up unnecessary reshape nodes.
            # Find reshape nodes with no actual data in "shape" input and remove.
            reshape_shape = self.get_constant_value(node.input[1])
            if reshape_shape is not None and reshape_shape.size == 0:
                nodes_to_remove.append(node)
                self.replace_input_of_all_nodes(node.output[0], node.input[0])
                continue

            # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by
            # changing current reshape's input to output of slice.
            reshape_path = self.match_parent_path(
                node,
                ["Expand", "Expand", "Reshape", "Slice"],
                [0, 0, 0, 0],
                output_name_to_node,
            )
            if reshape_path is None:
                continue

            expand_node = reshape_path[-3]
            expand_shape_value = self.get_constant_value(expand_node.input[1])

            reshape_before_expand = reshape_path[-2]
            shape_value = self.get_constant_value(reshape_before_expand.input[1])

            slice_node = reshape_path[-1]
            if (
                expand_shape_value is not None
                and shape_value is not None
                and len(expand_shape_value) == 2
                and len(shape_value) == 1
                and expand_shape_value[1] == shape_value[0]
            ):
                node.input[0] = slice_node.output[0]

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            # Only Reshape nodes are ever collected for removal above.
            logger.info(f"Removed Reshape count: {len(nodes_to_remove)}")

    def clean_graph(self):
        """Shortcut mask-like subgraphs rooted at the first graph input and drop trivial Attention masks."""
        output_name_to_node = self.output_name_to_node()
        nodes_to_remove = []
        # Input index, per op type, at which the ConstantOfShape subgraph shown below is looked for.
        op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
        for node in self.nodes():
            # Before:
            # input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
            #          |                                                    |
            #          |                                                    v
            #          +----> Shape --> Gather(indices=1) --> Unsqueeze---> Concat --> ConstantOfShape -->Cast --> EmbedLayerNormalization/ReduceSum
            # After:
            # input_ids --> Shape --> ConstantOfShape -->Cast --> EmbedLayerNormalization/ReduceSum
            # TODO: merge ConstantOfShape -->Cast to ConstantOfShape (need update the data type of value)
            if node.op_type in op_input_id:
                i = op_input_id[node.op_type]
                parent_nodes = self.match_parent_path(
                    node,
                    [
                        "Cast",
                        "ConstantOfShape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                    ],
                    [i, 0, 0, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    _cast, constant_of_shape, _concat, _unsqueeze, _gather, shape = parent_nodes
                    # Only when the Shape reads the first graph input (presumably input_ids).
                    if shape.input[0] == self.graph().input[0].name:
                        constant_of_shape.input[0] = shape.output[0]
                        output_name_to_node = self.output_name_to_node()

            if node.op_type == "Attention":
                # Before:
                #   input_ids --> Shape -->ConstantOfShape -->Cast --> ReduceSum --> Attention
                # After:
                #   remove this path, and remove the optional mask_index input of Attention node.
                parent_nodes = self.match_parent_path(
                    node,
                    ["ReduceSum", "Cast", "ConstantOfShape", "Shape"],
                    [3, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None and parent_nodes[-1].input[0] == self.graph().input[0].name:
                    # Mark mask_index (input 3) as missing with an empty name, then strip trailing empty
                    # names: when mask_index is the last input the node simply loses it, and any later
                    # optional inputs keep their positions.
                    new_inputs = list(node.input)
                    new_inputs[3] = ""
                    while len(new_inputs) > 3 and new_inputs[-1] == "":
                        new_inputs.pop()

                    attention_node = helper.make_node(
                        "Attention",
                        inputs=new_inputs,
                        outputs=node.output,
                        name=node.name + "_remove_mask",
                    )
                    attention_node.domain = "com.microsoft"
                    # Copy every attribute of the original node: its num_heads is authoritative even when
                    # self.num_heads is 0 (auto-detect), and other attributes must not be dropped.
                    attention_node.attribute.extend(node.attribute)
                    # Look up the graph of the ORIGINAL node; the new node is not part of any graph yet.
                    self.add_node(attention_node, self.get_graph_by_node(node).name)
                    nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)

    def postprocess(self):
        """Graph clean-up run after the main fusions: ``clean_graph`` then ``prune_graph``."""
        self.clean_graph()
        self.prune_graph()

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        """Run the fusion pipeline on the model in place.

        Args:
            options: selects which fusions run. ``None`` runs every fusion guarded by
                ``options is None or ...`` and skips the opt-in ones (Gelu approximation, GemmFastGelu).
            add_dynamic_axes: when True, call ``use_dynamic_axes`` at the end.
        """
        if (options is not None) and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()

        # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        if options is not None:
            self.attention_mask.set_mask_format(options.attention_mask_format)
            if options.use_multi_head_attention:
                self.attention_fusion = FusionAttention(
                    self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention
                )

        if (options is None) or options.enable_attention:
            self.fuse_attention()

        # Perform the MatMul fusion after the Attention fusion as we do not
        # want to fuse the MatMuls inside the Attention subgraphs
        if (options is None) or options.enable_qordered_matmul:
            self.fuse_qordered_matmul()

        self.fuse_shape()

        if (options is None) or options.enable_embed_layer_norm:
            # Without options, set_mask_format() is never called, so the attention fusion used
            # AttentionMask's own default format. NOTE(review): that default is assumed to be
            # MaskIndexEnd; confirm against fusion_attention.AttentionMask. (Previously this line
            # dereferenced options unconditionally and raised AttributeError when options was None.)
            use_mask_index = options is None or options.attention_mask_format == AttentionMaskFormat.MaskIndexEnd
            self.fuse_embed_layer(use_mask_index)

        # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()

        self.postprocess()

        # Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization
        if (options is None) or options.enable_bias_gelu:
            # Fuse Gelu and Add Bias before it.
            self.fuse_bias_gelu(is_fastgelu=True)
            self.fuse_bias_gelu(is_fastgelu=False)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()

        if options is not None and options.enable_gemm_fast_gelu:
            self.fuse_gemm_fast_gelu()

        self.remove_unused_constant()

        # Use symbolic batch dimension in input and output.
        if add_dynamic_axes:
            self.use_dynamic_axes()

        logger.info(f"opset version: {self.get_opset_version()}")

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.

        Returns:
            dict: op type -> number of nodes of that type, for the fused and QOrdered op types below.
        """
        ops = [
            "EmbedLayerNormalization",
            "Attention",
            "MultiHeadAttention",
            "Gelu",
            "FastGelu",
            "BiasGelu",
            "GemmFastGelu",
            "LayerNormalization",
            "SkipLayerNormalization",
        ]
        q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"]
        op_count = {op: len(self.get_nodes_by_op_type(op)) for op in ops + q_ops}
        logger.info(f"Optimized operators:{op_count}")
        return op_count

    def is_fully_optimized(self):
        """
        Returns True when the model is fully optimized.

        "Fully optimized" means: at least one EmbedLayerNormalization, at least one attention node,
        one Gelu-type node per attention node, and at least two layer-norm-type nodes per attention node.
        """
        op_count = self.get_fused_operator_statistics()
        embed = op_count["EmbedLayerNormalization"]
        attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"]
        # Count every fused flavour (including the QOrdered ones, as is already done for attention, and
        # GemmFastGelu, which replaces a FastGelu), so quantized models can be judged fully optimized.
        gelu = (
            op_count["Gelu"]
            + op_count["BiasGelu"]
            + op_count["FastGelu"]
            + op_count["GemmFastGelu"]
            + op_count["QOrderedGelu"]
        )
        layer_norm = (
            op_count["LayerNormalization"]
            + op_count["SkipLayerNormalization"]
            + op_count["QOrderedLayerNormalization"]
        )
        is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)

        if layer_norm == 0:
            logger.debug("Layer Normalization not fused")

        if gelu == 0:
            logger.debug("Gelu/FastGelu not fused")

        if embed == 0:
            logger.debug("Embed Layer not fused")

        if attention == 0:
            logger.warning("Attention not fused")

        return is_perfect