m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

174 lines
6.1 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy as np
  7. from fusion_base import Fusion
  8. from onnx import TensorProto, helper, numpy_helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionReshape(Fusion):
  12. def __init__(self, model: OnnxModel):
  13. super().__init__(model, "Reshape", "Reshape")
  14. self.prune_graph: bool = False
  15. def replace_reshape_node(self, shape, reshape_node, concat_node):
  16. shape_value = np.asarray(shape, dtype=np.int64)
  17. constant_shape_name = self.model.create_node_name("Constant", "constant_shape")
  18. new_node = helper.make_node(
  19. "Constant",
  20. inputs=[],
  21. outputs=[constant_shape_name],
  22. value=helper.make_tensor(
  23. name="const_tensor",
  24. data_type=TensorProto.INT64,
  25. dims=shape_value.shape,
  26. vals=bytes(shape_value),
  27. raw=True,
  28. ),
  29. )
  30. reshape_node.input[1] = constant_shape_name
  31. reshape_node.name = self.model.create_node_name("Reshape", "Reshape_Fuse")
  32. self.nodes_to_remove.extend([concat_node])
  33. self.nodes_to_add.append(new_node)
  34. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  35. def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
  36. if reshape_node.input[1] not in output_name_to_node:
  37. return
  38. concat_node = output_name_to_node[reshape_node.input[1]]
  39. if concat_node.op_type != "Concat" or len(concat_node.input) < 3 or len(concat_node.input) > 4:
  40. return
  41. path0 = self.model.match_parent_path(
  42. concat_node,
  43. ["Unsqueeze", "Gather", "Shape"],
  44. [0, 0, 0],
  45. output_name_to_node,
  46. )
  47. if path0 is None:
  48. return
  49. (unsqueeze_0, gather_0, shape_0) = path0
  50. path1 = self.model.match_parent_path(
  51. concat_node,
  52. ["Unsqueeze", "Gather", "Shape"],
  53. [1, 0, 0],
  54. output_name_to_node,
  55. )
  56. if path1 is None:
  57. return
  58. (unsqueeze_1, gather_1, shape_1) = path1
  59. shape = []
  60. gather_value = self.model.get_constant_value(gather_0.input[1])
  61. if gather_value == 0:
  62. shape.append(0)
  63. gather_value = self.model.get_constant_value(gather_1.input[1])
  64. if gather_value == 1:
  65. shape.append(0)
  66. if len(shape) != 2:
  67. return
  68. path2 = []
  69. path3 = []
  70. shape_nodes = [shape_0, shape_1]
  71. if len(concat_node.input) == 3 and self.model.get_initializer(concat_node.input[2]) is None:
  72. path2 = self.model.match_parent_path(
  73. concat_node,
  74. ["Unsqueeze", "Mul", "Gather", "Shape"],
  75. [2, 0, 0, 0],
  76. output_name_to_node,
  77. )
  78. if path2 is None:
  79. path2 = self.model.match_parent_path(
  80. concat_node,
  81. ["Unsqueeze", "Mul", "Squeeze", "Slice", "Shape"],
  82. [2, 0, 0, 0, 0],
  83. output_name_to_node,
  84. ) # GPT2 exported by PyTorch 1.4 with opset_version=11
  85. if path2 is None:
  86. return
  87. path3 = self.model.match_parent_path(
  88. concat_node,
  89. ["Unsqueeze", "Mul", "Gather", "Shape"],
  90. [2, 0, 1, 0],
  91. output_name_to_node,
  92. )
  93. if path3 is None:
  94. path3 = self.model.match_parent_path(
  95. concat_node,
  96. ["Unsqueeze", "Mul", "Squeeze", "Slice", "Shape"],
  97. [2, 0, 1, 0, 0],
  98. output_name_to_node,
  99. ) # GPT2 exported by PyTorch 1.4 with opset_version=11
  100. if path3 is None:
  101. return
  102. shape_nodes.extend([path2[-1], path3[-1]])
  103. shape.append(-1)
  104. elif len(concat_node.input) > 2:
  105. concat_value = self.model.get_constant_value(concat_node.input[2])
  106. if concat_value is None:
  107. return
  108. if isinstance(concat_value, np.ndarray):
  109. shape.extend(concat_value.tolist())
  110. else:
  111. shape.append(concat_value)
  112. if len(concat_node.input) == 4 and self.model.get_constant_value(concat_node.input[3]) is None:
  113. if -1 in shape:
  114. return
  115. path2 = self.model.match_parent_path(
  116. concat_node,
  117. ["Unsqueeze", "Div", "Gather", "Shape"],
  118. [3, 0, 0, 0],
  119. output_name_to_node,
  120. )
  121. if path2 is None:
  122. path2 = self.model.match_parent_path(
  123. concat_node,
  124. ["Unsqueeze", "Div", "Squeeze", "Slice", "Shape"],
  125. [3, 0, 0, 0, 0],
  126. output_name_to_node,
  127. ) # GPT2 exported by PyTorch 1.4 with opset_version=11
  128. if path2 is None:
  129. return
  130. shape_nodes.extend([path2[-1]])
  131. shape.append(-1)
  132. elif len(concat_node.input) > 3:
  133. concat_3 = self.model.get_initializer(concat_node.input[3])
  134. if concat_3 is None:
  135. return
  136. concat_value = numpy_helper.to_array(concat_3)
  137. if isinstance(concat_value, np.ndarray):
  138. shape.extend(concat_value.tolist())
  139. else:
  140. shape.append(concat_value)
  141. root_input = reshape_node.input[0]
  142. same_shape_input = True
  143. for shape_node in shape_nodes:
  144. if shape_node.input[0] != root_input:
  145. same_shape_input = False
  146. if not same_shape_input:
  147. return
  148. self.replace_reshape_node(shape, reshape_node, concat_node)
  149. # TODO(tlwu): Subgraph blocks pruning un-used nodes. Add code to remove un-used nodes safely.
  150. self.prune_graph = True