# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- import argparse import logging import sys from collections import deque import numpy as np import onnx from onnx import ModelProto, TensorProto, numpy_helper from onnx_model_bert_tf import BertOnnxModelTF logger = logging.getLogger(__name__) class BertOnnxModelKeras(BertOnnxModelTF): def __init__(self, model, num_heads, hidden_size): super().__init__(model, num_heads, hidden_size) def match_mask_path(self, add_or_sub_before_softmax): mask_nodes = self.match_parent_path( add_or_sub_before_softmax, ["Mul", "Sub", "Reshape", "Cast"], [1, None, 1, 0], ) if mask_nodes is not None: return mask_nodes mask_nodes = self.match_parent_path( add_or_sub_before_softmax, ["Mul", "Sub", "Cast", "Slice", "Unsqueeze"], [1, 1, 1, 0, 0], ) if mask_nodes is not None: return mask_nodes mask_nodes = self.match_parent_path( add_or_sub_before_softmax, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [1, None, 1, 0, 0], ) return mask_nodes def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node): reshape_nodes = [] for x in [matmul_q, matmul_k, matmul_v]: root_input = x.input[0] root_node = output_name_to_node[root_input] if root_node == parent: continue if root_node.op_type == "Reshape" and root_node.input[0] == parent.output[0]: reshape_nodes.append(root_node) continue logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}") return False, [] return True, reshape_nodes def fuse_attention(self): input_name_to_nodes = self.input_name_to_nodes() output_name_to_node = self.output_name_to_node() nodes_to_remove = [] attention_count = 0 skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization") for normalize_node in skip_layer_norm_nodes: # SkipLayerNormalization has two inputs, and one of them is the root input for attention. parent = self.get_parent(normalize_node, 0) if parent is None or parent.op_type not in [ "SkipLayerNormalization", "EmbedLayerNormalization", ]: if parent.op_type == "Add": parent = self.get_parent(normalize_node, 1) if parent is None or parent.op_type not in [ "SkipLayerNormalization", "EmbedLayerNormalization", ]: logger.debug( "First input for skiplayernorm: {}".format(parent.op_type if parent is not None else None) ) continue else: logger.debug( "First input for skiplayernorm: {}".format(parent.op_type if parent is not None else None) ) continue else: # TODO: shall we add back the checking of children op types. pass qkv_nodes = self.match_parent_path( normalize_node, ["Add", "Reshape", "MatMul", "Reshape", "Transpose", "MatMul"], [None, 0, 0, 0, 0, 0], ) if qkv_nodes is None: logger.debug("Failed to match qkv nodes") continue ( add, extra_reshape_0, matmul, reshape_qkv, transpose_qkv, matmul_qkv, ) = qkv_nodes logger.debug("Matched qkv nodes") v_nodes = self.match_parent_path( matmul_qkv, ["Transpose", "Reshape", "Add", "Reshape", "MatMul"], [1, 0, 0, 0, 0], ) if v_nodes is None: logger.debug("Failed to match v path") continue (transpose_v, reshape_v, add_v, extra_reshape_1, matmul_v) = v_nodes qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Sub", "MatMul"], [0, 0, 0]) if qk_nodes is not None: (softmax_qk, sub_qk, matmul_qk) = qk_nodes q_nodes = self.match_parent_path( matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "Reshape", "MatMul"], [0, None, 0, 0, 0, 0], ) if q_nodes is not None: ( mul_q, transpose_q, reshape_q, add_q, extra_reshape_2, matmul_q, ) = q_nodes else: qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, None]) if qk_nodes is None: qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Div", "MatMul"], [0, 0, 0, None]) if qk_nodes is None: logger.debug("Failed to match qk path") continue (softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes q_nodes = self.match_parent_path( matmul_qk, ["Transpose", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0], ) if q_nodes is not None: (transpose_q, reshape_q, add_q, extra_reshape_2, matmul_q) = q_nodes if q_nodes is None: logger.debug("Failed to match q path") continue k_nodes = self.match_parent_path( matmul_qk, ["Transpose", "Reshape", "Add", "Reshape", "MatMul"], [1, 0, 0, 0, 0], ) if k_nodes is None: logger.debug("Failed to match k path") continue (transpose_k, reshape_k, add_k, extra_reshape_3, matmul_k) = k_nodes mask_nodes = self.match_mask_path(qk_nodes[1]) if mask_nodes is None: logger.debug("Failed to match mask path") continue if not self.has_constant_input(mask_nodes[1], 1): logger.debug("Sub node expected to have an input with constant value 1.0.") continue is_same_root, reshape_nodes = self.check_attention_input( matmul_q, matmul_k, matmul_v, parent, output_name_to_node ) if is_same_root: mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) logger.debug("Create an Attention node.") attention_node = self.attention_fusion.create_attention_node( mask_index, matmul_q, matmul_k, matmul_v, add_q, add_k, add_v, self.num_heads, self.hidden_size, parent.output[0], reshape_qkv.output[0], None, ) if attention_node is None: continue self.add_node(attention_node) attention_count += 1 nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv]) nodes_to_remove.extend(qk_nodes) nodes_to_remove.extend(q_nodes) nodes_to_remove.extend(k_nodes) nodes_to_remove.extend(v_nodes) nodes_to_remove.extend(mask_nodes) nodes_to_remove.extend(reshape_nodes) nodes_to_remove.append(extra_reshape_0) self.replace_node_input(add, extra_reshape_0.output[0], matmul.output[0]) else: logger.debug("Root node not matched.") continue self.remove_nodes(nodes_to_remove) self.update_graph() logger.info(f"Fused Attention count:{attention_count}") def preprocess(self): self.process_embedding() self.fuse_mask() self.skip_reshape() def skip_reshape(self): input_name_to_nodes = self.input_name_to_nodes() output_name_to_node = self.output_name_to_node() nodes_to_remove = [] attention_count = 0 count = 0 reshape_nodes = self.get_nodes_by_op_type("Reshape") for reshape_node in reshape_nodes: parent = self.get_parent(reshape_node, 0) if parent is not None and parent.op_type == "Reshape": reshape_node.input[0] = parent.input[0] count += 1 if count > 0: logger.info(f"Skip consequent Reshape count: {count}") def fuse_embedding(self, node, output_name_to_node): assert node.op_type == "LayerNormalization" logger.debug(f"start fusing embedding from node with output={node.output[0]}...") word_embed_path = self.match_parent_path(node, ["Add", "Add", "Gather"], [0, 0, 0], output_name_to_node) if word_embed_path is None: logger.debug("failed to match word_embed_path") return False skip_node, add_node, gather_node = word_embed_path word_initializer = self.get_initializer(gather_node.input[0]) if word_initializer is None: logger.debug("failed to get word initializer") return False temp = numpy_helper.to_array(word_initializer) if len(temp.shape) == 2: logger.info("Found word embedding. name:{}, shape:{}".format(word_initializer.name, temp.shape)) word_embedding = word_initializer.name else: logger.info("Failed to find word embedding. name:{}, shape:{}".format(word_initializer.name, temp.shape)) return False pos_initializer = self.get_initializer(add_node.input[1]) if pos_initializer is not None: temp = numpy_helper.to_array(pos_initializer) if len(temp.shape) == 3 and temp.shape[0] == 1: tensor = numpy_helper.from_array(temp.reshape((temp.shape[1], temp.shape[2])), "position_embedding") self.add_initializer(tensor) logger.info("Found position embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape[1:])) position_embedding = "position_embedding" else: logger.info( "Failed to find position embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape) ) return False else: pos_embed_path = self.match_parent_path(add_node, ["Gather", "Slice"], [1, 1], output_name_to_node) if pos_embed_path is None: logger.debug("failed to match pos_embed_path") return False pos_gather, pos_slice = pos_embed_path pos_initializer = self.get_initializer(pos_gather.input[0]) if pos_initializer is None: logger.debug("failed to get pos initializer") return False temp = numpy_helper.to_array(pos_initializer) if len(temp.shape) == 2: logger.info("Found word embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape)) position_embedding = pos_initializer.name else: logger.info( "Failed to find position embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape) ) return False gather = self.get_parent(skip_node, 1, output_name_to_node) if gather is None or gather.op_type != "Gather": logger.debug("failed to get gather") return False segment_initializer = self.get_initializer(gather.input[0]) if segment_initializer is None: logger.debug("failed to get segment initializer") return False temp = numpy_helper.to_array(segment_initializer) if len(temp.shape) == 2: logger.info("Found segment embedding. name:{}, shape:{}".format(segment_initializer.name, temp.shape)) segment_embedding = segment_initializer.name else: logger.info( "Failed to find segment embedding. name:{}, shape:{}".format(segment_initializer.name, temp.shape) ) return False logger.info("Create Embedding node") self.create_embedding_subgraph(node, word_embedding, segment_embedding, position_embedding) return True def process_embedding(self): """ Automatically detect word, segment and position embeddings. """ logger.info("start processing embedding layer...") output_name_to_node = self.output_name_to_node() for node in self.nodes(): if node.op_type == "LayerNormalization": if self.fuse_embedding(node, output_name_to_node): return break def fuse_mask(self): nodes_to_remove = [] for node in self.nodes(): if node.op_type == "Mul" and self.has_constant_input(node, -10000): mask_path = self.match_parent_path(node, ["Sub", "Cast", "Slice", "Unsqueeze"], [0, 1, 0, 0]) if mask_path is None: continue sub_node, cast_node, slice_node, unsqueeze_node = mask_path mask_input_name = self.attention_mask.get_first_mask() if unsqueeze_node.input[0] != mask_input_name: print("Cast input {} is not mask input {}".format(unsqueeze_node.input[0], mask_input_name)) continue unsqueeze_added_1 = onnx.helper.make_node( "Unsqueeze", inputs=[mask_input_name], outputs=["mask_fuse_unsqueeze1_output"], name="Mask_UnSqueeze_1", axes=[1], ) unsqueeze_added_2 = onnx.helper.make_node( "Unsqueeze", inputs=["mask_fuse_unsqueeze1_output"], outputs=["mask_fuse_unsqueeze2_output"], name="Mask_UnSqueeze_2", axes=[2], ) # self.replace_node_input(cast_node, cast_node.input[0], 'mask_fuse_unsqueeze2_output') cast_node_2 = onnx.helper.make_node( "Cast", inputs=["mask_fuse_unsqueeze2_output"], outputs=["mask_fuse_cast_output"], ) cast_node_2.attribute.extend([onnx.helper.make_attribute("to", 1)]) self.replace_node_input(sub_node, sub_node.input[1], "mask_fuse_cast_output") nodes_to_remove.extend([slice_node, unsqueeze_node, cast_node]) self.add_node(unsqueeze_added_1) self.add_node(unsqueeze_added_2) self.add_node(cast_node_2) self.remove_nodes(nodes_to_remove) # Prune graph is done after removing nodes to remove island nodes. if len(nodes_to_remove) > 0: self.prune_graph() logger.info("Fused mask" if len(nodes_to_remove) > 0 else "Failed to fuse mask") def remove_extra_reshape(self): skiplayernorm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization") reshape_removed = 0 for skiplayernorm_node in skiplayernorm_nodes: path = self.match_parent_path( skiplayernorm_node, [ "Add", "Reshape", "MatMul", "Reshape", "Gelu", "Add", "Reshape", "MatMul", "SkipLayerNormalization", ], [0, 0, 0, 0, 0, 0, 0, 0, 0], ) if path is None: continue ( add_1, reshape_1, matmul_1, reshape_2, gelu, add_2, reshape_3, matmul_2, skiplayernorm, ) = path add_2.input[0] = matmul_2.output[0] self.remove_node(reshape_3) matmul_1.input[0] = gelu.output[0] self.remove_node(reshape_2) add_1.input[0] = matmul_1.output[0] self.remove_node(reshape_1) reshape_removed += 3 return reshape_removed def remove_extra_reshape_2(self): skiplayernorm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization") reshape_removed = 0 for skiplayernorm_node in skiplayernorm_nodes: path = self.match_parent_path( skiplayernorm_node, [ "Add", "Reshape", "MatMul", "Reshape", "Gelu", "Add", "Reshape", "MatMul", "Reshape", "SkipLayerNormalization", ], [None, 0, 0, 0, 0, 0, 0, 0, 0, 0], ) # yapf: disable if path is None: continue ( add_1, reshape_1, matmul_1, reshape_2, gelu, add_2, reshape_3, matmul_2, reshape_4, skiplayernorm, ) = path matmul_2.input[0] = skiplayernorm.output[0] self.remove_node(reshape_4) add_2.input[0] = matmul_2.output[0] self.remove_node(reshape_3) matmul_1.input[0] = gelu.output[0] self.remove_node(reshape_2) add_1.input[0] = matmul_1.output[0] self.remove_node(reshape_1) reshape_removed += 4 return reshape_removed def postprocess(self): reshape_removed = self.remove_extra_reshape() + self.remove_extra_reshape_2() logger.info(f"Remove {reshape_removed} Reshape nodes.") self.prune_graph()