# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from argparse import ArgumentParser


class AttentionMaskFormat:
    # Build a 1D mask index (sequence length). It requires right-side padding! Recommended for BERT models to get the best performance.
    MaskIndexEnd = 0

    # For experiments only. Do not use it in production.
    MaskIndexEndAndStart = 1

    # Raw attention mask where 0 means padding (or no attention) and 1 otherwise.
    AttentionMask = 2

    # No attention mask
    NoMask = 3


class FusionOptions:
    """Options of fusion in graph optimization"""

    def __init__(self, model_type):
        self.enable_gelu = True
        self.enable_layer_norm = True
        self.enable_attention = True

        # Use MultiHeadAttention instead of the Attention operator. The differences:
        # (1) Attention has merged weights for the Q/K/V projections, which might be faster in some cases since
        #     three MatMuls are merged into one.
        # (2) Attention can only handle self attention; MultiHeadAttention can handle both self and cross attention.
        # (3) MultiHeadAttention has only a CUDA implementation right now.
        self.use_multi_head_attention = False

        self.enable_skip_layer_norm = True
        self.enable_embed_layer_norm = True
        self.enable_bias_skip_layer_norm = True
        self.enable_bias_gelu = True
        self.enable_gelu_approximation = False
        self.enable_qordered_matmul = True
        self.enable_shape_inference = True
        self.enable_gemm_fast_gelu = False

        # Default to the mask-index (sequence length) format for BERT models so that fused attention can be used for speed-up.
        # Note that embed layer normalization will convert a 2D mask to 1D when the mask type is MaskIndexEnd.
        self.attention_mask_format = (
            AttentionMaskFormat.MaskIndexEnd if model_type == "bert" else AttentionMaskFormat.AttentionMask
        )

        # Options for stable diffusion
        self.enable_group_norm = model_type == "unet"
        self.enable_bias_splitgelu = model_type == "unet"
        self.enable_packed_kv = model_type == "unet"

    def use_raw_attention_mask(self, use_raw_mask=True):
        if use_raw_mask:
            self.attention_mask_format = AttentionMaskFormat.AttentionMask
        else:
            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd

    def disable_attention_mask(self):
        self.attention_mask_format = AttentionMaskFormat.NoMask
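
    # Example (a minimal sketch, not part of the original module): constructing
    # the options programmatically instead of from command-line flags. For
    # model_type "bert" the default attention_mask_format is MaskIndexEnd;
    # use_raw_attention_mask() switches to the raw 2D mask (for inputs that are
    # not right-side padded), and disable_attention_mask() drops the mask.
    #
    #   options = FusionOptions("bert")
    #   options.use_raw_attention_mask()   # -> AttentionMaskFormat.AttentionMask
    #   options.disable_attention_mask()   # -> AttentionMaskFormat.NoMask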

    @staticmethod
    def parse(args):
        options = FusionOptions(args.model_type)
        if args.disable_gelu:
            options.enable_gelu = False
        if args.disable_layer_norm:
            options.enable_layer_norm = False
        if args.disable_attention:
            options.enable_attention = False
        if args.use_multi_head_attention:
            options.use_multi_head_attention = True
        if args.disable_skip_layer_norm:
            options.enable_skip_layer_norm = False
        if args.disable_embed_layer_norm:
            options.enable_embed_layer_norm = False
        if args.disable_bias_skip_layer_norm:
            options.enable_bias_skip_layer_norm = False
        if args.disable_bias_gelu:
            options.enable_bias_gelu = False
        if args.enable_gelu_approximation:
            options.enable_gelu_approximation = True
        if args.disable_shape_inference:
            options.enable_shape_inference = False
        if args.enable_gemm_fast_gelu:
            options.enable_gemm_fast_gelu = True
        if args.use_mask_index:
            options.use_raw_attention_mask(False)
        if args.use_raw_attention_mask:
            options.use_raw_attention_mask(True)
        if args.no_attention_mask:
            options.disable_attention_mask()
        if args.disable_group_norm:
            options.enable_group_norm = False
        if args.disable_packed_kv:
            options.enable_packed_kv = False
        return options
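
    # Example (illustration only, not part of the original module): with the
    # flags registered by add_arguments below, a command line such as
    #   --model_type bert --use_raw_attention_mask --enable_gelu_approximation
    # parsed and passed to parse() yields options with
    # attention_mask_format == AttentionMaskFormat.AttentionMask and
    # enable_gelu_approximation == True. Note that --model_type itself is
    # expected to be added by the calling script, not by add_arguments.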

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument(
            "--disable_attention",
            required=False,
            action="store_true",
            help="disable Attention fusion",
        )
        parser.set_defaults(disable_attention=False)

        parser.add_argument(
            "--disable_skip_layer_norm",
            required=False,
            action="store_true",
            help="disable SkipLayerNormalization fusion",
        )
        parser.set_defaults(disable_skip_layer_norm=False)

        parser.add_argument(
            "--disable_embed_layer_norm",
            required=False,
            action="store_true",
            help="disable EmbedLayerNormalization fusion",
        )
        parser.set_defaults(disable_embed_layer_norm=False)

        parser.add_argument(
            "--disable_bias_skip_layer_norm",
            required=False,
            action="store_true",
            help="disable Add Bias and SkipLayerNormalization fusion",
        )
        parser.set_defaults(disable_bias_skip_layer_norm=False)

        parser.add_argument(
            "--disable_bias_gelu",
            required=False,
            action="store_true",
            help="disable Add Bias and Gelu/FastGelu fusion",
        )
        parser.set_defaults(disable_bias_gelu=False)

        parser.add_argument(
            "--disable_layer_norm",
            required=False,
            action="store_true",
            help="disable LayerNormalization fusion",
        )
        parser.set_defaults(disable_layer_norm=False)

        parser.add_argument(
            "--disable_gelu",
            required=False,
            action="store_true",
            help="disable Gelu fusion",
        )
        parser.set_defaults(disable_gelu=False)

        parser.add_argument(
            "--enable_gelu_approximation",
            required=False,
            action="store_true",
            help="enable Gelu/BiasGelu to FastGelu conversion",
        )
        parser.set_defaults(enable_gelu_approximation=False)

        parser.add_argument(
            "--disable_shape_inference",
            required=False,
            action="store_true",
            help="disable symbolic shape inference",
        )
        parser.set_defaults(disable_shape_inference=False)

        parser.add_argument(
            "--enable_gemm_fast_gelu",
            required=False,
            action="store_true",
            help="enable GemmFastGelu fusion",
        )
        parser.set_defaults(enable_gemm_fast_gelu=False)

        parser.add_argument(
            "--use_mask_index",
            required=False,
            action="store_true",
            help="use mask index to activate fused attention to speed up. It requires right-side padding!",
        )
        parser.set_defaults(use_mask_index=False)

        parser.add_argument(
            "--use_raw_attention_mask",
            required=False,
            action="store_true",
            help="use raw attention mask. Use this option if your input is not right-side padded. "
            "This might deactivate fused attention and result in worse performance.",
        )
        parser.set_defaults(use_raw_attention_mask=False)

        parser.add_argument(
            "--no_attention_mask",
            required=False,
            action="store_true",
            help="no attention mask. Only works for model_type=bert",
        )
        parser.set_defaults(no_attention_mask=False)

        parser.add_argument(
            "--use_multi_head_attention",
            required=False,
            action="store_true",
            help="Use MultiHeadAttention instead of the Attention operator for testing purposes. "
            "Note that MultiHeadAttention might be slower than Attention since the MatMul of the input projection is excluded from fusion. "
            "MultiHeadAttention has only a CUDA implementation, so the model can only run with the CUDA execution provider.",
        )
        parser.set_defaults(use_multi_head_attention=False)

        parser.add_argument(
            "--disable_group_norm",
            required=False,
            action="store_true",
            help="do not fuse GroupNorm. Only works for model_type=unet",
        )
        parser.set_defaults(disable_group_norm=False)

        parser.add_argument(
            "--disable_packed_kv",
            required=False,
            action="store_true",
            help="do not use packed KV in cross attention. Only works for model_type=unet",
        )
        parser.set_defaults(disable_packed_kv=False)
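

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# It shows how the two static methods are typically combined: register the
# fusion flags on an ArgumentParser, parse the command line, and turn the
# result into a FusionOptions instance. The --model_type argument is normally
# supplied by the calling script; it is added here only so that this sketch
# is self-contained.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = ArgumentParser(description="FusionOptions usage sketch")
    parser.add_argument("--model_type", required=False, type=str, default="bert")
    FusionOptions.add_arguments(parser)
    args = parser.parse_args()

    options = FusionOptions.parse(args)
    print("attention_mask_format:", options.attention_mask_format)
    print("use_multi_head_attention:", options.use_multi_head_attention)
    print("enable_gelu_approximation:", options.enable_gelu_approximation)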