m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

81 lines
2.9 KiB

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List
from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import NodeProto, helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionTranspose(Fusion):
    """Fusion pass that collapses two chained Transpose nodes into a single one.

    An optional single-consumer Cast node sitting between the two Transpose
    nodes is tolerated (see ``fuse`` for the exact patterns).
    """

    def __init__(self, model: OnnxModel):
        # Both string arguments are forwarded unchanged to the Fusion base class.
        # NOTE(review): presumably (fused op type, op type to search for) —
        # confirm against Fusion.__init__.
        super().__init__(model, "Transpose", "Transpose")

    def fuse(
        self,
        transpose_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """Merge ``transpose_node`` with a preceding Transpose, if one exists.

        Case 1:
              (input)-->Transpose(perm=a)-->Transpose(perm=b)-->
        After:
              (input)-->Transpose(perm=a)-->    (this path can be removed if the output is not used anymore)
                |
                +----->Transpose(perm=a*b)-->

        Case 2 (Cast has only one child):
              (input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
        After:
              (input)-->Transpose(perm=a)-->    (this path can be removed if the output is not used anymore)
                |
                +----->Cast --> Transpose(perm=a*b)-->

        Args:
            transpose_node: The second Transpose of the pattern (``perm=b``).
                Its ``perm`` attribute is rewritten in place when fusion happens.
            input_name_to_nodes: Map from a tensor name to the nodes consuming it.
            output_name_to_node: Map from a tensor name to the node producing it.

        Returns:
            None. The method returns early, leaving the graph untouched, when
            the pattern does not match.

        Raises:
            AssertionError: If either ``perm`` attribute is not a list, or the
                two permutations have different lengths.
        """
        transpose_b = transpose_node

        # No producer node (e.g. fed by a graph input): nothing to fuse with.
        if transpose_b.input[0] not in output_name_to_node:
            return

        transpose_a = output_name_to_node[transpose_b.input[0]]
        cast_node = None
        if transpose_a.op_type == "Cast":
            cast_node = transpose_a

            # The Cast is re-parented below, which is only done when this
            # Transpose is its sole consumer.
            cast_children = self.model.get_children(cast_node, input_name_to_nodes)
            if cast_children and len(cast_children) > 1:
                return

            # The Cast may be fed by a graph input or initializer, in which
            # case it has no producer node; bail out instead of raising KeyError.
            if cast_node.input[0] not in output_name_to_node:
                return

            transpose_a = output_name_to_node[cast_node.input[0]]

        if transpose_a.op_type != "Transpose":
            return

        permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
        assert isinstance(permutation, list)

        parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
        assert isinstance(parent_permutation, list)

        assert len(parent_permutation) == len(permutation)

        # Compose the permutations: apply `a` first, then `b`.
        output_permutation = [parent_permutation[index] for index in permutation]

        # The node that directly consumes transpose_a's output is the one that
        # gets reconnected past transpose_a: the Cast when present, otherwise
        # transpose_b itself.
        child_node = transpose_b if cast_node is None else cast_node

        # NOTE(review): skip_parent is assumed to perform the rewiring shown in
        # the docstring and to return True when transpose_a's output is no
        # longer used — confirm against FusionUtils.skip_parent.
        if FusionUtils.skip_parent(self.model, child_node, transpose_a, input_name_to_nodes):
            self.nodes_to_remove.append(transpose_a)

        # Replace all attributes of transpose_b with the composed permutation.
        transpose_b.ClearField("attribute")
        transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])