m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

81 lines
2.9 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from typing import Dict, List
  7. from fusion_base import Fusion
  8. from fusion_utils import FusionUtils
  9. from onnx import NodeProto, helper
  10. from onnx_model import OnnxModel
  11. logger = getLogger(__name__)
  12. class FusionTranspose(Fusion):
  13. def __init__(self, model: OnnxModel):
  14. super().__init__(model, "Transpose", "Transpose")
  15. def fuse(
  16. self,
  17. transpose_node: NodeProto,
  18. input_name_to_nodes: Dict[str, List[NodeProto]],
  19. output_name_to_node: Dict[str, NodeProto],
  20. ):
  21. """
  22. Case 1:
  23. (input)-->Transpose(perm=a)-->Transpose(perm=b)-->
  24. After:
  25. (input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
  26. |
  27. +----->Transpose(perm=a*b)-->
  28. Case 2 (Cast has only one child):
  29. (input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
  30. After:
  31. (input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
  32. |
  33. +----->Cast --> Transpose(perm=a*b)-->
  34. """
  35. transpose_b = transpose_node
  36. if transpose_b.input[0] not in output_name_to_node:
  37. return
  38. transpose_a = output_name_to_node[transpose_b.input[0]]
  39. if transpose_a.op_type != "Cast":
  40. cast_node = None
  41. else:
  42. cast_node = transpose_a
  43. cast_children = self.model.get_children(cast_node, input_name_to_nodes)
  44. if cast_children and len(cast_children) > 1:
  45. return
  46. transpose_a = output_name_to_node[cast_node.input[0]]
  47. if transpose_a.op_type != "Transpose":
  48. return
  49. permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
  50. assert isinstance(permutation, list)
  51. parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
  52. assert isinstance(parent_permutation, list)
  53. assert len(parent_permutation) == len(permutation)
  54. output_permutation = []
  55. for j, index in enumerate(permutation):
  56. output_permutation.append(parent_permutation[index])
  57. if cast_node is None:
  58. if FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes):
  59. self.nodes_to_remove.append(transpose_a)
  60. else:
  61. if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes):
  62. self.nodes_to_remove.append(transpose_a)
  63. transpose_b.ClearField("attribute")
  64. transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])