m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
3.6 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List, Union

from numpy import ndarray
from onnx import NodeProto, TensorProto, ValueInfoProto

from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx_model import OnnxModel
  12. logger = getLogger(__name__)
  13. class FusionShape(Fusion):
  14. def __init__(self, model: OnnxModel):
  15. super().__init__(model, "Shape", "Concat")
  16. self.utils = FusionUtils(model)
  17. self.shape_infer = None
  18. self.shape_infer_done = False
  19. def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
  20. if tensor_proto.type.tensor_type.HasField("shape"):
  21. return len(tensor_proto.type.tensor_type.shape.dim)
  22. else:
  23. return None
  24. def get_dimensions(self, input_name: str) -> Union[int, None]:
  25. graph_input = self.model.find_graph_input(input_name)
  26. if graph_input:
  27. return self.get_dimensions_from_tensor_proto(graph_input)
  28. if not self.shape_infer_done:
  29. self.shape_infer = self.model.infer_runtime_shape({}, update=True)
  30. self.shape_infer_done = True
  31. if self.shape_infer is not None:
  32. return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])
  33. return None
  34. def fuse(
  35. self,
  36. concat_node: NodeProto,
  37. input_name_to_nodes: Dict[str, List[NodeProto]],
  38. output_name_to_node: Dict[str, NodeProto],
  39. ):
  40. """
  41. Smplify subgraph like
  42. (2d_input)
  43. / \
  44. Shape shape
  45. / \
  46. Gather(indices=0) Gather(indices=1)
  47. | |
  48. Unsqueeze(axes=0) Unsqueeze(axes=0)
  49. \ /
  50. Concat
  51. |
  52. into (2d_input) --> Shape -->
  53. """
  54. opset_version = self.model.get_opset_version()
  55. inputs = len(concat_node.input)
  56. root = None
  57. shape_output = None
  58. for i in range(inputs):
  59. path = self.model.match_parent_path(
  60. concat_node,
  61. ["Unsqueeze", "Gather", "Shape"],
  62. [i, 0, 0],
  63. output_name_to_node,
  64. )
  65. if path is None:
  66. return
  67. unsqueeze, gather, shape = path
  68. if i == 0:
  69. shape_output = shape.output[0]
  70. if root is None:
  71. root = shape.input[0]
  72. if self.get_dimensions(root) != inputs:
  73. return
  74. elif shape.input[0] != root:
  75. return
  76. if not FusionUtils.check_node_attribute(unsqueeze, "axis", 0, default_value=0):
  77. return
  78. if opset_version < 13:
  79. if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
  80. return
  81. else:
  82. if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
  83. return
  84. value = self.model.get_constant_value(gather.input[1])
  85. if not (isinstance(value, ndarray) and value.size == 1 and value.item() == i):
  86. return
  87. if self.model.find_graph_output(concat_node.output[0]) is None:
  88. self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
  89. self.increase_counter("Reshape")
  90. self.prune_graph = True