# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script helps evaluation of GPT-2 model.
import logging
import math
import os
import statistics
import sys
import timeit

import numpy
import torch
from gpt2_helper import Gpt2Helper, Gpt2Inputs

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from benchmark_helper import Precision

logger = logging.getLogger(__name__)


class Gpt2Metric:
    def __init__(self, treatment_name, baseline_name="Torch", top_k=20):
        assert top_k > 1 and top_k <= 100
        self.baseline = baseline_name
        self.treatment = treatment_name
        self.name: str = f"{treatment_name} vs {baseline_name}"
        self.top_k = top_k
        self.top_1_error: int = 0
        self.top_k_error: int = 0
        self.total_samples: int = 0
        self.max_logits_diff: float = 0  # for non-empty past state
        self.max_logits_diff_no_past: float = 0  # for empty past state
        self.batch_top1_error: torch.Tensor = None  # top 1 error for current batch (bool tensor)
        self.batch_topk_error: torch.Tensor = None  # top k error for current batch (bool tensor)
        self.seq_len_latency = {}

    def print(self):
        if self.baseline != self.treatment:
            print("---")
            print(f"Metrics for {self.treatment} (baseline={self.baseline}):")
            if self.total_samples > 0:
                top_1_error_rate = 100.0 * self.top_1_error / self.total_samples
                top_k_error_rate = 100.0 * self.top_k_error / self.total_samples
                print(
                    f"Total={self.total_samples} Top1Error={self.top_1_error} ({top_1_error_rate:.2f}%) Top{self.top_k}Error={self.top_k_error} ({top_k_error_rate:.2f}%)"
                )
            print("Max logits diffs:")
            print(f"\twith past = {self.max_logits_diff:.6f}")
            print(f"\tempty past = {self.max_logits_diff_no_past:.6f}")
        else:
            print(f"Metrics for {self.treatment} (baseline):")

        if self.seq_len_latency:
            print("Past sequence length range and average latency:")
            total = 0
            count = 0
            for key in sorted(self.seq_len_latency.keys()):
                average = statistics.mean(self.seq_len_latency[key]) * 1000.0
                if key == 0:
                    print("\t{}: \t{:.2f} ms".format(key, average))
                else:
                    print("\t[{}, {}]:\t{:.2f} ms".format(2 ** (key - 1), 2**key - 1, average))
                total += average * len(self.seq_len_latency[key])
                count += len(self.seq_len_latency[key])
            print("Average Latency: {:.2f} ms".format(total / count))

    def diff_logits(self, baseline_logits, treatment_logits, is_empty_past: bool):
        diff = (baseline_logits - treatment_logits).abs().max()

        if is_empty_past:
            self.max_logits_diff_no_past = max(self.max_logits_diff_no_past, diff)
        else:
            self.max_logits_diff = max(self.max_logits_diff, diff)

        return diff

    def start_batch(self, batch_size: int):
        self.total_samples += batch_size
        self.batch_top1_error = torch.zeros((batch_size, 1), dtype=torch.bool)
        self.batch_topk_error = torch.zeros((batch_size, 1), dtype=torch.bool)

    def eval_batch(self, baseline, treatment, past_seq_len, verbose=True):
        self._eval_topk(baseline.top_1_tokens, treatment.top_1_tokens, 1, verbose)
        self._eval_topk(baseline.top_k_tokens, treatment.top_k_tokens, self.top_k, verbose)

        max_diff = self.diff_logits(baseline.logits, treatment.logits, past_seq_len == 0)
        if verbose:
            print(f"Max logits diffs of {self.name}: {max_diff}")

    def _eval_topk(self, baseline_topk, treatment_topk, top_k, verbose=True):
        if not torch.all(torch.eq(baseline_topk, treatment_topk)):
            if top_k == 1:
                if verbose:
                    print(f"Generated tokens not matched for {self.name}")
                self.batch_top1_error |= torch.eq(baseline_topk, treatment_topk).logical_not()
            else:
                if verbose:
                    print(
                        f"Top {top_k} tokens not matched for {self.name}. This will lead to wrong beam search results."
                    )
                self.batch_topk_error |= (
                    torch.eq(baseline_topk, treatment_topk).logical_not().sum(1).unsqueeze(dim=1) > 0
                )

    def end_batch(self):
        self.top_1_error += self.batch_top1_error.sum()
        self.top_k_error += self.batch_topk_error.sum()

    def add_latency(self, past_seq_len, latency):
        # Bucket by power of two: key 0 holds the empty-past case, and key k
        # covers past lengths in [2**(k - 1), 2**k - 1].
        key = int(math.log2(past_seq_len)) + 1 if past_seq_len > 0 else 0
        if key not in self.seq_len_latency:
            self.seq_len_latency[key] = []
        self.seq_len_latency[key].append(latency)
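

# A minimal usage sketch of Gpt2Metric (illustrative only; the latency values
# below are made up). add_latency buckets a measurement by
# floor(log2(past_seq_len)) + 1, so past lengths 4..7 share one bucket,
# 8..15 the next, and so on.
def _example_gpt2_metric_usage():
    metric = Gpt2Metric("Onnx", baseline_name="Torch", top_k=20)
    metric.add_latency(past_seq_len=0, latency=0.010)  # empty past goes to bucket 0
    metric.add_latency(past_seq_len=5, latency=0.012)  # bucket [4, 7]
    metric.add_latency(past_seq_len=6, latency=0.014)  # bucket [4, 7]
    metric.print()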


class Gpt2Tester:
    def __init__(
        self,
        input_ids,
        position_ids,
        attention_mask,
        num_attention_heads,
        hidden_size,
        num_layer,
        device,
        is_fp16=False,
        top_k=20,
        top_k_required_order=False,
    ):
        self.batch_size = input_ids.shape[0]
        self.input_length = input_ids.shape[1]
        self.n_layer = num_layer

        self.input_ids = input_ids
        self.position_ids = position_ids
        self.attention_mask = attention_mask

        self.has_position_ids = position_ids is not None
        self.has_attention_mask = attention_mask is not None

        # Empty past state for the first inference
        self.past = []
        past_shape = [
            2,
            self.batch_size,
            num_attention_heads,
            0,
            hidden_size // num_attention_heads,
        ]
        for _ in range(num_layer):
            empty_past = torch.empty(past_shape).type(torch.float16 if is_fp16 else torch.float32)
            self.past.append(empty_past.to(device))

        self.logits = None
        self.top_1_tokens = None
        self.top_k_tokens = None
        self.top_k = top_k
        self.top_k_required_order = top_k_required_order

    def get_inputs(self) -> Gpt2Inputs:
        return Gpt2Inputs(self.input_ids, self.position_ids, self.attention_mask, self.past)

    def save_test_data(self, session, output, save_test_data_dir, test_case_id):
        from onnx import numpy_helper

        path = os.path.join(save_test_data_dir, "test_data_set_" + str(test_case_id))
        if os.path.exists(path):
            print(f"Directory {path} already exists. Skipping saving test data.")
            return

        os.makedirs(path, exist_ok=True)

        def add_tensor(input_tensors, torch_tensor, name):
            input_tensors.append(numpy_helper.from_array(torch_tensor.clone().cpu().numpy(), name))

        input_tensors = []
        add_tensor(input_tensors, self.input_ids, "input_ids")

        if self.has_position_ids:
            add_tensor(input_tensors, self.position_ids, "position_ids")

        if self.has_attention_mask:
            add_tensor(input_tensors, self.attention_mask, "attention_mask")

        for i in range(self.n_layer):
            add_tensor(input_tensors, self.past[i], "past_" + str(i))

        for i, tensor in enumerate(input_tensors):
            with open(os.path.join(path, "input_{}.pb".format(i)), "wb") as f:
                f.write(tensor.SerializeToString())

        output_names = [output.name for output in session.get_outputs()]
        for i, name in enumerate(output_names):
            tensor = numpy_helper.from_array(
                output[i] if isinstance(output[i], numpy.ndarray) else output[i].clone().cpu().numpy(), name
            )
            with open(os.path.join(path, "output_{}.pb".format(i)), "wb") as f:
                f.write(tensor.SerializeToString())

        print(f"Test data saved to directory {path}")
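
    # Each decoding step feeds the model's own greedy prediction back in: update()
    # below takes the top-1 token as the next input_ids, advances position_ids by
    # one, appends a column of ones to attention_mask, and replaces the past state
    # with the present state returned by the model.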
""" self.logits = ( torch.from_numpy(output[0]) if isinstance(output[0], numpy.ndarray) else output[0].clone().detach().cpu() ) self.top_1_tokens = Gpt2Tester.predict_next_token(self.logits) self.top_k_tokens = Gpt2Tester.predict_next_token(self.logits, self.top_k, self.top_k_required_order) self.input_ids = self.top_1_tokens.clone().detach().reshape([self.batch_size, 1]).to(device) if self.has_position_ids: self.position_ids = ( torch.tensor([self.input_length + step - 1]).unsqueeze(0).repeat(self.batch_size, 1).to(device) ) if self.has_attention_mask: self.attention_mask = torch.cat( [ self.attention_mask, torch.ones([self.batch_size, 1]).type_as(self.attention_mask), ], 1, ).to(device) self.past = [] if isinstance(output[1], tuple): # past in torch output is tuple self.past = list(output[1]) else: for i in range(self.n_layer): past_i = ( torch.from_numpy(output[i + 1]) if isinstance(output[i + 1], numpy.ndarray) else output[i + 1].clone().detach() ) self.past.append(past_i.to(device)) def diff(self, baseline): """ Compare inputs and logits output. """ print("start diff...") if self.logits is not None: max_io_diff = (self.logits - baseline.logits).abs().max() if max_io_diff > 1e-4: print(f"Max logits difference is too large: {max_io_diff}") if not torch.all(self.input_ids == baseline.input_ids): print("Input_ids is different", self.input_ids, baseline.input_ids) if self.has_position_ids: if not torch.all(self.position_ids == baseline.position_ids): print( "position_ids is different", self.position_ids, baseline.position_ids, ) if self.has_attention_mask: if not torch.all(self.attention_mask == baseline.attention_mask): print( "attention_mask is different", self.attention_mask, baseline.attention_mask, ) assert len(self.past) == len(baseline.past) for i, past_i in enumerate(self.past): assert past_i.shape == baseline.past[i].shape if past_i.nelement() > 0: max_past_diff = (past_i - baseline.past[i]).abs().max() if max_past_diff > 1e-4: print(f"max_past_diff[{i}]={max_past_diff}") @staticmethod def predict_next_token(logits, top_k=1, required_order=False): """ Get top k topkens based on logits. """ # logits has shape (batch_size, seq_len, vocab_size) # last token logits has shape (batch_size, vocab_size) lastTokenLogits = logits[:, -1] if top_k == 1: generatedTokens = torch.argmax(lastTokenLogits, 1, True) return generatedTokens else: topk = torch.argsort(lastTokenLogits, -1, descending=True)[:, :top_k] if not required_order: sorted_topk, _ = topk.sort() return sorted_topk return topk @staticmethod def diff_present(onnx_output, onnx_io_output, n_layer): """ Compare the present outputs of two outputs from ONNX Runtime. """ present_diff_max = [] for i in range(n_layer): onnx_present_i = ( torch.from_numpy(onnx_output[i + 1]) if isinstance(onnx_output[i + 1], numpy.ndarray) else onnx_output[i + 1] ) onnx_io_present_i = ( torch.from_numpy(onnx_io_output[i + 1]) if isinstance(onnx_io_output[i + 1], numpy.ndarray) else onnx_io_output[i + 1] ) max_diff = (onnx_present_i - onnx_io_present_i).abs().max() present_diff_max.append(max_diff) print(f"present_diff_max={present_diff_max}") @staticmethod def is_quantized_onnx_model(onnx_model_path): """ Returns True if the ONNX model is quantized. 
""" from onnx import load model = load(onnx_model_path) from onnxruntime.quantization.quantize import __producer__ as quantize_producer return model.producer_name == quantize_producer @staticmethod def test_generation( session, model, device, test_inputs, precision=Precision.FLOAT32, model_class="Gpt2LMHeadModel", top_k=20, top_k_no_order=True, max_steps=24, max_inputs=0, verbose=False, save_test_data=0, save_test_data_dir=".", ): """ Test Generation using greedy beam search (without sampling) to compare PyTorch and ONNX model. It will print top 1 and top k errors on the given test inputs. """ print( f"start test generation: (top_k={top_k} top_k_no_order={top_k_no_order} max_steps={max_steps} test_inputs={len(test_inputs)} max_inputs={max_inputs})" ) n_layer = model.config.n_layer n_head = model.config.n_head n_embd = model.config.n_embd eos_token_id = model.config.eos_token_id test_data_saved = 0 is_float16 = precision == Precision.FLOAT16 if is_float16: assert "float16" in session.get_outputs()[0].type # We will still use fp32 torch model as baseline when onnx model if fp16 model.eval().to(device) # Allocate initial buffers for IO Binding of ONNX Runtimne. The buffer size will automatically increase later. init_output_shapes = Gpt2Helper.get_output_shapes( batch_size=4, past_sequence_length=128, sequence_length=32, config=model.config, model_class=model_class, ) output_buffers = Gpt2Helper.get_output_buffers(init_output_shapes, device, is_float16=is_float16) baseline_name = "Torch" treatment_name = "Quantized Onnx" if precision == Precision.INT8 else "Onnx" torch_metric = Gpt2Metric(baseline_name, baseline_name, top_k) onnx_metric = Gpt2Metric(treatment_name, baseline_name, top_k) onnx_io_metric = Gpt2Metric(treatment_name + " with IO Binding", baseline_name, top_k) for i, inputs in enumerate(test_inputs): if max_inputs > 0 and i == max_inputs: break if i % 10 == 0: print(f"{i}") input_ids = inputs["input_ids"] position_ids = inputs["position_ids"] if "position_ids" in inputs else None attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None onnx_runner = Gpt2Tester( input_ids, position_ids, attention_mask, n_head, n_embd, n_layer, device, is_float16, top_k, not top_k_no_order, ) onnx_io_runner = Gpt2Tester( input_ids, position_ids, attention_mask, n_head, n_embd, n_layer, device, is_float16, top_k, not top_k_no_order, ) torch_runner = Gpt2Tester( input_ids, position_ids, attention_mask, n_head, n_embd, n_layer, device, False, top_k, not top_k_no_order, ) # Torch model baseline is fp32 batch_size = torch_runner.batch_size onnx_metric.start_batch(batch_size) onnx_io_metric.start_batch(batch_size) with torch.no_grad(): done = torch.zeros(batch_size, dtype=torch.bool) for step in range(max_steps): seq_len = list(onnx_runner.input_ids.size())[1] past_seq_len = list(onnx_runner.past[0].size())[3] start_time = timeit.default_timer() pytorch_output = Gpt2Helper.pytorch_inference(model, torch_runner.get_inputs()) torch_metric.add_latency(past_seq_len, timeit.default_timer() - start_time) torch_runner.update(pytorch_output, step, device) onnx_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference( session, onnx_runner.get_inputs(), total_runs=1 ) onnx_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0) onnx_runner.update(onnx_output, step, device) output_shapes = Gpt2Helper.get_output_shapes( batch_size, past_seq_len, seq_len, model.config, model_class=model_class, ) Gpt2Helper.auto_increase_buffer_size(output_buffers, output_shapes) (onnx_io_output, 
                    (
                        onnx_io_output,
                        avg_latency_ms,
                    ) = Gpt2Helper.onnxruntime_inference_with_binded_io(
                        session,
                        onnx_io_runner.get_inputs(),
                        output_buffers,
                        output_shapes,
                        total_runs=1,
                        return_numpy=False,
                        include_copy_output_latency=True,
                    )
                    onnx_io_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)

                    if test_data_saved < save_test_data:
                        onnx_io_runner.save_test_data(session, onnx_io_output, save_test_data_dir, test_data_saved)
                        test_data_saved += 1

                    onnx_io_runner.update(onnx_io_output, step, device)

                    if verbose:
                        onnx_runner.diff(onnx_io_runner)
                        Gpt2Tester.diff_present(onnx_output, onnx_io_output, n_layer)

                        print("Top 1 tokens:")
                        print("\tTorch", torch_runner.top_1_tokens)
                        print("\tONNX", onnx_runner.top_1_tokens)
                        print("\tONNX with IO binding", onnx_io_runner.top_1_tokens)

                    onnx_metric.eval_batch(torch_runner, onnx_runner, past_seq_len, verbose=verbose)
                    onnx_io_metric.eval_batch(torch_runner, onnx_io_runner, past_seq_len, verbose=verbose)

                    # .any() reduces over the batch, so generation stops once any
                    # sample emits the EOS token.
                    done = done | (torch_runner.top_1_tokens == eos_token_id).any()
                    if torch.all(done):
                        break

            onnx_metric.end_batch()
            onnx_io_metric.end_batch()

        torch_metric.print()
        onnx_metric.print()
        onnx_io_metric.print()
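

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative assumption, not part of the tool's
    # documented workflow): compare a PyTorch GPT-2 against an exported ONNX model
    # on a single input. "gpt2.onnx" is a hypothetical path; the ONNX model is
    # expected to be exported beforehand (for example with convert_to_onnx.py).
    import onnxruntime
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    device = torch.device("cpu")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    session = onnxruntime.InferenceSession("gpt2.onnx", providers=["CPUExecutionProvider"])

    encodings = tokenizer("best hotel in bay area", return_tensors="pt")
    test_inputs = [
        {
            "input_ids": encodings["input_ids"],
            "attention_mask": encodings["attention_mask"].to(torch.float32),
        }
    ]
    Gpt2Tester.test_generation(session, model, device, test_inputs, max_steps=4, verbose=True)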