# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from argparse import ArgumentParser


class AttentionMaskFormat:
    # Build 1D mask indices (sequence lengths). It requires right-side padding!
    # Recommended for BERT models to get the best performance.
    MaskIndexEnd = 0

    # For experiments only. Do not use it in production.
    MaskIndexEndAndStart = 1

    # Raw attention mask with 0 meaning padding (or no attention) and 1 otherwise.
    AttentionMask = 2

    # No attention mask.
    NoMask = 3


class FusionOptions:
    """Options of fusion in graph optimization"""

    def __init__(self, model_type):
        self.enable_gelu = True
        self.enable_layer_norm = True
        self.enable_attention = True

        # Use MultiHeadAttention instead of the Attention operator. The differences:
        # (1) Attention has merged weights for the Q/K/V projection, which might be faster in some cases
        #     since three MatMuls are merged into one.
        # (2) Attention can only handle self attention; MultiHeadAttention can handle both self and cross attention.
        # (3) MultiHeadAttention has only a CUDA implementation right now.
        self.use_multi_head_attention = False

        self.enable_skip_layer_norm = True
        self.enable_embed_layer_norm = True
        self.enable_bias_skip_layer_norm = True
        self.enable_bias_gelu = True
        self.enable_gelu_approximation = False
        self.enable_qordered_matmul = True

        self.enable_shape_inference = True
        self.enable_gemm_fast_gelu = False

        # Default to mask index (sequence lengths) for BERT models so that fused attention can speed them up.
        # Note that embed layer normalization will convert a 2D mask to 1D when the mask type is MaskIndexEnd.
        self.attention_mask_format = (
            AttentionMaskFormat.MaskIndexEnd if model_type == "bert" else AttentionMaskFormat.AttentionMask
        )

        # Options for Stable Diffusion.
        self.enable_group_norm = model_type == "unet"
        self.enable_bias_splitgelu = model_type == "unet"
        self.enable_packed_kv = model_type == "unet"

    def use_raw_attention_mask(self, use_raw_mask=True):
        if use_raw_mask:
            self.attention_mask_format = AttentionMaskFormat.AttentionMask
        else:
            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd

    def disable_attention_mask(self):
        self.attention_mask_format = AttentionMaskFormat.NoMask

    @staticmethod
    def parse(args):
        options = FusionOptions(args.model_type)
        if args.disable_gelu:
            options.enable_gelu = False
        if args.disable_layer_norm:
            options.enable_layer_norm = False
        if args.disable_attention:
            options.enable_attention = False
        if args.use_multi_head_attention:
            options.use_multi_head_attention = True
        if args.disable_skip_layer_norm:
            options.enable_skip_layer_norm = False
        if args.disable_embed_layer_norm:
            options.enable_embed_layer_norm = False
        if args.disable_bias_skip_layer_norm:
            options.enable_bias_skip_layer_norm = False
        if args.disable_bias_gelu:
            options.enable_bias_gelu = False
        if args.enable_gelu_approximation:
            options.enable_gelu_approximation = True
        if args.disable_shape_inference:
            options.enable_shape_inference = False
        if args.enable_gemm_fast_gelu:
            options.enable_gemm_fast_gelu = True
        if args.use_mask_index:
            options.use_raw_attention_mask(False)
        if args.use_raw_attention_mask:
            options.use_raw_attention_mask(True)
        if args.no_attention_mask:
            options.disable_attention_mask()
        if args.disable_group_norm:
            options.enable_group_norm = False
        if args.disable_packed_kv:
            options.enable_packed_kv = False
        return options

    @staticmethod
    def add_arguments(parser: ArgumentParser):
"--disable_attention", required=False, action="store_true", help="disable Attention fusion", ) parser.set_defaults(disable_attention=False) parser.add_argument( "--disable_skip_layer_norm", required=False, action="store_true", help="disable SkipLayerNormalization fusion", ) parser.set_defaults(disable_skip_layer_norm=False) parser.add_argument( "--disable_embed_layer_norm", required=False, action="store_true", help="disable EmbedLayerNormalization fusion", ) parser.set_defaults(disable_embed_layer_norm=False) parser.add_argument( "--disable_bias_skip_layer_norm", required=False, action="store_true", help="disable Add Bias and SkipLayerNormalization fusion", ) parser.set_defaults(disable_bias_skip_layer_norm=False) parser.add_argument( "--disable_bias_gelu", required=False, action="store_true", help="disable Add Bias and Gelu/FastGelu fusion", ) parser.set_defaults(disable_bias_gelu=False) parser.add_argument( "--disable_layer_norm", required=False, action="store_true", help="disable LayerNormalization fusion", ) parser.set_defaults(disable_layer_norm=False) parser.add_argument( "--disable_gelu", required=False, action="store_true", help="disable Gelu fusion", ) parser.set_defaults(disable_gelu=False) parser.add_argument( "--enable_gelu_approximation", required=False, action="store_true", help="enable Gelu/BiasGelu to FastGelu conversion", ) parser.set_defaults(enable_gelu_approximation=False) parser.add_argument( "--disable_shape_inference", required=False, action="store_true", help="disable symbolic shape inference", ) parser.set_defaults(disable_shape_inference=False) parser.add_argument( "--enable_gemm_fast_gelu", required=False, action="store_true", help="enable GemmfastGelu fusion", ) parser.set_defaults(enable_gemm_fast_gelu=False) parser.add_argument( "--use_mask_index", required=False, action="store_true", help="use mask index to activate fused attention to speed up. It requires right-side padding!", ) parser.set_defaults(use_mask_index=False) parser.add_argument( "--use_raw_attention_mask", required=False, action="store_true", help="use raw attention mask. Use this option if your input is not right-side padding. This might deactivate fused attention and get worse performance.", ) parser.set_defaults(use_raw_attention_mask=False) parser.add_argument( "--no_attention_mask", required=False, action="store_true", help="no attention mask. Only works for model_type=bert", ) parser.set_defaults(no_attention_mask=False) parser.add_argument( "--use_multi_head_attention", required=False, action="store_true", help="Use MultiHeadAttention instead of Attention operator for testing purpose. " "Note that MultiHeadAttention might be slower than Attention since MatMul of input projection is excluded. " "MultiHeadAttention has only CUDA implementation so the model can only run with cuda execution provider.", ) parser.set_defaults(use_multi_head_attention=False) parser.add_argument( "--disable_group_norm", required=False, action="store_true", help="not fuse GroupNorm. Only works for model_type=unet", ) parser.set_defaults(disable_group_norm=False) parser.add_argument( "--disable_packed_kv", required=False, action="store_true", help="not use packed kv in cross attention. Only works for model_type=unet", ) parser.set_defaults(disable_packed_kv=False)