# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script helps evaluate a GPT-2 model.
import logging
import math
import os
import statistics
import sys
import timeit

import numpy
import torch
from gpt2_helper import Gpt2Helper, Gpt2Inputs

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from benchmark_helper import Precision

logger = logging.getLogger(__name__)


class Gpt2Metric:
    def __init__(self, treatment_name, baseline_name="Torch", top_k=20):
        assert top_k > 1 and top_k <= 100
        self.baseline = baseline_name
        self.treatment = treatment_name
        self.name: str = f"{treatment_name} vs {baseline_name}"
        self.top_k = top_k
        self.top_1_error: int = 0
        self.top_k_error: int = 0
        self.total_samples: int = 0
        self.max_logits_diff: float = 0  # for non-empty past state
        self.max_logits_diff_no_past: float = 0  # for empty past state
        self.batch_top1_error: torch.FloatTensor = None  # top 1 error for current batch
        self.batch_topk_error: torch.FloatTensor = None  # top k error for current batch
        self.seq_len_latency = {}

    def print(self):
        if self.baseline != self.treatment:
            print("---")
            print(f"Metrics for {self.treatment} (baseline={self.baseline}):")
            if self.total_samples > 0:
                top_1_error_rate = 100.0 * self.top_1_error / self.total_samples
                top_k_error_rate = 100.0 * self.top_k_error / self.total_samples
                print(
                    f"Total={self.total_samples} Top1Error={self.top_1_error} ({top_1_error_rate:.2f}%) Top{self.top_k}Error={self.top_k_error} ({top_k_error_rate:.2f}%)"
                )
            print("Max logits diffs:")
            print(f"\twith past = {self.max_logits_diff:.6f}")
            print(f"\tempty past = {self.max_logits_diff_no_past:.6f}")
        else:
            print(f"Metrics for {self.treatment} (baseline):")

        if self.seq_len_latency:
            print("Past sequence length range and average latency:")
            total = 0
            count = 0
            for key in sorted(self.seq_len_latency.keys()):
                average = statistics.mean(self.seq_len_latency[key]) * 1000.0
                if key == 0:
                    print("\t{}: \t{:.2f} ms".format(key, average))
                else:
                    print("\t[{}, {}]:\t{:.2f} ms".format(2**key, 2 ** (key + 1) - 1, average))
                total += average * len(self.seq_len_latency[key])
                count += len(self.seq_len_latency[key])
            print("Average Latency: {:.2f} ms".format(total / count))

    def diff_logits(self, baseline_logits, treatment_logits, is_empty_past: bool):
        diff = (baseline_logits - treatment_logits).abs().max()

        if is_empty_past:
            self.max_logits_diff_no_past = max(self.max_logits_diff_no_past, diff)
        else:
            self.max_logits_diff = max(self.max_logits_diff, diff)

        return diff

    def start_batch(self, batch_size: int):
        self.total_samples += batch_size
        self.batch_top1_error = torch.zeros((batch_size, 1), dtype=torch.bool)
        self.batch_topk_error = torch.zeros((batch_size, 1), dtype=torch.bool)

    def eval_batch(self, baseline, treatment, past_seq_len, verbose=True):
        self._eval_topk(baseline.top_1_tokens, treatment.top_1_tokens, 1, verbose)
        self._eval_topk(baseline.top_k_tokens, treatment.top_k_tokens, self.top_k, verbose)

        max_diff = self.diff_logits(baseline.logits, treatment.logits, past_seq_len == 0)
        if verbose:
            print(f"Max logits diffs of {self.name}: {max_diff}")

    def _eval_topk(self, baseline_topk, treatment_topk, top_k, verbose=True):
        if not torch.all(torch.eq(baseline_topk, treatment_topk)):
            if top_k == 1:
                if verbose:
                    print(f"Generated tokens not matched for {self.name}")
                self.batch_top1_error |= torch.eq(baseline_topk, treatment_topk).logical_not()
            else:
                if verbose:
                    print(
                        f"Top {top_k} tokens not matched for {self.name}. This will lead to wrong beam search results"
                    )
                self.batch_topk_error |= (
                    torch.eq(baseline_topk, treatment_topk).logical_not().sum(1).unsqueeze(dim=1) > 0
                )

    def end_batch(self):
        self.top_1_error += self.batch_top1_error.sum()
        self.top_k_error += self.batch_topk_error.sum()

    def add_latency(self, past_seq_len, latency):
        key = int(math.log2(past_seq_len)) + 1 if past_seq_len > 0 else 0
        if key not in self.seq_len_latency:
            self.seq_len_latency[key] = []
        self.seq_len_latency[key].append(latency)
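

# Gpt2Tester drives step-by-step greedy generation for one backend (PyTorch or ONNX
# Runtime): construct it with the initial batch inputs, then repeatedly call get_inputs(),
# run inference, and pass the outputs to update() so the next step uses the newly
# generated token and the accumulated past state. See test_generation below for the
# driver loop that compares backends using Gpt2Metric.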
class Gpt2Tester:
    def __init__(
        self,
        input_ids,
        position_ids,
        attention_mask,
        num_attention_heads,
        hidden_size,
        num_layer,
        device,
        is_fp16=False,
        top_k=20,
        top_k_required_order=False,
    ):
        self.batch_size = input_ids.shape[0]
        self.input_length = input_ids.shape[1]
        self.n_layer = num_layer

        self.input_ids = input_ids
        self.position_ids = position_ids
        self.attention_mask = attention_mask

        self.has_position_ids = position_ids is not None
        self.has_attention_mask = attention_mask is not None

        # Empty past state for first inference
        self.past = []
        past_shape = [
            2,
            self.batch_size,
            num_attention_heads,
            0,
            hidden_size // num_attention_heads,
        ]
        for i in range(num_layer):
            empty_past = torch.empty(past_shape).type(torch.float16 if is_fp16 else torch.float32)
            self.past.append(empty_past.to(device))

        self.logits = None
        self.top_1_tokens = None
        self.top_k_tokens = None
        self.top_k = top_k
        self.top_k_required_order = top_k_required_order

    def get_inputs(self) -> Gpt2Inputs:
        return Gpt2Inputs(self.input_ids, self.position_ids, self.attention_mask, self.past)

    def save_test_data(self, session, output, save_test_data_dir, test_case_id):
        from onnx import numpy_helper

        path = os.path.join(save_test_data_dir, "test_data_set_" + str(test_case_id))
        if os.path.exists(path):
            print(f"Directory {path} already exists. Skip saving test data.")
            return

        os.makedirs(path, exist_ok=True)

        def add_tensor(input_tensors, torch_tensor, name):
            input_tensors.append(numpy_helper.from_array(torch_tensor.clone().cpu().numpy(), name))

        input_tensors = []
        add_tensor(input_tensors, self.input_ids, "input_ids")

        if self.has_position_ids:
            add_tensor(input_tensors, self.position_ids, "position_ids")

        if self.has_attention_mask:
            add_tensor(input_tensors, self.attention_mask, "attention_mask")

        for i in range(self.n_layer):
            add_tensor(input_tensors, self.past[i], "past_" + str(i))

        for i, tensor in enumerate(input_tensors):
            with open(os.path.join(path, "input_{}.pb".format(i)), "wb") as f:
                f.write(tensor.SerializeToString())

        output_names = [output.name for output in session.get_outputs()]
        for i, name in enumerate(output_names):
            tensor = numpy_helper.from_array(
                output[i] if isinstance(output[i], numpy.ndarray) else output[i].clone().cpu().numpy()
            )
            with open(os.path.join(path, "output_{}.pb".format(i)), "wb") as f:
                f.write(tensor.SerializeToString())

        print(f"Test data saved to directory {path}")

    def update(self, output, step, device):
        """
        Update the inputs for next inference.
        """
        self.logits = (
            torch.from_numpy(output[0]) if isinstance(output[0], numpy.ndarray) else output[0].clone().detach().cpu()
        )

        self.top_1_tokens = Gpt2Tester.predict_next_token(self.logits)
        self.top_k_tokens = Gpt2Tester.predict_next_token(self.logits, self.top_k, self.top_k_required_order)

        self.input_ids = self.top_1_tokens.clone().detach().reshape([self.batch_size, 1]).to(device)

        if self.has_position_ids:
            self.position_ids = (
                torch.tensor([self.input_length + step - 1]).unsqueeze(0).repeat(self.batch_size, 1).to(device)
            )

        if self.has_attention_mask:
            self.attention_mask = torch.cat(
                [
                    self.attention_mask,
                    torch.ones([self.batch_size, 1]).type_as(self.attention_mask),
                ],
                1,
            ).to(device)

        self.past = []

        if isinstance(output[1], tuple):  # past in torch output is tuple
            self.past = list(output[1])
        else:
            for i in range(self.n_layer):
                past_i = (
                    torch.from_numpy(output[i + 1])
                    if isinstance(output[i + 1], numpy.ndarray)
                    else output[i + 1].clone().detach()
                )
                self.past.append(past_i.to(device))

    def diff(self, baseline):
        """
        Compare inputs and logits output.
        """
        print("start diff...")
        if self.logits is not None:
            max_io_diff = (self.logits - baseline.logits).abs().max()
            if max_io_diff > 1e-4:
                print(f"Max logits difference is too large: {max_io_diff}")

        if not torch.all(self.input_ids == baseline.input_ids):
            print("Input_ids is different", self.input_ids, baseline.input_ids)

        if self.has_position_ids:
            if not torch.all(self.position_ids == baseline.position_ids):
                print(
                    "position_ids is different",
                    self.position_ids,
                    baseline.position_ids,
                )

        if self.has_attention_mask:
            if not torch.all(self.attention_mask == baseline.attention_mask):
                print(
                    "attention_mask is different",
                    self.attention_mask,
                    baseline.attention_mask,
                )

        assert len(self.past) == len(baseline.past)
        for i, past_i in enumerate(self.past):
            assert past_i.shape == baseline.past[i].shape
            if past_i.nelement() > 0:
                max_past_diff = (past_i - baseline.past[i]).abs().max()
                if max_past_diff > 1e-4:
                    print(f"max_past_diff[{i}]={max_past_diff}")

    @staticmethod
    def predict_next_token(logits, top_k=1, required_order=False):
        """
        Get top k tokens based on logits.
        """
        # logits has shape (batch_size, seq_len, vocab_size)
        # last token logits has shape (batch_size, vocab_size)
        lastTokenLogits = logits[:, -1]
        if top_k == 1:
            generatedTokens = torch.argmax(lastTokenLogits, 1, True)
            return generatedTokens
        else:
            topk = torch.argsort(lastTokenLogits, -1, descending=True)[:, :top_k]
            if not required_order:
                sorted_topk, _ = topk.sort()
                return sorted_topk
            return topk
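
    # Illustrative shapes (assuming GPT-2's vocabulary size of 50257): for logits of shape
    # (batch_size, seq_len, 50257), predict_next_token(logits) returns token ids of shape
    # (batch_size, 1), while predict_next_token(logits, top_k=5) returns shape (batch_size, 5),
    # sorted by token id unless required_order=True (then kept in descending score order).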

    @staticmethod
    def diff_present(onnx_output, onnx_io_output, n_layer):
        """
        Compare the present state outputs of two ONNX Runtime runs.
        """
        present_diff_max = []
        for i in range(n_layer):
            onnx_present_i = (
                torch.from_numpy(onnx_output[i + 1])
                if isinstance(onnx_output[i + 1], numpy.ndarray)
                else onnx_output[i + 1]
            )
            onnx_io_present_i = (
                torch.from_numpy(onnx_io_output[i + 1])
                if isinstance(onnx_io_output[i + 1], numpy.ndarray)
                else onnx_io_output[i + 1]
            )
            max_diff = (onnx_present_i - onnx_io_present_i).abs().max()
            present_diff_max.append(max_diff)
        print(f"present_diff_max={present_diff_max}")

    @staticmethod
    def is_quantized_onnx_model(onnx_model_path):
        """
        Returns True if the ONNX model is quantized.
        """
        from onnx import load

        model = load(onnx_model_path)
        from onnxruntime.quantization.quantize import __producer__ as quantize_producer

        return model.producer_name == quantize_producer
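
    # For example, Gpt2Tester.is_quantized_onnx_model("gpt2_int8.onnx") would be expected to
    # return True for a model produced by onnxruntime.quantization (the path is hypothetical).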

    @staticmethod
    def test_generation(
        session,
        model,
        device,
        test_inputs,
        precision=Precision.FLOAT32,
        model_class="Gpt2LMHeadModel",
        top_k=20,
        top_k_no_order=True,
        max_steps=24,
        max_inputs=0,
        verbose=False,
        save_test_data=0,
        save_test_data_dir=".",
    ):
        """
        Test generation using greedy beam search (without sampling) to compare PyTorch and ONNX models.
        It will print top 1 and top k errors on the given test inputs.
        """
        print(
            f"start test generation: (top_k={top_k} top_k_no_order={top_k_no_order} max_steps={max_steps} test_inputs={len(test_inputs)} max_inputs={max_inputs})"
        )
        n_layer = model.config.n_layer
        n_head = model.config.n_head
        n_embd = model.config.n_embd
        eos_token_id = model.config.eos_token_id
        test_data_saved = 0

        is_float16 = precision == Precision.FLOAT16
        if is_float16:
            assert "float16" in session.get_outputs()[0].type

        # We will still use the fp32 torch model as baseline when the onnx model is fp16
        model.eval().to(device)

        # Allocate initial buffers for IO Binding of ONNX Runtime. The buffer size will automatically increase later.
        init_output_shapes = Gpt2Helper.get_output_shapes(
            batch_size=4,
            past_sequence_length=128,
            sequence_length=32,
            config=model.config,
            model_class=model_class,
        )
        output_buffers = Gpt2Helper.get_output_buffers(init_output_shapes, device, is_float16=is_float16)

        baseline_name = "Torch"
        treatment_name = "Quantized Onnx" if precision == Precision.INT8 else "Onnx"
        torch_metric = Gpt2Metric(baseline_name, baseline_name, top_k)
        onnx_metric = Gpt2Metric(treatment_name, baseline_name, top_k)
        onnx_io_metric = Gpt2Metric(treatment_name + " with IO Binding", baseline_name, top_k)

        for i, inputs in enumerate(test_inputs):
            if max_inputs > 0 and i == max_inputs:
                break
            if i % 10 == 0:
                print(f"{i}")
            input_ids = inputs["input_ids"]
            position_ids = inputs["position_ids"] if "position_ids" in inputs else None
            attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None

            onnx_runner = Gpt2Tester(
                input_ids,
                position_ids,
                attention_mask,
                n_head,
                n_embd,
                n_layer,
                device,
                is_float16,
                top_k,
                not top_k_no_order,
            )
            onnx_io_runner = Gpt2Tester(
                input_ids,
                position_ids,
                attention_mask,
                n_head,
                n_embd,
                n_layer,
                device,
                is_float16,
                top_k,
                not top_k_no_order,
            )
            torch_runner = Gpt2Tester(
                input_ids,
                position_ids,
                attention_mask,
                n_head,
                n_embd,
                n_layer,
                device,
                False,
                top_k,
                not top_k_no_order,
            )  # Torch model baseline is fp32

            batch_size = torch_runner.batch_size
            onnx_metric.start_batch(batch_size)
            onnx_io_metric.start_batch(batch_size)

            with torch.no_grad():
                done = torch.zeros(batch_size, dtype=torch.bool)
                for step in range(max_steps):
                    seq_len = list(onnx_runner.input_ids.size())[1]
                    past_seq_len = list(onnx_runner.past[0].size())[3]

                    start_time = timeit.default_timer()
                    pytorch_output = Gpt2Helper.pytorch_inference(model, torch_runner.get_inputs())
                    torch_metric.add_latency(past_seq_len, timeit.default_timer() - start_time)
                    torch_runner.update(pytorch_output, step, device)

                    onnx_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference(
                        session, onnx_runner.get_inputs(), total_runs=1
                    )
                    onnx_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
                    onnx_runner.update(onnx_output, step, device)

                    output_shapes = Gpt2Helper.get_output_shapes(
                        batch_size,
                        past_seq_len,
                        seq_len,
                        model.config,
                        model_class=model_class,
                    )
                    Gpt2Helper.auto_increase_buffer_size(output_buffers, output_shapes)

                    (onnx_io_output, avg_latency_ms,) = Gpt2Helper.onnxruntime_inference_with_binded_io(
                        session,
                        onnx_io_runner.get_inputs(),
                        output_buffers,
                        output_shapes,
                        total_runs=1,
                        return_numpy=False,
                        include_copy_output_latency=True,
                    )
                    onnx_io_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)

                    if test_data_saved < save_test_data:
                        onnx_io_runner.save_test_data(session, onnx_io_output, save_test_data_dir, test_data_saved)
                        test_data_saved += 1

                    onnx_io_runner.update(onnx_io_output, step, device)

                    if verbose:
                        onnx_runner.diff(onnx_io_runner)
                        Gpt2Tester.diff_present(onnx_output, onnx_io_output, n_layer)

                        print("Top 1 tokens:")
                        print("\tTorch", torch_runner.top_1_tokens)
                        print("\tONNX", onnx_runner.top_1_tokens)
                        print("\tONNX with IO binding", onnx_io_runner.top_1_tokens)

                    onnx_metric.eval_batch(torch_runner, onnx_runner, past_seq_len, verbose=verbose)
                    onnx_io_metric.eval_batch(torch_runner, onnx_io_runner, past_seq_len, verbose=verbose)

                    done = done | (torch_runner.top_1_tokens == eos_token_id).any()
                    if torch.all(done):
                        break

            onnx_metric.end_batch()
            onnx_io_metric.end_batch()

        torch_metric.print()
        onnx_metric.print()
        onnx_io_metric.print()
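

if __name__ == "__main__":
    # Minimal usage sketch: it assumes a GPT-2 model has already been exported to ONNX with
    # past-state inputs (for example via convert_to_onnx.py) and that the transformers and
    # onnxruntime packages are installed. The model path and prompt below are illustrative,
    # not part of this tool; adjust them to your own exported model.
    import onnxruntime
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    onnx_model_path = "gpt2_past.onnx"  # hypothetical path to the exported ONNX model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if torch.cuda.is_available()
        else ["CPUExecutionProvider"]
    )

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    session = onnxruntime.InferenceSession(onnx_model_path, providers=providers)

    # Build one single-sequence test batch. The GPT-2 exports produced by this package
    # typically take a float attention_mask and explicit position_ids, so both are provided.
    encoded = tokenizer("here is an example input for GPT-2", return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(torch.float32).to(device)
    position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).unsqueeze(0).to(device)

    test_inputs = [
        {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }
    ]

    Gpt2Tester.test_generation(session, model, device, test_inputs, max_steps=8, verbose=True)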