m2m model translation

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import csv
import logging
import os
import random
import timeit
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from time import sleep
from typing import Any, Dict, List, Optional

import coloredlogs
import numpy
import torch
import transformers
from packaging import version

import onnxruntime

logger = logging.getLogger(__name__)


class Precision(Enum):
    FLOAT32 = "fp32"
    FLOAT16 = "fp16"
    INT8 = "int8"

    def __str__(self):
        return self.value


class OptimizerInfo(Enum):
    # no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as
    # graph optimization level is not 0 (disable all).
    NOOPT = "no_opt"
    BYORT = "by_ort"
    BYSCRIPT = "by_script"

    def __str__(self):
        return self.value


class ConfigModifier:
    def __init__(self, num_layers):
        self.num_layers = num_layers

    def modify(self, config):
        if self.num_layers is None:
            return
        if hasattr(config, "num_hidden_layers"):
            config.num_hidden_layers = self.num_layers
            logger.info(f"Modifying pytorch model's number of hidden layers to: {self.num_layers}")
        if hasattr(config, "encoder_layers"):
            config.encoder_layers = self.num_layers
            logger.info(f"Modifying pytorch model's number of encoder layers to: {self.num_layers}")
        # Note: the attribute name must not contain a trailing space, or hasattr never matches.
        if hasattr(config, "decoder_layers"):
            config.decoder_layers = self.num_layers
            logger.info(f"Modifying pytorch model's number of decoder layers to: {self.num_layers}")

    def get_layer_num(self):
        return self.num_layers
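

# Illustrative usage sketch (not part of the original file, and never called here):
# shrinking a Hugging Face config with ConfigModifier before benchmarking or export.
# "bert-base-uncased" is an assumed example checkpoint; any config exposing a
# num_hidden_layers / encoder_layers / decoder_layers attribute would work.
def _example_config_modifier():
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("bert-base-uncased")  # assumes the checkpoint is cached or downloadable
    ConfigModifier(num_layers=2).modify(config)  # BERT configs expose num_hidden_layers, so it becomes 2
    return config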


IO_BINDING_DATA_TYPE_MAP = {
    "float32": numpy.float32,
    # TODO: Add more.
}


def create_onnxruntime_session(
    onnx_model_path,
    use_gpu,
    provider=None,
    enable_all_optimization=True,
    num_threads=-1,
    enable_profiling=False,
    verbose=False,
    provider_options=None,  # optional map from execution provider name to its options; None avoids a mutable default
):
    session = None
    try:
        sess_options = onnxruntime.SessionOptions()

        if enable_all_optimization:
            sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        else:
            sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC

        if enable_profiling:
            sess_options.enable_profiling = True

        if num_threads > 0:
            sess_options.intra_op_num_threads = num_threads
            logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}")

        if verbose:
            sess_options.log_severity_level = 0
        else:
            sess_options.log_severity_level = 4

        logger.debug(f"Create session for onnx model: {onnx_model_path}")
        if use_gpu:
            if provider == "dml":
                providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
            elif provider == "rocm":
                providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
            elif provider == "migraphx":
                providers = [
                    "MIGraphXExecutionProvider",
                    "ROCMExecutionProvider",
                    "CPUExecutionProvider",
                ]
            elif provider == "cuda":
                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
            elif provider == "tensorrt":
                providers = [
                    "TensorrtExecutionProvider",
                    "CUDAExecutionProvider",
                    "CPUExecutionProvider",
                ]
            else:
                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        if provider_options:
            providers = [(name, provider_options[name]) if name in provider_options else name for name in providers]

        session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
    except Exception:  # a bare except would also swallow SystemExit/KeyboardInterrupt
        logger.error("Exception", exc_info=True)

    return session
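

# Illustrative usage sketch (not part of the original file): creating a CPU session
# for a hypothetical model file "model.onnx" and running one inference. The model
# path and the "input_ids" input name are assumptions for illustration only.
def _example_create_session():
    session = create_onnxruntime_session("model.onnx", use_gpu=False, num_threads=2)
    if session is None:
        return None  # creation failed; the exception was already logged above
    dummy_inputs = {"input_ids": numpy.ones((1, 128), dtype=numpy.int64)}
    return session.run(None, dummy_inputs)  # None means "fetch all outputs"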


def setup_logger(verbose=True):
    if verbose:
        coloredlogs.install(
            level="DEBUG",
            fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
        )
    else:
        coloredlogs.install(fmt="%(message)s")
        logging.getLogger("transformers").setLevel(logging.WARNING)


def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
    if cache_dir and not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if use_gpu:
        if provider == "dml":
            assert (
                "DmlExecutionProvider" in onnxruntime.get_available_providers()
            ), "Please install onnxruntime-directml package to test GPU inference."
        else:
            assert not set(onnxruntime.get_available_providers()).isdisjoint(
                ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"]
            ), "Please install onnxruntime-gpu package, or install ROCm support, to test GPU inference."

    logger.info(f"PyTorch Version:{torch.__version__}")
    logger.info(f"Transformers Version:{transformers.__version__}")
    logger.info(f"Onnxruntime Version:{onnxruntime.__version__}")

    # Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers.
    assert version.parse(torch.__version__) >= version.parse("1.10.0")
    assert version.parse(transformers.__version__) >= version.parse("4.12.0")
    assert version.parse(onnxruntime.__version__) >= version.parse("1.10.0")
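

# Illustrative usage sketch (not part of the original file): typical setup before a
# CPU benchmark run. The directory names are assumptions.
def _example_setup():
    setup_logger(verbose=False)
    prepare_environment(cache_dir="./cache_models", output_dir="./onnx_models", use_gpu=False)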


def get_latency_result(latency_list, batch_size):
    latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
    latency_variance = numpy.var(latency_list, dtype=numpy.float64) * 1000.0
    throughput = batch_size * (1000.0 / latency_ms)

    return {
        "test_times": len(latency_list),
        "latency_variance": "{:.2f}".format(latency_variance),
        "latency_90_percentile": "{:.2f}".format(numpy.percentile(latency_list, 90) * 1000.0),
        "latency_95_percentile": "{:.2f}".format(numpy.percentile(latency_list, 95) * 1000.0),
        "latency_99_percentile": "{:.2f}".format(numpy.percentile(latency_list, 99) * 1000.0),
        "average_latency_ms": "{:.2f}".format(latency_ms),
        "QPS": "{:.2f}".format(throughput),
    }
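

# Illustrative sketch (not part of the original file): get_latency_result takes raw
# per-run latencies in seconds (e.g. from timeit.repeat) and reports milliseconds.
# The numbers below are invented purely to show the input/output shape.
def _example_latency_result():
    latency_list = [0.0102, 0.0098, 0.0101, 0.0105]  # seconds per run
    metrics = get_latency_result(latency_list, batch_size=8)
    # metrics["average_latency_ms"] == "10.15"; metrics["QPS"] is batch_size * 1000 / latency_ms.
    return metrics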


def output_details(results, csv_filename):
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        column_names = [
            "engine",
            "version",
            "providers",
            "device",
            "precision",
            "optimizer",
            "io_binding",
            "model_name",
            "inputs",
            "threads",
            "batch_size",
            "sequence_length",
            "custom_layer_num",
            "datetime",
            "test_times",
            "QPS",
            "average_latency_ms",
            "latency_variance",
            "latency_90_percentile",
            "latency_95_percentile",
            "latency_99_percentile",
        ]

        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
        csv_writer.writeheader()
        for result in results:
            csv_writer.writerow(result)

    logger.info(f"Detail results are saved to csv file: {csv_filename}")
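

# Illustrative sketch (not part of the original file): the dicts passed to
# output_details should use the column names above; csv.DictWriter fills missing
# columns with empty cells. All field values here are invented.
def _example_output_details():
    results = [
        {
            "engine": "onnxruntime",
            "version": onnxruntime.__version__,
            "device": "cpu",
            "model_name": "bert-base-uncased",  # assumed example model
            "batch_size": 1,
            "sequence_length": 128,
            "average_latency_ms": "10.15",
        }
    ]
    output_details(results, "details.csv")  # appends to details.csv in the working directory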


def output_summary(results, csv_filename, args):
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        header_names = [
            "model_name",
            "inputs",
            "custom_layer_num",
            "engine",
            "version",
            "providers",
            "device",
            "precision",
            "optimizer",
            "io_binding",
            "threads",
        ]

        data_names = []
        for batch_size in args.batch_sizes:
            for sequence_length in args.sequence_lengths:
                data_names.append(f"b{batch_size}_s{sequence_length}")

        csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
        csv_writer.writeheader()
        for model_name in args.models:
            for input_count in [1, 2, 3]:
                for engine_name in args.engines:
                    for io_binding in [True, False, ""]:
                        for threads in args.num_threads:
                            row = {}
                            for result in results:
                                if (
                                    result["model_name"] == model_name
                                    and result["inputs"] == input_count
                                    and result["engine"] == engine_name
                                    and result["io_binding"] == io_binding
                                    and result["threads"] == threads
                                ):
                                    headers = {k: v for k, v in result.items() if k in header_names}
                                    if not row:
                                        row.update(headers)
                                        row.update({k: "" for k in data_names})
                                    else:
                                        for k in header_names:
                                            assert row[k] == headers[k]
                                    b = result["batch_size"]
                                    s = result["sequence_length"]
                                    row[f"b{b}_s{s}"] = result["average_latency_ms"]
                            if row:
                                csv_writer.writerow(row)

    logger.info(f"Summary results are saved to csv file: {csv_filename}")
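

# Illustrative sketch (not part of the original file): output_summary reads the
# sweep configuration from an args object produced by argparse; a SimpleNamespace
# stands in for it here, with invented sweep values.
def _example_output_summary(results):
    from types import SimpleNamespace

    args = SimpleNamespace(
        batch_sizes=[1],
        sequence_lengths=[128],
        models=["bert-base-uncased"],  # assumed example model
        engines=["onnxruntime"],
        num_threads=[2],
    )
    output_summary(results, "summary.csv", args)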


def output_fusion_statistics(model_fusion_statistics, csv_filename):
    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
        column_names = ["model_filename", "datetime", "transformers", "torch"] + list(
            next(iter(model_fusion_statistics.values())).keys()
        )
        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
        csv_writer.writeheader()
        for key in model_fusion_statistics.keys():
            model_fusion_statistics[key]["datetime"] = str(datetime.now())
            model_fusion_statistics[key]["transformers"] = transformers.__version__
            model_fusion_statistics[key]["torch"] = torch.__version__
            model_fusion_statistics[key]["model_filename"] = key
            csv_writer.writerow(model_fusion_statistics[key])

    logger.info(f"Fusion statistics are saved to csv file: {csv_filename}")


def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, warm_up_repeat=0):
    result = {}
    timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=warm_up_repeat)  # Dry run
    latency_list = timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times)
    result.update(result_template)
    result.update({"io_binding": False})
    result.update(get_latency_result(latency_list, batch_size))
    return result
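

# Illustrative sketch (not part of the original file): timing a session with
# inference_ort. The model path and input name are the same hypothetical ones as in
# _example_create_session; result_template carries the metadata columns to report.
def _example_inference_ort():
    session = create_onnxruntime_session("model.onnx", use_gpu=False)  # assumed model path
    ort_inputs = {"input_ids": numpy.ones((1, 128), dtype=numpy.int64)}  # assumed input name
    result_template = {"engine": "onnxruntime", "device": "cpu", "batch_size": 1}
    return inference_ort(session, ort_inputs, result_template, repeat_times=100, batch_size=1, warm_up_repeat=3)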


def inference_ort_with_io_binding(
    ort_session,
    ort_inputs,
    result_template,
    repeat_times,
    ort_output_names,
    ort_outputs,
    output_buffers,
    output_buffer_max_sizes,
    batch_size,
    device,
    data_type=numpy.longlong,
    warm_up_repeat=0,
):
    result = {}

    # Bind inputs and outputs to onnxruntime session
    io_binding = ort_session.io_binding()

    # Bind inputs to device
    for name in ort_inputs.keys():
        np_input = torch.from_numpy(ort_inputs[name]).to(device)
        input_type = (
            IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)]
            if str(ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP
            else data_type
        )
        io_binding.bind_input(
            name,
            np_input.device.type,
            0,
            input_type,
            np_input.shape,
            np_input.data_ptr(),
        )

    # Bind outputs buffers with the sizes needed if not allocated already
    if len(output_buffers) == 0:
        allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)

    for i, ort_output_name in enumerate(ort_output_names):
        io_binding.bind_output(
            ort_output_name,
            output_buffers[i].device.type,
            0,
            numpy.float32,
            ort_outputs[i].shape,
            output_buffers[i].data_ptr(),
        )

    timeit.repeat(
        lambda: ort_session.run_with_iobinding(io_binding),
        number=1,
        repeat=warm_up_repeat,
    )  # Dry run

    latency_list = timeit.repeat(
        lambda: ort_session.run_with_iobinding(io_binding),
        number=1,
        repeat=repeat_times,
    )
    result.update(result_template)
    result.update({"io_binding": True})
    result.update(get_latency_result(latency_list, batch_size))
    return result


def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device):
    # Allocate output tensors with the largest test size needed. So the allocated memory can be reused
    # for each test run.
    for i in output_buffer_max_sizes:
        output_buffers.append(torch.empty(i, dtype=torch.float32, device=device))
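

# Illustrative sketch (not part of the original file): wiring up the IO binding
# path. One eager run supplies the output shapes used to size the reusable buffers.
# The model path, input name, float32 outputs, and CUDA availability are all
# assumptions for illustration.
def _example_io_binding():
    session = create_onnxruntime_session("model.onnx", use_gpu=True)  # assumed model path
    ort_inputs = {"input_ids": numpy.ones((1, 128), dtype=numpy.int64)}  # assumed input name
    ort_output_names = [node.name for node in session.get_outputs()]
    ort_outputs = session.run(ort_output_names, ort_inputs)  # eager run to discover output shapes
    output_buffers = []
    output_buffer_max_sizes = [int(numpy.prod(output.shape)) for output in ort_outputs]
    return inference_ort_with_io_binding(
        session,
        ort_inputs,
        {"engine": "onnxruntime", "device": "cuda", "batch_size": 1},
        repeat_times=100,
        ort_output_names=ort_output_names,
        ort_outputs=ort_outputs,
        output_buffers=output_buffers,
        output_buffer_max_sizes=output_buffer_max_sizes,
        batch_size=1,
        device="cuda",
    )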


def set_random_seed(seed=123):
    """Set random seed manually to get deterministic results"""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.enabled = False
    # torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.deterministic = True


def get_gpu_info() -> Optional[List[Dict[str, Any]]]:
    from py3nvml.py3nvml import (
        NVMLError,
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlDeviceGetName,
        nvmlInit,
        nvmlShutdown,
    )

    try:
        nvmlInit()
        result = []
        device_count = nvmlDeviceGetCount()
        if not isinstance(device_count, int):
            return None

        for i in range(device_count):
            info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
            if isinstance(info, str):
                return None
            result.append(
                {
                    "id": i,
                    "name": nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)),
                    "total": info.total,
                    "free": info.free,
                    "used": info.used,
                }
            )
        nvmlShutdown()
        return result
    except NVMLError as error:
        logger.error("Error fetching GPU information using nvml: %s", error)
        return None
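

# Illustrative sketch (not part of the original file): get_gpu_info returns None
# when NVML is unavailable, so callers should guard for that.
def _example_print_gpu_info():
    gpus = get_gpu_info()
    if gpus is None:
        print("No NVIDIA GPU information available")
        return
    for gpu in gpus:
        print(f"GPU {gpu['id']} ({gpu['name']}): {gpu['used']}/{gpu['total']} bytes used")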


def measure_memory(is_gpu, func):
    class MemoryMonitor:
        def __init__(self, keep_measuring=True):
            self.keep_measuring = keep_measuring

        def measure_cpu_usage(self):
            import psutil

            max_usage = 0
            while True:
                max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
                sleep(0.005)  # 5ms
                if not self.keep_measuring:
                    break
            return max_usage

        def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
            from py3nvml.py3nvml import (
                NVMLError,
                nvmlDeviceGetCount,
                nvmlDeviceGetHandleByIndex,
                nvmlDeviceGetMemoryInfo,
                nvmlDeviceGetName,
                nvmlInit,
                nvmlShutdown,
            )

            max_gpu_usage = []
            gpu_name = []
            try:
                nvmlInit()
                device_count = nvmlDeviceGetCount()
                if not isinstance(device_count, int):
                    logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
                    return None

                max_gpu_usage = [0 for i in range(device_count)]
                gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
                while True:
                    for i in range(device_count):
                        info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
                        if isinstance(info, str):
                            logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
                            return None
                        max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
                    sleep(0.005)  # 5ms
                    if not self.keep_measuring:
                        break
                nvmlShutdown()
                return [
                    {
                        "device_id": i,
                        "name": gpu_name[i],
                        "max_used_MB": max_gpu_usage[i],
                    }
                    for i in range(device_count)
                ]
            except NVMLError as error:
                logger.error("Error fetching GPU information using nvml: %s", error)
                return None

    monitor = MemoryMonitor(False)

    if is_gpu:
        memory_before_test = monitor.measure_gpu_usage()
        if memory_before_test is None:
            return None

        with ThreadPoolExecutor() as executor:
            monitor = MemoryMonitor()
            mem_thread = executor.submit(monitor.measure_gpu_usage)
            try:
                fn_thread = executor.submit(func)
                _ = fn_thread.result()
            finally:
                monitor.keep_measuring = False
                max_usage = mem_thread.result()

            if max_usage is None:
                return None

            print(f"GPU memory usage: before={memory_before_test} peak={max_usage}")
            if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage):
                # When there are multiple GPUs, we will check the one with maximum usage.
                max_used = 0
                for i, memory_before in enumerate(memory_before_test):
                    before = memory_before["max_used_MB"]
                    after = max_usage[i]["max_used_MB"]
                    used = after - before
                    max_used = max(max_used, used)
                return max_used
            return None

    # CPU memory
    memory_before_test = monitor.measure_cpu_usage()
    with ThreadPoolExecutor() as executor:
        monitor = MemoryMonitor()
        mem_thread = executor.submit(monitor.measure_cpu_usage)
        try:
            fn_thread = executor.submit(func)
            _ = fn_thread.result()
        finally:
            monitor.keep_measuring = False
            max_usage = mem_thread.result()

        print(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB")
        return max_usage - memory_before_test
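

# Illustrative sketch (not part of the original file): measuring the peak extra CPU
# memory (in MB) a workload uses. Requires psutil; the workload is a made-up
# allocation just to have something to observe.
def _example_measure_memory():
    def workload():
        data = numpy.zeros((1000, 1000), dtype=numpy.float32)  # roughly 4 MB
        return data.sum()

    return measure_memory(is_gpu=False, func=workload)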


def get_ort_environment_variables():
    # Environment variables might impact ORT performance on transformer models. Note that they are for testing only.
    env_names = [
        "ORT_DISABLE_FUSED_ATTENTION",
        "ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
        "ORT_DISABLE_FUSED_CROSS_ATTENTION",
        "ORT_DISABLE_TRT_FLASH_ATTENTION",
        "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
        "ORT_TRANSFORMER_OPTIONS",
        "ORT_CUDA_GEMM_OPTIONS",
    ]
    env = ""
    for name in env_names:
        value = os.getenv(name)
        if value is None:
            continue
        if env:
            env += ","
        env += f"{name}={value}"
    return env
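

# Illustrative sketch (not part of the original file): only variables that are
# actually set appear in the comma-separated result, so this returns
# "ORT_TRANSFORMER_OPTIONS=1" assuming none of the other listed variables are set.
def _example_ort_env():
    os.environ["ORT_TRANSFORMER_OPTIONS"] = "1"
    return get_ort_environment_variables()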