m2m model translation
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict

from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionQOrderedMatMul(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "QOrderedMatMul", "MatMul")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        matmul_children = self.model.get_children(node, input_name_to_nodes)

        # Should only have 1 child - Bias Add
        if len(matmul_children) != 1 or matmul_children[0].op_type != "Add":
            return

        bias_add_node = matmul_children[0]

        # At least one of the inputs to the Bias Add node must be a constant
        bias_add_node_index = 0
        if (
            self.model.get_constant_value(bias_add_node.input[0]) is None
            and self.model.get_constant_value(bias_add_node.input[1]) is None
        ):
            return

        if self.model.get_constant_value(bias_add_node.input[0]) is None:
            bias_add_node_index = 1

        bias_add_children = self.model.get_children(bias_add_node, input_name_to_nodes)

        if len(bias_add_children) != 1:
            return

        bias_add_child = bias_add_children[0]

        # Bias Add can have another Add downstream (Residual Add layer)
        residual_add_node = None
        downstream_quantize_node = None

        if bias_add_child.op_type == "Add":
            residual_add_node = bias_add_child

            residual_add_children = self.model.get_children(residual_add_node, input_name_to_nodes)

            if len(residual_add_children) != 1 or residual_add_children[0].op_type != "QuantizeLinear":
                return

            downstream_quantize_node = residual_add_children[0]
        elif bias_add_child.op_type == "QuantizeLinear":
            downstream_quantize_node = bias_add_child
        else:
            return

        # Make sure the downstream QuantizeLinear has the proper zero points and scales
        if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
            return

        # The first input to MatMul should flow through a DequantizeLinear node
        first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
            node,
            [(["DequantizeLinear"], [0])],
            output_name_to_node,
        )

        # If Attention is not fused, this is the pattern to look for
        # leading up to the MatMul
        reshape_node_0 = None
        transpose_node_0 = None
        if first_path_id < 0:
            first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
                node,
                [(["Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear"], [0, 0, 0, 0])],
                output_name_to_node,
            )

            if first_path_id < 0:
                return

            reshape_node_0 = first_input_parent_nodes[0]
            transpose_node_0 = first_input_parent_nodes[1]
            dequantize_node_0 = first_input_parent_nodes[2]
        else:
            dequantize_node_0 = first_input_parent_nodes[0]

        # Make sure the upstream DequantizeLinear-0 has the proper zero points and scales
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_0, self.model):
            return

        # The second input to MatMul should flow through a DequantizeLinear node
        dequantize_node_1 = None
        is_weight_transpose_required = True

        weight_path_id, weight_nodes, _ = self.model.match_parent_paths(
            node,
            [(["DequantizeLinear", "QuantizeLinear", "Transpose", "DequantizeLinear"], [1, 0, 0, 0])],
            output_name_to_node,
        )

        if weight_path_id < 0:
            weight_path_id, weight_nodes, _ = self.model.match_parent_paths(
                node,
                [(["DequantizeLinear"], [1])],
                output_name_to_node,
            )

            if weight_path_id < 0:
                return

            dequantize_node_1 = weight_nodes[0]
        else:
            is_weight_transpose_required = False
            dequantize_node_1 = weight_nodes[3]

        # Check if weight 'B' is a constant
        if self.model.get_constant_value(dequantize_node_1.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_1, self.model, False):
            return

        # Make sure the upstream flow into the Residual Add node flows through a DQ node
        residual_add_dequantize_node = None

        if residual_add_node is not None:
            residual_path_id, residual_input_parent_nodes, _ = self.model.match_parent_paths(
                residual_add_node,
                [
                    (["DequantizeLinear"], [1]),
                ],
                output_name_to_node,
            )

            if residual_path_id < 0:
                return

            residual_add_dequantize_node = residual_input_parent_nodes[0]

        # Make sure the upstream DequantizeLinear to the Residual Add has the proper zero points and scales
        if residual_add_dequantize_node is not None and not FusionUtils.check_qdq_node_for_fusion(
            residual_add_dequantize_node, self.model
        ):
            return

        # Subgraph nodes to be fused
        subgraph_nodes = [node, bias_add_node]  # MatMul + Bias Add

        if residual_add_node is not None:
            subgraph_nodes.extend([residual_add_node])  # Residual Add

        subgraph_nodes.extend(weight_nodes)
        subgraph_nodes.extend([downstream_quantize_node])  # Downstream Q node

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, downstream_quantize_node.output, input_name_to_nodes, output_name_to_node
        ):
            logger.debug("It is not safe to fuse the QOrderedMatMul node. Skip.")
            return

        # Deal with the case wherein the Attention subgraph is not fused
        if transpose_node_0 is not None:
            self.model.replace_node_input(transpose_node_0, transpose_node_0.input[0], dequantize_node_0.input[0])

        # Make inputs
        fused_node_inputs = [
            reshape_node_0.output[0] if reshape_node_0 is not None else dequantize_node_0.input[0],
            dequantize_node_0.input[1],
            dequantize_node_1.input[0],
            dequantize_node_1.input[1],
            downstream_quantize_node.input[1],
            bias_add_node.input[bias_add_node_index],
        ]

        if residual_add_node is not None:
            fused_node_inputs.append(residual_add_dequantize_node.input[0])
            fused_node_inputs.append(residual_add_dequantize_node.input[1])

        # The MatMul weight 'B' and 'bias' need some post-processing
        # Transpose weight 'B' from order ROW to order COL
        # This offline transpose is needed only while using the CUDA EP
        # TODO: Make this fusion logic EP-agnostic?
        if is_weight_transpose_required:
            weight_tensor = self.model.get_initializer(dequantize_node_1.input[0])
            FusionUtils.transpose_2d_int8_tensor(weight_tensor)

        fused_node = helper.make_node(
            "QOrderedMatMul",
            inputs=fused_node_inputs,
            outputs=[downstream_quantize_node.output[0]],
            name=self.model.create_node_name("QOrderedMatMul", name_prefix="QOrderedMatMul"),
        )

        fused_node.attribute.extend([helper.make_attribute("order_A", 1)])
        fused_node.attribute.extend([helper.make_attribute("order_B", 0)])
        fused_node.attribute.extend([helper.make_attribute("order_Y", 1)])

        fused_node.domain = "com.microsoft"

        self.nodes_to_remove.extend(subgraph_nodes)
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
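Below is a minimal sketch of how a fusion class like this is typically driven, assuming the Fusion base class in fusion_base exposes an apply() method that visits every MatMul node and calls fuse(), and that OnnxModel provides prune_graph() and save_model_to_file() as in onnxruntime's transformer optimization tooling. The module name and file paths are placeholders, not part of the file above.

    import onnx
    from onnx_model import OnnxModel
    from fusion_qordered_matmul import FusionQOrderedMatMul  # hypothetical module name for the file above

    # Load a QDQ-quantized ONNX model (path is a placeholder).
    model = OnnxModel(onnx.load("model_qdq.onnx"))

    # Visit every MatMul node and fuse the matching QDQ pattern into QOrderedMatMul.
    fusion = FusionQOrderedMatMul(model)
    fusion.apply()  # assumption: apply() is provided by the Fusion base class

    # Remove dangling nodes/initializers left behind by the fusion and save the result.
    model.prune_graph()
    model.save_model_to_file("model_qordered.onnx")

In practice this fusion is usually invoked indirectly through the onnxruntime transformer optimizer rather than instantiated by hand.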