You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # --------------------------------------------------------------------------
from logging import getLogger from typing import Dict, List
from fusion_base import Fusion from fusion_utils import FusionUtils from onnx import NodeProto, helper from onnx_model import OnnxModel
# Module-level logger, named after this module.
logger = getLogger(__name__)
class FusionTranspose(Fusion):
    """Fuse two consecutive Transpose nodes (optionally separated by a Cast) into one."""

    def __init__(self, model: OnnxModel):
        # Both the fused-op name and the searched op type are "Transpose".
        super().__init__(model, "Transpose", "Transpose")

    def fuse(
        self,
        transpose_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ) -> None:
        """Merge ``transpose_node`` with a Transpose that feeds it.

        Case 1:
              (input)-->Transpose(perm=a)-->Transpose(perm=b)-->
        After:
              (input)-->Transpose(perm=a)-->  (this path can be removed if the output is not used anymore)
                |
                +----->Transpose(perm=a*b)-->

        Case 2 (Cast has only one child):
              (input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
        After:
              (input)-->Transpose(perm=a)-->  (this path can be removed if the output is not used anymore)
                |
                +----->Cast --> Transpose(perm=a*b)-->

        Here ``a*b`` is the composed permutation ``[a[i] for i in b]``.

        The pattern is skipped (the graph is left untouched) when:
          * the producer of the relevant input is not a node in the graph
            (e.g. it is a graph input or an initializer),
          * the intermediate Cast has more than one child,
          * the upstream node is not a Transpose,
          * either Transpose has no explicit ``perm`` list, or the two
            permutations have different lengths.

        Args:
            transpose_node: The second Transpose (``perm=b``) of the pattern.
            input_name_to_nodes: Map from tensor name to the nodes consuming it.
            output_name_to_node: Map from tensor name to the node producing it.
        """
        transpose_b = transpose_node
        if transpose_b.input[0] not in output_name_to_node:
            return

        transpose_a = output_name_to_node[transpose_b.input[0]]
        cast_node = None
        if transpose_a.op_type == "Cast":
            cast_node = transpose_a

            cast_children = self.model.get_children(cast_node, input_name_to_nodes)
            if cast_children and len(cast_children) > 1:
                return

            # The Cast may be fed directly by a graph input or an initializer,
            # in which case there is no producer node to look at.
            cast_input = cast_node.input[0]
            if cast_input not in output_name_to_node:
                return
            transpose_a = output_name_to_node[cast_input]

        if transpose_a.op_type != "Transpose":
            return

        permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
        parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")

        # "perm" is optional on Transpose; without two explicit permutations of
        # the same rank there is nothing to compose, so skip instead of
        # aborting the whole optimization pass.
        if not isinstance(permutation, list) or not isinstance(parent_permutation, list):
            return
        if len(parent_permutation) != len(permutation):
            return

        # Output axis i of the second Transpose reads axis permutation[i] of the
        # intermediate tensor, which is axis parent_permutation[permutation[i]]
        # of the original input.
        output_permutation = [parent_permutation[index] for index in permutation]

        # The node directly consuming transpose_a's output is the Cast when one
        # is present, otherwise transpose_b itself.
        # NOTE(review): skip_parent is assumed to rewire that node onto
        # transpose_a's input (as drawn above); its return value is used here
        # as "transpose_a can now be removed" -- confirm in FusionUtils.
        child_node = transpose_b if cast_node is None else cast_node
        if FusionUtils.skip_parent(self.model, child_node, transpose_a, input_name_to_nodes):
            self.nodes_to_remove.append(transpose_a)

        # Replace every attribute of transpose_b with the composed permutation.
        transpose_b.ClearField("attribute")
        transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])
|