# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple

import numpy
from numpy import array_equal, ndarray
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx import onnx_pb as onnx_proto
from onnx_model import OnnxModel

logger = getLogger(__name__)
class FusionUtils:
    def __init__(self, model: OnnxModel):
        self.model: OnnxModel = model

    def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]:
        graph_input = self.model.find_graph_input(input_name)
        if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
            cast_output, _ = self.cast_input_to_int32(input_name)
            logger.debug(f"Casted graph input {input_name} to int32")
            return True, cast_output

        logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}")
        return False, input_name
    def cast_input(self, input_name: str, target_type="int32"):
        cast_output = input_name + "_" + target_type

        # Avoid creating consecutive Cast nodes: if the input already comes from a Cast,
        # cast from that node's original input instead.
        inputs = [input_name]
        output_name_to_node = self.model.output_name_to_node()
        if input_name in output_name_to_node:
            parent_node = output_name_to_node[input_name]
            if parent_node and parent_node.op_type == "Cast":
                inputs = [parent_node.input[0]]

        cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output])

        if target_type == "int32":
            to_type = int(TensorProto.INT32)
        elif target_type == "float32":
            to_type = int(TensorProto.FLOAT)
        elif target_type == "float16":
            to_type = int(TensorProto.FLOAT16)
        else:
            raise ValueError(f"Invalid target_type: {target_type}")

        cast_node.attribute.extend([helper.make_attribute("to", to_type)])
        self.model.add_node(cast_node)

        return cast_output, cast_node
    def cast_input_to_int32(self, input_name: str):
        return self.cast_input(input_name, "int32")
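
    # Usage sketch (illustrative; `onnx_model` and "input_ids" are assumptions, not part
    # of this module): cast an int64 graph input to int32 for fused ops that require it.
    #
    #   utils = FusionUtils(onnx_model)
    #   casted, int32_name = utils.cast_graph_input_to_int32("input_ids")
    #   if casted:
    #       ...  # downstream nodes can now consume int32_name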
    def remove_cast_int32(self, input_name: str):
        input_name_to_nodes = self.model.input_name_to_nodes()
        nodes = input_name_to_nodes[input_name]
        for node in nodes:
            if node.op_type == "Cast":
                is_int32 = False
                for att in node.attribute:
                    if att.name == "to" and att.i == int(TensorProto.INT32):
                        is_int32 = True
                        break
                if is_int32:
                    output_name = node.output[0]
                    self.model.remove_node(node)
                    self.model.replace_input_of_all_nodes(output_name, input_name)
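
    # Sketch (name illustrative): remove_cast_int32 undoes cast_input_to_int32 by
    # rewiring consumers of the Cast's output back to the original edge:
    #
    #   utils.remove_cast_int32("input_ids")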
    @staticmethod
    def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes):
        """Rewire `node` to bypass `parent_node`.

        Before:
            (input)-->parent-->node-->(output)
        After:
            (input)-->parent-->
               |
               +----->node-->(output)

        Returns a flag indicating whether the parent node can be removed.
        Note: this function assumes that the first input of `node` is linked from `parent_node`!
        """
        parent_can_be_removed = False
        input_name_to_nodes[node.input[0]].remove(node)

        # The parent can be removed if its output is no longer used (neither by a graph
        # output nor by other nodes).
        if len(input_name_to_nodes[node.input[0]]) == 0 and not model.find_graph_output(
            node.input[0]
        ):  # checks main graph output. TODO: deal with subgraph
            parent_can_be_removed = True

        input_name_to_nodes[parent_node.input[0]].append(node)
        node.input[0] = parent_node.input[0]
        return parent_can_be_removed
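
    # Sketch (node names illustrative): for a Transpose -> MatMul pattern where the
    # fusion reads the un-transposed input directly, bypass the Transpose and queue it
    # for removal once it is dead:
    #
    #   if FusionUtils.skip_parent(model, matmul_node, transpose_node, input_name_to_nodes):
    #       nodes_to_remove.append(transpose_node)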
    @staticmethod
    def check_node_attribute(node, attribute_name: str, expected_value, default_value=None):
        """Verify that a node has the expected value for an attribute.

        Args:
            node (NodeProto): a node to check
            attribute_name (str): name of the attribute
            expected_value (Any): expected value of the attribute
            default_value (Any, optional): default value if the attribute does not exist. Defaults to None.

        Returns:
            bool: whether the check passed
        """
        value = default_value
        for attr in node.attribute:
            if attr.name == attribute_name:
                value = helper.get_attribute_value(attr)

        if isinstance(expected_value, list):
            return isinstance(value, (ndarray, list)) and array_equal(expected_value, value, equal_nan=False)
        else:
            return value == expected_value
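
    # Sketch: list-valued attributes are compared element-wise via numpy.array_equal,
    # e.g. checking a Transpose node's permutation (node and perm are illustrative):
    #
    #   if FusionUtils.check_node_attribute(transpose_node, "perm", [0, 2, 1, 3]):
    #       ...  # attribute matches; safe to proceed with the fusion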
    @staticmethod
    def transpose_2d_int8_tensor(tensor: onnx_proto.TensorProto):
        """Transpose a 2-D INT8 TensorProto.

        Args:
            tensor (TensorProto): tensor to be transposed

        Returns:
            tensor (TensorProto): the same tensor with its raw buffer transposed
        """
        if not isinstance(tensor, onnx_proto.TensorProto):
            raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")

        if len(tensor.dims) != 2 or tensor.data_type != onnx_proto.TensorProto.INT8:
            raise ValueError("Only INT8 2-D tensors can be transposed")

        if tensor.raw_data:
            int8_data = numpy.reshape(numpy.frombuffer(tensor.raw_data, dtype="int8"), tensor.dims)
            int8_transposed_data = numpy.transpose(int8_data, [1, 0])
            # Only raw_data is rewritten; tensor.dims is left unchanged.
            tensor.raw_data = int8_transposed_data.tobytes()
        else:
            raise ValueError("Only tensors with raw_data are supported")

        return tensor
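
    # Self-contained sketch: build a small INT8 initializer and transpose its buffer.
    # onnx.numpy_helper.from_array stores int8 data as raw_data, so it satisfies the
    # raw-buffer requirement above.
    #
    #   import numpy as np
    #   from onnx import numpy_helper
    #   t = numpy_helper.from_array(np.arange(6, dtype=np.int8).reshape(2, 3), name="w")
    #   t = FusionUtils.transpose_2d_int8_tensor(t)  # raw buffer is now transposed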
    @staticmethod
    def check_qdq_node_for_fusion(node: NodeProto, model: OnnxModel, allow_per_tensor_quantization_only=True):
        """Verify whether a provided QuantizeLinear (Q) / DequantizeLinear (DQ) node is a good candidate for fusion.

        It is a good candidate for fusion if:
            (1) The Q/DQ node is for per-tensor quantization if allow_per_tensor_quantization_only is `True`
            (2) The Q/DQ node has a constant scale
            (3) The Q/DQ node has a zero point of 0

        Args:
            node (NodeProto): a Q/DQ node to check
            model (OnnxModel): the model containing the node

        Returns:
            bool: whether the check passed
        """
        if node.op_type not in {"QuantizeLinear", "DequantizeLinear"}:
            logger.debug(f"Provided node is not a Q/DQ node. Op Type: {node.op_type}")
            return False

        # Scale must be constant
        scale = model.get_constant_value(node.input[1])
        if scale is None:
            return False

        # Reject per-channel quantization if only per-tensor is allowed
        scale_has_single_element = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
        if allow_per_tensor_quantization_only and not scale_has_single_element:
            return False

        # If the Q/DQ node has no zero-point input, it is assumed to be 0 (per the ONNX spec)
        if len(node.input) == 2:
            return True

        # Zero point must be constant (check this before touching its ndim) ...
        zero_point = model.get_constant_value(node.input[2])
        if zero_point is None:
            return False

        # ... must have the same number of dims as scale ...
        if scale.ndim != zero_point.ndim:
            return False

        # ... and must be zero everywhere
        return numpy.all(zero_point == 0)
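
    # Sketch (node name illustrative): gate a fusion on its surrounding DQ node.
    # Non-constant scales, per-channel quantization, and non-zero zero points are rejected:
    #
    #   if not FusionUtils.check_qdq_node_for_fusion(dequantize_node, model):
    #       return  # not a safe fusion candidate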
    def check_node_input_value(self, node, input_index: int, expected_value):
        """Verify that a node's input has the expected constant value.

        Args:
            node (NodeProto): a node to check
            input_index (int): index of the input to be verified
            expected_value (Any): expected value of the input

        Returns:
            bool: whether the check passed
        """
        assert len(node.input) > input_index

        value = self.model.get_constant_value(node.input[input_index])

        if isinstance(expected_value, list):
            return isinstance(value, (ndarray, list)) and array_equal(expected_value, value, equal_nan=False)
        else:
            return value == expected_value
    def remove_identity_nodes(self):
        """Remove Identity nodes, except those right before a graph output."""
        nodes_to_remove = []
        for node in self.model.nodes():
            if node.op_type == "Identity":
                if node.output[0] not in self.model.get_graphs_output_names():
                    self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
                    nodes_to_remove.append(node)

        if nodes_to_remove:
            self.model.remove_nodes(nodes_to_remove)
            logger.info(f"Removed {len(nodes_to_remove)} Identity nodes")

    def remove_cascaded_cast_nodes(self):
        self.model.remove_cascaded_cast_nodes()

    def remove_useless_cast_nodes(self):
        self.model.remove_useless_cast_nodes()
    def remove_useless_reshape_nodes(self):
        """Remove Reshape nodes that are not needed: symbolic shape inference shows the input and output shapes are identical."""
        shape_infer = self.model.infer_runtime_shape(update=True)
        if shape_infer is None:
            return

        nodes_to_remove = []
        for node in self.model.nodes():
            if node.op_type == "Reshape":
                input_shape = shape_infer.get_edge_shape(node.input[0])
                output_shape = shape_infer.get_edge_shape(node.output[0])
                if input_shape and output_shape and input_shape == output_shape:
                    logger.info(f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}")
                    nodes_to_remove.append(node)

        if nodes_to_remove:
            graph_input_names = set(self.model.get_graphs_input_names())
            graph_output_names = set(self.model.get_graphs_output_names())
            for node in nodes_to_remove:
                if bool(set(node.output) & graph_output_names):
                    # The Reshape produces a graph output: it can only be dropped by renaming
                    # the parent's output, which requires that the Reshape's input is not a
                    # graph input and that this Reshape is the parent's only consumer.
                    if (
                        not bool(set(node.input) & graph_input_names)
                        and len(self.model.input_name_to_nodes()[node.input[0]]) == 1  # parent has only one child
                    ):
                        self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
                    else:
                        continue
                else:
                    self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
                self.model.remove_node(node)
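
    # Sketch: these cleanup passes are typically chained after running fusions
    # (`onnx_model` is an OnnxModel instance; the ordering here is illustrative):
    #
    #   utils = FusionUtils(onnx_model)
    #   utils.remove_cascaded_cast_nodes()
    #   utils.remove_useless_cast_nodes()
    #   utils.remove_identity_nodes()
    #   utils.remove_useless_reshape_nodes()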

class NumpyHelper:
    @staticmethod
    def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray:
        # When weights are in external data format but not present on disk, we can still
        # test the optimizer with two changes: (1) set fill_zeros=True here, and
        # (2) set load_external_data=False in optimizer.py.
        if fill_zeros:
            from onnx import mapping

            # Use numpy.zeros (not an uninitialized ndarray) so the fake weights are
            # deterministic, matching the fill_zeros name.
            return numpy.zeros(
                shape=tensor.dims,
                dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.data_type],
            )

        return numpy_helper.to_array(tensor)
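
# Sketch: NumpyHelper.to_array mirrors onnx.numpy_helper.to_array, with an escape hatch
# for externally stored weights that are missing from disk (`initializer` is illustrative):
#
#   weight = NumpyHelper.to_array(initializer)                  # real values
#   fake = NumpyHelper.to_array(initializer, fill_zeros=True)   # zeros, right shape/dtype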