m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

545 lines
18 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. # It is a tool to generate test data for a bert model.
  6. # The test data can be used by onnxruntime_perf_test tool to evaluate the inference latency.
  7. import argparse
  8. import os
  9. import random
  10. from pathlib import Path
  11. from typing import Dict, Optional, Tuple
  12. import numpy as np
  13. from onnx import ModelProto, TensorProto, numpy_helper
  14. from onnx_model import OnnxModel
  15. def fake_input_ids_data(
  16. input_ids: TensorProto, batch_size: int, sequence_length: int, dictionary_size: int
  17. ) -> np.ndarray:
  18. """Create input tensor based on the graph input of input_ids
  19. Args:
  20. input_ids (TensorProto): graph input of the input_ids input tensor
  21. batch_size (int): batch size
  22. sequence_length (int): sequence length
  23. dictionary_size (int): vocabulary size of dictionary
  24. Returns:
  25. np.ndarray: the input tensor created
  26. """
  27. assert input_ids.type.tensor_type.elem_type in [
  28. TensorProto.FLOAT,
  29. TensorProto.INT32,
  30. TensorProto.INT64,
  31. ]
  32. data = np.random.randint(dictionary_size, size=(batch_size, sequence_length), dtype=np.int32)
  33. if input_ids.type.tensor_type.elem_type == TensorProto.FLOAT:
  34. data = np.float32(data)
  35. elif input_ids.type.tensor_type.elem_type == TensorProto.INT64:
  36. data = np.int64(data)
  37. return data
  38. def fake_segment_ids_data(segment_ids: TensorProto, batch_size: int, sequence_length: int) -> np.ndarray:
  39. """Create input tensor based on the graph input of segment_ids
  40. Args:
  41. segment_ids (TensorProto): graph input of the token_type_ids input tensor
  42. batch_size (int): batch size
  43. sequence_length (int): sequence length
  44. Returns:
  45. np.ndarray: the input tensor created
  46. """
  47. assert segment_ids.type.tensor_type.elem_type in [
  48. TensorProto.FLOAT,
  49. TensorProto.INT32,
  50. TensorProto.INT64,
  51. ]
  52. data = np.zeros((batch_size, sequence_length), dtype=np.int32)
  53. if segment_ids.type.tensor_type.elem_type == TensorProto.FLOAT:
  54. data = np.float32(data)
  55. elif segment_ids.type.tensor_type.elem_type == TensorProto.INT64:
  56. data = np.int64(data)
  57. return data
  58. def fake_input_mask_data(
  59. input_mask: TensorProto,
  60. batch_size: int,
  61. sequence_length: int,
  62. random_mask_length: bool,
  63. ) -> np.ndarray:
  64. """Create input tensor based on the graph input of segment_ids.
  65. Args:
  66. input_mask (TensorProto): graph input of the attention mask input tensor
  67. batch_size (int): batch size
  68. sequence_length (int): sequence length
  69. random_mask_length (bool): whether mask according to random padding length
  70. Returns:
  71. np.ndarray: the input tensor created
  72. """
  73. assert input_mask.type.tensor_type.elem_type in [
  74. TensorProto.FLOAT,
  75. TensorProto.INT32,
  76. TensorProto.INT64,
  77. ]
  78. if random_mask_length:
  79. actual_seq_len = random.randint(int(sequence_length * 2 / 3), sequence_length)
  80. data = np.zeros((batch_size, sequence_length), dtype=np.int32)
  81. temp = np.ones((batch_size, actual_seq_len), dtype=np.int32)
  82. data[: temp.shape[0], : temp.shape[1]] = temp
  83. else:
  84. data = np.ones((batch_size, sequence_length), dtype=np.int32)
  85. if input_mask.type.tensor_type.elem_type == TensorProto.FLOAT:
  86. data = np.float32(data)
  87. elif input_mask.type.tensor_type.elem_type == TensorProto.INT64:
  88. data = np.int64(data)
  89. return data
  90. def output_test_data(directory: str, inputs: Dict[str, np.ndarray]):
  91. """Output input tensors of test data to a directory
  92. Args:
  93. directory (str): path of a directory
  94. inputs (Dict[str, np.ndarray]): map from input name to value
  95. """
  96. if not os.path.exists(directory):
  97. try:
  98. os.mkdir(directory)
  99. except OSError:
  100. print("Creation of the directory %s failed" % directory)
  101. else:
  102. print("Successfully created the directory %s " % directory)
  103. else:
  104. print("Warning: directory %s existed. Files will be overwritten." % directory)
  105. index = 0
  106. for name, data in inputs.items():
  107. tensor = numpy_helper.from_array(data, name)
  108. with open(os.path.join(directory, "input_{}.pb".format(index)), "wb") as file:
  109. file.write(tensor.SerializeToString())
  110. index += 1
  111. def fake_test_data(
  112. batch_size: int,
  113. sequence_length: int,
  114. test_cases: int,
  115. dictionary_size: int,
  116. verbose: bool,
  117. random_seed: int,
  118. input_ids: TensorProto,
  119. segment_ids: TensorProto,
  120. input_mask: TensorProto,
  121. random_mask_length: bool,
  122. ):
  123. """Create given number of input data for testing
  124. Args:
  125. batch_size (int): batch size
  126. sequence_length (int): sequence length
  127. test_cases (int): number of test cases
  128. dictionary_size (int): vocabulary size of dictionary for input_ids
  129. verbose (bool): print more information or not
  130. random_seed (int): random seed
  131. input_ids (TensorProto): graph input of input IDs
  132. segment_ids (TensorProto): graph input of token type IDs
  133. input_mask (TensorProto): graph input of attention mask
  134. random_mask_length (bool): whether mask random number of words at the end
  135. Returns:
  136. List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
  137. with input name as key and a tensor as value
  138. """
  139. assert input_ids is not None
  140. np.random.seed(random_seed)
  141. random.seed(random_seed)
  142. all_inputs = []
  143. for test_case in range(test_cases):
  144. input_1 = fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
  145. inputs = {input_ids.name: input_1}
  146. if segment_ids:
  147. inputs[segment_ids.name] = fake_segment_ids_data(segment_ids, batch_size, sequence_length)
  148. if input_mask:
  149. inputs[input_mask.name] = fake_input_mask_data(input_mask, batch_size, sequence_length, random_mask_length)
  150. if verbose and len(all_inputs) == 0:
  151. print("Example inputs", inputs)
  152. all_inputs.append(inputs)
  153. return all_inputs
  154. def generate_test_data(
  155. batch_size: int,
  156. sequence_length: int,
  157. test_cases: int,
  158. seed: int,
  159. verbose: bool,
  160. input_ids: TensorProto,
  161. segment_ids: TensorProto,
  162. input_mask: TensorProto,
  163. random_mask_length: bool,
  164. ):
  165. """Create given number of input data for testing
  166. Args:
  167. batch_size (int): batch size
  168. sequence_length (int): sequence length
  169. test_cases (int): number of test cases
  170. seed (int): random seed
  171. verbose (bool): print more information or not
  172. input_ids (TensorProto): graph input of input IDs
  173. segment_ids (TensorProto): graph input of token type IDs
  174. input_mask (TensorProto): graph input of attention mask
  175. random_mask_length (bool): whether mask random number of words at the end
  176. Returns:
  177. List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
  178. with input name as key and a tensor as value
  179. """
  180. dictionary_size = 10000
  181. all_inputs = fake_test_data(
  182. batch_size,
  183. sequence_length,
  184. test_cases,
  185. dictionary_size,
  186. verbose,
  187. seed,
  188. input_ids,
  189. segment_ids,
  190. input_mask,
  191. random_mask_length,
  192. )
  193. if len(all_inputs) != test_cases:
  194. print("Failed to create test data for test.")
  195. return all_inputs
  196. def get_graph_input_from_embed_node(onnx_model, embed_node, input_index):
  197. if input_index >= len(embed_node.input):
  198. return None
  199. input = embed_node.input[input_index]
  200. graph_input = onnx_model.find_graph_input(input)
  201. if graph_input is None:
  202. parent_node = onnx_model.get_parent(embed_node, input_index)
  203. if parent_node is not None and parent_node.op_type == "Cast":
  204. graph_input = onnx_model.find_graph_input(parent_node.input[0])
  205. return graph_input
  206. def find_bert_inputs(
  207. onnx_model: OnnxModel,
  208. input_ids_name: Optional[str] = None,
  209. segment_ids_name: Optional[str] = None,
  210. input_mask_name: Optional[str] = None,
  211. ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
  212. """Find graph inputs for BERT model.
  213. First, we will deduce inputs from EmbedLayerNormalization node.
  214. If not found, we will guess the meaning of graph inputs based on naming.
  215. Args:
  216. onnx_model (OnnxModel): onnx model object
  217. input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
  218. segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
  219. input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
  220. Raises:
  221. ValueError: Graph does not have input named of input_ids_name or segment_ids_name or input_mask_name
  222. ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name
  223. and input_mask_name
  224. Returns:
  225. Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
  226. segment_ids and input_mask
  227. """
  228. graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
  229. if input_ids_name is not None:
  230. input_ids = onnx_model.find_graph_input(input_ids_name)
  231. if input_ids is None:
  232. raise ValueError(f"Graph does not have input named {input_ids_name}")
  233. segment_ids = None
  234. if segment_ids_name:
  235. segment_ids = onnx_model.find_graph_input(segment_ids_name)
  236. if segment_ids is None:
  237. raise ValueError(f"Graph does not have input named {segment_ids_name}")
  238. input_mask = None
  239. if input_mask_name:
  240. input_mask = onnx_model.find_graph_input(input_mask_name)
  241. if input_mask is None:
  242. raise ValueError(f"Graph does not have input named {input_mask_name}")
  243. expected_inputs = 1 + (1 if segment_ids else 0) + (1 if input_mask else 0)
  244. if len(graph_inputs) != expected_inputs:
  245. raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")
  246. return input_ids, segment_ids, input_mask
  247. if len(graph_inputs) != 3:
  248. raise ValueError("Expect the graph to have 3 inputs. Got {}".format(len(graph_inputs)))
  249. embed_nodes = onnx_model.get_nodes_by_op_type("EmbedLayerNormalization")
  250. if len(embed_nodes) == 1:
  251. embed_node = embed_nodes[0]
  252. input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0)
  253. segment_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 1)
  254. input_mask = get_graph_input_from_embed_node(onnx_model, embed_node, 7)
  255. if input_mask is None:
  256. for input in graph_inputs:
  257. input_name_lower = input.name.lower()
  258. if "mask" in input_name_lower:
  259. input_mask = input
  260. if input_mask is None:
  261. raise ValueError(f"Failed to find attention mask input")
  262. return input_ids, segment_ids, input_mask
  263. # Try guess the inputs based on naming.
  264. input_ids = None
  265. segment_ids = None
  266. input_mask = None
  267. for input in graph_inputs:
  268. input_name_lower = input.name.lower()
  269. if "mask" in input_name_lower: # matches input with name like "attention_mask" or "input_mask"
  270. input_mask = input
  271. elif (
  272. "token" in input_name_lower or "segment" in input_name_lower
  273. ): # matches input with name like "segment_ids" or "token_type_ids"
  274. segment_ids = input
  275. else:
  276. input_ids = input
  277. if input_ids and segment_ids and input_mask:
  278. return input_ids, segment_ids, input_mask
  279. raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.")
  280. def get_bert_inputs(
  281. onnx_file: str,
  282. input_ids_name: Optional[str] = None,
  283. segment_ids_name: Optional[str] = None,
  284. input_mask_name: Optional[str] = None,
  285. ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
  286. """Find graph inputs for BERT model.
  287. First, we will deduce inputs from EmbedLayerNormalization node.
  288. If not found, we will guess the meaning of graph inputs based on naming.
  289. Args:
  290. onnx_file (str): onnx model path
  291. input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
  292. segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
  293. input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
  294. Returns:
  295. Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
  296. segment_ids and input_mask
  297. """
  298. model = ModelProto()
  299. with open(onnx_file, "rb") as file:
  300. model.ParseFromString(file.read())
  301. onnx_model = OnnxModel(model)
  302. return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
  303. def parse_arguments():
  304. parser = argparse.ArgumentParser()
  305. parser.add_argument("--model", required=True, type=str, help="bert onnx model path.")
  306. parser.add_argument(
  307. "--output_dir",
  308. required=False,
  309. type=str,
  310. default=None,
  311. help="output test data path. Default is current directory.",
  312. )
  313. parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input")
  314. parser.add_argument(
  315. "--sequence_length",
  316. required=False,
  317. type=int,
  318. default=128,
  319. help="maximum sequence length of input",
  320. )
  321. parser.add_argument(
  322. "--input_ids_name",
  323. required=False,
  324. type=str,
  325. default=None,
  326. help="input name for input ids",
  327. )
  328. parser.add_argument(
  329. "--segment_ids_name",
  330. required=False,
  331. type=str,
  332. default=None,
  333. help="input name for segment ids",
  334. )
  335. parser.add_argument(
  336. "--input_mask_name",
  337. required=False,
  338. type=str,
  339. default=None,
  340. help="input name for attention mask",
  341. )
  342. parser.add_argument(
  343. "--samples",
  344. required=False,
  345. type=int,
  346. default=1,
  347. help="number of test cases to be generated",
  348. )
  349. parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")
  350. parser.add_argument(
  351. "--verbose",
  352. required=False,
  353. action="store_true",
  354. help="print verbose information",
  355. )
  356. parser.set_defaults(verbose=False)
  357. parser.add_argument(
  358. "--only_input_tensors",
  359. required=False,
  360. action="store_true",
  361. help="only save input tensors and no output tensors",
  362. )
  363. parser.set_defaults(only_input_tensors=False)
  364. args = parser.parse_args()
  365. return args
  366. def create_and_save_test_data(
  367. model: str,
  368. output_dir: str,
  369. batch_size: int,
  370. sequence_length: int,
  371. test_cases: int,
  372. seed: int,
  373. verbose: bool,
  374. input_ids_name: Optional[str],
  375. segment_ids_name: Optional[str],
  376. input_mask_name: Optional[str],
  377. only_input_tensors: bool,
  378. ):
  379. """Create test data for a model, and save test data to a directory.
  380. Args:
  381. model (str): path of ONNX bert model
  382. output_dir (str): output directory
  383. batch_size (int): batch size
  384. sequence_length (int): sequence length
  385. test_cases (int): number of test cases
  386. seed (int): random seed
  387. verbose (bool): whether print more information
  388. input_ids_name (str): graph input name of input_ids
  389. segment_ids_name (str): graph input name of segment_ids
  390. input_mask_name (str): graph input name of input_mask
  391. only_input_tensors (bool): only save input tensors
  392. """
  393. input_ids, segment_ids, input_mask = get_bert_inputs(model, input_ids_name, segment_ids_name, input_mask_name)
  394. all_inputs = generate_test_data(
  395. batch_size,
  396. sequence_length,
  397. test_cases,
  398. seed,
  399. verbose,
  400. input_ids,
  401. segment_ids,
  402. input_mask,
  403. random_mask_length=False,
  404. )
  405. for i, inputs in enumerate(all_inputs):
  406. directory = os.path.join(output_dir, "test_data_set_" + str(i))
  407. output_test_data(directory, inputs)
  408. if only_input_tensors:
  409. return
  410. import onnxruntime
  411. session = onnxruntime.InferenceSession(model)
  412. output_names = [output.name for output in session.get_outputs()]
  413. for i, inputs in enumerate(all_inputs):
  414. directory = os.path.join(output_dir, "test_data_set_" + str(i))
  415. result = session.run(output_names, inputs)
  416. for i, output_name in enumerate(output_names):
  417. tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_name)
  418. with open(os.path.join(directory, "output_{}.pb".format(i)), "wb") as file:
  419. file.write(tensor_result.SerializeToString())
  420. def main():
  421. args = parse_arguments()
  422. output_dir = args.output_dir
  423. if output_dir is None:
  424. # Default output directory is a sub-directory under the directory of model.
  425. p = Path(args.model)
  426. output_dir = os.path.join(p.parent, "batch_{}_seq_{}".format(args.batch_size, args.sequence_length))
  427. if output_dir is not None:
  428. # create the output directory if not existed
  429. path = Path(output_dir)
  430. path.mkdir(parents=True, exist_ok=True)
  431. else:
  432. print("Directory existed. test data files will be overwritten.")
  433. create_and_save_test_data(
  434. args.model,
  435. output_dir,
  436. args.batch_size,
  437. args.sequence_length,
  438. args.samples,
  439. args.seed,
  440. args.verbose,
  441. args.input_ids_name,
  442. args.segment_ids_name,
  443. args.input_mask_name,
  444. args.only_input_tensors,
  445. )
  446. print("Test data is saved to directory:", output_dir)
  447. if __name__ == "__main__":
  448. main()