m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

243 lines
7.3 KiB

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# It is a tool to compare the inference results of the original model and optimized model.
import argparse
import statistics
from pathlib import Path
import numpy as np
import psutil
from bert_perf_test import create_session, onnxruntime_inference
from bert_test_data import generate_test_data, get_bert_inputs, output_test_data
def run_model(model_path, all_inputs, use_gpu, disable_optimization):
    """Create a session for ``model_path`` and run all test inputs through it.

    Args:
        model_path: path of the ONNX model, forwarded to ``create_session``.
        all_inputs: test inputs, forwarded unchanged to ``onnxruntime_inference``.
        use_gpu: forwarded to ``create_session``; also selects the string
            ``"cuda"`` (true) or ``"cpu"`` (false) passed as its third argument.
        disable_optimization: when true, the session is created with
            ``GraphOptimizationLevel.ORT_DISABLE_ALL``; otherwise ``None`` is
            passed as the graph optimization level.

    Returns:
        A tuple ``(results, latency_list, output_names)`` where the first two
        items come from ``onnxruntime_inference`` and ``output_names`` lists
        the names of the session outputs.
    """
    # onnxruntime is imported here rather than at module level.
    import onnxruntime

    optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL if disable_optimization else None
    )
    # Thread count is the number of physical (non-logical) cores reported by psutil.
    physical_cores = psutil.cpu_count(logical=False)
    device = "cuda" if use_gpu else "cpu"

    session = create_session(model_path, use_gpu, device, physical_cores, optimization_level)

    output_names = [node.name for node in session.get_outputs()]
    results, latency_list = onnxruntime_inference(session, all_inputs, output_names)
    return results, latency_list, output_names
def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4):
    """Validate that treatment outputs are close to baseline outputs and print a summary.

    Args:
        baseline_results: one entry per test case; each entry is a sequence of
            numpy arrays (one per model output).
        treatment_results: outputs of the treatment model, indexed the same way
            as ``baseline_results`` (``treatment_results[case][output]``).
        verbose: when true, print the baseline/treatment values and the
            differences for every output that fails the tolerance check.
        rtol: relative tolerance passed to ``np.allclose``.
        atol: absolute tolerance passed to ``np.allclose``.

    Returns:
        None. The pass/fail summary and the maximum absolute and relative
        differences are printed to stdout.
    """
    diff_count = 0
    max_rel_diff = 0
    max_abs_diff = 0
    for test_case_id, baseline_outputs in enumerate(baseline_results):
        case_passed = True
        for i, baseline_output in enumerate(baseline_outputs):
            treatment_output = treatment_results[test_case_id][i]

            abs_diff_tensor = np.abs(np.asarray(treatment_output) - np.asarray(baseline_output))
            abs_diff = np.amax(abs_diff_tensor)

            # The relative difference is only defined where the baseline is non-zero.
            # Dividing by a zero baseline would produce NaN (0/0) or inf (x/0); a single
            # NaN makes np.amax return NaN, which max() below then silently drops,
            # hiding the relative difference of the whole tensor. Zero-baseline
            # elements therefore contribute 0 here and are covered by abs_diff.
            denominator = np.abs(np.asarray(baseline_output))
            nonzero = denominator != 0
            safe_denominator = np.where(nonzero, denominator, 1)
            rel_diff = np.amax(np.where(nonzero, abs_diff_tensor / safe_denominator, 0))

            max_rel_diff = max(max_rel_diff, rel_diff)
            max_abs_diff = max(max_abs_diff, abs_diff)

            # np.allclose(a, b) tests |a - b| <= atol + rtol * |b|, so rtol scales
            # with the treatment values here.
            if not np.allclose(baseline_output.tolist(), treatment_output.tolist(), rtol=rtol, atol=atol):
                # Count each failing test case once, however many outputs differ.
                if case_passed:
                    case_passed = False
                    diff_count += 1
                if verbose:
                    print("case {} output {}".format(test_case_id, i))
                    print("baseline={}\ntreatment={}".format(baseline_output.tolist(), treatment_output))
                    print("rel_diff={} abs_diff={}".format(rel_diff, abs_diff))

    if diff_count == 0:
        print(
            "100% passed for {} random inputs given thresholds (rtol={}, atol={}).".format(
                len(baseline_results), rtol, atol
            )
        )
    else:
        print(
            "WARNING: {} out of {} results NOT passed for thresholds (rtol={}, atol={}).".format(
                diff_count, len(baseline_results), rtol, atol
            )
        )
    print("maximum absolute difference={}".format(max_abs_diff))
    print("maximum relative difference={}".format(max_rel_diff))
