# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger
from typing import Optional

from fusion_attention_unet import FusionAttentionUnet
from fusion_biassplitgelu import FusionBiasSplitGelu
from fusion_group_norm import FusionGroupNorm
from fusion_nhwc_conv import FusionNhwcConv
from fusion_options import FusionOptions
from fusion_transpose import FusionTranspose
from onnx import ModelProto
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

logger = getLogger(__name__)


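# UnetOnnxModel extends BertOnnxModel with UNet-specific fusions: GroupNorm,
# BiasSplitGelu, NhwcConv, and FusionAttentionUnet for self attention and for
# cross attention (optionally with packed KV).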
class UnetOnnxModel(BertOnnxModel):
    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize UNet ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
        """
        # Either both parameters are auto-detected (0), or both are given and
        # hidden_size is divisible by num_heads.
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def preprocess(self):
        """Run before the fusion passes: remove no-op Div nodes."""
        self.remove_useless_div()

    def postprocess(self):
        """Run after the fusion passes: merge Transpose chains and prune dead nodes."""
        self.merge_sequential_transpose()
        self.prune_graph()
        self.remove_unused_constant()

    def remove_useless_div(self):
        """Remove Div nodes that divide by the constant 1.0, which are no-ops."""
        div_nodes = [node for node in self.nodes() if node.op_type == "Div"]

        nodes_to_remove = []
        for div in div_nodes:
            # The divisor is input 1; keep only Div nodes whose divisor is the constant 1.0.
            if self.find_constant_input(div, 1.0) == 1:
                nodes_to_remove.append(div)

        # Reroute consumers of each Div output to the Div input, then drop the nodes.
        for node in nodes_to_remove:
            self.replace_input_of_all_nodes(node.output[0], node.input[0])

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info("Removed %d useless Div (by 1) nodes", len(nodes_to_remove))

    def convert_conv_to_nhwc(self):
        """Convert Conv nodes to NhwcConv."""
        # Do not update the weights here, since saving external data has a bug.
        conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False)
        conv_to_nhwc_conv.apply()

    def merge_sequential_transpose(self):
        """Merge consecutive Transpose nodes, then remove identity Transpose nodes."""
        fusion_transpose = FusionTranspose(self)
        fusion_transpose.apply()

        remove_count = 0
        nodes = self.get_nodes_by_op_type("Transpose")
        for node in nodes:
            permutation = OnnxModel.get_node_attribute(node, "perm")
            assert isinstance(permutation, list)
            if permutation != list(range(len(permutation))):
                continue
            assert not (
                self.find_graph_output(node.output[0])
                or self.find_graph_input(node.input[0])
                or self.find_graph_output(node.input[0])
            )

            # Let all child nodes skip the current Transpose node and link to its parent.
            # Note that we cannot update the parent node's output since the parent might have more than one child.
            self.replace_input_of_all_nodes(node.output[0], node.input[0])

            self.remove_node(node)
            remove_count += 1

        total = len(fusion_transpose.nodes_to_remove) + remove_count
        if total:
            logger.info("Removed %d Transpose nodes", total)

    def optimize(self, options: Optional[FusionOptions] = None):
        """Run the fusion passes; all fusions are enabled when options is None."""
        if (options is not None) and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()

        # Remove Cast nodes whose input and output data types are the same, based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_group_norm:
            group_norm_fusion = FusionGroupNorm(self)
            group_norm_fusion.apply()

        if (options is None) or options.enable_bias_splitgelu:
            bias_split_gelu_fusion = FusionBiasSplitGelu(self)
            bias_split_gelu_fusion.apply()

        if (options is None) or options.enable_attention:
            self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False)
            self_attention_fusion.apply()

            enable_packed_kv = (options is None) or options.enable_packed_kv
            cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv)
            cross_attention_fusion.apply()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        self.fuse_shape()

        # Remove Reshape nodes whose input and output shapes are the same, based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()

        self.convert_conv_to_nhwc()

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization with the bias Add that precedes it.
            self.fuse_add_bias_skip_layer_norm()

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()

        self.postprocess()

        logger.info("opset version: %s", self.get_opset_version())

    def get_fused_operator_statistics(self):
        """Return the node count of each fused operator."""
        op_count = {}
        ops = [
            "Attention",
            "MultiHeadAttention",
            "Gelu",
            "FastGelu",
            "LayerNormalization",
            "SkipLayerNormalization",
            "BiasSplitGelu",
            "GroupNorm",
            "NhwcConv",
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info("Optimized operators: %s", op_count)
        return op_count
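
# Example usage (a minimal sketch; the file paths are hypothetical, and num_heads
# and hidden_size keep their defaults of 0 so both are detected automatically).
# save_model_to_file comes from the OnnxModel base class.
if __name__ == "__main__":
    import onnx

    optimizer = UnetOnnxModel(onnx.load("unet.onnx"))
    optimizer.optimize(FusionOptions("unet"))  # passing None enables all fusions
    optimizer.get_fused_operator_statistics()
    optimizer.save_model_to_file("unet_optimized.onnx")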