m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

90 lines
3.4 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from typing import List
  7. from fusion_base import Fusion
  8. from onnx import TensorProto, helper, numpy_helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionNhwcConv(Fusion):
  12. """Convert Conv to NhwcConv"""
  13. def __init__(self, model: OnnxModel, update_weight=False):
  14. super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
  15. self.update_weight = update_weight
  16. def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
  17. """Append a Transpose node after an input"""
  18. node_name = self.model.create_node_name("Transpose")
  19. if output_name is None:
  20. output_name = node_name + "_out" + "-" + input_name
  21. transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
  22. transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
  23. return transpose_node
  24. def fuse(self, conv, input_name_to_nodes, output_name_to_node):
  25. # Add Transpose node to convert input from NCHW to NHWC
  26. input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1])
  27. nhwc_conv_input = input_transpose_node.output[0]
  28. # Create a tensor for transposed weights (already in NHWC format).
  29. node_name = self.model.create_node_name("NhwcConv")
  30. # Make sure the weights is 4D
  31. weight_tensor = self.model.get_initializer(conv.input[1])
  32. if weight_tensor is None:
  33. return
  34. weight = numpy_helper.to_array(weight_tensor)
  35. if len(weight.shape) != 4:
  36. return
  37. if self.update_weight:
  38. # Transpose weights from NCHW to NHWC
  39. weight = weight.transpose(0, 2, 3, 1)
  40. weight_name = node_name + "_weight_NHWC"
  41. nhwc_weight = helper.make_tensor(
  42. name=weight_name,
  43. data_type=TensorProto.FLOAT,
  44. dims=list(weight.shape),
  45. vals=weight.flatten().tolist(),
  46. )
  47. self.model.add_initializer(nhwc_weight, self.this_graph_name)
  48. weight_transpose_node = None
  49. else:
  50. weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
  51. weight_name = weight_transpose_node.output[0]
  52. nhwc_output_name = node_name + "_out" + "-" + conv.output[0]
  53. nhwc_conv = helper.make_node(
  54. "NhwcConv",
  55. inputs=[nhwc_conv_input, weight_name] + conv.input[2:],
  56. outputs=[nhwc_output_name],
  57. name=node_name + "-" + conv.name,
  58. )
  59. nhwc_conv.attribute.extend(conv.attribute)
  60. nhwc_conv.domain = "com.microsoft"
  61. output_transpose_node = self.create_transpose_node(nhwc_conv.output[0], [0, 3, 1, 2], conv.output[0])
  62. self.nodes_to_remove.append(conv)
  63. nodes_to_add = [input_transpose_node, nhwc_conv, output_transpose_node]
  64. if weight_transpose_node:
  65. nodes_to_add.append(weight_transpose_node)
  66. for node in nodes_to_add:
  67. self.node_name_to_graph_name[node.name] = self.this_graph_name
  68. self.nodes_to_add.extend(nodes_to_add)
  69. self.increase_counter("NhwcConv")