# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import csv
import logging
import os
import random
import timeit
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from time import sleep
from typing import Any, Dict, List, Optional

import coloredlogs
import numpy
import torch
import transformers
from packaging import version

import onnxruntime

logger = logging.getLogger(__name__)


class Precision(Enum):
    FLOAT32 = "fp32"
    FLOAT16 = "fp16"
    INT8 = "int8"

    def __str__(self):
        return self.value


class OptimizerInfo(Enum):
    # no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as
    # graph optimization level is not 0 (disable all).
    NOOPT = "no_opt"
    BYORT = "by_ort"
    BYSCRIPT = "by_script"

    def __str__(self):
        return self.value
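
# Usage note: both enums stringify to their CLI-friendly value, so they can be
# passed directly to argparse choices or logged, e.g.:
#
#   assert str(Precision.FLOAT16) == "fp16"
#   assert str(OptimizerInfo.BYSCRIPT) == "by_script"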


class ConfigModifier:
    def __init__(self, num_layers):
        self.num_layers = num_layers

    def modify(self, config):
        if self.num_layers is None:
            return
        if hasattr(config, "num_hidden_layers"):
            config.num_hidden_layers = self.num_layers
            logger.info(f"Modifying pytorch model's number of hidden layers to: {self.num_layers}")
        if hasattr(config, "encoder_layers"):
            config.encoder_layers = self.num_layers
            logger.info(f"Modifying pytorch model's number of encoder layers to: {self.num_layers}")
        if hasattr(config, "decoder_layers"):
            config.decoder_layers = self.num_layers
            logger.info(f"Modifying pytorch model's number of decoder layers to: {self.num_layers}")

    def get_layer_num(self):
        return self.num_layers
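
# Illustrative sketch (the model name is a placeholder): shrink a Hugging Face
# config to two layers before building a smaller model for benchmarking.
#
#   from transformers import AutoConfig
#   config = AutoConfig.from_pretrained("bert-base-uncased")
#   ConfigModifier(2).modify(config)  # config.num_hidden_layers is now 2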


IO_BINDING_DATA_TYPE_MAP = {
    "float32": numpy.float32,
    # TODO: Add more.
}


def create_onnxruntime_session(
    onnx_model_path,
    use_gpu,
    provider=None,
    enable_all_optimization=True,
    num_threads=-1,
    enable_profiling=False,
    verbose=False,
    provider_options=None,  # map an execution provider name to its options; None avoids a mutable default
):
    session = None
    try:
        sess_options = onnxruntime.SessionOptions()

        if enable_all_optimization:
            sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        else:
            sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC

        if enable_profiling:
            sess_options.enable_profiling = True

        if num_threads > 0:
            sess_options.intra_op_num_threads = num_threads
            logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}")

        if verbose:
            sess_options.log_severity_level = 0
        else:
            sess_options.log_severity_level = 4

        logger.debug(f"Create session for onnx model: {onnx_model_path}")
        if use_gpu:
            if provider == "dml":
                providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
            elif provider == "rocm":
                providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
            elif provider == "migraphx":
                providers = [
                    "MIGraphXExecutionProvider",
                    "ROCMExecutionProvider",
                    "CPUExecutionProvider",
                ]
            elif provider == "cuda":
                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
            elif provider == "tensorrt":
                providers = [
                    "TensorrtExecutionProvider",
                    "CUDAExecutionProvider",
                    "CPUExecutionProvider",
                ]
            else:
                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        if provider_options:
            providers = [(name, provider_options[name]) if name in provider_options else name for name in providers]

        session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
    except Exception:
        logger.error("Exception", exc_info=True)

    return session
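
# Usage sketch ("model.onnx" and the input name are placeholders): build a CPU
# session and run it once with a dummy int64 input.
#
#   session = create_onnxruntime_session("model.onnx", use_gpu=False, num_threads=4)
#   if session is not None:
#       outputs = session.run(None, {"input_ids": numpy.ones((1, 8), dtype=numpy.int64)})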


def setup_logger(verbose=True):
    if verbose:
        coloredlogs.install(
            level="DEBUG",
            fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
        )
    else:
        coloredlogs.install(fmt="%(message)s")
        logging.getLogger("transformers").setLevel(logging.WARNING)


def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
    if cache_dir and not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if use_gpu:
        if provider == "dml":
            assert (
                "DmlExecutionProvider" in onnxruntime.get_available_providers()
            ), "Please install onnxruntime-directml package to test GPU inference."
        else:
            assert not set(onnxruntime.get_available_providers()).isdisjoint(
                ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"]
            ), "Please install onnxruntime-gpu package, or install ROCm support, to test GPU inference."

    logger.info(f"PyTorch Version:{torch.__version__}")
    logger.info(f"Transformers Version:{transformers.__version__}")
    logger.info(f"Onnxruntime Version:{onnxruntime.__version__}")

    # Support three major versions of PyTorch and OnnxRuntime, and transformers releases up to about 9 months old.
    assert version.parse(torch.__version__) >= version.parse("1.10.0")
    assert version.parse(transformers.__version__) >= version.parse("4.12.0")
    assert version.parse(onnxruntime.__version__) >= version.parse("1.10.0")


def get_latency_result(latency_list, batch_size):
    latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
    latency_variance = numpy.var(latency_list, dtype=numpy.float64) * 1000.0
    throughput = batch_size * (1000.0 / latency_ms)

    return {
        "test_times": len(latency_list),
        "latency_variance": "{:.2f}".format(latency_variance),
        "latency_90_percentile": "{:.2f}".format(numpy.percentile(latency_list, 90) * 1000.0),
        "latency_95_percentile": "{:.2f}".format(numpy.percentile(latency_list, 95) * 1000.0),
        "latency_99_percentile": "{:.2f}".format(numpy.percentile(latency_list, 99) * 1000.0),
        "average_latency_ms": "{:.2f}".format(latency_ms),
        "QPS": "{:.2f}".format(throughput),
    }
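
# Worked example: three measured runs of 10, 12 and 11 ms at batch size 8 give an
# average latency of 11.00 ms and a throughput of 8 * 1000 / 11 ≈ 727.27 QPS.
#
#   metrics = get_latency_result([0.010, 0.012, 0.011], batch_size=8)
#   # metrics["average_latency_ms"] == "11.00", metrics["QPS"] == "727.27"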


def output_details(results, csv_filename):
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        column_names = [
            "engine",
            "version",
            "providers",
            "device",
            "precision",
            "optimizer",
            "io_binding",
            "model_name",
            "inputs",
            "threads",
            "batch_size",
            "sequence_length",
            "custom_layer_num",
            "datetime",
            "test_times",
            "QPS",
            "average_latency_ms",
            "latency_variance",
            "latency_90_percentile",
            "latency_95_percentile",
            "latency_99_percentile",
        ]

        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
        csv_writer.writeheader()
        for result in results:
            csv_writer.writerow(result)

    logger.info(f"Detail results are saved to csv file: {csv_filename}")


def output_summary(results, csv_filename, args):
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        header_names = [
            "model_name",
            "inputs",
            "custom_layer_num",
            "engine",
            "version",
            "providers",
            "device",
            "precision",
            "optimizer",
            "io_binding",
            "threads",
        ]

        data_names = []
        for batch_size in args.batch_sizes:
            for sequence_length in args.sequence_lengths:
                data_names.append(f"b{batch_size}_s{sequence_length}")

        csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
        csv_writer.writeheader()
        for model_name in args.models:
            for input_count in [1, 2, 3]:
                for engine_name in args.engines:
                    for io_binding in [True, False, ""]:
                        for threads in args.num_threads:
                            row = {}
                            for result in results:
                                if (
                                    result["model_name"] == model_name
                                    and result["inputs"] == input_count
                                    and result["engine"] == engine_name
                                    and result["io_binding"] == io_binding
                                    and result["threads"] == threads
                                ):
                                    headers = {k: v for k, v in result.items() if k in header_names}
                                    if not row:
                                        row.update(headers)
                                        row.update({k: "" for k in data_names})
                                    else:
                                        for k in header_names:
                                            assert row[k] == headers[k]
                                    b = result["batch_size"]
                                    s = result["sequence_length"]
                                    row[f"b{b}_s{s}"] = result["average_latency_ms"]
                            if row:
                                csv_writer.writerow(row)

    logger.info(f"Summary results are saved to csv file: {csv_filename}")


def output_fusion_statistics(model_fusion_statistics, csv_filename):
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        column_names = ["model_filename", "datetime", "transformers", "torch"] + list(
            next(iter(model_fusion_statistics.values())).keys()
        )
        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
        csv_writer.writeheader()
        for key in model_fusion_statistics.keys():
            model_fusion_statistics[key]["datetime"] = str(datetime.now())
            model_fusion_statistics[key]["transformers"] = transformers.__version__
            model_fusion_statistics[key]["torch"] = torch.__version__
            model_fusion_statistics[key]["model_filename"] = key
            csv_writer.writerow(model_fusion_statistics[key])

    logger.info(f"Fusion statistics are saved to csv file: {csv_filename}")


def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, warm_up_repeat=0):
    result = {}
    timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=warm_up_repeat)  # Dry run
    latency_list = timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times)
    result.update(result_template)
    result.update({"io_binding": False})
    result.update(get_latency_result(latency_list, batch_size))
    return result
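
# Usage sketch (the session, input name, and template values are placeholders):
# time 100 runs after a 10-run warm up.
#
#   template = {"engine": "onnxruntime", "model_name": "my_model"}
#   inputs = {"input_ids": numpy.ones((1, 128), dtype=numpy.int64)}
#   result = inference_ort(session, inputs, template, repeat_times=100, batch_size=1, warm_up_repeat=10)
#   print(result["average_latency_ms"], result["QPS"])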


def inference_ort_with_io_binding(
    ort_session,
    ort_inputs,
    result_template,
    repeat_times,
    ort_output_names,
    ort_outputs,
    output_buffers,
    output_buffer_max_sizes,
    batch_size,
    device,
    data_type=numpy.longlong,
    warm_up_repeat=0,
):
    result = {}

    # Bind inputs and outputs to onnxruntime session
    io_binding = ort_session.io_binding()

    # Bind inputs to device
    for name in ort_inputs.keys():
        np_input = torch.from_numpy(ort_inputs[name]).to(device)
        input_type = (
            IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)]
            if str(ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP
            else data_type
        )
        io_binding.bind_input(
            name,
            np_input.device.type,
            0,
            input_type,
            np_input.shape,
            np_input.data_ptr(),
        )

    # Bind output buffers with the sizes needed, if not allocated already
    if len(output_buffers) == 0:
        allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)

    for i, ort_output_name in enumerate(ort_output_names):
        io_binding.bind_output(
            ort_output_name,
            output_buffers[i].device.type,
            0,
            numpy.float32,
            ort_outputs[i].shape,
            output_buffers[i].data_ptr(),
        )

    timeit.repeat(
        lambda: ort_session.run_with_iobinding(io_binding),
        number=1,
        repeat=warm_up_repeat,
    )  # Dry run

    latency_list = timeit.repeat(
        lambda: ort_session.run_with_iobinding(io_binding),
        number=1,
        repeat=repeat_times,
    )
    result.update(result_template)
    result.update({"io_binding": True})
    result.update(get_latency_result(latency_list, batch_size))
    return result
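
# Usage sketch (assumes a CUDA session whose outputs are all float32; session,
# inputs, template, and session_output_names are placeholders): one eager run
# supplies output shapes, then the same buffers are reused across timed runs.
#
#   ort_outputs = session.run(None, inputs)  # eager run to get output shapes
#   max_sizes = [int(numpy.prod(out.shape)) for out in ort_outputs]
#   output_buffers = []  # filled by allocateOutputBuffers on first use
#   result = inference_ort_with_io_binding(
#       session, inputs, template, 100, session_output_names, ort_outputs,
#       output_buffers, max_sizes, batch_size=1, device="cuda")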


def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device):
    # Allocate output tensors with the largest test size needed, so the allocated
    # memory can be reused for each test run.
    for i in output_buffer_max_sizes:
        output_buffers.append(torch.empty(i, dtype=torch.float32, device=device))


def set_random_seed(seed=123):
    """Set random seed manually to get deterministic results"""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.enabled = False
    # torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.deterministic = True


def get_gpu_info() -> Optional[List[Dict[str, Any]]]:
    from py3nvml.py3nvml import (
        NVMLError,
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlDeviceGetName,
        nvmlInit,
        nvmlShutdown,
    )

    try:
        nvmlInit()
        result = []
        device_count = nvmlDeviceGetCount()
        if not isinstance(device_count, int):
            return None

        for i in range(device_count):
            info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
            if isinstance(info, str):
                return None
            result.append(
                {
                    "id": i,
                    "name": nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)),
                    "total": info.total,
                    "free": info.free,
                    "used": info.used,
                }
            )
        nvmlShutdown()
        return result
    except NVMLError as error:
        logger.error("Error fetching GPU information using nvml: %s", error)
        return None
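
# Usage sketch (requires the py3nvml package and an NVIDIA driver); returns None
# on any NVML failure.
#
#   for gpu in get_gpu_info() or []:
#       print(gpu["id"], gpu["name"], f'{gpu["used"]}/{gpu["total"]} bytes used')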


def measure_memory(is_gpu, func):
    class MemoryMonitor:
        def __init__(self, keep_measuring=True):
            self.keep_measuring = keep_measuring

        def measure_cpu_usage(self):
            import psutil

            max_usage = 0
            while True:
                max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
                sleep(0.005)  # 5ms
                if not self.keep_measuring:
                    break
            return max_usage

        def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
            from py3nvml.py3nvml import (
                NVMLError,
                nvmlDeviceGetCount,
                nvmlDeviceGetHandleByIndex,
                nvmlDeviceGetMemoryInfo,
                nvmlDeviceGetName,
                nvmlInit,
                nvmlShutdown,
            )

            max_gpu_usage = []
            gpu_name = []
            try:
                nvmlInit()
                device_count = nvmlDeviceGetCount()
                if not isinstance(device_count, int):
                    logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
                    return None

                max_gpu_usage = [0 for i in range(device_count)]
                gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
                while True:
                    for i in range(device_count):
                        info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
                        if isinstance(info, str):
                            logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
                            return None
                        max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
                    sleep(0.005)  # 5ms
                    if not self.keep_measuring:
                        break
                nvmlShutdown()
                return [
                    {
                        "device_id": i,
                        "name": gpu_name[i],
                        "max_used_MB": max_gpu_usage[i],
                    }
                    for i in range(device_count)
                ]
            except NVMLError as error:
                logger.error("Error fetching GPU information using nvml: %s", error)
                return None

    monitor = MemoryMonitor(False)

    if is_gpu:
        memory_before_test = monitor.measure_gpu_usage()
        if memory_before_test is None:
            return None

        with ThreadPoolExecutor() as executor:
            monitor = MemoryMonitor()
            mem_thread = executor.submit(monitor.measure_gpu_usage)
            try:
                fn_thread = executor.submit(func)
                _ = fn_thread.result()
            finally:
                monitor.keep_measuring = False
                max_usage = mem_thread.result()

            if max_usage is None:
                return None

            print(f"GPU memory usage: before={memory_before_test} peak={max_usage}")
            if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage):
                # When there are multiple GPUs, we will check the one with maximum usage.
                max_used = 0
                for i, memory_before in enumerate(memory_before_test):
                    before = memory_before["max_used_MB"]
                    after = max_usage[i]["max_used_MB"]
                    used = after - before
                    max_used = max(max_used, used)
                return max_used
            return None

    # CPU memory
    memory_before_test = monitor.measure_cpu_usage()

    with ThreadPoolExecutor() as executor:
        monitor = MemoryMonitor()
        mem_thread = executor.submit(monitor.measure_cpu_usage)
        try:
            fn_thread = executor.submit(func)
            _ = fn_thread.result()
        finally:
            monitor.keep_measuring = False
            max_usage = mem_thread.result()
        print(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB")
        return max_usage - memory_before_test
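
# Usage sketch: measure the peak extra CPU memory (in MB) that a workload
# allocates while it runs; requires the psutil package.
#
#   delta_mb = measure_memory(is_gpu=False, func=lambda: list(range(10_000_000)))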


def get_ort_environment_variables():
    # Environment variables that might impact ORT performance on transformer models. Note that they are for testing only.
    env_names = [
        "ORT_DISABLE_FUSED_ATTENTION",
        "ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
        "ORT_DISABLE_FUSED_CROSS_ATTENTION",
        "ORT_DISABLE_TRT_FLASH_ATTENTION",
        "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
        "ORT_TRANSFORMER_OPTIONS",
        "ORT_CUDA_GEMM_OPTIONS",
    ]
    env = ""
    for name in env_names:
        value = os.getenv(name)
        if value is None:
            continue
        if env:
            env += ","
        env += f"{name}={value}"
    return env
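
# Usage sketch: with, e.g., ORT_DISABLE_FUSED_ATTENTION=1 exported in the shell,
# this returns "ORT_DISABLE_FUSED_ATTENTION=1"; with none of the variables set,
# it returns an empty string.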