from pathlib import Path

import onnx
import onnx.helper as onnx_helper
import onnx.numpy_helper as onnx_numpy_helper

from .quant_utils import attribute_to_kwarg, find_by_name


def _clean_initializers_helper(graph, model):
    """Clean unused initializers from graph.

    Returns:
        A cleaned graph without unused initializers
        A list of tensor names, which are not produced by this graph and its subgraphs
    """
    requesting_tensor_names = set()
    requesting_tensor_names.update(input_name for node in graph.node for input_name in node.input if input_name)
    requesting_tensor_names.update(g_out.name for g_out in graph.output if g_out.name)

    new_nodes = []
    for node in graph.node:
        new_node = node
        graph_attrs = [
            attr
            for attr in node.attribute
            if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
        ]
        if graph_attrs:
            kwargs = {}
            for attr in node.attribute:
                new_attribute = {}
                if attr.type == onnx.AttributeProto.GRAPH:
                    (
                        cleaned_sub_graph,
                        sub_requesting_tensor_names,
                    ) = _clean_initializers_helper(attr.g, model)
                    new_attribute = {attr.name: cleaned_sub_graph}
                    requesting_tensor_names.update(sub_requesting_tensor_names)
                elif attr.type == onnx.AttributeProto.GRAPHS:
                    cleaned_graphs = []
                    for subgraph in attr.graphs:
                        (
                            cleaned_sub_graph,
                            sub_requesting_tensor_names,
                        ) = _clean_initializers_helper(subgraph, model)
                        cleaned_graphs.append(cleaned_sub_graph)
                        requesting_tensor_names.update(sub_requesting_tensor_names)
                    new_attribute = {attr.name: cleaned_graphs}
                else:
                    new_attribute = attribute_to_kwarg(attr)
                kwargs.update(new_attribute)
            new_node = onnx_helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)
        new_nodes.append(new_node)

    graph.ClearField("node")
    graph.node.extend(new_nodes)

    requesting_tensor_names.difference_update(output for node in graph.node for output in node.output)

    unused_initializer = []
    for initializer in graph.initializer:
        if initializer.name in requesting_tensor_names:
            requesting_tensor_names.remove(initializer.name)
        else:
            # Mark it for removal; removing it here directly would misbehave while iterating.
            unused_initializer.append(initializer)

    name_to_input = {input.name: input for input in graph.input}
    for initializer in unused_initializer:
        graph.initializer.remove(initializer)
        if initializer.name in name_to_input:
            try:
                graph.input.remove(name_to_input[initializer.name])
            except StopIteration:
                if model.ir_version < 4:
                    print(
                        "Warning: invalid weight name {} found in the graph (not a graph input)".format(
                            initializer.name
                        )
                    )

    requesting_tensor_names.difference_update(input.name for input in graph.input)

    return graph, requesting_tensor_names


class ONNXModel:
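    """Wrapper around an onnx.ModelProto providing in-place graph inspection and
    editing helpers (node/initializer management, Gemm-to-MatMul rewriting,
    topological sort, saving with optional external data).

    A minimal usage sketch (file paths are placeholders):

        onnx_model = ONNXModel(onnx.load("model.onnx"))
        onnx_model.replace_gemm_with_matmul()
        onnx_model.remove_unused_constant()
        onnx_model.save_model_to_file("model_out.onnx")
    """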
    def __init__(self, model):
        self.model = model

    def nodes(self):
        return self.model.graph.node

    def initializer(self):
        return self.model.graph.initializer

    def graph(self):
        return self.model.graph

    def ir_version(self):
        return self.model.ir_version

    def opset_import(self):
        return self.model.opset_import

    def remove_node(self, node):
        if node in self.model.graph.node:
            self.model.graph.node.remove(node)

    def remove_nodes(self, nodes_to_remove):
        for node in nodes_to_remove:
            self.remove_node(node)

    def add_node(self, node):
        self.model.graph.node.extend([node])

    def add_nodes(self, nodes_to_add):
        self.model.graph.node.extend(nodes_to_add)

    def add_initializer(self, tensor):
        if find_by_name(tensor.name, self.model.graph.initializer) is None:
            self.model.graph.initializer.extend([tensor])

    def get_initializer(self, name):
        for tensor in self.model.graph.initializer:
            if tensor.name == name:
                return tensor
        return None

    def get_initializer_name_set(self):
        return set(initializer.name for initializer in self.model.graph.initializer)

    def remove_initializer(self, tensor):
        if tensor in self.model.graph.initializer:
            self.model.graph.initializer.remove(tensor)
            for input in self.model.graph.input:
                if input.name == tensor.name:
                    self.model.graph.input.remove(input)
                    break

    def remove_initializers(self, init_to_remove):
        for initializer in init_to_remove:
            self.remove_initializer(initializer)

    def get_non_initializer_inputs(self):
        initializer_names = self.get_initializer_name_set()
        non_initializer_inputs = set()
        for input in self.model.graph.input:
            if input.name not in initializer_names:
                non_initializer_inputs.add(input.name)
        return non_initializer_inputs

    def input_name_to_nodes(self):
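        """Map each tensor name to the list of nodes that consume it as an input."""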
        input_name_to_nodes = {}
        for node in self.model.graph.node:
            for input_name in node.input:
                if input_name not in input_name_to_nodes:
                    input_name_to_nodes[input_name] = [node]
                else:
                    input_name_to_nodes[input_name].append(node)
        return input_name_to_nodes

    def output_name_to_node(self):
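        """Map each tensor name to the node that produces it as an output."""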
        output_name_to_node = {}
        for node in self.model.graph.node:
            for output_name in node.output:
                output_name_to_node[output_name] = node
        return output_name_to_node

    def get_children(self, node, input_name_to_nodes=None):
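        """Return all nodes that consume any output of the given node."""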
        if input_name_to_nodes is None:
            input_name_to_nodes = self.input_name_to_nodes()
        children = []
        for output in node.output:
            if output in input_name_to_nodes:
                for child in input_name_to_nodes[output]:
                    children.append(child)
        return children

    def get_parents(self, node, output_name_to_node=None):
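        """Return the nodes that produce the inputs of the given node."""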
        if output_name_to_node is None:
            output_name_to_node = self.output_name_to_node()
        parents = []
        for input in node.input:
            if input in output_name_to_node:
                parents.append(output_name_to_node[input])
        return parents

    def get_parent(self, node, idx, output_name_to_node=None):
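        """Return the node producing the idx-th input of the given node, or None."""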
        if output_name_to_node is None:
            output_name_to_node = self.output_name_to_node()
        if len(node.input) <= idx:
            return None
        input = node.input[idx]
        if input not in output_name_to_node:
            return None
        return output_name_to_node[input]

    def find_node_by_name(self, node_name, new_nodes_list, graph):
        """Find out if a node exists in a graph or a node is in the
        new set of nodes created during quantization.

        Returns:
            The node found or None.
        """
        graph_nodes_list = list(graph.node)  # shallow copy of the node list
        graph_nodes_list.extend(new_nodes_list)
        node = find_by_name(node_name, graph_nodes_list)
        return node

    def find_nodes_by_initializer(self, graph, initializer):
        """
        Find all nodes with given initializer as an input.
        """
        nodes = []
        for node in graph.node:
            for node_input in node.input:
                if node_input == initializer.name:
                    nodes.append(node)
        return nodes

    @staticmethod
    def __get_initializer(name, graph_path):
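        """Search graph_path from the innermost graph outward for an initializer with
        the given name. Returns (tensor, owning_graph), or (None, None) if not found."""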
        for gid in range(len(graph_path) - 1, -1, -1):
            graph = graph_path[gid]
            for tensor in graph.initializer:
                if tensor.name == name:
                    return tensor, graph
        return None, None

    @staticmethod
    def __replace_gemm_with_matmul(graph_path):
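        """Rewrite Gemm nodes with alpha == 1.0, beta == 1.0 and transA == 0 in the
        graph at the top of graph_path into MatMul, followed by an Add node when a
        bias input is present. transB == 1 is handled by transposing the weight
        initializer in place, or by inserting a Transpose node when B is not an
        initializer. Subgraph attributes are processed recursively.
        """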
        new_nodes = []
        graph = graph_path[-1]
        for node in graph.node:
            graph_attrs = [
                attr
                for attr in node.attribute
                if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
            ]
            if len(graph_attrs):
                kwargs = {}
                for attr in node.attribute:
                    if attr.type == onnx.AttributeProto.GRAPH:
                        graph_path.append(attr.g)
                        kv = {attr.name: ONNXModel.__replace_gemm_with_matmul(graph_path)}
                    elif attr.type == onnx.AttributeProto.GRAPHS:
                        value = []
                        for subgraph in attr.graphs:
                            graph_path.append(subgraph)
                            value.extend([ONNXModel.__replace_gemm_with_matmul(graph_path)])
                        kv = {attr.name: value}
                    else:
                        kv = attribute_to_kwarg(attr)
                    kwargs.update(kv)
                node = onnx_helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)
  215. if node.op_type == "Gemm":
  216. alpha = 1.0
  217. beta = 1.0
  218. transA = 0
  219. transB = 0
  220. for attr in node.attribute:
  221. if attr.name == "alpha":
  222. alpha = onnx_helper.get_attribute_value(attr)
  223. elif attr.name == "beta":
  224. beta = onnx_helper.get_attribute_value(attr)
  225. elif attr.name == "transA":
  226. transA = onnx_helper.get_attribute_value(attr)
  227. elif attr.name == "transB":
  228. transB = onnx_helper.get_attribute_value(attr)
  229. if alpha == 1.0 and beta == 1.0 and transA == 0:
  230. inputB = node.input[1]
  231. if transB == 1:
  232. B, Bs_graph = ONNXModel.__get_initializer(node.input[1], graph_path)
  233. if B:
  234. # assume B is not used by any other node
  235. B_array = onnx_numpy_helper.to_array(B)
  236. B_trans = onnx_numpy_helper.from_array(B_array.T)
  237. B_trans.name = B.name
  238. Bs_graph.initializer.remove(B)
  239. for input in Bs_graph.input:
  240. if input.name == inputB:
  241. Bs_graph.input.remove(input)
  242. break
  243. Bs_graph.initializer.extend([B_trans])
  244. else:
  245. inputB += "_Transposed"
  246. transpose_node = onnx_helper.make_node(
  247. "Transpose",
  248. inputs=[node.input[1]],
  249. outputs=[inputB],
  250. name=node.name + "_Transpose" if node.name != "" else "",
  251. )
  252. new_nodes.append(transpose_node)
  253. matmul_node = onnx_helper.make_node(
  254. "MatMul",
  255. inputs=[node.input[0], inputB],
  256. outputs=[node.output[0] + ("_MatMul" if len(node.input) > 2 else "")],
  257. name=node.name + "_MatMul" if node.name != "" else "",
  258. )
  259. new_nodes.append(matmul_node)
  260. if len(node.input) > 2:
  261. add_node = onnx_helper.make_node(
  262. "Add",
  263. inputs=[node.output[0] + "_MatMul", node.input[2]],
  264. outputs=node.output,
  265. name=node.name + "_Add" if node.name != "" else "",
  266. )
  267. new_nodes.append(add_node)
  268. # unsupported
  269. else:
  270. new_nodes.append(node)
  271. # not GEMM
  272. else:
  273. new_nodes.append(node)
  274. graph.ClearField("node")
  275. graph.node.extend(new_nodes)
  276. graph_path.pop()
  277. return graph
  278. def replace_gemm_with_matmul(self):
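        """Rewrite Gemm nodes as MatMul (+ Add) in the main graph and all subgraphs."""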
        graph_path = [self.graph()]
        ONNXModel.__replace_gemm_with_matmul(graph_path)

    def save_model_to_file(self, output_path, use_external_data_format=False):
        """
        Save the model to a file. When use_external_data_format is True, tensor data is
        stored in an external file, which is needed for models larger than 2GB.
        """
        self.topological_sort()
        if use_external_data_format:
            onnx.external_data_helper.convert_model_to_external_data(
                self.model,
                all_tensors_to_one_file=True,
                location=Path(output_path).name + ".data",
            )
        onnx.save_model(self.model, output_path)

    @staticmethod
    def replace_node_input(node, old_input_name, new_input_name):
        assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
        for j in range(len(node.input)):
            if node.input[j] == old_input_name:
                node.input[j] = new_input_name

    def replace_input_of_all_nodes(self, old_input_name, new_input_name):
        for node in self.model.graph.node:
            ONNXModel.replace_node_input(node, old_input_name, new_input_name)

    @staticmethod
    def replace_node_output(node, old_output_name, new_output_name):
        assert isinstance(old_output_name, str) and isinstance(new_output_name, str)
        for j in range(len(node.output)):
            if node.output[j] == old_output_name:
                node.output[j] = new_output_name

    def replace_output_of_all_nodes(self, old_output_name, new_output_name):
        for node in self.model.graph.node:
            ONNXModel.replace_node_output(node, old_output_name, new_output_name)

    def remove_unused_constant(self):
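        """Remove Constant nodes and initializers whose outputs are neither consumed by
        any node nor exposed as graph outputs."""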
        input_name_to_nodes = self.input_name_to_nodes()

        # remove unused constant
        unused_nodes = []
        nodes = self.nodes()
        for node in nodes:
            if (
                node.op_type == "Constant"
                and not self.is_graph_output(node.output[0])
                and node.output[0] not in input_name_to_nodes
            ):
                unused_nodes.append(node)
        self.remove_nodes(unused_nodes)

        unused_weights = []
        for w in self.initializer():
            if w.name not in input_name_to_nodes and not self.is_graph_output(w.name):
                unused_weights.append(w)
                # Remove from graph.input
                for graph_input in self.graph().input:
                    if graph_input.name == w.name:
                        self.graph().input.remove(graph_input)
        self.remove_initializers(unused_weights)

    def is_graph_output(self, output_name):
        for output in self.model.graph.output:
            if output.name == output_name:
                return True
        return False

    def is_graph_input(self, tensor_name: str) -> bool:
        for input in self.model.graph.input:
            if input.name == tensor_name:
                return True
        return False

    # TODO: use OnnxModel.graph_topological_sort(self.model.graph) from transformers.onnx_model
    # Currently it breaks the OpenVINO/Linux training GPU pipeline, so hold off for the 1.8 release
    def topological_sort(self):
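        """Sort graph nodes in topological order (Kahn's algorithm over tensor-name
        edges) and rewrite graph.node in that order. Asserts that the graph is a DAG."""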
        deps_count = [0] * len(self.nodes())  # dependency count of each node
        deps_to_nodes = {}  # input name to node indices
        sorted_nodes = []  # initialize sorted_nodes
        for node_idx, node in enumerate(self.nodes()):
            # CANNOT use len(node.input) directly because an input can be optional (empty name)
            deps_count[node_idx] = sum(1 for _ in node.input if _)
            if deps_count[node_idx] == 0:  # Constant doesn't depend on any inputs
                sorted_nodes.append(self.nodes()[node_idx])
                continue

            for input_name in node.input:
                if input_name not in deps_to_nodes:
                    deps_to_nodes[input_name] = [node_idx]
                else:
                    deps_to_nodes[input_name].append(node_idx)

        initializer_names = [init.name for init in self.initializer()]
        graph_input_names = [input.name for input in self.model.graph.input]
        input_names = initializer_names + graph_input_names
        input_names.sort()
        prev_input_name = None
        for input_name in input_names:
            if prev_input_name == input_name:
                continue
            prev_input_name = input_name
            if input_name in deps_to_nodes:
                for node_idx in deps_to_nodes[input_name]:
                    deps_count[node_idx] = deps_count[node_idx] - 1
                    if deps_count[node_idx] == 0:
                        sorted_nodes.append(self.nodes()[node_idx])

        start = 0
        end = len(sorted_nodes)
        while start < end:
            for output in sorted_nodes[start].output:
                if output in deps_to_nodes:
                    for node_idx in deps_to_nodes[output]:
                        deps_count[node_idx] = deps_count[node_idx] - 1
                        if deps_count[node_idx] == 0:
                            sorted_nodes.append(self.nodes()[node_idx])
                            end = end + 1
            start = start + 1

        assert end == len(self.graph().node), "Graph is not a DAG"

        self.graph().ClearField("node")
        self.graph().node.extend(sorted_nodes)

    def clean_initializers(self):
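        """Remove unused initializers from the graph and all subgraphs; returns the
        cleaned graph and the set of tensor names it references but does not produce."""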
        return _clean_initializers_helper(self.graph(), self.model)