图片解析应用 (Image Parsing Application)
"""
This profiler tool runs a transformer model and prints the kernel time spent on each node of the model.

Example of profiling a Longformer model:
    python profiler.py --model longformer-base-4096_fp32.onnx --batch_size 1 --sequence_length 4096 --global_length 8 --samples 1000 --thread_num 8 --dummy_inputs longformer --use_gpu

Example of importing a profile result file from onnxruntime_perf_test:
    python profiler.py --input profile_2021-10-25_12-02-41.json
"""

import argparse
import json
import os

import numpy
import psutil
from onnx import TensorProto

NODES_TYPE_CONTAINING_SUBGRAPH = ["Scan", "Loop", "If"]


def parse_arguments(argv=None):
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input",
        required=False,
        type=str,
        help="Set the input file for reading the profile results",
    )

    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        help="onnx model path to run profiling. Required when --input is not specified.",
    )

    parser.add_argument(
        "-b",
        "--batch_size",
        required=False,
        type=int,
        default=1,
        help="batch size of input",
    )

    parser.add_argument(
        "-s",
        "--sequence_length",
        required=False,
        type=int,
        default=32,
        help="sequence length of input",
    )

    parser.add_argument(
        "--past_sequence_length",
        required=False,
        type=int,
        default=1,
        help="past sequence length for gpt2",
    )

    parser.add_argument(
        "--global_length",
        required=False,
        type=int,
        default=1,
        help="number of global tokens for longformer",
    )

    parser.add_argument(
        "--samples",
        required=False,
        type=int,
        default=1000,
        help="number of samples to test. Set it large enough to reduce the variance of performance result.",
    )

    parser.add_argument(
        "--threshold",
        required=False,
        type=float,
        default=0.01,
        help="Threshold of run time ratio among all nodes. Nodes with larger ratio will show in top expensive nodes.",
    )

    parser.add_argument(
        "--thread_num",
        required=False,
        type=int,
        default=-1,
        help="number of threads to use",
    )

    parser.add_argument(
        "--input_ids_name",
        required=False,
        type=str,
        default=None,
        help="input name for input IDs, for bert",
    )

    parser.add_argument(
        "--segment_ids_name",
        required=False,
        type=str,
        default=None,
        help="input name for segment IDs, for bert",
    )

    parser.add_argument(
        "--input_mask_name",
        required=False,
        type=str,
        default=None,
        help="input name for attention mask, for bert",
    )

    parser.add_argument(
        "--dummy_inputs",
        required=False,
        default="default",
        choices=["bert", "gpt2", "longformer", "default"],
        help="Type of model inputs. The default will create dummy inputs with ones.",
    )

    parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="use GPU")
    parser.set_defaults(use_gpu=False)

    parser.add_argument(
        "--provider",
        required=False,
        type=str,
        default="cuda",
        help="Execution provider to use",
    )

    parser.add_argument(
        "--basic_optimization",
        required=False,
        action="store_true",
        help="Enable only basic graph optimizations. By default, all optimizations are enabled in OnnxRuntime",
    )
    parser.set_defaults(basic_optimization=False)

    parser.add_argument(
        "--kernel_time_only",
        required=False,
        action="store_true",
        help="Only include the kernel time and no fence time",
    )
    parser.set_defaults(kernel_time_only=False)

    parser.add_argument("-v", "--verbose", required=False, action="store_true")
    parser.set_defaults(verbose=False)

    return parser.parse_args(argv)
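

# Illustrative example (not from the original file): parse_arguments can be driven
# programmatically, which is handy for testing the CLI surface without a shell:
#
#   args = parse_arguments(["--model", "model.onnx", "--batch_size", "4", "--use_gpu"])
#   assert args.batch_size == 4 and args.use_gpu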


def run_profile(onnx_model_path, use_gpu, provider, basic_optimization, thread_num, all_inputs):
    from benchmark_helper import create_onnxruntime_session

    session = create_onnxruntime_session(
        onnx_model_path,
        use_gpu,
        provider,
        enable_all_optimization=not basic_optimization,
        num_threads=thread_num,
        enable_profiling=True,
    )

    for inputs in all_inputs:
        _ = session.run(None, inputs)

    profile_file = session.end_profiling()
    return profile_file
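

# Minimal usage sketch (model path and inputs are placeholders, assuming benchmark_helper
# from this repo is importable):
#
#   inputs = create_dummy_inputs(onnx_model, batch_size=1, sequence_length=32, samples=10)
#   trace_path = run_profile("model.onnx", use_gpu=False, provider="cuda",
#                            basic_optimization=False, thread_num=-1, all_inputs=inputs)
#   # trace_path is the JSON trace file written by session.end_profiling()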


def load_profile_json(profile_file):
    print(f"loading profile output {profile_file} ...")

    with open(profile_file, "r") as opened_file:
        sess_time = json.load(opened_file)

    assert isinstance(sess_time, list)
    return sess_time
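

# Note: onnxruntime emits the profile in Chrome trace format: a JSON list of event
# dicts, each carrying "cat", "name", a duration "dur" in microseconds, and
# ORT-specific details under "args". The parsers below rely on exactly these fields.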


def parse_kernel_results(sess_time, threshold=0):
    """Parse profile data and output kernels in two sections - top expensive kernels, and kernel time grouped by operator.

    Args:
        sess_time (List[Dict]): profile data
        threshold (float, optional): Minimum ratio of duration among all. Defaults to 0.

    Returns:
        List[str]: lines of string for output.
    """
    kernel_name_to_op_name = {}
    kernel_time = {}
    kernel_freq = {}
    total = 0
    session_init = False
    for item in sess_time:
        # Skip all events (like MemcpyHostToDevice) before session_initialization
        if item["cat"] == "Session" and item["name"] == "session_initialization":
            session_init = True
        if not session_init:
            continue

        if item["cat"] == "Kernel" and "dur" in item and "args" in item and "op_name" in item["args"]:
            kernel_name = item["name"]

            op_name = item["args"]["op_name"]
            if op_name in NODES_TYPE_CONTAINING_SUBGRAPH:
                continue

            # Handle MemcpyHostToDevice and MemcpyDeviceToHost here
            if not op_name:
                op_name = f"({kernel_name})"

            if kernel_name in kernel_time:
                kernel_time[kernel_name] += item["dur"]
                kernel_freq[kernel_name] += 1
            else:
                kernel_time[kernel_name] = item["dur"]
                kernel_freq[kernel_name] = 1
                kernel_name_to_op_name[kernel_name] = op_name

            total += item["dur"]

    if not kernel_time:
        return ["No kernel record found!"]

    # Output items with run time ratio >= threshold, sorted by duration in descending order.
    lines = []
    lines.append(f"\nTop expensive kernels with Time% >= {threshold * 100:.2f}:")
    lines.append("-" * 64)
    lines.append("Total(μs)\tTime%\tCalls\tAvg(μs)\tKernel")
    for kernel_name, duration in sorted(kernel_time.items(), key=lambda x: x[1], reverse=True):
        ratio = duration / total
        if ratio < threshold:
            continue
        calls = kernel_freq[kernel_name]
        avg_time = duration / float(calls)
        lines.append(f"{duration:10d}\t{ratio * 100.0:5.2f}\t{calls:5d}\t{avg_time:8.1f}\t{kernel_name}")

    # Group kernel time by operator
    op_time = {}
    for kernel_name, op_name in kernel_name_to_op_name.items():
        duration = kernel_time[kernel_name]
        if op_name in op_time:
            op_time[op_name] += duration
        else:
            op_time[op_name] = duration

    lines.append("\nGroup kernel time by operator:")
    lines.append("-" * 64)
    lines.append("Total(μs)\tTime%\tOperator")
    for op_name, duration in sorted(op_time.items(), key=lambda x: x[1], reverse=True):
        ratio = duration / total
        lines.append(f"{duration:10d}\t{ratio * 100.0:5.2f}\t{op_name}")

    return lines
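

# A kernel event in the trace looks roughly like
#   {"cat": "Kernel", "name": "Add_kernel", "dur": 42, "args": {"op_name": "Add"}}
# so the parser above can be exercised with a small hand-built list (hypothetical data):
#
#   events = [
#       {"cat": "Session", "name": "session_initialization", "args": {}},
#       {"cat": "Kernel", "name": "Add_kernel", "dur": 10, "args": {"op_name": "Add"}},
#   ]
#   print("\n".join(parse_kernel_results(events)))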


def parse_node_results(sess_time, kernel_time_only=False, threshold=0):
    """Parse profile data and output nodes in two sections - nodes in the original order, and top expensive nodes.

    Args:
        sess_time (List[Dict]): profile data
        kernel_time_only (bool, optional): Only include items for kernel time. Defaults to False.
        threshold (float, optional): Minimum ratio of duration among all. Defaults to 0.

    Returns:
        List[str]: lines of string for output.
    """
    node_name_list = []
    node_time = {}
    node_freq = {}
    node_provider = {}
    total = 0
    for item in sess_time:
        if item["cat"] == "Node" and "dur" in item and "args" in item and "op_name" in item["args"]:
            node_name = (
                item["name"].replace("_kernel_time", "").replace("_fence_before", "").replace("_fence_after", "")
            )

            if "provider" in item["args"]:
                if item["args"]["provider"] == "CPUExecutionProvider":
                    device = "CPU"
                elif item["args"]["provider"] == "CUDAExecutionProvider":
                    device = "CUDA"
                elif item["args"]["provider"] == "DmlExecutionProvider":
                    device = "DML"
                else:
                    # Fall back to the raw provider name so that `device` is always bound.
                    device = item["args"]["provider"]

                if node_name not in node_provider:
                    node_provider[node_name] = device
                else:
                    assert node_provider[node_name] == device
            elif kernel_time_only:
                continue

            op_name = item["args"]["op_name"]
            if op_name in NODES_TYPE_CONTAINING_SUBGRAPH:
                continue

            if node_name in node_time:
                node_time[node_name] += item["dur"]
                node_freq[node_name] += 1
            else:
                node_time[node_name] = item["dur"]
                node_freq[node_name] = 1
                node_name_list.append(node_name)

            total += item["dur"]

    # Output items in the original order.
    lines = [
        "\nNodes in the original order:",
        "-" * 64,
        "Total(μs)\tTime%\tAcc %\tAvg(μs)\tCalls\tProvider\tNode",
    ]
    before_percentage = 0.0
    for node_name in node_name_list:
        duration = node_time[node_name]
        calls = node_freq[node_name]
        avg_time = duration / float(calls)
        percentage = (duration / total) * 100.0
        provider = node_provider[node_name] if node_name in node_provider else ""
        before_percentage += percentage
        lines.append(
            f"{duration:10d}\t{percentage:5.2f}\t{before_percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}"
        )

    # Output items with run time ratio >= threshold, sorted by duration in descending order.
    lines.append(f"\nTop expensive nodes with Time% >= {threshold * 100:.2f}:")
    lines.append("-" * 64)
    lines.append("Total(μs)\tTime%\tAvg(μs)\tCalls\tProvider\tNode")
    for node_name, duration in sorted(node_time.items(), key=lambda x: x[1], reverse=True):
        ratio = duration / total
        if ratio < threshold:
            continue
        calls = node_freq[node_name]
        avg_time = duration / float(calls)
        percentage = (duration / total) * 100.0
        provider = node_provider[node_name] if node_name in node_provider else ""
        lines.append(f"{duration:10d}\t{percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}")

    return lines
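

# Note: the "Acc %" column is the running cumulative percentage over nodes in their
# original graph order, which shows at a glance how quickly time accumulates, e.g.
# where the first few nodes already cross 80% of the end-to-end node time.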


def group_node_results(sess_time, kernel_time_only, use_gpu):
    """Group results by operator name.

    Args:
        sess_time (List[Dict]): profile data
        kernel_time_only (bool): Only include items for kernel time.
        use_gpu (bool): GPU is used in profiling or not.

    Returns:
        List[str]: lines of string for output.
    """
    op_kernel_time = {}
    op_kernel_records = {}
    total_kernel_time = 0

    provider_op_kernel_time = {}
    provider_op_kernel_records = {}
    provider_kernel_time = {}

    op_fence_time = {}
    total_fence_time = 0

    provider_counter = {}
    for item in sess_time:
        if item["cat"] == "Node" and "dur" in item and "args" in item and "op_name" in item["args"]:
            op_name = item["args"]["op_name"]

            # TODO: shall we have a separated group for nodes with subgraph?
            if op_name in NODES_TYPE_CONTAINING_SUBGRAPH:
                continue

            if "provider" not in item["args"]:
                if "fence" in item["name"]:
                    if op_name in op_fence_time:
                        op_fence_time[op_name] += item["dur"]
                    else:
                        op_fence_time[op_name] = item["dur"]
                    total_fence_time += item["dur"]
                continue

            provider = item["args"]["provider"] if "provider" in item["args"] else ""
            if provider in provider_counter:
                provider_counter[provider] += 1
            else:
                provider_counter[provider] = 1

            key = f"{provider}:{op_name}"
            if key in provider_op_kernel_time:
                provider_op_kernel_time[key] += item["dur"]
                provider_op_kernel_records[key] += 1
            else:
                provider_op_kernel_time[key] = item["dur"]
                provider_op_kernel_records[key] = 1

            if provider in provider_kernel_time:
                provider_kernel_time[provider] += item["dur"]
            else:
                provider_kernel_time[provider] = item["dur"]

            if op_name in op_kernel_time:
                op_kernel_time[op_name] += item["dur"]
                op_kernel_records[op_name] += 1
            else:
                op_kernel_time[op_name] = item["dur"]
                op_kernel_records[op_name] = 1

            total_kernel_time += item["dur"]

    lines = ["", "Grouped by operator"]
    lines.append("-" * 64)
    lines.append("Total(μs)\tTime%\tKernel(μs)\tKernel%\tCalls\tAvgKernel(μs)\tFence(μs)\tOperator")
    for op_name, kernel_time in sorted(op_kernel_time.items(), key=lambda x: x[1], reverse=True):
        fence_time = op_fence_time[op_name] if op_name in op_fence_time else 0
        kernel_time_ratio = kernel_time / total_kernel_time
        total_time = kernel_time + fence_time
        time_ratio = total_time / (total_kernel_time + total_fence_time)
        kernel_calls = op_kernel_records[op_name]
        avg_kernel_time = kernel_time / kernel_calls
        lines.append(
            f"{total_time:10d}\t{time_ratio * 100.0:5.2f}\t{kernel_time:11d}\t{kernel_time_ratio * 100.0:5.2f}\t{kernel_calls:5d}\t{avg_kernel_time:14.1f}\t{fence_time:10d}\t{op_name}"
        )

    lines += ["", "Grouped by provider + operator"]
    lines.append("-" * 64)
    lines.append("Kernel(μs)\tProvider%\tCalls\tAvgKernel(μs)\tProvider\tOperator")
    for key, kernel_time in sorted(provider_op_kernel_time.items(), key=lambda x: x[1], reverse=True):
        parts = key.split(":")
        provider = parts[0]
        op_name = parts[1]
        short_ep = provider.replace("ExecutionProvider", "")
        calls = provider_op_kernel_records[key]
        avg_kernel_time = kernel_time / calls
        provider_time_ratio = kernel_time / provider_kernel_time[provider]
        lines.append(
            f"{kernel_time:10d}\t{provider_time_ratio * 100.0:9.2f}\t{calls:5d}\t{avg_kernel_time:14.1f}\t{short_ep:8s}\t{op_name}"
        )

    return lines
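

# Note: entries in provider_op_kernel_time are keyed like "CUDAExecutionProvider:MatMul";
# the report shortens the provider by stripping the "ExecutionProvider" suffix, so the
# same operator can appear once per execution provider that ran it.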


def get_dim_from_type_proto(dim):
    return getattr(dim, dim.WhichOneof("value")) if isinstance(dim.WhichOneof("value"), str) else None


def get_shape_from_type_proto(type_proto):
    return [get_dim_from_type_proto(d) for d in type_proto.tensor_type.shape.dim]
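

# Quick check of the two helpers above (onnx.helper ships with the onnx package);
# symbolic dimensions come back as strings and concrete ones as ints:
#
#   from onnx import helper
#   tp = helper.make_tensor_type_proto(TensorProto.FLOAT, ["batch_size", 128])
#   get_shape_from_type_proto(tp)  # -> ["batch_size", 128]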


def create_dummy_inputs(onnx_model, batch_size, sequence_length, samples):
    """Create dummy inputs for ONNX model.

    Args:
        onnx_model (OnnxModel): ONNX model
        batch_size (int): batch size
        sequence_length (int): sequence length
        samples (int): number of samples

    Returns:
        List[Dict]: list of inputs
    """
    dummy_inputs = {}
    for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(graph_input.type)
        symbol_dims = []
        for i, dim in enumerate(shape):
            if isinstance(dim, str):
                symbol_dims.append(i)

        # allowed symbolic dimensions: batch_size and sequence_length
        if len(symbol_dims) > 2:
            return None
        if len(symbol_dims) > 0:
            shape[symbol_dims[0]] = batch_size
        if len(symbol_dims) > 1:
            shape[symbol_dims[1]] = sequence_length

        elem_type = graph_input.type.tensor_type.elem_type
        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
        data_type = (
            numpy.float32
            if elem_type == TensorProto.FLOAT
            else (numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
        )
        data = numpy.ones(shape, dtype=data_type)
        dummy_inputs[graph_input.name] = data

    all_inputs = [dummy_inputs for _ in range(samples)]
    return all_inputs
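

# Note: all samples intentionally share one input dict (and its arrays) since the values
# never change between runs, and callers must handle the None returned when a model has
# more than two symbolic dimensions (only batch_size and sequence_length are filled in).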


def create_bert_inputs(
    onnx_model,
    batch_size,
    sequence_length,
    samples,
    input_ids_name=None,
    segment_ids_name=None,
    input_mask_name=None,
):
    """Create dummy inputs for BERT model.

    Args:
        onnx_model (OnnxModel): ONNX model
        batch_size (int): batch size
        sequence_length (int): sequence length
        samples (int): number of samples
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Returns:
        List[Dict]: list of inputs
    """
    from bert_test_data import find_bert_inputs, generate_test_data

    input_ids, segment_ids, input_mask = find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
    all_inputs = generate_test_data(
        batch_size,
        sequence_length,
        test_cases=samples,
        seed=123,
        verbose=False,
        input_ids=input_ids,
        segment_ids=segment_ids,
        input_mask=input_mask,
        random_mask_length=False,
    )

    return all_inputs
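

# Unlike the other generators here, the BERT path delegates to bert_test_data's
# generate_test_data, which builds the test cases with a fixed seed (123) so repeated
# profiling runs see the same inputs.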


def create_gpt2_inputs(onnx_model, batch_size, sequence_length, past_sequence_length, samples):
    """Create dummy inputs for GPT-2 model.

    Args:
        onnx_model (OnnxModel): ONNX model
        batch_size (int): batch size
        sequence_length (int): sequence length
        past_sequence_length (int): past sequence length
        samples (int): number of samples

    Raises:
        RuntimeError: a symbolic dimension is not supported. Use the tool convert_to_onnx.py to export the ONNX model instead.

    Returns:
        List[Dict]: list of inputs
    """
    # The symbolic names shall be the same as those used in the Gpt2Helper.export_onnx(...) function.
    symbols = {
        "batch_size": batch_size,
        "seq_len": sequence_length,
        "past_seq_len": past_sequence_length,
        "total_seq_len": sequence_length + past_sequence_length,
    }

    dummy_inputs = {}
    for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(graph_input.type)
        for i, dim in enumerate(shape):
            if isinstance(dim, str):
                if dim not in symbols:
                    raise RuntimeError(f"symbol is not supported: {dim}")
                shape[i] = symbols[dim]

        elem_type = graph_input.type.tensor_type.elem_type
        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
        data_type = (
            numpy.float32
            if elem_type == TensorProto.FLOAT
            else (numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
        )
        data = numpy.ones(shape, dtype=data_type)
        dummy_inputs[graph_input.name] = data

    all_inputs = [dummy_inputs for _ in range(samples)]
    return all_inputs
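

# For example, a graph input declared with shape ["batch_size", "seq_len"] and
# batch_size=2, sequence_length=5 becomes a (2, 5) array of ones; any symbolic name
# outside the table above raises RuntimeError by design.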


def create_longformer_inputs(onnx_model, batch_size, sequence_length, global_length, samples):
    """Create dummy inputs for Longformer model.

    Args:
        onnx_model (OnnxModel): ONNX model
        batch_size (int): batch size
        sequence_length (int): sequence length
        global_length (int): number of global tokens
        samples (int): number of samples

    Raises:
        RuntimeError: a symbolic dimension is not supported. Use the tool convert_longformer_to_onnx.py to export the ONNX model instead.

    Returns:
        List[Dict]: list of inputs
    """
    symbols = {"batch_size": batch_size, "sequence_length": sequence_length}

    dummy_inputs = {}
    for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(graph_input.type)
        for i, dim in enumerate(shape):
            if isinstance(dim, str):
                if dim not in symbols:
                    raise RuntimeError(f"symbol is not supported: {dim}")
                shape[i] = symbols[dim]

        elem_type = graph_input.type.tensor_type.elem_type
        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
        data_type = (
            numpy.float32
            if elem_type == TensorProto.FLOAT
            else (numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
        )

        if "global" in graph_input.name:
            data = numpy.zeros(shape, dtype=data_type)
            data[:, :global_length] = 1
        else:
            data = numpy.ones(shape, dtype=data_type)
        dummy_inputs[graph_input.name] = data

    all_inputs = [dummy_inputs for _ in range(samples)]
    return all_inputs
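

# For example, with sequence_length=8 and global_length=2, an input whose name contains
# "global" gets the mask [[1, 1, 0, 0, 0, 0, 0, 0]] (per batch row), while all other
# inputs are filled with ones.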


def process_results(profile_file, args):
    profile_records = load_profile_json(profile_file)

    lines = parse_kernel_results(profile_records, args.threshold)

    lines += parse_node_results(profile_records, args.kernel_time_only, args.threshold)

    lines += group_node_results(profile_records, args.kernel_time_only, args.use_gpu)

    return lines


def run(args):
    num_threads = args.thread_num if args.thread_num > 0 else psutil.cpu_count(logical=False)

    # Set the OMP environment variable before importing onnxruntime. It is needed for CPU only,
    # and has no impact for the onnxruntime-gpu package.
    if "OMP_NUM_THREADS" not in os.environ:
        os.environ["OMP_NUM_THREADS"] = str(num_threads)

    from onnx import load
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(load(args.model))

    all_inputs = None
    if args.dummy_inputs == "bert":
        all_inputs = create_bert_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.samples,
            args.input_ids_name,
            args.segment_ids_name,
            args.input_mask_name,
        )
    elif args.dummy_inputs == "gpt2":
        all_inputs = create_gpt2_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.past_sequence_length,
            args.samples,
        )
    elif args.dummy_inputs == "longformer":
        all_inputs = create_longformer_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.global_length,
            args.samples,
        )
    else:  # default
        all_inputs = create_dummy_inputs(onnx_model, args.batch_size, args.sequence_length, args.samples)

    profile_file = run_profile(
        args.model,
        args.use_gpu,
        args.provider,
        args.basic_optimization,
        args.thread_num,
        all_inputs,
    )

    return profile_file
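

# End-to-end sketch (paths are placeholders): the CLI flow below can also be invoked
# from Python:
#
#   args = parse_arguments(["--model", "model.onnx"])
#   profile_file = run(args)
#   print("\n".join(process_results(profile_file, args)))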


if __name__ == "__main__":
    arguments = parse_arguments()
    print("Arguments", arguments)

    from benchmark_helper import setup_logger

    setup_logger(arguments.verbose)

    if not arguments.input:
        assert arguments.model, "requires either --model to run profiling or --input to read profiling results"
        profile_file = run(arguments)
    else:
        profile_file = arguments.input

    results = process_results(profile_file, arguments)

    for line in results:
        print(line)