m2m model translation

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# Model class names used to load the pretrained models below
MODEL_CLASSES = [
    "AutoModel",
    "AutoModelWithLMHead",
    "AutoModelForSequenceClassification",
    "AutoModelForQuestionAnswering",
    "AutoModelForCausalLM",
]
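
# A minimal sketch, not part of the original benchmark code: each class-name
# string above can be resolved to the corresponding transformers class at
# runtime, e.g. before calling from_pretrained(). The helper name
# load_model_class is an assumption for illustration.
def load_model_class(class_name: str):
    """Return the transformers class named by one of the MODEL_CLASSES strings."""
    import transformers

    return getattr(transformers, class_name)


# Example: load_model_class("AutoModelForCausalLM").from_pretrained("gpt2")
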
# List of pretrained models: https://huggingface.co/transformers/pretrained_models.html
# Maps a pretrained model name to a tuple of (input names, opset_version, use_external_data_format, optimization model type)
MODELS = {
    # BERT
    "bert-base-uncased": (
        ["input_ids", "attention_mask", "token_type_ids"],
        12,
        False,
        "bert",
    ),
    "bert-large-uncased": (
        ["input_ids", "attention_mask", "token_type_ids"],
        12,
        False,
        "bert",
    ),
    "bert-base-cased": (
        ["input_ids", "attention_mask", "token_type_ids"],
        12,
        False,
        "bert",
    ),
    # "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # "bert-base-multilingual-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # "bert-base-multilingual-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # "bert-base-chinese": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # "bert-base-german-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # "bert-large-uncased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # "bert-large-cased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # "bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
    #     "token_type_ids"], 12, False, "bert"),
    # "bert-large-cased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
    #     "token_type_ids"], 12, False, "bert"),
    # "bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # "bert-base-german-dbmdz-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # "bert-base-german-dbmdz-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
    # todo: more models to add
    # GPT (no past state)
    "openai-gpt": (["input_ids"], 11, False, "gpt2"),
    # GPT-2 (no past state, use benchmark_gpt2.py for past_key_values)
    "gpt2": (["input_ids"], 11, False, "gpt2"),
    "gpt2-medium": (["input_ids"], 11, False, "gpt2"),
    "gpt2-large": (["input_ids"], 11, True, "gpt2"),
    "gpt2-xl": (["input_ids"], 11, True, "gpt2"),
    "distilgpt2": (["input_ids"], 11, False, "gpt2"),
    # Transformer-XL (model uses Einsum, which needs opset version 12 or later)
    "transfo-xl-wt103": (["input_ids", "mems"], 12, False, "bert"),
    # XLNet
    "xlnet-base-cased": (["input_ids"], 12, False, "bert"),
    "xlnet-large-cased": (["input_ids"], 12, False, "bert"),
    # XLM
    "xlm-mlm-en-2048": (["input_ids"], 11, True, "bert"),
    "xlm-mlm-ende-1024": (["input_ids"], 11, False, "bert"),
    "xlm-mlm-enfr-1024": (["input_ids"], 11, False, "bert"),
    # RoBERTa
    "roberta-base": (["input_ids", "attention_mask"], 12, False, "bert"),
    "roberta-large": (["input_ids", "attention_mask"], 12, False, "bert"),
    "roberta-large-mnli": (["input_ids", "attention_mask"], 12, False, "bert"),
    "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 11, False, "bert"),
    "distilroberta-base": (["input_ids", "attention_mask"], 12, False, "bert"),
    # DistilBERT
    "distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"),
    "distilbert-base-uncased-distilled-squad": (
        ["input_ids", "attention_mask"],
        11,
        False,
        "bert",
    ),
    # CTRL
    "ctrl": (["input_ids"], 11, True, "bert"),
    # CamemBERT
    "camembert-base": (["input_ids"], 11, False, "bert"),
    # ALBERT
    "albert-base-v1": (["input_ids"], 12, False, "bert"),
    "albert-large-v1": (["input_ids"], 12, False, "bert"),
    "albert-xlarge-v1": (["input_ids"], 12, True, "bert"),
    # "albert-xxlarge-v1": (["input_ids"], 12, True, "bert"),
    "albert-base-v2": (["input_ids"], 12, False, "bert"),
    "albert-large-v2": (["input_ids"], 12, False, "bert"),
    "albert-xlarge-v2": (["input_ids"], 12, True, "bert"),
    # "albert-xxlarge-v2": (["input_ids"], 12, True, "bert"),
    # T5 (use benchmark_t5.py instead)
    # "t5-small": (["input_ids", "decoder_input_ids"], 12, False, "bert"),
    # "t5-base": (["input_ids", "decoder_input_ids"], 12, False, "bert"),
    # "t5-large": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
    # "t5-3b": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
    # "t5-11b": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
    # "valhalla/t5-small-qa-qg-hl": (["input_ids"], 12, True, "bert"),
    # XLM-RoBERTa
    "xlm-roberta-base": (["input_ids"], 11, False, "bert"),
    "xlm-roberta-large": (["input_ids"], 11, True, "bert"),
    # FlauBERT
    "flaubert/flaubert_small_cased": (["input_ids"], 11, False, "bert"),
    # "flaubert/flaubert_base_uncased": (["input_ids"], 11, False, "bert"),
    "flaubert/flaubert_base_cased": (["input_ids"], 11, False, "bert"),
    # "flaubert/flaubert_large_cased": (["input_ids"], 11, False, "bert"),
    # BART
    "facebook/bart-large": (["input_ids", "attention_mask"], 11, False, "bart"),
    "facebook/bart-base": (["input_ids", "attention_mask"], 11, False, "bart"),
    "facebook/bart-large-mnli": (["input_ids", "attention_mask"], 11, False, "bart"),
    "facebook/bart-large-cnn": (["input_ids", "attention_mask"], 11, False, "bart"),
    # DialoGPT
    "microsoft/DialoGPT-small": (["input_ids"], 11, False, "gpt2"),
    "microsoft/DialoGPT-medium": (["input_ids"], 11, False, "gpt2"),
    # "microsoft/DialoGPT-large": (["input_ids"], 11, True, "gpt2"),
    # Reformer
    # "google/reformer-enwik8": (["input_ids"], 11, False, "bert"),
    # "google/reformer-crime-and-punishment": (["input_ids"], 11, False, "bert"),
    # MarianMT
    # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"),
    # Longformer (use benchmark_longformer.py instead)
    # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"),
    # "allenai/longformer-large-4096": (["input_ids"], 12, False, "bert"),
    # MBart
    "facebook/mbart-large-cc25": (["input_ids"], 11, True, "bert"),
    "facebook/mbart-large-en-ro": (["input_ids"], 11, True, "bert"),
  130. # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"),
  131. # # Longformer
  132. # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"),
  133. # "allenai/longformer-large-4096": (["input_ids"], 12, True, "bert"),
  134. # "funnel-transformer/small": (["input_ids"], 12, False, "bert"),
  135. # "funnel-transformer/small-base": (["input_ids"], 12, False, "bert"),
  136. # "funnel-transformer/medium": (["input_ids"], 12, False, "bert"),
  137. # "funnel-transformer/medium-base": (["input_ids"], 12, False, "bert"),
  138. # "funnel-transformer/intermediate": (["input_ids"], 12, False, "bert"),
  139. # "funnel-transformer/intermediate-base": (["input_ids"], 12, False, "bert"),
  140. # "funnel-transformer/large": (["input_ids"], 12, True, "bert"),
  141. # "funnel-transformer/large-base": (["input_ids"], 12, True, "bert"),
  142. # "funnel-transformer/xlarge": (["input_ids"], 12, True, "bert"),
  143. # "funnel-transformer/xlarge-base": (["input_ids"], 12, True, "bert"),
  144. # Layoutlm
  145. "microsoft/layoutlm-base-uncased": (["input_ids"], 11, False, "bert"),
  146. "microsoft/layoutlm-large-uncased": (["input_ids"], 11, False, "bert"),
  147. # Squeezebert
  148. "squeezebert/squeezebert-uncased": (["input_ids"], 11, False, "bert"),
  149. "squeezebert/squeezebert-mnli": (["input_ids"], 11, False, "bert"),
  150. "squeezebert/squeezebert-mnli-headless": (["input_ids"], 11, False, "bert"),
  151. "unc-nlp/lxmert-base-uncased": (
  152. ["input_ids", "visual_feats", "visual_pos"],
  153. 11,
  154. False,
  155. "bert",
  156. ),
  157. # "google/pegasus-xsum": (["input_ids"], 11, False, "bert"),
  158. # "google/pegasus-large": (["input_ids"], 11, False, "bert"),
  159. }
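

# A minimal sketch, not part of the original file, of how a MODELS entry might
# drive an ONNX export; the helper name export_to_onnx is an assumption. The
# tuple layout follows the comment above MODELS: (input names, opset_version,
# use_external_data_format, optimization model type).
def export_to_onnx(model_name: str, output_path: str):
    import torch
    from transformers import AutoModel, AutoTokenizer

    input_names, opset_version, use_external_data_format, _model_type = MODELS[model_name]

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()

    # Dummy inputs covering exactly the input names listed in the table. This
    # works for tokenizer-produced inputs (input_ids, attention_mask,
    # token_type_ids); entries with extra inputs such as "mems" or
    # "visual_feats" would need those tensors built separately.
    encoded = tokenizer("Hello world", return_tensors="pt")
    dummy_inputs = tuple(encoded[name] for name in input_names)

    # The use_external_data_format flag matters for multi-gigabyte models:
    # older torch.onnx.export versions accepted it as a keyword argument,
    # while recent PyTorch releases store large models externally on their own.
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_inputs,
            output_path,
            input_names=input_names,
            opset_version=opset_version,
        )


# Example: export_to_onnx("bert-base-uncased", "bert-base-uncased.onnx")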