m2m model translation (m2m模型翻译)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

100 lines
4.1 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import argparse
  7. import logging
  8. import os
  9. import sys
  10. from utils import (
  11. chain_enc_dec_with_beamsearch,
  12. export_summarization_edinit,
  13. export_summarization_enc_dec_past,
  14. onnx_inference,
  15. )
  16. # GLOBAL ENVS
  17. logging.basicConfig(
  18. format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
  19. datefmt="%Y-%m-%d %H:%M:%S",
  20. level=os.environ.get("LOGLEVEL", "INFO").upper(),
  21. stream=sys.stdout,
  22. )
  23. logger = logging.getLogger("generate")
  24. def print_args(args):
  25. for arg in vars(args):
  26. logger.info(f"{arg}: {getattr(args, arg)}")
  27. def user_command():
  28. parent_parser = argparse.ArgumentParser(add_help=False)
  29. parent_parser.add_argument("--max_length", type=int, default=20, help="default to 20")
  30. parent_parser.add_argument("--min_length", type=int, default=0, help="default to 0")
  31. parent_parser.add_argument("-o", "--output", type=str, default="onnx_models", help="default name is onnx_models.")
  32. parent_parser.add_argument("-i", "--input_text", type=str, default=None, help="input text")
  33. parent_parser.add_argument("-s", "--spm_path", type=str, default=None, help="tokenizer model from sentencepice")
  34. parent_parser.add_argument("-v", "--vocab_path", type=str, help="vocab dictionary")
  35. parent_parser.add_argument("-b", "--num_beams", type=int, default=5, help="default to 5")
  36. parent_parser.add_argument("--repetition_penalty", type=float, default=1.0, help="default to 1.0")
  37. parent_parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3")
  38. parent_parser.add_argument("--early_stopping", type=bool, default=False, help="default to False")
  39. parent_parser.add_argument("--opset_version", type=int, default=14, help="minimum is 14")
  40. parent_parser.add_argument("--no_encoder", action="store_true")
  41. parent_parser.add_argument("--no_decoder", action="store_true")
  42. parent_parser.add_argument("--no_chain", action="store_true")
  43. parent_parser.add_argument("--no_inference", action="store_true")
  44. required_args = parent_parser.add_argument_group("required input arguments")
  45. required_args.add_argument(
  46. "-m",
  47. "--model_dir",
  48. type=str,
  49. required=True,
  50. help="The directory contains input huggingface model. \
  51. An official model like facebook/bart-base is also acceptable.",
  52. )
  53. print_args(parent_parser.parse_args())
  54. return parent_parser.parse_args()
  55. if __name__ == "__main__":
  56. args = user_command()
  57. if args.opset_version < 14:
  58. raise ValueError(f"The minimum supported opset version is 14! The given one was {args.opset_version}.")
  59. isExist = os.path.exists(args.output)
  60. if not isExist:
  61. os.makedirs(args.output)
  62. # beam search op only supports CPU for now
  63. args.device = "cpu"
  64. logger.info("ENV: CPU ...")
  65. if not args.input_text:
  66. args.input_text = (
  67. "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
  68. "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
  69. "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
  70. )
  71. if not args.no_encoder:
  72. logger.info(f"========== EXPORTING ENCODER ==========")
  73. export_summarization_edinit.export_encoder(args)
  74. if not args.no_decoder:
  75. logger.info(f"========== EXPORTING DECODER ==========")
  76. export_summarization_enc_dec_past.export_decoder(args)
  77. if not args.no_chain:
  78. logger.info(f"========== CONVERTING MODELS ==========")
  79. chain_enc_dec_with_beamsearch.convert_model(args)
  80. if not args.no_inference:
  81. logger.info(f"========== INFERENCING WITH ONNX MODEL ==========")
  82. onnx_inference.run_inference(args)