m2m model translation
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import List, Optional, Union

import numpy
import onnx
import torch
from past_helper import PastKeyValuesHelper
from t5_encoder import T5EncoderInputs
from transformers import MT5Config, T5Config

from onnxruntime import InferenceSession

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from io_binding_helper import TypeHelper  # noqa: E402
from onnx_model import OnnxModel  # noqa: E402
from torch_onnx_export_helper import torch_onnx_export  # noqa: E402

logger = logging.getLogger(__name__)
class T5DecoderInit(torch.nn.Module):
    """A T5 decoder with LM head to create initial past key values.
    This model is only called once at the start of decoding.
    """

    def __init__(
        self,
        decoder: torch.nn.Module,
        lm_head: torch.nn.Module,
        config: Union[T5Config, MT5Config],
        decoder_start_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config
        self.decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor,
    ):
        # When no decoder input is given, start decoding from decoder_start_token_id.
        if decoder_input_ids is None:
            batch_size = encoder_attention_mask.shape[0]
            decoder_input_ids = (
                torch.ones(
                    (batch_size, 1),
                    dtype=torch.long,
                    device=encoder_attention_mask.device,
                )
                * self.decoder_start_token_id
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )
        sequence_output = decoder_outputs.last_hidden_state
        present_key_values = decoder_outputs.past_key_values

        # T5 scales the decoder output by d_model**-0.5 before the LM head.
        sequence_output = sequence_output * (self.config.d_model**-0.5)
        lm_logits = self.lm_head(sequence_output)

        past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
        return lm_logits, past_self, past_cross
class T5Decoder(torch.nn.Module):
    """A T5 decoder with LM head and past key values."""

    def __init__(self, decoder, lm_head, config):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config

    def forward(self, decoder_input_ids, encoder_attention_mask, encoder_hidden_states, *past):
        past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )
        sequence_output = decoder_outputs.last_hidden_state
        present_key_values = decoder_outputs.past_key_values

        sequence_output = sequence_output * (self.config.d_model**-0.5)
        lm_logits = self.lm_head(sequence_output)

        present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
        # Do not return present_cross since they are identical to the corresponding past_cross inputs.
        return lm_logits, present_self
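

# A minimal sketch of the intended two-phase decoding loop (assuming decoder_init
# and decoder wrap the same Hugging Face decoder stack and LM head, and that the
# grouped past lists returned by PastKeyValuesHelper are flat lists of tensors):
#
#   # First step: decoder_input_ids is None, so decoder_start_token_id is used.
#   logits, past_self, past_cross = decoder_init(None, attention_mask, encoder_hidden_states)
#   # Subsequent steps: feed the flattened past (all self tensors, then all cross).
#   logits, present_self = decoder(next_token_ids, attention_mask, encoder_hidden_states,
#                                  *past_self, *past_cross)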
class T5DecoderInputs:
    def __init__(
        self,
        decoder_input_ids,
        encoder_attention_mask,
        encoder_hidden_states,
        past_key_values=None,
    ):
        self.decoder_input_ids: torch.LongTensor = decoder_input_ids
        self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
        self.encoder_hidden_states: Union[torch.FloatTensor, torch.HalfTensor] = encoder_hidden_states
        self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values

    @staticmethod
    def create_dummy(
        config: Union[T5Config, MT5Config],
        batch_size: int,
        encode_sequence_length: int,
        past_decode_sequence_length: int,
        device: torch.device,
        float16: bool = False,
        use_int32_inputs: bool = False,
    ):  # -> T5DecoderInputs:
        """Create dummy inputs for T5Decoder.

        Args:
            config (Union[T5Config, MT5Config]): model configuration
            batch_size (int): batch size
            encode_sequence_length (int): sequence length of input_ids for encoder
            past_decode_sequence_length (int): past sequence length of input_ids for decoder
            device (torch.device): device of output tensors
            float16 (bool): whether float inputs use float16 instead of float32
            use_int32_inputs (bool): whether to use int32 instead of int64 for some inputs

        Returns:
            T5DecoderInputs: dummy inputs for decoder
        """
        hidden_size: int = config.d_model
        num_attention_heads: int = config.num_heads
        num_layers: int = config.num_layers
        vocab_size: int = config.vocab_size

        # Do not use head_size = hidden_size / num_attention_heads here.
        # For example, mt5-small has d_model=512 and num_heads=6.
        head_size: int = config.d_kv

        sequence_length: int = 1  # fixed for decoding

        decoder_input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,
            size=(batch_size, sequence_length),
            dtype=(torch.int32 if use_int32_inputs else torch.int64),
            device=device,
        )

        encoder_inputs = T5EncoderInputs.create_dummy(
            batch_size,
            encode_sequence_length,
            vocab_size,
            device,
            use_int32_inputs=use_int32_inputs,
        )

        float_type = torch.float16 if float16 else torch.float32
        encoder_hidden_state = torch.rand(
            batch_size,
            encode_sequence_length,
            hidden_size,
            dtype=float_type,
            device=device,
        )

        if past_decode_sequence_length > 0:
            self_attention_past_shape = [
                batch_size,
                num_attention_heads,
                past_decode_sequence_length,
                head_size,
            ]
            cross_attention_past_shape = [
                batch_size,
                num_attention_heads,
                encode_sequence_length,
                head_size,
            ]
            past = []
            # Self-attention past key/values come first, then cross-attention past key/values.
            for _ in range(2 * num_layers):
                past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))
            for _ in range(2 * num_layers):
                past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
        else:
            past = None

        return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, encoder_hidden_state, past)

    def to_list(self) -> List:
        input_list = [
            self.decoder_input_ids,
            self.encoder_attention_mask,
            self.encoder_hidden_states,
        ]
        if self.past_key_values:
            input_list.extend(self.past_key_values)
        return input_list

    def to_fp32(self):
        encoder_hidden_state = self.encoder_hidden_states.to(dtype=torch.float32)
        past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
        return T5DecoderInputs(
            self.decoder_input_ids.clone(),
            self.encoder_attention_mask.clone(),
            encoder_hidden_state,
            past,
        )
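

# A minimal usage sketch for T5DecoderInputs: create dummy inputs with past
# state and flatten them into the positional-argument order that both
# T5Decoder.forward and the exported ONNX graph expect, e.g.
#
#   inputs = T5DecoderInputs.create_dummy(
#       config,
#       batch_size=2,
#       encode_sequence_length=3,
#       past_decode_sequence_length=5,
#       device=torch.device("cpu"),
#   )
#   input_list = inputs.to_list()  # ids, mask, hidden states, then flattened past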
class T5DecoderHelper:
    @staticmethod
    def export_onnx(
        decoder: Union[T5Decoder, T5DecoderInit],
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export decoder to ONNX.

        Args:
            decoder (Union[T5Decoder, T5DecoderInit]): decoder object
            device (torch.device): device of decoder object
            onnx_model_path (str): ONNX model path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs. Defaults to False.
        """
        assert isinstance(decoder, (T5Decoder, T5DecoderInit))

        inputs = T5DecoderInputs.create_dummy(
            decoder.config,
            batch_size=2,
            encode_sequence_length=3,
            past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )
        input_list = inputs.to_list()

        past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False)
        present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True)
        present_self_names = present_names[: 2 * decoder.config.num_layers]

        input_past_names = past_names if isinstance(decoder, T5Decoder) else []
        output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
        output_names = ["logits"] + output_present_names

        # Shape of input tensors (sequence_length==1):
        #    input_ids: (batch_size, sequence_length)
        #    encoder_attention_mask: (batch_size, encode_sequence_length)
        #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        input_names = ["input_ids"]
        input_names.append("encoder_attention_mask")
        input_names.append("encoder_hidden_states")
        input_names.extend(input_past_names)

        dynamic_axes = {
            "input_ids": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
            "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
            "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
            "logits": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
        }

        for name in input_past_names:
            dynamic_axes[name] = {
                0: "batch_size",
                2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
            }

        for name in output_present_names:
            if "cross" in name:
                dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
            else:  # self attention past state
                if isinstance(decoder, T5Decoder):
                    dynamic_axes[name] = {
                        0: "batch_size",
                        2: "past_decode_sequence_length + 1",
                    }
                else:
                    dynamic_axes[name] = {
                        0: "batch_size",
                        # 2: 'sequence_length'
                    }

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                decoder,
                args=tuple(input_list),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )
    @staticmethod
    def onnxruntime_inference(ort_session, inputs: T5DecoderInputs):
        """Run inference of ONNX model."""
        logger.debug("start onnxruntime_inference")

        ort_inputs = {
            "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
            "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
            "encoder_hidden_states": numpy.ascontiguousarray(inputs.encoder_hidden_states.cpu().numpy()),
        }

        if inputs.past_key_values:
            # Each layer has 4 past tensors: self key, self value, cross key, cross value.
            assert len(inputs.past_key_values) % 4 == 0
            num_layers = len(inputs.past_key_values) // 4
            past_names = PastKeyValuesHelper.get_past_names(num_layers)
            for i, past_tensor in enumerate(inputs.past_key_values):
                ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())

        ort_outputs = ort_session.run(None, ort_inputs)
        return ort_outputs
    @staticmethod
    def verify_onnx(
        model: Union[T5Decoder, T5DecoderInit],
        ort_session: InferenceSession,
        device: torch.device,
        use_int32_inputs: bool,
        max_cases: int = 4,
    ):
        """Compare the results from PyTorch and ONNX Runtime to verify that the ONNX model is good."""
        float16: bool = TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)"

        test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
        test_cases_max_diff = []
        for (
            batch_size,
            encode_sequence_length,
            past_decode_sequence_length,
        ) in test_cases[:max_cases]:
            if isinstance(model, T5DecoderInit):
                past_decode_sequence_length = 0

            inputs = T5DecoderInputs.create_dummy(
                model.config,
                batch_size,
                encode_sequence_length,
                past_decode_sequence_length,
                device=device,
                float16=float16,
                use_int32_inputs=use_int32_inputs,
            )

            # We use the fp32 PyTorch model as baseline even when the ONNX model is fp16.
            input_list = inputs.to_fp32().to_list()

            # Run inference of PyTorch model
            with torch.no_grad():
                torch_outputs = model(*input_list)

            ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)

            max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
            max_diff_all = max_diff
            logger.debug(f"logits max_diff={max_diff}")

            for i in range(2 * model.config.num_layers):
                max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
                logger.debug(f"self attention past state {i} max_diff={max_diff}")
                max_diff_all = max(max_diff_all, max_diff)

            if isinstance(model, T5DecoderInit):
                for i in range(2 * model.config.num_layers):
                    max_diff = numpy.amax(
                        numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.num_layers + i])
                    )
                    logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                    max_diff_all = max(max_diff_all, max_diff)

            test_cases_max_diff.append(max_diff_all)
            logger.info(
                f"batch_size={batch_size}, encode_sequence_length={encode_sequence_length}, "
                f"past_decode_sequence_length={past_decode_sequence_length}, max_diff={max_diff_all}"
            )

        # Return the largest difference observed across all test cases.
        return max(test_cases_max_diff)
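

# ---------------------------------------------------------------------------
# A minimal end-to-end sketch, under two assumptions: the "t5-small" checkpoint
# can be downloaded from the Hugging Face hub, and the sibling helper modules
# imported above are on the path. It wraps the decoder with past state, exports
# it to ONNX, and verifies the ONNX model against the fp32 PyTorch baseline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import T5ForConditionalGeneration  # assumption: transformers is installed

    device = torch.device("cpu")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").eval().to(device)

    # Wrap the Hugging Face decoder stack and LM head in the export-friendly module.
    decoder = T5Decoder(model.decoder, model.lm_head, model.config).eval().to(device)

    onnx_path = os.path.join(tempfile.gettempdir(), "t5_decoder.onnx")
    T5DecoderHelper.export_onnx(decoder, device, onnx_path, verbose=False)

    session = InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    max_diff = T5DecoderHelper.verify_onnx(decoder, session, device, use_int32_inputs=False)
    print(f"max diff between PyTorch and ONNX Runtime: {max_diff}")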