m2m model translation
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import sys
from pathlib import Path

import numpy
import torch
from affinity_helper import AffinitySetting
from benchmark_helper import OptimizerInfo, Precision, create_onnxruntime_session
from huggingface_models import MODEL_CLASSES
from quantize_helper import QuantizeHelper
from torch_onnx_export_helper import torch_onnx_export
from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig

sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2"))
from gpt2_helper import PRETRAINED_GPT2_MODELS, GPT2ModelNoPastState, TFGPT2ModelNoPastState

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

logger = logging.getLogger(__name__)

# Workaround: replace torch.triu with a self-defined op, since torch.triu cannot be exported to ONNX.
# See https://github.com/pytorch/pytorch/issues/32968
torch_func = {"triu": torch.triu}

def triu_onnx(x, diagonal=0, out=None):
    assert out is None
    assert len(x.shape) == 2 and x.size(0) == x.size(1)

    torch_triu = torch_func["triu"]
    template = torch_triu(torch.ones((1024, 1024), dtype=torch.uint8), diagonal)
    mask = template[: x.size(0), : x.size(1)]
    return torch.where(mask.bool(), x, torch.zeros_like(x))


def replace_torch_functions():
    torch.triu = triu_onnx


def restore_torch_functions():
    torch.triu = torch_func["triu"]

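# Build numpy input feeds for an ONNX Runtime session: random input_ids plus optional attention_mask
# and token_type_ids, with extra inputs for encoder-decoder, LXMERT and Transformer-XL models.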
def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, data_type=numpy.int64):
    input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type)
    inputs = {"input_ids": input_ids}

    if "attention_mask" in input_names:
        attention_mask = numpy.ones([batch_size, sequence_length], dtype=data_type)
        inputs["attention_mask"] = attention_mask

    if "token_type_ids" in input_names:
        segment_ids = numpy.zeros([batch_size, sequence_length], dtype=data_type)
        inputs["token_type_ids"] = segment_ids

    if config.is_encoder_decoder:
        inputs["decoder_input_ids"] = input_ids

    if isinstance(config, LxmertConfig):
        inputs["visual_feats"] = numpy.random.randn(1, 1, config.visual_feat_dim).astype(numpy.float32)
        inputs["visual_pos"] = numpy.random.randn(1, 1, config.visual_pos_dim).astype(numpy.float32)

    if isinstance(config, TransfoXLConfig):
        inputs["tf_transfo_xl_model/transformer/pos_emb/einsum/Einsum/inputs_1:0"] = numpy.zeros(
            [config.hidden_size], dtype=numpy.float32
        )
    return inputs

def filter_inputs(inputs, input_names):
    remaining_model_inputs = {}
    for input_name in input_names:
        if input_name in inputs:
            remaining_model_inputs[input_name] = inputs[input_name]
    return remaining_model_inputs

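# Recursively flatten nested model outputs (lists/tuples of tensors) into a flat list.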
def flatten(inputs):
    return [[flatten(i) for i in inputs] if isinstance(inputs, (list, tuple)) else inputs]


def update_flatten_list(inputs, res_list):
    for i in inputs:
        res_list.append(i) if not isinstance(i, (list, tuple)) else update_flatten_list(i, res_list)
    return res_list

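# Mark the batch and sequence dimensions of the example inputs/outputs as dynamic axes for export.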
def build_dynamic_axes(example_inputs, outputs_flatten):
    sequence_length = example_inputs["input_ids"].shape[-1]

    dynamic_axes = {key: {0: "batch_size", 1: "seq_len"} for key in example_inputs.keys()}

    output_names = ["output_" + str(i + 1) for i in range(len(outputs_flatten))]
    for i, output_name in enumerate(output_names):
        dynamic_axes[output_name] = {0: "batch_size"}
        dims = outputs_flatten[i].shape
        for j, dim in enumerate(dims):
            if dim == sequence_length:
                dynamic_axes[output_name].update({j: "seq_len"})
    return dynamic_axes, output_names

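# Run the exported model with ONNX Runtime and compare every output tensor against the
# PyTorch/TensorFlow reference outputs within rtol/atol tolerances.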
def validate_onnx_model(
    onnx_model_path,
    example_inputs,
    example_outputs_flatten,
    use_gpu,
    fp16,
    output_names=None,
):
    test_session = create_onnxruntime_session(onnx_model_path, use_gpu, enable_all_optimization=False)
    if test_session is None:
        logger.error(f"{onnx_model_path} is an invalid ONNX model")
        return False

    logger.info(f"{onnx_model_path} is a valid ONNX model")

    # Compare the inference result with PyTorch or TensorFlow
    example_ort_inputs = {k: t.numpy() for k, t in example_inputs.items()}
    example_ort_outputs = test_session.run(output_names, example_ort_inputs)
    if len(example_outputs_flatten) != len(example_ort_outputs):
        logger.error(
            f"Number of output tensors expected {len(example_outputs_flatten)}, got {len(example_ort_outputs)}"
        )
        return False

    for i in range(len(example_outputs_flatten)):
        abs_diff = numpy.amax(numpy.abs(example_ort_outputs[i] - example_outputs_flatten[i].cpu().numpy()))
        if abs_diff > 1e-4:
            logger.info(f"Max absolute diff={abs_diff} for output tensor {i}")

        rtol = 5e-02 if fp16 else 1e-4
        atol = 1e-01 if fp16 else 1e-4
        if not numpy.allclose(
            example_ort_outputs[i],
            example_outputs_flatten[i].cpu().numpy(),
            rtol=rtol,
            atol=atol,
        ):
            logger.error(f"Output tensor {i} is not close: rtol={rtol}, atol={atol}")
            return False

    logger.info(f"inference result of onnxruntime is validated on {onnx_model_path}")
    return True

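# Compose the output .onnx path from the model name, input count, precision and target device;
# models saved with external data get a dedicated sub-directory.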
def get_onnx_file_path(
    onnx_dir: str,
    model_name: str,
    input_count: int,
    optimized_by_script: bool,
    use_gpu: bool,
    precision: Precision,
    optimized_by_onnxruntime: bool,
    use_external_data: bool,
):
    from re import sub

    normalized_model_name = sub(r"[^a-zA-Z0-9_]", "_", model_name)

    if not optimized_by_script:
        filename = f"{normalized_model_name}_{input_count}"
    else:
        device = "gpu" if use_gpu else "cpu"
        filename = f"{normalized_model_name}_{input_count}_{precision}_{device}"

    if optimized_by_onnxruntime:
        filename += "_ort"

    directory = onnx_dir
    # ONNX Runtime will not write external data, so the raw and optimized models shall be in the same directory.
    if use_external_data and not optimized_by_onnxruntime:
        directory = os.path.join(onnx_dir, filename)
        if not os.path.exists(directory):
            os.makedirs(directory)

    return os.path.join(directory, f"{filename}.onnx")

def add_filename_suffix(file_path: str, suffix: str) -> str:
    """
    Append a suffix to the filename (before the extension).

    Args:
        file_path: the path of the file to which the suffix is added
        suffix: the suffix to add

    Returns:
        The path with the suffix appended to the filename, before the extension.
    """
    path = Path(file_path)
    return str(path.parent.joinpath(path.stem + suffix).with_suffix(path.suffix))

def optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite, model_fusion_statistics):
    if overwrite or not os.path.exists(ort_model_path):
        Path(ort_model_path).parent.mkdir(parents=True, exist_ok=True)
        from optimizer import get_fusion_statistics, optimize_by_onnxruntime

        # Use onnxruntime to optimize the model, which will be saved to *_ort.onnx
        _ = optimize_by_onnxruntime(
            onnx_model_path,
            use_gpu=use_gpu,
            optimized_model_path=ort_model_path,
            opt_level=99,
        )
        model_fusion_statistics[ort_model_path] = get_fusion_statistics(ort_model_path)
    else:
        logger.info(f"Skip optimization since model already exists: {ort_model_path}")

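# Optimize the exported model with the transformers optimizer script (optimizer.py): apply fusion
# options, optionally convert to fp16, and record fused-operator statistics.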
def optimize_onnx_model(
    onnx_model_path,
    optimized_model_path,
    model_type,
    num_attention_heads,
    hidden_size,
    use_gpu,
    precision,
    use_raw_attention_mask,
    overwrite,
    model_fusion_statistics,
    use_external_data_format,
    optimization_options=None,
):
    if overwrite or not os.path.exists(optimized_model_path):
        Path(optimized_model_path).parent.mkdir(parents=True, exist_ok=True)

        from fusion_options import FusionOptions
        from optimizer import optimize_model

        if optimization_options is None:
            optimization_options = FusionOptions(model_type)
        optimization_options.use_raw_attention_mask(use_raw_attention_mask)
        if Precision.FLOAT16 == precision:
            optimization_options.enable_gelu_approximation = True
        if Precision.INT8 == precision:
            optimization_options.enable_embed_layer_norm = False

        # Use script to optimize model.
        # Use opt_level <= 1 for models to be converted to fp16, because some fused ops (like FusedGemm) exist only in fp32.
        # It is better to be conservative, so we use opt_level=0 here, in case MemcpyFromHost is added to the graph by OnnxRuntime.
        opt_model = optimize_model(
            onnx_model_path,
            model_type,
            num_heads=num_attention_heads,
            hidden_size=hidden_size,
            opt_level=0,
            optimization_options=optimization_options,
            use_gpu=use_gpu,
            only_onnxruntime=False,
        )
        if model_type == "bert_keras" or model_type == "bert_tf":
            opt_model.use_dynamic_axes()

        model_fusion_statistics[optimized_model_path] = opt_model.get_fused_operator_statistics()
        if Precision.FLOAT16 == precision:
            opt_model.convert_float_to_float16(keep_io_types=True)

        opt_model.save_model_to_file(optimized_model_path, use_external_data_format)
    else:
        logger.info(f"Skip optimization since model already exists: {optimized_model_path}")

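# Map a model name (or an explicitly requested custom class) to the Hugging Face model class name used to load it.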
def modelclass_dispatcher(model_name, custom_model_class):
    if custom_model_class is not None:
        if custom_model_class in MODEL_CLASSES:
            return custom_model_class
        else:
            raise Exception("Valid model class: " + " ".join(MODEL_CLASSES))

    if model_name in PRETRAINED_GPT2_MODELS:
        return "GPT2ModelNoPastState"

    import re

    if re.search("-squad$", model_name) is not None:
        return "AutoModelForQuestionAnswering"
    elif re.search("-mprc$", model_name) is not None:
        return "AutoModelForSequenceClassification"
    elif re.search("gpt2", model_name) is not None:
        return "AutoModelWithLMHead"

    return "AutoModel"

def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_tf_model=False):
    model_class_name = modelclass_dispatcher(model_name, custom_model_class)

    if model_class_name == "GPT2ModelNoPastState":
        if is_tf_model:
            return TFGPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
        else:
            return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)

    if is_tf_model:
        model_class_name = "TF" + model_class_name

    transformers_module = __import__("transformers", fromlist=[model_class_name])
    logger.info(f"Model class name: {model_class_name}")
    model_class = getattr(transformers_module, model_class_name)

    return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir)


def load_pt_model(model_name, model_class, cache_dir, config_modifier):
    config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
    if hasattr(config, "return_dict"):
        config.return_dict = False

    config_modifier.modify(config)
    model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class)

    return config, model


def load_tf_model(model_name, model_class, cache_dir, config_modifier):
    config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
    config_modifier.modify(config)

    # Loading a TF model from transformers limits the CPU affinity to {0} when KMP_AFFINITY is set.
    # Restore the affinity after model loading for expected ORT performance.
    affinity_setting = AffinitySetting()
    affinity_setting.get_affinity()
    model = load_pretrained_model(
        model_name,
        config=config,
        cache_dir=cache_dir,
        custom_model_class=model_class,
        is_tf_model=True,
    )
    affinity_setting.set_affinity()

    return config, model

# For test only
def load_pt_model_from_tf(model_name):
    # Note: a PyTorch model can be obtained from a TF checkpoint, but the model source and its structure
    # then differ from the models returned by load_pt_model() and load_tf_model(), even for the same name.
    # Therefore it should not be used to compare against them.
    from convert_tf_models_to_pytorch import tf2pt_pipeline

    config, model = tf2pt_pipeline(model_name)

    return config, model

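# Validate the freshly exported model against the framework outputs, then optimize it either with
# the optimizer script or with ONNX Runtime, and quantize to int8 when requested.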
def validate_and_optimize_onnx(
    model_name,
    use_external_data_format,
    model_type,
    onnx_dir,
    input_names,
    use_gpu,
    precision,
    optimize_info,
    validate_onnx,
    use_raw_attention_mask,
    overwrite,
    config,
    model_fusion_statistics,
    onnx_model_path,
    example_inputs,
    example_outputs_flatten,
    output_names,
    fusion_options,
):
    is_valid_onnx_model = True
    if validate_onnx:
        is_valid_onnx_model = validate_onnx_model(
            onnx_model_path,
            example_inputs,
            example_outputs_flatten,
            use_gpu,
            False,
            output_names,
        )

    if optimize_info == OptimizerInfo.NOOPT:
        return onnx_model_path, is_valid_onnx_model, config.vocab_size

    if (
        optimize_info == OptimizerInfo.BYSCRIPT or precision == Precision.FLOAT16 or precision == Precision.INT8
    ):  # Use script (optimizer.py) to optimize
        optimized_model_path = get_onnx_file_path(
            onnx_dir,
            model_name,
            len(input_names),
            True,
            use_gpu,
            precision,
            False,
            use_external_data_format,
        )
        optimize_onnx_model(
            onnx_model_path,
            optimized_model_path,
            model_type,
            config.num_attention_heads,
            config.hidden_size,
            use_gpu,
            precision,
            use_raw_attention_mask,
            overwrite,
            model_fusion_statistics,
            use_external_data_format,
            fusion_options,
        )

        onnx_model_path = optimized_model_path
        if validate_onnx:
            is_valid_onnx_model = validate_onnx_model(
                onnx_model_path,
                example_inputs,
                example_outputs_flatten,
                use_gpu,
                precision == Precision.FLOAT16,
                output_names,
            )

        if precision == Precision.INT8:
            logger.info(f"Quantizing model: {onnx_model_path}")
            QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_path, use_external_data_format)
            logger.info(f"Finished quantizing model: {onnx_model_path}")

    if optimize_info == OptimizerInfo.BYORT:  # Use OnnxRuntime to optimize
        if is_valid_onnx_model:
            ort_model_path = add_filename_suffix(onnx_model_path, "_ort")
            optimize_onnx_model_by_ort(
                onnx_model_path,
                ort_model_path,
                use_gpu,
                overwrite,
                model_fusion_statistics,
            )

    return onnx_model_path, is_valid_onnx_model, config.vocab_size

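# Export a Hugging Face PyTorch model to ONNX: load it on CPU, build example inputs with the tokenizer,
# export with dynamic batch/sequence axes, then validate and optimize the exported model.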
def export_onnx_model_from_pt(
    model_name,
    opset_version,
    use_external_data_format,
    model_type,
    model_class,
    config_modifier,
    cache_dir,
    onnx_dir,
    input_names,
    use_gpu,
    precision,
    optimizer_info,
    validate_onnx,
    use_raw_attention_mask,
    overwrite,
    model_fusion_statistics,
    fusion_options,
):
    config, model = load_pt_model(model_name, model_class, cache_dir, config_modifier)
    # config, model = load_pt_model_from_tf(model_name)
    model.cpu()

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    max_input_size = (
        tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
    )

    example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
    example_inputs = filter_inputs(example_inputs, input_names)

    example_outputs = model(**example_inputs)

    assert isinstance(example_outputs, (list, tuple)), f"type of output is not list or tuple: {type(example_outputs)}"

    # Flatten is needed for gpt2 and distilgpt2.
    example_outputs_flatten = flatten(example_outputs)
    example_outputs_flatten = update_flatten_list(example_outputs_flatten, [])

    onnx_model_path = get_onnx_file_path(
        onnx_dir,
        model_name,
        len(input_names),
        False,
        use_gpu,
        precision,
        False,
        use_external_data_format,
    )

    if overwrite or not os.path.exists(onnx_model_path):
        logger.info("Exporting ONNX model to {}".format(onnx_model_path))
        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)

        replace_torch_functions()
        torch_onnx_export(
            model=model,
            args=tuple(example_inputs.values()),
            f=onnx_model_path,
            input_names=list(example_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset_version,
            use_external_data_format=use_external_data_format,
        )
        restore_torch_functions()
    else:
        logger.info(f"Skip export since model already exists: {onnx_model_path}")

    onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
        model_name,
        use_external_data_format,
        model_type,
        onnx_dir,
        input_names,
        use_gpu,
        precision,
        optimizer_info,
        validate_onnx,
        use_raw_attention_mask,
        overwrite,
        config,
        model_fusion_statistics,
        onnx_model_path,
        example_inputs,
        example_outputs_flatten,
        None,
        fusion_options,
    )

    return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size

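# Export a Hugging Face TensorFlow (Keras) model to ONNX with tf2onnx on CPU, then validate and
# optimize the result through the same pipeline as the PyTorch path.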
def export_onnx_model_from_tf(
    model_name,
    opset_version,
    use_external_data_format,
    model_type,
    model_class,
    config_modifier,
    cache_dir,
    onnx_dir,
    input_names,
    use_gpu,
    precision,
    optimizer_info,
    validate_onnx,
    use_raw_attention_mask,
    overwrite,
    model_fusion_statistics,
    fusion_options,
):
    # Use CPU to export
    import tensorflow as tf

    tf.config.set_visible_devices([], "GPU")

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

    # Fix "Using pad_token, but it is not set yet" error.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    max_input_size = (
        tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
    )

    config, model = load_tf_model(model_name, model_class, cache_dir, config_modifier)
    model.resize_token_embeddings(len(tokenizer))

    example_inputs = tokenizer.encode_plus(
        "This is a sample input",
        return_tensors="tf",
        max_length=max_input_size,
        padding="max_length",
        truncation=True,
    )
    example_inputs = filter_inputs(example_inputs, input_names)

    if config.is_encoder_decoder:
        example_inputs["decoder_input_ids"] = tokenizer.encode_plus(
            "This is a sample input",
            return_tensors="tf",
            max_length=max_input_size,
            padding="max_length",
            truncation=True,
        ).input_ids

    if model_name == "unc-nlp/lxmert-base-uncased":
        example_inputs["visual_feats"] = tf.random.normal([1, 1, config.visual_feat_dim])
        example_inputs["visual_pos"] = tf.random.normal([1, 1, config.visual_pos_dim])

    try:
        # Use no past state for these models
        if config.use_cache:
            config.use_cache = False
    except Exception:
        pass

    example_outputs = model(example_inputs, training=False)
    output_names = None

    # For xlnet models, only compare the last_hidden_state output.
    if model_name == "xlnet-base-cased" or model_name == "xlnet-large-cased":
        output_names = ["last_hidden_state"]
        example_outputs = example_outputs["last_hidden_state"]

    # Flatten is needed for gpt2 and distilgpt2. Output name sorting is needed for tf2onnx outputs to match onnx outputs.
    from tensorflow.python.util import nest

    example_outputs_flatten = nest.flatten(example_outputs)

    onnx_model_path = get_onnx_file_path(
        onnx_dir,
        model_name,
        len(input_names),
        False,
        use_gpu,
        precision,
        False,
        use_external_data_format,
    )
    tf_internal_model_path = onnx_model_path[:-5] if use_external_data_format else onnx_model_path

    if overwrite or not os.path.exists(tf_internal_model_path):
        logger.info("Exporting ONNX model to {}".format(onnx_model_path))
        if not use_external_data_format:
            Path(tf_internal_model_path).parent.mkdir(parents=True, exist_ok=True)

        import zipfile

        import tf2onnx

        tf2onnx.logging.set_level(tf2onnx.logging.ERROR)
        specs = []
        for name, value in example_inputs.items():
            dims = [None] * len(value.shape)
            specs.append(tf.TensorSpec(tuple(dims), value.dtype, name=name))
        _, _ = tf2onnx.convert.from_keras(
            model,
            input_signature=tuple(specs),
            opset=opset_version,
            large_model=use_external_data_format,
            output_path=tf_internal_model_path,
        )
        if use_external_data_format:
            # Need to unpack the zip for run_onnxruntime().
            with zipfile.ZipFile(tf_internal_model_path, "r") as z:
                z.extractall(os.path.dirname(tf_internal_model_path))
            tf_internal_model_path = os.path.join(os.path.dirname(tf_internal_model_path), "__MODEL_PROTO.onnx")
            if os.path.exists(onnx_model_path):
                os.remove(onnx_model_path)
            os.rename(tf_internal_model_path, onnx_model_path)
    else:
        logger.info(f"Skip export since model already exists: {onnx_model_path}")

    model_type = model_type + "_tf"
    optimized_onnx_path, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
        model_name,
        use_external_data_format,
        model_type,
        onnx_dir,
        input_names,
        use_gpu,
        precision,
        optimizer_info,
        validate_onnx,
        use_raw_attention_mask,
        overwrite,
        config,
        model_fusion_statistics,
        onnx_model_path,
        example_inputs,
        example_outputs_flatten,
        output_names,
        fusion_options,
    )

    return (
        optimized_onnx_path,
        is_valid_onnx_model,
        vocab_size,
        max_input_size,
    )
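# --------------------------------------------------------------------------
# Minimal usage sketch for the PyTorch export path above. All parameter values below are
# illustrative assumptions (model name, opset, directories), not a prescribed setup, and the
# no-op config modifier is a hypothetical stand-in: export_onnx_model_from_pt only needs an
# object exposing modify(config). Assumes benchmark_helper.Precision defines FLOAT32.
if __name__ == "__main__":
    class _NoOpConfigModifier:
        def modify(self, config):
            # Leave the Hugging Face config unchanged.
            pass

    fusion_statistics = {}
    onnx_path, is_valid, vocab_size, max_sequence_length = export_onnx_model_from_pt(
        model_name="bert-base-cased",
        opset_version=12,
        use_external_data_format=False,
        model_type="bert",
        model_class=None,
        config_modifier=_NoOpConfigModifier(),
        cache_dir="./cache_models",
        onnx_dir="./onnx_models",
        input_names=["input_ids", "attention_mask", "token_type_ids"],
        use_gpu=False,
        precision=Precision.FLOAT32,
        optimizer_info=OptimizerInfo.BYSCRIPT,
        validate_onnx=True,
        use_raw_attention_mask=False,
        overwrite=False,
        model_fusion_statistics=fusion_statistics,
        fusion_options=None,
    )
    print(f"model={onnx_path} valid={is_valid} vocab_size={vocab_size} max_input_size={max_sequence_length}")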