m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1105 lines
44 KiB

7 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import os
  7. import sys
  8. from collections import deque
  9. from pathlib import Path
  10. from typing import Dict, List, Optional, Tuple
  11. from float16 import convert_float_to_float16
  12. from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper, save_model
  13. from shape_infer_helper import SymbolicShapeInferenceHelper
# Module-level logger; the OnnxModel methods below report through it.
logger = logging.getLogger(__name__)
  15. class OnnxModel:
  16. def __init__(self, model):
  17. self.initialize(model)
  18. def initialize(self, model):
  19. self.model: ModelProto = model
  20. self._node_name_suffix: Dict[str, int] = {} # key is node name prefix, value is the last suffix generated
  21. self.shape_infer_helper: SymbolicShapeInferenceHelper = None
  22. self.enable_shape_infer: bool = True
  23. self.all_graphs: Optional[List[GraphProto]] = None
  24. def disable_shape_inference(self):
  25. self.enable_shape_infer = False
  26. def infer_runtime_shape(self, dynamic_axis_mapping={}, update=False):
  27. if self.enable_shape_infer:
  28. if self.shape_infer_helper is None or update:
  29. self.shape_infer_helper = SymbolicShapeInferenceHelper(self.model)
  30. try:
  31. if self.shape_infer_helper.infer(dynamic_axis_mapping):
  32. return self.shape_infer_helper
  33. except: # noqa
  34. self.enable_shape_infer = False # disable shape inference to suppress same error message.
  35. print("failed in shape inference", sys.exc_info()[0])
  36. return None
  37. def input_name_to_nodes(self):
  38. input_name_to_nodes = {}
  39. for node in self.nodes():
  40. for input_name in node.input:
  41. if input_name not in input_name_to_nodes:
  42. input_name_to_nodes[input_name] = [node]
  43. else:
  44. input_name_to_nodes[input_name].append(node)
  45. return input_name_to_nodes
  46. def output_name_to_node(self):
  47. output_name_to_node = {}
  48. for node in self.nodes():
  49. for output_name in node.output:
  50. output_name_to_node[output_name] = node
  51. return output_name_to_node
  52. def nodes(self):
  53. all_nodes = []
  54. for graph in self.graphs():
  55. for node in graph.node:
  56. all_nodes.append(node)
  57. return all_nodes
  58. def graph(self):
  59. return self.model.graph
  60. def graphs(self):
  61. if self.all_graphs is not None:
  62. return self.all_graphs
  63. self.all_graphs = []
  64. graph_queue = [self.model.graph]
  65. while graph_queue:
  66. graph = graph_queue.pop(0)
  67. self.all_graphs.append(graph)
  68. for node in graph.node:
  69. for attr in node.attribute:
  70. if attr.type == AttributeProto.AttributeType.GRAPH:
  71. assert isinstance(attr.g, GraphProto)
  72. graph_queue.append(attr.g)
  73. if attr.type == AttributeProto.AttributeType.GRAPHS:
  74. for g in attr.graphs:
  75. assert isinstance(g, GraphProto)
  76. graph_queue.append(g)
  77. return self.all_graphs
  78. def get_graphs_input_names(self):
  79. input_names = []
  80. for graph in self.graphs():
  81. for input in graph.input:
  82. input_names.append(input.name)
  83. return input_names
  84. def get_graphs_output_names(self):
  85. output_names = []
  86. for graph in self.graphs():
  87. for output in graph.output:
  88. output_names.append(output.name)
  89. return output_names
  90. def get_graph_by_node(self, node):
  91. for graph in self.graphs():
  92. if node in graph.node:
  93. return graph
  94. return None
  95. def get_graph_by_name(self, graph_name):
  96. for graph in self.graphs():
  97. if graph_name == graph.name:
  98. return graph
  99. return None
  100. def get_topological_insert_id(self, graph, outputs):
  101. for idx, node in enumerate(graph.node):
  102. for input in node.input:
  103. if input in outputs:
  104. return idx
  105. return len(graph.node)
  106. def remove_node(self, node):
  107. for graph in self.graphs():
  108. if node in graph.node:
  109. graph.node.remove(node)
  110. return
  111. logger.warning("Failed to remove node %s", node) # It might be a bug to hit this line.
  112. def remove_nodes(self, nodes_to_remove):
  113. for node in nodes_to_remove:
  114. self.remove_node(node)
  115. def add_node(self, node, graph_name=None):
  116. if graph_name is None or graph_name == self.model.graph.name:
  117. self.model.graph.node.extend([node])
  118. else:
  119. graph = self.get_graph_by_name(graph_name)
  120. insert_idx = self.get_topological_insert_id(graph, node.output)
  121. graph.node.insert(insert_idx, node)
  122. def add_nodes(self, nodes_to_add, node_name_to_graph_name=None):
  123. if node_name_to_graph_name is None:
  124. self.model.graph.node.extend(nodes_to_add)
  125. else:
  126. for node in nodes_to_add:
  127. graph_name = node_name_to_graph_name[node.name]
  128. self.add_node(node, graph_name)
  129. def add_initializer(self, tensor, graph_name=None):
  130. if graph_name is None or graph_name == self.model.graph.name:
  131. self.model.graph.initializer.extend([tensor])
  132. else:
  133. graph = self.get_graph_by_name(graph_name)
  134. graph.initializer.extend([tensor])
  135. def add_input(self, input, graph_name=None):
  136. if graph_name is None or graph_name == self.model.graph.name:
  137. self.model.graph.input.extend([input])
  138. else:
  139. graph = self.get_graph_by_name(graph_name)
  140. graph.input.extend([input])
  141. @staticmethod
  142. def replace_node_input(node, old_input_name, new_input_name):
  143. assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
  144. for j in range(len(node.input)):
  145. if node.input[j] == old_input_name:
  146. node.input[j] = new_input_name
  147. def replace_input_of_all_nodes(self, old_input_name, new_input_name):
  148. for node in self.model.graph.node:
  149. OnnxModel.replace_node_input(node, old_input_name, new_input_name)
  150. @staticmethod
  151. def replace_node_output(node, old_output_name, new_output_name):
  152. assert isinstance(old_output_name, str) and isinstance(new_output_name, str)
  153. for j in range(len(node.output)):
  154. if node.output[j] == old_output_name:
  155. node.output[j] = new_output_name
  156. def replace_output_of_all_nodes(self, old_output_name, new_output_name):
  157. # This function shall be used carefully. For example:
  158. # Add --[old_name]--> Cast ---> [new_name]
  159. # |
  160. # +----[old_name]--> Transpose -->
  161. # If we want to remove the Cast node: replace output of Add to new_name is not enough;
  162. # The input of Transpose shall also be updated to new_name.
  163. for node in self.model.graph.node:
  164. OnnxModel.replace_node_output(node, old_output_name, new_output_name)
  165. def get_initializer(self, name):
  166. for graph in self.graphs():
  167. for tensor in graph.initializer:
  168. if tensor.name == name:
  169. return tensor
  170. return None
  171. def get_nodes_by_op_type(self, op_type):
  172. nodes = []
  173. for node in self.nodes():
  174. if node.op_type == op_type:
  175. nodes.append(node)
  176. return nodes
  177. def get_children(self, node, input_name_to_nodes=None):
  178. if input_name_to_nodes is None:
  179. input_name_to_nodes = self.input_name_to_nodes()
  180. children = []
  181. for output in node.output:
  182. if output in input_name_to_nodes:
  183. for node in input_name_to_nodes[output]:
  184. children.append(node)
  185. return children
  186. def get_parents(self, node, output_name_to_node=None):
  187. if output_name_to_node is None:
  188. output_name_to_node = self.output_name_to_node()
  189. parents = []
  190. for input in node.input:
  191. if input in output_name_to_node:
  192. parents.append(output_name_to_node[input])
  193. return parents
  194. def get_parent(self, node, i, output_name_to_node=None):
  195. if output_name_to_node is None:
  196. output_name_to_node = self.output_name_to_node()
  197. if len(node.input) <= i:
  198. return None
  199. input = node.input[i]
  200. if input not in output_name_to_node:
  201. return None
  202. return output_name_to_node[input]
  203. def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=[]):
  204. """
  205. Find parent node based on constraints on op_type.
  206. Args:
  207. node (str): current node name.
  208. parent_op_type (str): constraint of parent node op_type.
  209. output_name_to_node (dict): dictionary with output name as key, and node as value.
  210. exclude (list): list of nodes that are excluded (not allowed to match as parent).
  211. Returns:
  212. parent: The matched parent node. None if not found.
  213. index: The input index of matched parent node. None if not found.
  214. """
  215. for i, input in enumerate(node.input):
  216. if input in output_name_to_node:
  217. parent = output_name_to_node[input]
  218. if parent.op_type == parent_op_type and parent not in exclude:
  219. return parent, i
  220. else:
  221. logger.debug(f"To find first {parent_op_type}, current {parent.op_type}")
  222. return None, None
  223. def match_parent(
  224. self,
  225. node,
  226. parent_op_type,
  227. input_index=None,
  228. output_name_to_node=None,
  229. exclude=[],
  230. return_indice=None,
  231. ):
  232. """
  233. Find parent node based on constraints on op_type and index.
  234. When input_index is None, we will find the first parent node based on constraints,
  235. and return_indice will be appended the corresponding input index.
  236. Args:
  237. node (str): current node name.
  238. parent_op_type (str): constraint of parent node op_type.
  239. input_index (int or None): only check the parent given input index of current node.
  240. output_name_to_node (dict): dictionary with output name as key, and node as value.
  241. exclude (list): list of nodes that are excluded (not allowed to match as parent).
  242. return_indice (list): a list to append the input index when input_index is None.
  243. Returns:
  244. parent: The matched parent node.
  245. """
  246. assert node is not None
  247. assert input_index is None or input_index >= 0
  248. if output_name_to_node is None:
  249. output_name_to_node = self.output_name_to_node()
  250. if input_index is None:
  251. parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
  252. if return_indice is not None:
  253. return_indice.append(index)
  254. return parent
  255. if input_index >= len(node.input):
  256. logger.debug(f"input_index {input_index} >= node inputs {len(node.input)}")
  257. return None
  258. parent = self.get_parent(node, input_index, output_name_to_node)
  259. if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
  260. return parent
  261. if parent is not None:
  262. logger.debug(f"Expect {parent_op_type}, Got {parent.op_type}")
  263. return None
  264. def match_parent_paths(self, node, paths, output_name_to_node):
  265. for i, path in enumerate(paths):
  266. assert isinstance(path, List) or isinstance(path, Tuple)
  267. return_indice = []
  268. matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
  269. if matched:
  270. return i, matched, return_indice
  271. return -1, None, None
  272. def match_parent_path(
  273. self,
  274. node,
  275. parent_op_types,
  276. parent_input_index=None,
  277. output_name_to_node=None,
  278. return_indice=None,
  279. ):
  280. """
  281. Find a sequence of input edges based on constraints on parent op_type and index.
  282. When input_index is None, we will find the first parent node based on constraints,
  283. and return_indice will be appended the corresponding input index.
  284. Args:
  285. node (str): current node name.
  286. parent_op_types (str): constraint of parent node op_type of each input edge.
  287. parent_input_index (list): constraint of input index of each input edge. None means no constraint.
  288. output_name_to_node (dict): dictionary with output name as key, and node as value.
  289. return_indice (list): a list to append the input index
  290. When there is no constraint on input index of an edge.
  291. Returns:
  292. parents: a list of matched parent node.
  293. """
  294. if parent_input_index is not None:
  295. assert len(parent_input_index) == len(parent_op_types)
  296. if output_name_to_node is None:
  297. output_name_to_node = self.output_name_to_node()
  298. current_node = node
  299. matched_parents = []
  300. for i, op_type in enumerate(parent_op_types):
  301. matched_parent = self.match_parent(
  302. current_node,
  303. op_type,
  304. parent_input_index[i] if parent_input_index is not None else None,
  305. output_name_to_node,
  306. exclude=[],
  307. return_indice=return_indice,
  308. )
  309. if matched_parent is None:
  310. if parent_input_index is not None:
  311. logger.debug(
  312. f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}",
  313. stack_info=True,
  314. )
  315. else:
  316. logger.debug(f"Failed to match index={i} op_type={op_type}", stack_info=True)
  317. return None
  318. matched_parents.append(matched_parent)
  319. current_node = matched_parent
  320. return matched_parents
  321. def find_first_child_by_type(self, node, child_type, input_name_to_nodes=None, recursive=True):
  322. children = self.get_children(node, input_name_to_nodes)
  323. dq = deque(children)
  324. while len(dq) > 0:
  325. current_node = dq.pop()
  326. if current_node.op_type == child_type:
  327. return current_node
  328. if recursive:
  329. children = self.get_children(current_node, input_name_to_nodes)
  330. for child in children:
  331. dq.appendleft(child)
  332. return None
  333. def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True):
  334. if output_name_to_node is None:
  335. output_name_to_node = self.output_name_to_node()
  336. parents = self.get_parents(node, output_name_to_node)
  337. dq = deque(parents)
  338. while len(dq) > 0:
  339. current_node = dq.pop()
  340. if current_node.op_type == parent_type:
  341. return current_node
  342. if recursive:
  343. parents = self.get_parents(current_node, output_name_to_node)
  344. for parent in parents:
  345. dq.appendleft(parent)
  346. return None
  347. def get_constant_value(self, output_name):
  348. for node in self.get_nodes_by_op_type("Constant"):
  349. if node.output[0] == output_name:
  350. for att in node.attribute:
  351. if att.name == "value":
  352. return numpy_helper.to_array(att.t)
  353. # Fall back to intializer since constant folding might have been applied.
  354. initializer = self.get_initializer(output_name)
  355. if initializer is not None:
  356. return numpy_helper.to_array(initializer)
  357. return None
  358. def get_constant_input(self, node):
  359. for i, input in enumerate(node.input):
  360. value = self.get_constant_value(input)
  361. if value is not None:
  362. return i, value
  363. return None, None
  364. def find_constant_input(self, node, expected_value, delta=0.000001):
  365. i, value = self.get_constant_input(node)
  366. if value is not None and value.size == 1 and abs(value - expected_value) < delta:
  367. return i
  368. return -1
  369. def is_constant_with_specified_dimension(self, output_name, dimensions, description):
  370. value = self.get_constant_value(output_name)
  371. if value is None:
  372. logger.debug(f"{description} {output_name} is not initializer.")
  373. return False
  374. if len(value.shape) != dimensions:
  375. logger.debug(f"{description} {output_name} shall have {dimensions} dimensions. Got shape {value.shape}")
  376. return False
  377. return True
  378. def has_constant_input(self, node, expected_value, delta=0.000001):
  379. return self.find_constant_input(node, expected_value, delta) >= 0
  380. def get_children_subgraph_nodes(self, root_node, stop_nodes, input_name_to_nodes=None):
  381. if input_name_to_nodes is None:
  382. input_name_to_nodes = self.input_name_to_nodes()
  383. children = input_name_to_nodes[root_node.output[0]]
  384. unique_nodes = []
  385. dq = deque(children)
  386. while len(dq) > 0:
  387. current_node = dq.pop()
  388. if current_node in stop_nodes:
  389. continue
  390. if current_node not in unique_nodes:
  391. unique_nodes.append(current_node)
  392. for output in current_node.output:
  393. if output in input_name_to_nodes:
  394. children = input_name_to_nodes[output]
  395. for child in children:
  396. dq.appendleft(child)
  397. return unique_nodes
  398. def tensor_shape_to_list(self, tensor_type):
  399. """Convert tensor shape to list"""
  400. shape_list = []
  401. for d in tensor_type.shape.dim:
  402. if d.HasField("dim_value"):
  403. shape_list.append(d.dim_value) # known dimension
  404. elif d.HasField("dim_param"):
  405. shape_list.append(d.dim_param) # unknown dimension with symbolic name
  406. else:
  407. shape_list.append("?") # shall not happen
  408. return shape_list
  409. def get_dtype(self, input_or_output: str):
  410. """Try get data type given a name (could be initializer, graph input or output)."""
  411. tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info}
  412. if input_or_output in tensor_type_map:
  413. return tensor_type_map[input_or_output].tensor_type.elem_type
  414. graph_input = self.find_graph_input(input_or_output)
  415. if graph_input:
  416. return graph_input.type.tensor_type.elem_type
  417. graph_output = self.find_graph_output(input_or_output)
  418. if graph_output:
  419. return graph_output.type.tensor_type.elem_type
  420. return None
  421. @staticmethod
  422. def get_node_attribute(node: NodeProto, attribute_name: str):
  423. for attr in node.attribute:
  424. if attr.name == attribute_name:
  425. value = helper.get_attribute_value(attr)
  426. return value
  427. return None
  428. def remove_cascaded_cast_nodes(self):
  429. """Remove Cast node that are followed by another Cast node like --> Cast --> Cast -->
  430. Note that this shall be used carefully since it might introduce semantic change.
  431. For example, float -> int -> float could get different value than the original float value.
  432. So, it is recommended to used only in post-processing of mixed precision conversion.
  433. """
  434. output_name_to_node = self.output_name_to_node()
  435. removed_count = 0
  436. for node in self.nodes():
  437. if node.op_type == "Cast":
  438. parent = self.get_parent(node, 0, output_name_to_node=output_name_to_node)
  439. if parent and parent.op_type == "Cast":
  440. node.input[0] = parent.input[0]
  441. removed_count += 1
  442. if removed_count > 0:
  443. logger.info("Removed %d cascaded Cast nodes", removed_count)
  444. self.prune_graph()
  445. def remove_useless_cast_nodes(self):
  446. """Remove cast nodes that are not needed: input and output has same data type."""
  447. shape_infer = self.infer_runtime_shape(update=True)
  448. if shape_infer is None:
  449. logger.info("Skip removing useless cast nodes since shape inference failed.")
  450. return
  451. def get_data_type(input_or_output_name):
  452. dtype = self.get_dtype(input_or_output_name)
  453. if dtype:
  454. return dtype
  455. if shape_infer.known_vi_[input_or_output_name].type.tensor_type.HasField("elem_type"):
  456. return shape_infer.known_vi_[input_or_output_name].type.tensor_type.elem_type
  457. return None
  458. nodes_to_remove = []
  459. for node in self.nodes():
  460. if node.op_type == "Cast":
  461. input_dtype = get_data_type(node.input[0])
  462. output_dtype = get_data_type(node.output[0])
  463. if input_dtype and input_dtype == output_dtype:
  464. nodes_to_remove.append(node)
  465. if nodes_to_remove:
  466. graph_input_names = set(self.get_graphs_input_names())
  467. graph_output_names = set(self.get_graphs_output_names())
  468. for node in nodes_to_remove:
  469. if bool(set(node.output) & graph_output_names):
  470. if (not bool(set(node.input) & graph_input_names)) and len(
  471. self.input_name_to_nodes()[node.input[0]]
  472. ) == 1:
  473. self.replace_output_of_all_nodes(node.input[0], node.output[0])
  474. else:
  475. continue
  476. else:
  477. self.replace_input_of_all_nodes(node.output[0], node.input[0])
  478. self.remove_node(node)
  479. logger.info("Removed %d Cast nodes with output type same as input", len(nodes_to_remove))
  480. def convert_model_float32_to_float16(self, cast_input_output=True):
  481. logger.warning(
  482. "The function convert_model_float32_to_float16 is deprecated. Use convert_float_to_float16 instead!"
  483. )
  484. self.convert_float_to_float16(use_symbolic_shape_infer=True, keep_io_types=cast_input_output)
  485. def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs):
  486. """Convert a model to half (default) or mixed precision.
  487. To use mixed precision, user need specify which graph inputs, outputs, operator type
  488. or list of nodes shall keep in float32.
  489. By default, we use symbolic shape inference to get shape and type information.
  490. If not, ONNX shape inference will be used.
  491. Note that symbolic/ONNX shape inference might fail, and the conversion might not proceed
  492. without shape and type information.
  493. Args:
  494. use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference.
  495. Defaults to True.
  496. keep_io_types (Union[bool, List[str]], optional): boolean or a list of float32 input/output names.
  497. If True, model inputs/outputs should be left as float32.
  498. Defaults to False.
  499. op_block_list (List[str], optional): List of operator types to leave as float32.
  500. Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`.
  501. node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None.
  502. force_fp16_initializers(bool): force converting all float initializers to float16.
  503. Default to false.
  504. min_positive_val (float, optional): minimal positive value. Defaults to 1e-7.
  505. max_finite_val (float, optional): maximal finite value. Defaults to 1e4.
  506. """
  507. if "keep_io_types" not in kwargs:
  508. kwargs["keep_io_types"] = True
  509. model = self.model
  510. if use_symbolic_shape_infer:
  511. # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
  512. # are not recognized by onnx shape inference.
  513. shape_infer_helper = SymbolicShapeInferenceHelper(model)
  514. model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False)
  515. parameters = {"disable_shape_infer": use_symbolic_shape_infer}
  516. parameters.update(
  517. {
  518. key: kwargs[key]
  519. for key in [
  520. "keep_io_types",
  521. "min_positive_val",
  522. "max_finite_val",
  523. "op_block_list",
  524. "node_block_list",
  525. "force_fp16_initializers",
  526. ]
  527. if key in kwargs
  528. }
  529. )
  530. fp16_model = convert_float_to_float16(model, **parameters)
  531. self.initialize(fp16_model)
  532. self.remove_cascaded_cast_nodes()
  533. self.remove_useless_cast_nodes()
  534. def create_node_name(self, op_type, name_prefix=None):
  535. """Create a unique node name that starts with a prefix (default is operator type).
  536. The name will not be duplicated with any name that generated or existed in current graphs.
  537. Args:
  538. op_type (str): operator type
  539. name_prefix (str, optional): prefix of node name. Defaults to None.
  540. Returns:
  541. str: node name
  542. """
  543. if name_prefix:
  544. prefix = name_prefix if name_prefix.endswith("_") else (name_prefix + "_")
  545. else:
  546. prefix = op_type + "_"
  547. suffix: int = 0
  548. if prefix in self._node_name_suffix:
  549. suffix = self._node_name_suffix[prefix] + 1
  550. else:
  551. # Check existed node name only once for a prefix
  552. # as we assume create_node_name is called for every new node in fusion.
  553. for node in self.nodes():
  554. if node.name and node.name.startswith(prefix):
  555. try:
  556. index = int(node.name[len(prefix) :])
  557. suffix = max(index + 1, suffix)
  558. except ValueError:
  559. continue
  560. # Record the generated suffix so that we can avoid generating duplicated name.
  561. self._node_name_suffix[prefix] = suffix
  562. return prefix + str(suffix)
  563. def find_graph_input(self, input_name):
  564. for input in self.model.graph.input:
  565. if input.name == input_name:
  566. return input
  567. return None
  568. def find_graph_output(self, output_name):
  569. for output in self.model.graph.output:
  570. if output.name == output_name:
  571. return output
  572. return None
  573. def get_parent_subgraph_nodes(self, node, stop_nodes, output_name_to_node=None):
  574. if output_name_to_node is None:
  575. output_name_to_node = self.output_name_to_node()
  576. unique_nodes = []
  577. parents = self.get_parents(node, output_name_to_node)
  578. dq = deque(parents)
  579. while len(dq) > 0:
  580. current_node = dq.pop()
  581. if current_node in stop_nodes:
  582. continue
  583. if current_node not in unique_nodes:
  584. unique_nodes.append(current_node)
  585. for input in current_node.input:
  586. if input in output_name_to_node:
  587. dq.appendleft(output_name_to_node[input])
  588. return unique_nodes
  589. def get_graph_inputs(self, current_node, recursive=False):
  590. """
  591. Find graph inputs that linked to current node.
  592. """
  593. graph_inputs = []
  594. for input in current_node.input:
  595. if self.find_graph_input(input) and input not in graph_inputs:
  596. graph_inputs.append(input)
  597. if recursive:
  598. parent_nodes = self.get_parent_subgraph_nodes(current_node, [])
  599. for node in parent_nodes:
  600. for input in node.input:
  601. if self.find_graph_input(input) and input not in graph_inputs:
  602. graph_inputs.append(input)
  603. return graph_inputs
  604. @staticmethod
  605. def input_index(node_output, child_node):
  606. index = 0
  607. for input in child_node.input:
  608. if input == node_output:
  609. return index
  610. index += 1
  611. return -1
  612. def remove_unused_constant(self):
  613. input_name_to_nodes = self.input_name_to_nodes()
  614. # remove unused constant
  615. unused_nodes = []
  616. nodes = self.nodes()
  617. for node in nodes:
  618. if node.op_type == "Constant" and node.output[0] not in input_name_to_nodes:
  619. unused_nodes.append(node)
  620. self.remove_nodes(unused_nodes)
  621. if len(unused_nodes) > 0:
  622. logger.debug(f"Removed unused constant nodes: {len(unused_nodes)}")
    def prune_graph(self, outputs=None):
        """
        Prune graph to keep only required outputs. It removes unnecessary inputs and nodes.
        Nodes that are not linked (directly or indirectly) to any required output will be removed.
        Main-graph outputs not listed in ``outputs`` are removed, then main-graph inputs that no
        node consumes any more. update_graph() runs at the end to drop unused initializers and
        Constant nodes. Nothing is done when the model contains subgraphs.
        Args:
            outputs (list): a list of graph outputs to retain. If it is None, all graph outputs will be kept.
        """
        if len(self.graphs()) > 1:
            # Reachability below only covers the main graph, so nested graphs are not handled.
            logger.debug("Skip prune_graph since graph has subgraph")
            return
        if outputs is None:
            outputs = [output.name for output in self.model.graph.output]
        output_name_to_node = self.output_name_to_node()
        # Gather the producer of every retained output together with all of its ancestors.
        all_nodes = []
        for output in outputs:
            if output in output_name_to_node:
                last_node = output_name_to_node[output]
                if last_node in all_nodes:
                    # Already collected (with its ancestors) through another output.
                    continue
                nodes = self.get_parent_subgraph_nodes(last_node, [])
                all_nodes.append(last_node)
                all_nodes.extend(nodes)
        nodes_to_remove = []
        for node in self.model.graph.node:
            if node not in all_nodes:
                nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)
        # remove outputs not in list
        output_to_remove = []
        for output in self.model.graph.output:
            if output.name not in outputs:
                output_to_remove.append(output)
        for output in output_to_remove:
            self.model.graph.output.remove(output)
        # remove inputs not used by any node.
        # The consumer map is rebuilt here so it reflects the nodes removed above.
        input_name_to_nodes = self.input_name_to_nodes()
        input_to_remove = []
        for input in self.model.graph.input:
            if input.name not in input_name_to_nodes:
                input_to_remove.append(input)
        for input in input_to_remove:
            self.model.graph.input.remove(input)
        if input_to_remove or output_to_remove or nodes_to_remove:
            logger.info(
                "Graph pruned: {} inputs, {} outputs and {} nodes are removed".format(
                    len(input_to_remove), len(output_to_remove), len(nodes_to_remove)
                )
            )
        self.update_graph()
  672. def update_graph(self, verbose=False):
  673. graph = self.model.graph
  674. remaining_input_names = []
  675. for node in graph.node:
  676. if node.op_type in ["Loop", "Scan", "If"]:
  677. # TODO: handle inner graph
  678. logger.debug(f"Skip update_graph since graph has operator: {node.op_type}")
  679. return
  680. if node.op_type != "Constant":
  681. for input_name in node.input:
  682. if input_name not in remaining_input_names:
  683. remaining_input_names.append(input_name)
  684. if verbose:
  685. logger.debug(f"remaining input names: {remaining_input_names}")
  686. # remove graph input that is not used
  687. inputs_to_remove = []
  688. for input in graph.input:
  689. if input.name not in remaining_input_names:
  690. inputs_to_remove.append(input)
  691. for input in inputs_to_remove:
  692. graph.input.remove(input)
  693. names_to_remove = [input.name for input in inputs_to_remove]
  694. logger.debug(f"remove {len(inputs_to_remove)} unused inputs: {names_to_remove}")
  695. # remove weights that are not used
  696. weights_to_remove = []
  697. weights_to_keep = []
  698. for initializer in graph.initializer:
  699. if initializer.name not in remaining_input_names and not self.find_graph_output(initializer.name):
  700. weights_to_remove.append(initializer)
  701. else:
  702. weights_to_keep.append(initializer.name)
  703. for initializer in weights_to_remove:
  704. graph.initializer.remove(initializer)
  705. names_to_remove = [initializer.name for initializer in weights_to_remove]
  706. logger.debug(f"remove {len(weights_to_remove)} unused initializers: {names_to_remove}")
  707. if verbose:
  708. logger.debug(f"remaining initializers:{weights_to_keep}")
  709. self.remove_unused_constant()
  710. def is_safe_to_fuse_nodes(self, nodes_to_remove, keep_outputs, input_name_to_nodes, output_name_to_node):
  711. for node_to_remove in nodes_to_remove:
  712. for output_to_remove in node_to_remove.output:
  713. if output_to_remove in keep_outputs:
  714. continue
  715. if output_to_remove in input_name_to_nodes:
  716. for impacted_node in input_name_to_nodes[output_to_remove]:
  717. if impacted_node not in nodes_to_remove:
  718. logger.debug(
  719. "it is not safe to remove nodes since output %s is used by %s",
  720. output_to_remove,
  721. impacted_node,
  722. )
  723. return False
  724. return True
  725. @staticmethod
  726. def graph_topological_sort(graph):
  727. deps_count = [0] * len(graph.node) # dependency count of each node
  728. deps_to_nodes = {} # input to node indice
  729. sorted_nodes = [] # initialize sorted_nodes
  730. for node_idx, node in enumerate(graph.node):
  731. # CANNOT use len(node.input) directly because input can be optional
  732. deps_count[node_idx] = sum(1 for _ in node.input if _)
  733. if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs
  734. sorted_nodes.append(graph.node[node_idx])
  735. continue
  736. for input_name in node.input:
  737. if input_name not in deps_to_nodes:
  738. deps_to_nodes[input_name] = [node_idx]
  739. else:
  740. deps_to_nodes[input_name].append(node_idx)
  741. # Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph
  742. initializer_names = [init.name for init in graph.initializer]
  743. graph_input_names = [input.name for input in graph.input]
  744. input_names = initializer_names + graph_input_names
  745. input_names.sort()
  746. prev_input_name = None
  747. for input_name in input_names:
  748. if prev_input_name == input_name:
  749. continue
  750. prev_input_name = input_name
  751. if input_name in deps_to_nodes:
  752. for node_idx in deps_to_nodes[input_name]:
  753. deps_count[node_idx] = deps_count[node_idx] - 1
  754. if deps_count[node_idx] == 0:
  755. sorted_nodes.append(graph.node[node_idx])
  756. start = 0
  757. end = len(sorted_nodes)
  758. while start < end:
  759. for output in sorted_nodes[start].output:
  760. if output in deps_to_nodes:
  761. for node_idx in deps_to_nodes[output]:
  762. deps_count[node_idx] = deps_count[node_idx] - 1
  763. if deps_count[node_idx] == 0:
  764. sorted_nodes.append(graph.node[node_idx])
  765. end = end + 1
  766. start = start + 1
  767. if end != len(graph.node):
  768. raise RuntimeError(
  769. f"Graph is not a DAG: end={end}, len(graph.node)={len(graph.node)}, graph.node[end]={graph.node[end]}"
  770. )
  771. graph.ClearField("node")
  772. graph.node.extend(sorted_nodes)
  773. def topological_sort(self):
  774. # TODO: support graph_topological_sort() in subgraphs
  775. # for graph in self.graphs():
  776. # self.graph_topological_sort(graph)
  777. OnnxModel.graph_topological_sort(self.model.graph)
  778. @staticmethod
  779. def save(
  780. model,
  781. output_path,
  782. save_as_external_data=False,
  783. all_tensors_to_one_file=True,
  784. size_threshold=1024,
  785. convert_attribute=False,
  786. ):
  787. Path(output_path).parent.mkdir(parents=True, exist_ok=True)
  788. # Add ms domain if needed
  789. ms_opset = [opset for opset in model.opset_import if opset.domain == "com.microsoft"]
  790. # Check whether there is custom op in top level graph (our fusion is on top level right now).
  791. # May need to extend to subgraph if our fusion are extended to subgraphs.
  792. ms_node = [node for node in model.graph.node if node.domain == "com.microsoft"]
  793. if ms_node and not ms_opset:
  794. opset = model.opset_import.add()
  795. opset.version = 1
  796. opset.domain = "com.microsoft"
  797. if save_as_external_data:
  798. # Save model to external data, which is needed for model size > 2GB
  799. output_dir = Path(output_path).parent
  800. output_dir.mkdir(parents=True, exist_ok=True)
  801. external_data_path = output_path + ".data"
  802. location = Path(external_data_path).name if all_tensors_to_one_file else None
  803. if os.path.exists(output_path):
  804. logger.info(f"Delete the existed onnx file: {output_path}")
  805. os.remove(output_path)
  806. if all_tensors_to_one_file:
  807. if os.path.exists(external_data_path):
  808. # Delete the external data file. Otherwise, data will be appended to existing file.
  809. logger.info(f"Delete the existed external data file: {external_data_path}")
  810. os.remove(external_data_path)
  811. else:
  812. if os.listdir(output_dir):
  813. raise RuntimeError(f"Output directory ({output_dir}) for external data is not empty.")
  814. save_model(
  815. model,
  816. output_path,
  817. save_as_external_data=True,
  818. all_tensors_to_one_file=all_tensors_to_one_file,
  819. location=location,
  820. size_threshold=size_threshold,
  821. convert_attribute=convert_attribute,
  822. )
  823. else:
  824. save_model(model, output_path)
  825. def save_model_to_file(self, output_path, use_external_data_format=False, all_tensors_to_one_file=True):
  826. logger.info("Sort graphs in topological order")
  827. self.topological_sort()
  828. # Note: After the model is saved to another directory with external data,
  829. # You need reload the onnx model if you want to read tensor from self.model object.
  830. # It is because the base directory is not updated for self.model object so attempt to read tensor data
  831. # might encounter error since external data cannot be located.
  832. OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file)
  833. logger.info(f"Model saved to {output_path}")
  834. def get_graph_inputs_excluding_initializers(self):
  835. """
  836. Returns real graph inputs (excluding initializers from older onnx model).
  837. """
  838. graph_inputs = []
  839. for input in self.model.graph.input:
  840. if self.get_initializer(input.name) is None:
  841. graph_inputs.append(input)
  842. return graph_inputs
  843. def get_opset_version(self):
  844. """Get opset version of onnx domain
  845. Raises:
  846. RuntimeError: ONNX model has no opset for default domain.
  847. Returns:
  848. int: opset version of onnx domain.
  849. """
  850. for opset in self.model.opset_import:
  851. if opset.domain in ["", "ai.onnx"]:
  852. return opset.version
  853. raise RuntimeError("ONNX model has no opset for default domain")
  854. @staticmethod
  855. def has_same_value(tensor1: TensorProto, tensor2: TensorProto) -> bool:
  856. """Returns True when two tensors have same value.
  857. Note that name can be different.
  858. Args:
  859. tensor1 (TensorProto): initializer 1
  860. tensor2 (TensorProto): initializer 2
  861. Returns:
  862. bool: True when two intializers has same value.
  863. """
  864. if tensor1.data_type != tensor2.data_type or tensor1.dims != tensor2.dims:
  865. return False
  866. if tensor1.HasField("raw_data") and tensor2.HasField("raw_data"):
  867. return tensor1.raw_data == tensor2.raw_data
  868. return numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)
  869. def remove_duplicated_initializer(self):
  870. """Remove initializers with duplicated values, and only keep the first one.
  871. It could help reduce size of models (like ALBert) with shared weights.
  872. Note: this function does not process subgraph.
  873. """
  874. if len(self.graphs()) > 1:
  875. logger.warning("remove_duplicated_initializer does not process subgraphs.")
  876. initializer_count = len(self.model.graph.initializer)
  877. same = [-1] * initializer_count
  878. for i in range(initializer_count - 1):
  879. if same[i] >= 0:
  880. continue
  881. for j in range(i + 1, initializer_count):
  882. if OnnxModel.has_same_value(self.model.graph.initializer[i], self.model.graph.initializer[j]):
  883. same[j] = i
  884. count = 0
  885. for i in range(initializer_count):
  886. if same[i] >= 0:
  887. count += 1
  888. self.replace_input_of_all_nodes(
  889. self.model.graph.initializer[i].name, self.model.graph.initializer[same[i]].name
  890. )
  891. if count > 0:
  892. self.update_graph()
  893. print(f"Removed {count} initializers with duplicated value")
  894. def add_prefix_to_names(self, prefix: str):
  895. """Add prefix to initializer or intermediate outputs in graph. Main graph inputs and outputs are excluded.
  896. It could help avoid conflicting in name of node_args when merging two graphs.
  897. Note: this function does not process subgraph.
  898. """
  899. if len(self.graphs()) > 1:
  900. logger.warning("add_prefix_to_names does not process subgraphs.")
  901. # Exclude the names of inputs and outputs of main graph (but not subgraphs)
  902. # and empty names ("") as they have special meaning to denote missing optional inputs
  903. excluded = [i.name for i in self.model.graph.input] + [o.name for o in self.model.graph.output] + [""]
  904. for initializer in self.model.graph.initializer:
  905. if initializer.name not in excluded:
  906. if prefix + initializer.name not in excluded:
  907. initializer.name = prefix + initializer.name
  908. for node in self.model.graph.node:
  909. # update name of node inputs
  910. for j in range(len(node.input)):
  911. if node.input[j] not in excluded:
  912. if prefix + node.input[j] not in excluded:
  913. node.input[j] = prefix + node.input[j]
  914. # update name of node outputs
  915. for j in range(len(node.output)):
  916. if node.output[j] not in excluded:
  917. if prefix + node.output[j] not in excluded:
  918. node.output[j] = prefix + node.output[j]
  919. for value_info in self.model.graph.value_info:
  920. if value_info.name not in excluded:
  921. value_info.name = prefix + value_info.name