# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import List, Optional

from fusion_attention import AttentionMask, FusionAttention
from fusion_biasgelu import FusionBiasGelu
from fusion_embedlayer import FusionEmbedLayerNormalization
from fusion_fastgelu import FusionFastGelu
from fusion_gelu import FusionGelu
from fusion_gelu_approximation import FusionGeluApproximation
from fusion_gemmfastgelu import FusionGemmFastGelu
from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF
from fusion_options import AttentionMaskFormat, FusionOptions
from fusion_qordered_attention import FusionQOrderedAttention
from fusion_qordered_gelu import FusionQOrderedGelu
from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
from fusion_qordered_matmul import FusionQOrderedMatMul
from fusion_reshape import FusionReshape
from fusion_shape import FusionShape
from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
from fusion_utils import FusionUtils
from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class BertOptimizationOptions(FusionOptions):
    """This class is deprecated"""

    def __init__(self, model_type):
        logger.warning("BertOptimizationOptions is deprecated. Please use FusionOptions instead.")
        super().__init__(model_type)


class BertOnnxModel(OnnxModel):
    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize BERT ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
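
        Example (illustrative sketch; "bert.onnx" is a placeholder path to a BERT-like model):
            >>> import onnx
            >>> bert_model = BertOnnxModel(onnx.load("bert.onnx"))  # num_heads/hidden_size detected automatically
            >>> bert_model.optimize()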
""" assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0) super().__init__(model) self.num_heads = num_heads self.hidden_size = hidden_size self.attention_mask = AttentionMask(self) self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask) self.qordered_attention_fusion = FusionQOrderedAttention( self, self.hidden_size, self.num_heads, self.attention_mask ) self.utils = FusionUtils(self) def fuse_attention(self): self.attention_fusion.apply() # Only relevant in models with Q-DQ nodes self.qordered_attention_fusion.apply() def fuse_gelu(self): fusion = FusionGelu(self) fusion.apply() fusion = FusionFastGelu(self) fusion.apply() # Only relevant in models with Q-DQ nodes fusion = FusionQOrderedGelu(self) fusion.apply() def fuse_bias_gelu(self, is_fastgelu): fusion = FusionBiasGelu(self, is_fastgelu) fusion.apply() def gelu_approximation(self): fusion = FusionGeluApproximation(self) fusion.apply() def fuse_gemm_fast_gelu(self): fusion = FusionGemmFastGelu(self) fusion.apply() def fuse_add_bias_skip_layer_norm(self): fusion = FusionBiasSkipLayerNormalization(self) fusion.apply() def fuse_reshape(self): fusion = FusionReshape(self) fusion.apply() def fuse_shape(self): fusion = FusionShape(self) fusion.apply() def fuse_embed_layer(self, use_mask_index): fusion = FusionEmbedLayerNormalization(self, use_mask_index) fusion.apply() def fuse_layer_norm(self): fusion = FusionLayerNormalization(self) fusion.apply() fusion = FusionLayerNormalizationTF(self) fusion.apply() # Only relevant in models with Q-DQ nodes fusion = FusionQOrderedLayerNormalization(self) fusion.apply() def fuse_skip_layer_norm(self): fusion = FusionSkipLayerNormalization(self) fusion.apply() # Only relevant in models with Q-DQ nodes def fuse_qordered_mamtul(self): fusion = FusionQOrderedMatMul(self) fusion.apply() def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool): """ Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention). Returns a list of the graph input names based on the filter whether it is casted or not. """ graph_inputs = [] output_name_to_node = self.output_name_to_node() nodes = self.get_nodes_by_op_type(op_type) for node in nodes: bert_inputs = [node.input[i] for i in input_indices if i < len(node.input)] for bert_input in bert_inputs: if self.find_graph_input(bert_input): if not casted: graph_inputs.append(bert_input) elif bert_input in output_name_to_node: parent = output_name_to_node[bert_input] if parent.op_type == "Cast" and self.find_graph_input(parent.input[0]) is not None: if casted: graph_inputs.append(parent.input[0]) return graph_inputs def get_graph_inputs_from_fused_nodes(self, casted: bool): inputs = self.get_graph_inputs_from_node_type("EmbedLayerNormalization", [0, 1, 7], casted) inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted) return inputs def change_graph_input_type( self, graph: GraphProto, graph_input: ValueInfoProto, new_type: int = TensorProto.INT32, ): """Change graph input type, and add Cast node if needed. Args: graph (GraphProto): graph graph_input (TensorProto): input of the graph new_type (int, optional): new data type. Defaults to TensorProto.INT32. Returns: NodeProto: a new Cast node that added. None if Cast node is not added. List[NodeProto]: Cast nodes that have been removed. 
""" assert isinstance(graph, GraphProto) assert isinstance(graph_input, ValueInfoProto) assert self.find_graph_input(graph_input.name) if graph_input.type.tensor_type.elem_type == int(new_type): return None, [] new_cast_node = None nodes_to_remove = [] input_name_to_nodes = self.input_name_to_nodes() if graph_input.name in input_name_to_nodes: nodes = input_name_to_nodes[graph_input.name] # For children that is not Cast node, insert a Cast node to convert int32 to original data type. nodes_not_cast = [node for node in nodes if node.op_type != "Cast"] if nodes_not_cast: node_name = self.create_node_name("Cast") output_name = node_name + "_" + graph_input.name new_value_info = graph.value_info.add() new_value_info.CopyFrom(graph_input) new_value_info.name = output_name new_cast_node = helper.make_node( "Cast", [graph_input.name], [output_name], to=int(graph_input.type.tensor_type.elem_type), name=node_name, ) graph.node.extend([new_cast_node]) for node in nodes_not_cast: OnnxModel.replace_node_input(node, graph_input.name, output_name) # For children that is Cast node, no need to insert Cast. # When the children is Cast to int32, we can remove that Cast node since input type is int32 now. nodes_cast = [node for node in nodes if node.op_type == "Cast"] for node in nodes_cast: if OnnxModel.get_node_attribute(node, "to") == int(new_type): self.replace_input_of_all_nodes(node.output[0], graph_input.name) if not self.find_graph_output(node.output[0]): nodes_to_remove.append(node) if nodes_to_remove: self.remove_nodes(nodes_to_remove) graph_input.type.tensor_type.elem_type = int(new_type) return new_cast_node, nodes_to_remove def change_graph_inputs_to_int32(self): """Change data type of all graph inputs to int32 type, and add Cast node if needed.""" graph = self.graph() add_cast_count = 0 remove_cast_count = 0 for graph_input in graph.input: new_node, removed_nodes = self.change_graph_input_type(graph, graph_input, TensorProto.INT32) if new_node: add_cast_count += 1 remove_cast_count += len(removed_nodes) logger.info( f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes." ) def use_dynamic_axes(self, dynamic_batch_dim="batch_size", dynamic_seq_len="max_seq_len"): """ Update input and output shape to use dynamic axes. """ bert_graph_inputs = self.get_graph_inputs_from_fused_nodes( casted=True ) + self.get_graph_inputs_from_fused_nodes(casted=False) dynamic_batch_inputs = {} for input in self.model.graph.input: if input.name in bert_graph_inputs: dim_proto = input.type.tensor_type.shape.dim[0] dim_proto.dim_param = dynamic_batch_dim if dynamic_seq_len is not None: dim_proto = input.type.tensor_type.shape.dim[1] dim_proto.dim_param = dynamic_seq_len for output in self.model.graph.output: dim_proto = output.type.tensor_type.shape.dim[0] dim_proto.dim_param = dynamic_batch_dim def preprocess(self): self.adjust_reshape_and_expand() return def adjust_reshape_and_expand(self): nodes_to_remove = [] for node in self.nodes(): if node.op_type == "Reshape": # Clean up unneccessary reshape nodes. # Find reshape nodes with no actually data in "shape" attribute and remove. reshape_shape = self.get_constant_value(node.input[1]) if reshape_shape is not None and reshape_shape.size == 0: nodes_to_remove.extend([node]) self.replace_input_of_all_nodes(node.output[0], node.input[0]) continue # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by # changing current reshape's input to output of slice. 
                reshape_path = self.match_parent_path(
                    node,
                    ["Expand", "Expand", "Reshape", "Slice"],
                    [0, 0, 0, 0],
                    self.output_name_to_node(),
                )
                if reshape_path is not None:
                    expand_node = reshape_path[-3]
                    expand_shape_value = self.get_constant_value(expand_node.input[1])

                    reshape_before_expand = reshape_path[-2]
                    shape_value = self.get_constant_value(reshape_before_expand.input[1])

                    slice_node = reshape_path[-1]
                    if (
                        expand_shape_value is not None
                        and shape_value is not None
                        and len(expand_shape_value) == 2
                        and len(shape_value) == 1
                        and expand_shape_value[1] == shape_value[0]
                    ):
                        node.input[0] = slice_node.output[0]

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")

    def clean_graph(self):
        output_name_to_node = self.output_name_to_node()
        nodes_to_remove = []
        for node in self.nodes():
            # Before:
            #  input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
            #          |                                                     |
            #          |                                                     v
            #          +----> Shape --> Gather(indices=1) --> Unsqueeze --> Concat --> ConstantOfShape --> Cast --> EmbedLayerNormalization/ReduceSum
            # After:
            #  input_ids --> Shape                                                 --> ConstantOfShape --> Cast --> EmbedLayerNormalization/ReduceSum
            # TODO: merge ConstantOfShape --> Cast into ConstantOfShape (need to update the data type of value)
            op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
            if node.op_type in op_input_id:
                i = op_input_id[node.op_type]
                parent_nodes = self.match_parent_path(
                    node,
                    [
                        "Cast",
                        "ConstantOfShape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                    ],
                    [i, 0, 0, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    (
                        cast,
                        constantOfShape,
                        concat,
                        unsqueeze,
                        gather,
                        shape,
                    ) = parent_nodes
                    if shape.input[0] == self.graph().input[0].name:
                        constantOfShape.input[0] = shape.output[0]
                        output_name_to_node = self.output_name_to_node()

            if node.op_type == "Attention":
                # Before:
                #   input_ids --> Shape --> ConstantOfShape --> Cast --> ReduceSum --> Attention
                # After:
                #   remove this path, and remove the optional mask_index input of the Attention node.
                parent_nodes = self.match_parent_path(
                    node,
                    ["ReduceSum", "Cast", "ConstantOfShape", "Shape"],
                    [3, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    if parent_nodes[-1].input[0] == self.graph().input[0].name:
                        attention_node = helper.make_node(
                            "Attention",
                            inputs=node.input[0 : len(node.input) - 1],
                            outputs=node.output,
                            name=node.name + "_remove_mask",
                        )
                        attention_node.domain = "com.microsoft"
                        attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
                        # Add the new node to the same graph as the original Attention node.
                        self.add_node(attention_node, self.get_graph_by_node(node).name)
                        nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)

    def postprocess(self):
        self.clean_graph()
        self.prune_graph()

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        if (options is not None) and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()

        # Remove Cast nodes that have the same data type for input and output, based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        if options is not None:
            self.attention_mask.set_mask_format(options.attention_mask_format)
            if options.use_multi_head_attention:
                self.attention_fusion = FusionAttention(
                    self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention
                )

        if (options is None) or options.enable_attention:
            self.fuse_attention()

        # Perform the MatMul fusion after the Attention fusion as we do not
        # want to fuse the MatMuls inside the Attention subgraphs.
        if (options is None) or options.enable_qordered_matmul:
            self.fuse_qordered_mamtul()

        self.fuse_shape()

        if (options is None) or options.enable_embed_layer_norm:
            # Default to the mask index format when no options are given (matches the FusionOptions default).
            use_mask_index = (options is None) or (options.attention_mask_format == AttentionMaskFormat.MaskIndexEnd)
            self.fuse_embed_layer(use_mask_index)

        # Remove Reshape nodes that have the same shape for input and output, based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()

        self.postprocess()

        # Bias fusion is done after postprocess to avoid an extra Reshape between the bias and Gelu/FastGelu/SkipLayerNormalization.
        if (options is None) or options.enable_bias_gelu:
            # Fuse Gelu and the Add Bias before it.
            self.fuse_bias_gelu(is_fastgelu=True)
            self.fuse_bias_gelu(is_fastgelu=False)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and the Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()

        if options is not None and options.enable_gemm_fast_gelu:
            self.fuse_gemm_fast_gelu()

        self.remove_unused_constant()

        # Use symbolic batch dimension in inputs and outputs.
        if add_dynamic_axes:
            self.use_dynamic_axes()

        logger.info(f"opset version: {self.get_opset_version()}")

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            "EmbedLayerNormalization",
            "Attention",
            "MultiHeadAttention",
            "Gelu",
            "FastGelu",
            "BiasGelu",
            "GemmFastGelu",
            "LayerNormalization",
            "SkipLayerNormalization",
        ]
        q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"]
        for op in ops + q_ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info(f"Optimized operators: {op_count}")
        return op_count

    def is_fully_optimized(self):
        """
        Returns True when the model is fully optimized.
        """
        op_count = self.get_fused_operator_statistics()
        embed = op_count["EmbedLayerNormalization"]
        attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"]
        gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"]
        layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"]
        is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)

        if layer_norm == 0:
            logger.debug("Layer Normalization not fused")

        if gelu == 0:
            logger.debug("Gelu/FastGelu not fused")

        if embed == 0:
            logger.debug("Embed Layer not fused")

        if attention == 0:
            logger.warning("Attention not fused")

        return is_perfect
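

# The optimization entry points above are typically driven as in the sketch below. This is an
# illustrative example, not part of this module: the file paths are placeholders, and production
# pipelines normally go through optimizer.optimize_model instead of constructing BertOnnxModel directly.
#
#     from onnx import load_model
#     from fusion_options import FusionOptions
#
#     options = FusionOptions("bert")
#     bert_model = BertOnnxModel(load_model("bert_base.onnx"), num_heads=12, hidden_size=768)
#     bert_model.optimize(options)
#     print(bert_model.get_fused_operator_statistics())
#     bert_model.save_model_to_file("bert_base_opt.onnx")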