You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
282 lines
11 KiB
282 lines
11 KiB
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
# --------------------------------------------------------------------------
|
|
from logging import getLogger
|
|
from typing import Tuple
|
|
|
|
import numpy
|
|
from numpy import array_equal, ndarray
|
|
from onnx import NodeProto, TensorProto, helper, numpy_helper
|
|
from onnx import onnx_pb as onnx_proto
|
|
from onnx_model import OnnxModel
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
class FusionUtils:
|
|
def __init__(self, model: OnnxModel):
|
|
self.model: OnnxModel = model
|
|
|
|
def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]:
|
|
graph_input = self.model.find_graph_input(input_name)
|
|
if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
|
|
cast_output, cast_node = self.cast_input_to_int32(input_name)
|
|
logger.debug(f"Casted graph input {input_name} to int32")
|
|
return True, cast_output
|
|
|
|
logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}")
|
|
return False, input_name
|
|
|
|
def cast_input(self, input_name: str, target_type="int32"):
|
|
cast_output = input_name + "_" + target_type
|
|
|
|
# Avoid consequent Cast nodes.
|
|
inputs = [input_name]
|
|
output_name_to_node = self.model.output_name_to_node()
|
|
if input_name in output_name_to_node:
|
|
parent_node = output_name_to_node[input_name]
|
|
if parent_node and parent_node.op_type == "Cast":
|
|
inputs = [parent_node.input[0]]
|
|
|
|
cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output])
|
|
|
|
if target_type == "int32":
|
|
to_type = int(TensorProto.INT32)
|
|
elif target_type == "float32":
|
|
to_type = int(TensorProto.FLOAT)
|
|
elif target_type == "float16":
|
|
to_type = int(TensorProto.FLOAT16)
|
|
else:
|
|
raise ValueError("Invalid target_type: {target_type}")
|
|
|
|
cast_node.attribute.extend([helper.make_attribute("to", to_type)])
|
|
self.model.add_node(cast_node)
|
|
|
|
return cast_output, cast_node
|
|
|
|
def cast_input_to_int32(self, input_name: str):
|
|
return self.cast_input(input_name, "int32")
|
|
|
|
def remove_cast_int32(self, input_name: str):
|
|
input_name_to_nodes = self.model.input_name_to_nodes()
|
|
nodes = input_name_to_nodes[input_name]
|
|
for node in nodes:
|
|
if node.op_type == "Cast":
|
|
is_int32 = False
|
|
for att in node.attribute:
|
|
if att.name == "to" and att.i == int(TensorProto.INT32):
|
|
is_int32 = True
|
|
break
|
|
if is_int32:
|
|
output_name = node.output[0]
|
|
self.model.remove_node(node)
|
|
self.model.replace_input_of_all_nodes(output_name, input_name)
|
|
|
|
@staticmethod
|
|
def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes):
|
|
"""
|
|
Before:
|
|
(input)-->parent-->node-->(output)
|
|
After:
|
|
(input)-->parent-->
|
|
|
|
|
+----->node-->(output)
|
|
|
|
This function returns a flag about whether the parent node can be removed.
|
|
Note that this function assumes the node has first input links from parent!
|
|
"""
|
|
parent_can_be_removed = False
|
|
input_name_to_nodes[node.input[0]].remove(node)
|
|
# We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore.
|
|
if len(input_name_to_nodes[node.input[0]]) == 0 and not model.find_graph_output(
|
|
node.input[0]
|
|
): # checks main graph output. TODO: deal with subgraph
|
|
parent_can_be_removed = True
|
|
# self.nodes_to_remove.append(transpose_a)
|
|
|
|
input_name_to_nodes[parent_node.input[0]].append(node)
|
|
node.input[0] = parent_node.input[0]
|
|
return parent_can_be_removed
|
|
|
|
@staticmethod
|
|
def check_node_attribute(node, attribute_name: str, expected_value, default_value=None):
|
|
"""Verify that a node has expected value for an attribute.
|
|
|
|
Args:
|
|
node (NodeProto): a node to check
|
|
attribute_name (str): name of attribute
|
|
expected_value (Any): expected value of the attribute
|
|
default_value (Any, optional): default value if the attribute does not exist. Defaults to None.
|
|
|
|
Returns:
|
|
bool: whether the check is passed or not
|
|
"""
|
|
value = default_value
|
|
for attr in node.attribute:
|
|
if attr.name == attribute_name:
|
|
value = helper.get_attribute_value(attr)
|
|
|
|
if isinstance(expected_value, list):
|
|
return (isinstance(value, ndarray) or isinstance(value, list)) and array_equal(
|
|
expected_value, value, equal_nan=False
|
|
)
|
|
else:
|
|
return value == expected_value
|
|
|
|
@staticmethod
|
|
def transpose_2d_int8_tensor(tensor: onnx_proto.TensorProto):
|
|
"""Transpose a 2-D INT8 TensorProto
|
|
Args:
|
|
tensor (TensorProto): tensor to be transposed
|
|
Returns:
|
|
tensor (TensorProto): transposed tensor
|
|
"""
|
|
if not isinstance(tensor, onnx_proto.TensorProto):
|
|
raise ValueError("Expected input type is an ONNX TensorProto but got %s" % type(tensor))
|
|
|
|
if len(tensor.dims) != 2 or tensor.data_type != onnx_proto.TensorProto.INT8:
|
|
raise ValueError("Only INT8 2-D tensors can be transposed")
|
|
|
|
if tensor.raw_data:
|
|
int32_data = numpy.reshape(numpy.frombuffer(tensor.raw_data, dtype="int8"), tensor.dims)
|
|
int32_transposed_data = numpy.transpose(int32_data, [1, 0])
|
|
tensor.raw_data = int32_transposed_data.tobytes()
|
|
|
|
else:
|
|
raise ValueError("only raw buffer supported")
|
|
|
|
return tensor
|
|
|
|
@staticmethod
|
|
def check_qdq_node_for_fusion(node: NodeProto, model: OnnxModel, allow_per_tensor_quantization_only=True):
|
|
"""Verify if a provided QuantizeLinear (Q) / DequantizeLinear (DQ) node is a good candidate for fusion.
|
|
It is a good candidate for fusion if:
|
|
(1) The Q/DQ node is for per-tensor quantization if allow_per_tensor_quantization_only is `True`
|
|
(2) The Q/DQ node should have constant scale
|
|
(3) The Q/DQ node should have a zero point of 0
|
|
Args:
|
|
node (NodeProto): a Q/DQ node to check
|
|
Returns:
|
|
bool: whether the check is passed or not
|
|
"""
|
|
if not node.op_type in {"QuantizeLinear", "DequantizeLinear"}:
|
|
logger.debug(f"Provided node is not a Q/DQ node. Op Type: {node.op_type}")
|
|
|
|
scale = model.get_constant_value(node.input[1])
|
|
|
|
# Scale is not constant
|
|
if scale is None:
|
|
return False
|
|
|
|
# Not per-tensor quantization
|
|
scale_has_single_element = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
|
|
if allow_per_tensor_quantization_only and not scale_has_single_element:
|
|
return False
|
|
|
|
# If the Q/DQ node has no zero point input, it is assumed to be 0 (per ONNX spec)
|
|
if len(node.input) == 2:
|
|
return True
|
|
|
|
# Zero point should be constant and should have a value of 0
|
|
zero_point = model.get_constant_value(node.input[2])
|
|
|
|
# Zero point and scale should have same number of dims
|
|
if scale.ndim != zero_point.ndim:
|
|
return False
|
|
|
|
# Zero point is not constant or zero point is not zero
|
|
if zero_point is None:
|
|
return False
|
|
|
|
return numpy.all(zero_point == 0)
|
|
|
|
def check_node_input_value(self, node, input_index: int, expected_value):
|
|
"""Verify that a node has expected input value
|
|
|
|
Args:
|
|
node (NodeProto): a node to check
|
|
input_index (int): index of its input to be verified
|
|
expected_value (Any): expected value of the input
|
|
|
|
Returns:
|
|
bool: whether the check is passed or not
|
|
"""
|
|
assert len(node.input) > input_index
|
|
|
|
value = self.model.get_constant_value(node.input[input_index])
|
|
|
|
if isinstance(expected_value, list):
|
|
return (isinstance(value, ndarray) or isinstance(value, list)) and array_equal(
|
|
expected_value, value, equal_nan=False
|
|
)
|
|
else:
|
|
return value == expected_value
|
|
|
|
def remove_identity_nodes(self):
|
|
"""Remove Identity nodes, except those right before graph output."""
|
|
nodes_to_remove = []
|
|
for node in self.model.nodes():
|
|
if node.op_type == "Identity":
|
|
if node.output[0] not in self.model.get_graphs_output_names():
|
|
self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
|
|
nodes_to_remove.append(node)
|
|
|
|
if nodes_to_remove:
|
|
self.model.remove_nodes(nodes_to_remove)
|
|
logger.info(f"Removed {len(nodes_to_remove)} Identity nodes")
|
|
|
|
def remove_cascaded_cast_nodes(self):
|
|
self.model.remove_cascaded_cast_nodes()
|
|
|
|
def remove_useless_cast_nodes(self):
|
|
self.model.remove_useless_cast_nodes()
|
|
|
|
def remove_useless_reshape_nodes(self):
|
|
"""Remove reshape node that is not needed based on symbolic shape inference: input and output has same shape"""
|
|
shape_infer = self.model.infer_runtime_shape(update=True)
|
|
if shape_infer is None:
|
|
return
|
|
|
|
nodes_to_remove = []
|
|
for node in self.model.nodes():
|
|
if node.op_type == "Reshape":
|
|
input_shape = shape_infer.get_edge_shape(node.input[0])
|
|
output_shape = shape_infer.get_edge_shape(node.output[0])
|
|
if input_shape and output_shape and input_shape == output_shape:
|
|
logger.info(
|
|
f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}"
|
|
)
|
|
nodes_to_remove.append(node)
|
|
|
|
if nodes_to_remove:
|
|
graph_input_names = set(self.model.get_graphs_input_names())
|
|
graph_output_names = set(self.model.get_graphs_output_names())
|
|
for node in nodes_to_remove:
|
|
if bool(set(node.output) & graph_output_names):
|
|
if (
|
|
not bool(set(node.input) & graph_input_names)
|
|
and len(self.model.input_name_to_nodes()[node.input[0]]) == 1 # parent has only one child
|
|
):
|
|
self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
|
|
else:
|
|
continue
|
|
else:
|
|
self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
|
|
self.model.remove_node(node)
|
|
|
|
|
|
class NumpyHelper:
|
|
@staticmethod
|
|
def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray:
|
|
# When weights are in external data format but not presented, we can still test the optimizer with two changes:
|
|
# (1) set fill_zeros = True (2) change load_external_data=False in optimizer.py
|
|
if fill_zeros:
|
|
from onnx import mapping
|
|
|
|
return ndarray(
|
|
shape=tensor.dims,
|
|
dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.data_type],
|
|
)
|
|
|
|
return numpy_helper.to_array(tensor)
|