You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
90 lines
3.4 KiB
90 lines
3.4 KiB
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
# --------------------------------------------------------------------------
|
|
|
|
from logging import getLogger
|
|
from typing import List
|
|
|
|
from fusion_base import Fusion
|
|
from onnx import TensorProto, helper, numpy_helper
|
|
from onnx_model import OnnxModel
|
|
|
|
# Module-level logger, named after this module.
logger = getLogger(__name__)
|
|
|
|
|
|
class FusionNhwcConv(Fusion):
    """Convert Conv to NhwcConv.

    Each ``Conv`` whose weight (input 1) is a 4D initializer is replaced by

        Transpose(perm=[0, 2, 3, 1]) -> com.microsoft NhwcConv -> Transpose(perm=[0, 3, 1, 2])

    The final Transpose writes to the original Conv output name, so downstream
    consumers keep seeing the NCHW layout while the convolution runs on NHWC data.
    """

    def __init__(self, model: OnnxModel, update_weight=False):
        """
        Args:
            model: the model to rewrite.
            update_weight: if True, add a pre-transposed copy of each Conv weight as a new
                initializer. If False, insert a Transpose node on the weight instead.
        """
        super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
        self.update_weight = update_weight

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
        """Build (but do not insert into the graph) a Transpose node reading ``input_name``.

        Args:
            input_name: name of the tensor to transpose.
            perm: value of the Transpose ``perm`` attribute.
            output_name: name of the output tensor. When None, a name is derived from the
                generated node name and ``input_name``.

        Returns:
            The new Transpose NodeProto. The caller is responsible for adding it to the graph.
        """
        node_name = self.model.create_node_name("Transpose")
        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name
        # make_node turns keyword arguments into node attributes.
        return helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name, perm=perm)

    def fuse(self, conv, input_name_to_nodes, output_name_to_node):
        """Replace ``conv`` with Transpose -> NhwcConv -> Transpose.

        Conv nodes whose weight is not an initializer, or is not 4D, are left untouched.
        ``input_name_to_nodes`` and ``output_name_to_node`` belong to the Fusion callback
        signature and are not used here.
        """
        # Validate the weight before creating any node or node name, so that a Conv which
        # cannot be converted does not consume generated names.
        weight_tensor = self.model.get_initializer(conv.input[1])
        if weight_tensor is None:
            return
        weight = numpy_helper.to_array(weight_tensor)
        if weight.ndim != 4:
            return

        # Transpose the activation from NCHW to NHWC.
        input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1])

        node_name = self.model.create_node_name("NhwcConv")

        if self.update_weight:
            # Store a pre-transposed weight: ONNX Conv layout (M, C/group, kH, kW) becomes
            # (M, kH, kW, C/group). from_array keeps the source dtype (e.g. float16), so the
            # new initializer has the same element type as the original weight.
            weight_name = node_name + "_weight_NHWC"
            nhwc_weight = numpy_helper.from_array(weight.transpose(0, 2, 3, 1), weight_name)
            self.model.add_initializer(nhwc_weight, self.this_graph_name)
            weight_transpose_node = None
        else:
            # Leave the initializer as-is and transpose it in the graph instead.
            weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
            weight_name = weight_transpose_node.output[0]

        nhwc_output_name = node_name + "_out" + "-" + conv.output[0]
        nhwc_conv = helper.make_node(
            "NhwcConv",
            # Any optional inputs after the weight (e.g. bias) are forwarded unchanged.
            inputs=[input_transpose_node.output[0], weight_name] + list(conv.input[2:]),
            outputs=[nhwc_output_name],
            name=node_name + "-" + conv.name,
        )
        # Carry over all Conv attributes unchanged.
        nhwc_conv.attribute.extend(conv.attribute)
        nhwc_conv.domain = "com.microsoft"

        # Transpose the result back to NCHW under the original output name.
        output_transpose_node = self.create_transpose_node(nhwc_output_name, [0, 3, 1, 2], conv.output[0])

        self.nodes_to_remove.append(conv)

        nodes_to_add = [input_transpose_node, nhwc_conv, output_transpose_node]
        if weight_transpose_node is not None:
            nodes_to_add.append(weight_transpose_node)
        for node in nodes_to_add:
            self.node_name_to_graph_name[node.name] = self.this_graph_name
        self.nodes_to_add.extend(nodes_to_add)

        self.increase_counter("NhwcConv")
|