m2m model translation
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

174 lines
6.1 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import os
  8. import random
  9. import sys
  10. import tempfile
  11. from pathlib import Path
  12. from typing import List, Union
  13. import numpy
  14. import onnx
  15. import torch
  16. from transformers import MT5Config, T5Config
  17. from onnxruntime import InferenceSession
  18. sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
  19. from onnx_model import OnnxModel # noqa: E402
  20. from torch_onnx_export_helper import torch_onnx_export # noqa: E402
  21. logger = logging.getLogger(__name__)
  22. class T5Encoder(torch.nn.Module):
  23. """T5 encoder outputs only the last hidden state"""
  24. def __init__(self, encoder, config: Union[T5Config, MT5Config]):
  25. super().__init__()
  26. self.encoder = encoder
  27. self.config = config
  28. def forward(self, input_ids, attention_mask):
  29. return self.encoder(input_ids, attention_mask)[0]
  30. class T5EncoderInputs:
  31. def __init__(self, input_ids, attention_mask):
  32. self.input_ids: torch.LongTensor = input_ids
  33. self.attention_mask: torch.LongTensor = attention_mask
  34. @staticmethod
  35. def create_dummy(
  36. batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False
  37. ): # -> T5EncoderInputs
  38. """Create dummy inputs for T5 encoder.
  39. Args:
  40. batch_size (int): batch size
  41. sequence_length (int): sequence length
  42. vocab_size (int): vocabulary size
  43. device (torch.device): device of output tensors
  44. Returns:
  45. T5EncoderInputs: dummy inputs for encoder
  46. """
  47. dtype = torch.int32 if use_int32_inputs else torch.int64
  48. input_ids = torch.randint(
  49. low=0,
  50. high=vocab_size - 1,
  51. size=(batch_size, sequence_length),
  52. dtype=dtype,
  53. device=device,
  54. )
  55. attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
  56. if sequence_length >= 2:
  57. for i in range(batch_size):
  58. padding_position = random.randint(0, sequence_length - 1)
  59. attention_mask[i, :padding_position] = 0
  60. return T5EncoderInputs(input_ids, attention_mask)
  61. def to_list(self) -> List:
  62. input_list = [v for v in [self.input_ids, self.attention_mask] if v is not None]
  63. return input_list
  64. class T5EncoderHelper:
  65. @staticmethod
  66. def export_onnx(
  67. encoder: T5Encoder,
  68. device: torch.device,
  69. onnx_model_path: str,
  70. verbose: bool = True,
  71. use_external_data_format: bool = False,
  72. use_int32_inputs: bool = False,
  73. ):
  74. """Export encoder to ONNX
  75. Args:
  76. encoder (T5Encoder): encoder object
  77. device (torch.device): device of encoder object
  78. onnx_model_path (str): onnx path
  79. verbose (bool, optional): print verbose information. Defaults to True.
  80. use_external_data_format (bool, optional): use external data format or not. Defaults to False.
  81. """
  82. config = encoder.config
  83. encoder_inputs = T5EncoderInputs.create_dummy(
  84. batch_size=2,
  85. sequence_length=4,
  86. vocab_size=config.vocab_size,
  87. device=device,
  88. use_int32_inputs=use_int32_inputs,
  89. )
  90. Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  91. with tempfile.TemporaryDirectory() as tmp_dir_name:
  92. temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
  93. Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  94. torch_onnx_export(
  95. encoder,
  96. args=tuple(encoder_inputs.to_list()),
  97. f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
  98. export_params=True,
  99. input_names=["input_ids", "attention_mask"],
  100. output_names=["hidden_states"],
  101. dynamic_axes={
  102. "input_ids": {0: "batch_size", 1: "sequence_length"},
  103. "attention_mask": {0: "batch_size", 1: "sequence_length"},
  104. "hidden_states": {0: "batch_size", 1: "sequence_length"},
  105. },
  106. opset_version=12,
  107. do_constant_folding=True,
  108. use_external_data_format=use_external_data_format,
  109. verbose=verbose,
  110. )
  111. if use_external_data_format:
  112. model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
  113. OnnxModel.save(
  114. model,
  115. onnx_model_path,
  116. save_as_external_data=True,
  117. all_tensors_to_one_file=True,
  118. )
  119. @staticmethod
  120. def onnxruntime_inference(ort_session, inputs: T5EncoderInputs):
  121. """Run inference of ONNX model."""
  122. ort_inputs = {
  123. "input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()),
  124. "attention_mask": numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()),
  125. }
  126. return ort_session.run(None, ort_inputs)
  127. @staticmethod
  128. def verify_onnx(
  129. model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
  130. ):
  131. """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
  132. inputs = T5EncoderInputs.create_dummy(
  133. batch_size=4,
  134. sequence_length=11,
  135. vocab_size=model.config.vocab_size,
  136. device=device,
  137. use_int32_inputs=use_int32_inputs,
  138. )
  139. input_list = inputs.to_list()
  140. torch_outputs = model(*input_list)
  141. ort_outputs = T5EncoderHelper.onnxruntime_inference(ort_session, inputs)
  142. max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0]))
  143. logger.info(f"max_diff={max_diff}")
  144. return max_diff