def run_test(
    baseline_model,
    optimized_model,
    output_dir,
    batch_size,
    sequence_length,
    use_gpu,
    test_cases,
    seed,
    verbose,
    rtol,
    atol,
    input_ids_name,
    segment_ids_name,
    input_mask_name,
):
    """Run the same generated inputs through both models and compare the outputs.

    Args:
        baseline_model: path of the baseline model; run with optimizations disabled.
        optimized_model: path of the optimized model; also used to resolve input names.
        output_dir: if not ``None``, every generated test input is written there
            via ``output_test_data``.
        batch_size, sequence_length, test_cases, seed: forwarded to
            ``generate_test_data``.
        use_gpu: forwarded to ``run_model`` for both models.
        verbose: forwarded to ``generate_test_data`` and ``compare``; also enables
            the average-latency printouts.
        rtol, atol: tolerances forwarded to ``compare``.
        input_ids_name, segment_ids_name, input_mask_name: input names forwarded
            to ``get_bert_inputs``.
    """
    # Resolve the three BERT inputs against the optimized model.
    input_ids, segment_ids, input_mask = get_bert_inputs(
        optimized_model, input_ids_name, segment_ids_name, input_mask_name
    )

    # Random mask lengths are used for the accuracy test; this may slightly
    # inflate the latency numbers reported by this script.
    all_inputs = generate_test_data(
        batch_size,
        sequence_length,
        test_cases,
        seed,
        verbose,
        input_ids,
        segment_ids,
        input_mask,
        random_mask_length=True,
    )

    # Baseline: all graph optimizations disabled.
    baseline_results, baseline_latency, _ = run_model(
        baseline_model, all_inputs, use_gpu, disable_optimization=True
    )
    if verbose:
        baseline_ms = statistics.mean(baseline_latency) * 1000
        print("baseline average latency (all optimizations disabled): {} ms".format(baseline_ms))

    if output_dir is not None:
        for case_index, case_inputs in enumerate(all_inputs):
            output_test_data(output_dir, case_index, case_inputs)

    # Treatment: the optimized model, optimizations left enabled.
    treatment_results, treatment_latency, _ = run_model(
        optimized_model, all_inputs, use_gpu, disable_optimization=False
    )
    if verbose:
        treatment_ms = statistics.mean(treatment_latency) * 1000
        print("treatment average latency: {} ms".format(treatment_ms))

    # Check that baseline and treatment outputs agree within the tolerances.
    compare(baseline_results, treatment_results, verbose, rtol, atol)
def parse_arguments(argv=None):
    """Parse the command-line options of the comparison tool.

    Args:
        argv: optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, so existing callers are
            unaffected. Passing a list allows programmatic use and testing.

    Returns:
        The ``argparse.Namespace`` with attributes ``baseline_model``,
        ``optimized_model``, ``output_dir``, ``batch_size``, ``sequence_length``,
        ``rtol``, ``atol``, ``samples``, ``seed``, ``use_gpu``, ``verbose``,
        ``input_ids``, ``segment_ids`` and ``input_mask``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--baseline_model", required=True, type=str, help="baseline onnx model path.")

    # No default is given: a required option never falls back to one.
    parser.add_argument(
        "--optimized_model",
        required=True,
        type=str,
        help="path of the optimized model. It shall have same inputs as the baseline model.",
    )

    parser.add_argument(
        "--output_dir",
        required=False,
        type=str,
        default=None,
        help="output test data path. If not specified, test data will not be saved.",
    )

    parser.add_argument("--batch_size", required=True, type=int, help="batch size of input")

    parser.add_argument(
        "--sequence_length",
        required=True,
        type=int,
        help="maximum sequence length of input",
    )

    parser.add_argument("--rtol", required=False, type=float, default=1e-3, help="relative tolerance")

    parser.add_argument("--atol", required=False, type=float, default=1e-4, help="absolute tolerance")

    parser.add_argument(
        "--samples",
        required=False,
        type=int,
        default=100,
        help="number of test cases to be generated",
    )

    parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")

    # store_true flags already default to False, so no set_defaults call is needed.
    parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU")

    parser.add_argument(
        "--verbose",
        required=False,
        action="store_true",
        help="print verbose information",
    )

    parser.add_argument(
        "--input_ids",
        required=False,
        type=str,
        default=None,
        help="input name for input ids",
    )

    parser.add_argument(
        "--segment_ids",
        required=False,
        type=str,
        default=None,
        help="input name for segment ids",
    )

    parser.add_argument(
        "--input_mask",
        required=False,
        type=str,
        default=None,
        help="input name for attention mask",
    )

    return parser.parse_args(argv)
def main():
    """Script entry point: parse the options, prepare the output directory, run the test."""
    args = parse_arguments()

    if args.output_dir is not None:
        # Create the output directory (and any missing parents) if it does not exist yet.
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    run_test(
        baseline_model=args.baseline_model,
        optimized_model=args.optimized_model,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        sequence_length=args.sequence_length,
        use_gpu=args.use_gpu,
        test_cases=args.samples,
        seed=args.seed,
        verbose=args.verbose,
        rtol=args.rtol,
        atol=args.atol,
        input_ids_name=args.input_ids,
        segment_ids_name=args.segment_ids,
        input_mask_name=args.input_mask,
    )


if __name__ == "__main__":
    main()