You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
234 lines
8.5 KiB
234 lines
8.5 KiB
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
# --------------------------------------------------------------------------
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
class AttentionMaskFormat:
    """Integer codes for the attention-mask format used by attention fusion.

    The codes are plain ``int`` class attributes (not an ``Enum``), so they can be
    stored and compared as ordinary integers.
    """

    # 1D mask indices (sequence lengths). Requires right-side padding!
    # Recommended for BERT models to get the best performance.
    MaskIndexEnd = 0

    # Experimental only -- do not use in production.
    MaskIndexEndAndStart = 1

    # Raw attention mask: 0 marks padding (no attention), 1 marks attended tokens.
    AttentionMask = 2

    # No attention mask at all.
    NoMask = 3
|
|
|
|
|
|
class FusionOptions:
    """Options of fusion in graph optimization.

    Each ``enable_*`` attribute switches one fusion on or off. Defaults depend on the
    model type passed to the constructor. ``add_arguments`` registers command-line
    switches for most options, and ``parse`` turns the parsed arguments into an
    instance.
    """

    def __init__(self, model_type):
        """Initialize every option with its default for ``model_type``.

        Args:
            model_type: model type string. ``"bert"`` selects the
                ``MaskIndexEnd`` mask format (every other type gets
                ``AttentionMask``). ``"unet"`` turns on the stable-diffusion
                fusions (GroupNorm, BiasSplitGelu, packed KV).
        """
        self.enable_gelu = True
        self.enable_layer_norm = True
        self.enable_attention = True

        # Use MultiHeadAttention instead of Attention operator. The difference:
        # (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is
        #     merged into one.
        # (2) Attention could only handle self attention; MultiHeadAttention could handle both self and cross attention.
        # (3) MultiHeadAttention has only cuda implementation right now.
        self.use_multi_head_attention = False

        self.enable_skip_layer_norm = True
        self.enable_embed_layer_norm = True
        self.enable_bias_skip_layer_norm = True
        self.enable_bias_gelu = True
        self.enable_gelu_approximation = False
        # add_arguments registers no switch for this option; change it programmatically if needed.
        self.enable_qordered_matmul = True

        self.enable_shape_inference = True
        self.enable_gemm_fast_gelu = False

        # Set default to sequence length for BERT model to use fused attention to speed up.
        # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd.
        self.attention_mask_format = (
            AttentionMaskFormat.MaskIndexEnd if model_type == "bert" else AttentionMaskFormat.AttentionMask
        )

        # Options for stable diffusion: enabled by default only for model_type == "unet".
        is_unet = model_type == "unet"
        self.enable_group_norm = is_unet
        self.enable_bias_splitgelu = is_unet
        self.enable_packed_kv = is_unet

    def use_raw_attention_mask(self, use_raw_mask=True):
        """Choose between the raw attention mask and the 1D mask-index format.

        Args:
            use_raw_mask: True selects ``AttentionMaskFormat.AttentionMask``.
                False selects ``AttentionMaskFormat.MaskIndexEnd``, which
                requires right-side padding.
        """
        if use_raw_mask:
            self.attention_mask_format = AttentionMaskFormat.AttentionMask
        else:
            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd

    def disable_attention_mask(self):
        """Set the mask format to ``AttentionMaskFormat.NoMask``."""
        self.attention_mask_format = AttentionMaskFormat.NoMask

    @staticmethod
    def parse(args):
        """Build a FusionOptions from parsed command-line arguments.

        Args:
            args: namespace produced by a parser configured with ``add_arguments``.
                It must also carry a ``model_type`` attribute.

        Returns:
            FusionOptions with the defaults for ``args.model_type``, overridden by
            the flags that were set.

        The mask flags are applied in order, so the last one given wins:
        ``--no_attention_mask`` beats ``--use_raw_attention_mask``, which beats
        ``--use_mask_index``.
        """
        options = FusionOptions(args.model_type)
        if args.disable_gelu:
            options.enable_gelu = False
        if args.disable_layer_norm:
            options.enable_layer_norm = False
        if args.disable_attention:
            options.enable_attention = False
        if args.use_multi_head_attention:
            options.use_multi_head_attention = True
        if args.disable_skip_layer_norm:
            options.enable_skip_layer_norm = False
        if args.disable_embed_layer_norm:
            options.enable_embed_layer_norm = False
        if args.disable_bias_skip_layer_norm:
            options.enable_bias_skip_layer_norm = False
        if args.disable_bias_gelu:
            options.enable_bias_gelu = False
        if args.enable_gelu_approximation:
            options.enable_gelu_approximation = True
        if args.disable_shape_inference:
            options.enable_shape_inference = False
        if args.enable_gemm_fast_gelu:
            options.enable_gemm_fast_gelu = True

        # Order matters: later mask options override earlier ones.
        if args.use_mask_index:
            options.use_raw_attention_mask(False)
        if args.use_raw_attention_mask:
            options.use_raw_attention_mask(True)
        if args.no_attention_mask:
            options.disable_attention_mask()

        if args.disable_group_norm:
            options.enable_group_norm = False
        if args.disable_packed_kv:
            options.enable_packed_kv = False
        # Newer flag: default to False so namespaces built without it keep working.
        if getattr(args, "disable_bias_splitgelu", False):
            options.enable_bias_splitgelu = False
        return options

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register the fusion switches on ``parser``.

        Every switch is an optional ``store_true`` flag whose destination is also
        given an explicit parser-level default of False.

        Args:
            parser: argument parser to extend in place.
        """
        # (flag, help text), in the order they appear in --help.
        boolean_flags = (
            ("--disable_attention", "disable Attention fusion"),
            ("--disable_skip_layer_norm", "disable SkipLayerNormalization fusion"),
            ("--disable_embed_layer_norm", "disable EmbedLayerNormalization fusion"),
            ("--disable_bias_skip_layer_norm", "disable Add Bias and SkipLayerNormalization fusion"),
            ("--disable_bias_gelu", "disable Add Bias and Gelu/FastGelu fusion"),
            ("--disable_layer_norm", "disable LayerNormalization fusion"),
            ("--disable_gelu", "disable Gelu fusion"),
            ("--enable_gelu_approximation", "enable Gelu/BiasGelu to FastGelu conversion"),
            ("--disable_shape_inference", "disable symbolic shape inference"),
            ("--enable_gemm_fast_gelu", "enable GemmFastGelu fusion"),
            (
                "--use_mask_index",
                "use mask index to activate fused attention to speed up. It requires right-side padding!",
            ),
            (
                "--use_raw_attention_mask",
                "use raw attention mask. Use this option if your input is not right-side padding. "
                "This might deactivate fused attention and get worse performance.",
            ),
            ("--no_attention_mask", "no attention mask. Only works for model_type=bert"),
            (
                "--use_multi_head_attention",
                "Use MultiHeadAttention instead of Attention operator for testing purpose. "
                "Note that MultiHeadAttention might be slower than Attention since MatMul of input projection is excluded. "
                "MultiHeadAttention has only CUDA implementation so the model can only run with cuda execution provider.",
            ),
            ("--disable_group_norm", "not fuse GroupNorm. Only works for model_type=unet"),
            ("--disable_packed_kv", "not use packed kv in cross attention. Only works for model_type=unet"),
            ("--disable_bias_splitgelu", "not fuse BiasSplitGelu. Only works for model_type=unet"),
        )

        for flag, help_text in boolean_flags:
            parser.add_argument(flag, required=False, action="store_true", help=help_text)
            # Explicit parser-level default; the destination is the flag without its "--" prefix.
            parser.set_defaults(**{flag[2:]: False})
|