m2m model translation
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger
from typing import Optional

from fusion_attention_unet import FusionAttentionUnet
from fusion_biassplitgelu import FusionBiasSplitGelu
from fusion_group_norm import FusionGroupNorm
from fusion_nhwc_conv import FusionNhwcConv
from fusion_options import FusionOptions
from fusion_transpose import FusionTranspose
from onnx import ModelProto
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

logger = getLogger(__name__)


class UnetOnnxModel(BertOnnxModel):
    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize UNet ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
        """
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def preprocess(self):
        self.remove_useless_div()

    def postprocess(self):
        self.merge_sequential_transpose()
        self.prune_graph()
        self.remove_unused_constant()

    def remove_useless_div(self):
        """Remove Div nodes that divide by 1."""
        div_nodes = [node for node in self.nodes() if node.op_type == "Div"]

        nodes_to_remove = []
        for div in div_nodes:
            # find_constant_input returns the index of the constant input; the divisor
            # is input 1, so only remove the node when input 1 is the constant 1.0.
            if self.find_constant_input(div, 1.0) == 1:
                nodes_to_remove.append(div)

        for node in nodes_to_remove:
            self.replace_input_of_all_nodes(node.output[0], node.input[0])

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info("Removed %d useless Div (by 1) nodes", len(nodes_to_remove))

    def convert_conv_to_nhwc(self):
        # Do not update the weights here, since saving external data has a bug.
        conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False)
        conv_to_nhwc_conv.apply()

    def merge_sequential_transpose(self):
        """Merge sequential Transpose nodes, then remove those whose permutation is the identity."""
        fusion_transpose = FusionTranspose(self)
        fusion_transpose.apply()

        remove_count = 0
        nodes = self.get_nodes_by_op_type("Transpose")
        for node in nodes:
            permutation = OnnxModel.get_node_attribute(node, "perm")
            assert isinstance(permutation, list)
            # Only remove a Transpose whose perm is the identity (0, 1, ..., n-1); it is a no-op.
            if permutation != list(range(len(permutation))):
                continue
            assert not (
                self.find_graph_output(node.output[0])
                or self.find_graph_input(node.input[0])
                or self.find_graph_output(node.input[0])
            )

            # Let all child nodes skip the current Transpose node and link to its parent.
            # Note that we cannot update the parent node output, since the parent might have more than one child.
            self.replace_input_of_all_nodes(node.output[0], node.input[0])
            self.remove_node(node)
            remove_count += 1

        total = len(fusion_transpose.nodes_to_remove) + remove_count
        if total:
            logger.info("Removed %d Transpose nodes", total)
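
    # Pipeline note: optimize() below chains the fusion passes in a fixed order:
    # cleanup (Identity and no-op Cast removal), local fusions (LayerNorm, Gelu,
    # GroupNorm, BiasSplitGelu), attention fusion, SkipLayerNorm fusion, NHWC
    # conversion for Conv, and finally postprocess() to merge transposes and
    # prune the graph.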
    def optimize(self, options: Optional[FusionOptions] = None):
        if (options is not None) and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()

        # Remove Cast nodes whose input and output have the same data type, based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_group_norm:
            group_norm_fusion = FusionGroupNorm(self)
            group_norm_fusion.apply()

        if (options is None) or options.enable_bias_splitgelu:
            bias_split_gelu_fusion = FusionBiasSplitGelu(self)
            bias_split_gelu_fusion.apply()

        if (options is None) or options.enable_attention:
            # Fuse self attention first, then cross attention (optionally with packed KV).
            self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False)
            self_attention_fusion.apply()

            enable_packed_kv = (options is None) or options.enable_packed_kv
            cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv)
            cross_attention_fusion.apply()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        self.fuse_shape()

        # Remove Reshape nodes whose input and output have the same shape, based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()

        self.convert_conv_to_nhwc()

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization with the bias Add that precedes it.
            self.fuse_add_bias_skip_layer_norm()

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()

        self.postprocess()

        logger.info(f"opset version: {self.get_opset_version()}")

    def get_fused_operator_statistics(self):
        """Returns node count of fused operators."""
        op_count = {}
        ops = [
            "Attention",
            "MultiHeadAttention",
            "Gelu",
            "FastGelu",
            "LayerNormalization",
            "SkipLayerNormalization",
            "BiasSplitGelu",
            "GroupNorm",
            "NhwcConv",
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info(f"Optimized operators: {op_count}")
        return op_count
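
A minimal usage sketch (not part of the file above), assuming the module is saved as onnx_model_unet.py next to the other fusion modules from onnxruntime's transformers tooling; the input path "unet/model.onnx" and the output path are placeholders:

import onnx

from fusion_options import FusionOptions
from onnx_model_unet import UnetOnnxModel

# Load the exported UNet model (placeholder path).
model = onnx.load("unet/model.onnx")

# num_heads=0 and hidden_size=0 let the optimizer detect both automatically.
optimizer = UnetOnnxModel(model, num_heads=0, hidden_size=0)

# Apply the fusion pipeline with default options for the "unet" model type.
optimizer.optimize(FusionOptions("unet"))

# Log fused-operator counts, then save the optimized model (placeholder path).
optimizer.get_fused_operator_statistics()
optimizer.save_model_to_file("unet/model_optimized.onnx")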