# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict

import numpy as np
from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionGroupNorm(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "GroupNorm", "Add")

    def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
  17. """
  18. Fuse Group Normalization subgraph into one node GroupNorm.
  19. The following is the pattern with swish activation:
  20. +----------------Shape-------------------------------+
  21. | |
  22. | (0, 32, -1) v (512x1x1) (512x1x1) (optional)
  23. [Root] --> Reshape -------> InstanceNormalization --> Reshape ---> Mul --> Add --> Mul--> [output]
  24. Bx512xHxW (scale=ones(32), B=zeros(32)) | ^ Bx512xHxW
  25. | |
  26. +--->Sigmoid (optional)
  27. The Mul and Sigmoid before output is for Swish activation. They are optional.
  28. """
        nodes = self.model.match_parent_path(
            add_node, ["Mul", "Reshape", "InstanceNormalization", "Reshape"], [0, 0, 0, 0], output_name_to_node
        )
        if nodes is None:
            return

        weight_mul, reshape_4d, instance_norm, reshape_3d = nodes
        root = reshape_3d.input[0]

        parents = self.model.match_parent_path(reshape_4d, ["Shape"], [1], output_name_to_node)
        if parents is None:
            return
        if parents[0].input[0] != root:
            return
        shape_node = parents[0]

        # Check whether it has swish activation.
        swish_mul = self.model.find_first_child_by_type(add_node, "Mul")
        swish_sigmoid = None
        if swish_mul is not None:
            sigmoid_path = self.model.match_parent_path(swish_mul, ["Sigmoid"], [None], output_name_to_node)
            if sigmoid_path is not None:
                swish_sigmoid = sigmoid_path[0]
        # gamma (scale) is the other input of the Mul; beta (bias) is the other input of the Add.
        weight_input = weight_mul.input[1 - self.model.input_index(reshape_4d.output[0], weight_mul)]
        if not self.model.is_constant_with_specified_dimension(weight_input, 3, "group norm weight"):
            return

        bias_input = add_node.input[1 - self.model.input_index(weight_mul.output[0], add_node)]
        if not self.model.is_constant_with_specified_dimension(bias_input, 3, "group norm bias"):
            return
        # gamma and beta are expected to be constants of shape (C, 1, 1) with matching element counts.
        weight = self.model.get_constant_value(weight_input)
        if weight is None:
            return
        if not (len(weight.shape) == 3 and weight.shape[1] == 1 and weight.shape[2] == 1):
            return

        bias = self.model.get_constant_value(bias_input)
        if bias is None:
            return
        if not (len(bias.shape) == 3 and bias.shape[1] == 1 and bias.shape[2] == 1):
            return

        weight_elements = int(np.prod(weight.shape))
        bias_elements = int(np.prod(bias.shape))
        if weight_elements != bias_elements:
            return

        # InstanceNormalization shall have 32 groups, with scale of all ones and bias of all zeros.
        instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
        if instance_norm_scale is None:
            return
        instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
        if instance_norm_bias is None:
            return

        if not (
            len(instance_norm_scale.shape) == 1
            and len(instance_norm_bias.shape) == 1
            and instance_norm_scale.shape == instance_norm_bias.shape
            and instance_norm_scale.shape[0] == 32
        ):
            logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0])
            return

        if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
            return
        if not np.allclose(np.zeros_like(instance_norm_bias), instance_norm_bias):
            return
        group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")

        if weight_elements not in [320, 640, 960, 1280, 1920, 2560] + [128, 256, 512]:
            logger.info("GroupNorm channels=%d", weight_elements)

        gamma = helper.make_tensor(
            name=group_norm_name + "_gamma",
            data_type=TensorProto.FLOAT,
            dims=[weight_elements],
            vals=weight.flatten().tolist(),
        )
        self.model.add_initializer(gamma, self.this_graph_name)

        beta = helper.make_tensor(
            name=group_norm_name + "_beta",
            data_type=TensorProto.FLOAT,
            dims=[bias_elements],
            vals=bias.flatten().tolist(),
        )
        self.model.add_initializer(beta, self.this_graph_name)

        last_node = add_node
        subgraph_nodes = [add_node, weight_mul, reshape_4d, instance_norm, reshape_3d, shape_node]

        has_swish_activation = swish_mul and swish_sigmoid
        if swish_mul and swish_sigmoid:
            subgraph_nodes.extend([swish_mul, swish_sigmoid])
            last_node = swish_mul

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            self.nodes_to_remove.extend([last_node])
        else:
            self.nodes_to_remove.extend(subgraph_nodes)
        # instance_norm_scale might come from a Constant node. Use prune_graph to clear it.
        self.prune_graph = True

        # Right now GroupNorm only supports float16 input, so Cast nodes are needed for an fp32 model.
        utils = FusionUtils(self.model)

        input = root
        output = last_node.output[0]
        if weight.dtype == np.float32:
            # Add a Cast node to get float16 input for GroupNorm
            cast_input, _cast_node = utils.cast_input(root, "float16")
            input = cast_input

            # Add a Cast node to convert back to float32 after GroupNorm
            output = group_norm_name + "_out"
            cast_node = helper.make_node("Cast", inputs=[group_norm_name + "_out"], outputs=[last_node.output[0]])
            cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.FLOAT))])
            self.model.add_node(cast_node)

        # NCHW to NHWC
        transpose_input = helper.make_node(
            "Transpose",
            [input],
            [input + "_NHWC"],
            name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"),
            perm=[0, 2, 3, 1],
        )

        new_node = helper.make_node(
            "GroupNorm",
            inputs=[input + "_NHWC", group_norm_name + "_gamma", group_norm_name + "_beta"],
            outputs=[output + "_NHWC"],
            name=group_norm_name,
        )
        new_node.attribute.extend(instance_norm.attribute)
        new_node.attribute.extend([helper.make_attribute("groups", 32)])
        new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)])
        new_node.domain = "com.microsoft"

        # NHWC to NCHW
        transpose_output = helper.make_node(
            "Transpose",
            [output + "_NHWC"],
            [output],
            name=self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"),
            perm=[0, 3, 1, 2],
        )

        self.nodes_to_add.append(new_node)
        self.nodes_to_add.append(transpose_input)
        self.nodes_to_add.append(transpose_output)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name
        self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name
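

# Illustrative usage sketch (not part of the original file). It assumes the surrounding
# onnxruntime transformers tooling: OnnxModel wraps an onnx.ModelProto, and the Fusion base
# class (fusion_base) provides apply(), which visits every "Add" node, calls fuse(), and then
# commits nodes_to_remove/nodes_to_add (pruning the graph when fuse() sets prune_graph).
# The model paths below are placeholders.
if __name__ == "__main__":
    import onnx

    # Load a model whose GroupNorm appears as the Reshape/InstanceNormalization subgraph
    # shown in the docstring above (e.g. a Stable Diffusion UNet exported to ONNX).
    onnx_model = OnnxModel(onnx.load("model.onnx"))

    # Match and replace every qualifying subgraph with a com.microsoft GroupNorm node.
    FusionGroupNorm(onnx_model).apply()

    onnx_model.save_model_to_file("model_groupnorm_fused.onnx")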