You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
4.1 KiB
100 lines
4.1 KiB
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
# license information.
|
|
# --------------------------------------------------------------------------
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
|
|
from utils import (
|
|
chain_enc_dec_with_beamsearch,
|
|
export_summarization_edinit,
|
|
export_summarization_enc_dec_past,
|
|
onnx_inference,
|
|
)
|
|
|
|
# GLOBAL ENVS
# Configure root logging once at import time. The threshold is taken from the
# LOGLEVEL environment variable (case-insensitive, "INFO" when unset) and all
# records are written to stdout.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
# Logger shared by the helpers and the script entry point in this file.
logger = logging.getLogger("generate")
|
|
|
|
|
|
def print_args(args):
    """Log every attribute of a parsed-argument namespace.

    Args:
        args: An object supporting ``vars()`` (e.g. ``argparse.Namespace``).
            Each attribute is emitted as ``"<name>: <value>"`` at INFO level
            through the module-level ``logger``.
    """
    for name, value in vars(args).items():
        # Lazy %-style arguments: the message is only formatted when the
        # INFO level is actually enabled; the rendered text is unchanged.
        logger.info("%s: %s", name, value)
|
|
|
|
|
|
def _str_to_bool(value):
    """Convert a command-line boolean spelling into a real ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- becomes ``True``. This parser
    interprets the text instead.

    Args:
        value: The raw option value (or an actual ``bool``, returned as-is).

    Returns:
        ``True`` or ``False`` according to the spelling given.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def user_command():
    """Parse the command-line options of this script.

    Defines the generation options (lengths, beam count, penalties, opset
    version), the input/output options, the ``--no_*`` switches that skip
    individual stages, and the required ``-m/--model_dir`` argument. The
    parsed values are logged via ``print_args`` before being returned.

    Returns:
        argparse.Namespace: The parsed arguments taken from ``sys.argv``.
    """
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser.add_argument("--max_length", type=int, default=20, help="default to 20")
    parent_parser.add_argument("--min_length", type=int, default=0, help="default to 0")
    parent_parser.add_argument("-o", "--output", type=str, default="onnx_models", help="default name is onnx_models.")
    parent_parser.add_argument("-i", "--input_text", type=str, default=None, help="input text")
    parent_parser.add_argument("-s", "--spm_path", type=str, default=None, help="tokenizer model from sentencepice")
    parent_parser.add_argument("-v", "--vocab_path", type=str, help="vocab dictionary")
    parent_parser.add_argument("-b", "--num_beams", type=int, default=5, help="default to 5")
    parent_parser.add_argument("--repetition_penalty", type=float, default=1.0, help="default to 1.0")
    parent_parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3")
    # Was ``type=bool``, which made ``--early_stopping False`` evaluate to True.
    parent_parser.add_argument("--early_stopping", type=_str_to_bool, default=False, help="default to False")
    parent_parser.add_argument("--opset_version", type=int, default=14, help="minimum is 14")

    # Switches that disable individual stages of the script.
    parent_parser.add_argument("--no_encoder", action="store_true")
    parent_parser.add_argument("--no_decoder", action="store_true")
    parent_parser.add_argument("--no_chain", action="store_true")
    parent_parser.add_argument("--no_inference", action="store_true")

    required_args = parent_parser.add_argument_group("required input arguments")
    required_args.add_argument(
        "-m",
        "--model_dir",
        type=str,
        required=True,
        # Implicit concatenation instead of a backslash continuation, which
        # embedded the source indentation inside the help text.
        help="The directory contains input huggingface model. "
        "An official model like facebook/bart-base is also acceptable.",
    )

    # Parse once and reuse the result for both logging and the return value.
    args = parent_parser.parse_args()
    print_args(args)
    return args
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse options, prepare the output directory, then
    # run each stage unless its matching --no_* switch was given.
    args = user_command()
    if args.opset_version < 14:
        raise ValueError(f"The minimum supported opset version is 14! The given one was {args.opset_version}.")

    # exist_ok=True replaces the exists()-then-makedirs() pair and removes the
    # window in which another process could create the directory in between.
    # NOTE: if the path exists but is not a directory this now raises
    # FileExistsError immediately.
    os.makedirs(args.output, exist_ok=True)

    # beam search op only supports CPU for now
    args.device = "cpu"
    logger.info("ENV: CPU ...")

    if not args.input_text:
        # Built-in sample text used when -i/--input_text is not supplied.
        args.input_text = (
            "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
            "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
            "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
        )

    if not args.no_encoder:
        logger.info("========== EXPORTING ENCODER ==========")
        export_summarization_edinit.export_encoder(args)
    if not args.no_decoder:
        logger.info("========== EXPORTING DECODER ==========")
        export_summarization_enc_dec_past.export_decoder(args)
    if not args.no_chain:
        logger.info("========== CONVERTING MODELS ==========")
        chain_enc_dec_with_beamsearch.convert_model(args)
    if not args.no_inference:
        logger.info("========== INFERENCING WITH ONNX MODEL ==========")
        onnx_inference.run_inference(args)
|