图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

166 lines
5.5 KiB

import logging
import numpy as np
import onnx
from onnx import onnx_pb as onnx_proto
from ..quant_utils import (
TENSOR_NAME_QUANT_SUFFIX,
QuantizedValue,
QuantizedValueType,
attribute_to_kwarg,
find_by_name,
get_mul_node,
ms_domain,
)
from .base_operator import QuantOperatorBase
from .matmul import QOpMatMul
from .qdq_base_operator import QDQOperatorBase
def is_B_transposed(gemm_node):
    """Tell whether the Gemm node's ``transB`` attribute holds a positive value.

    Args:
        gemm_node: Node object exposing an iterable ``attribute`` collection
            whose items have a ``name``.

    Returns:
        True when a ``transB`` attribute exists and its value is greater than
        zero; False when the value is not positive or the attribute is absent.
    """
    for attr in gemm_node.attribute:
        if attr.name == "transB":
            # Only the first matching attribute is consulted.
            return onnx.helper.get_attribute_value(attr) > 0
    return False
def get_beta(gemm_node):
    """Return the value of the Gemm node's ``beta`` attribute.

    Args:
        gemm_node: Node object exposing an iterable ``attribute`` collection
            whose items have a ``name``.

    Returns:
        The value of the first attribute named ``beta``, or 1.0 when the node
        has no such attribute.
    """
    beta_attr = next((attr for attr in gemm_node.attribute if attr.name == "beta"), None)
    if beta_attr is None:
        return 1.0
    return onnx.helper.get_attribute_value(beta_attr)
def set_default_beta(gemm_node):
    """Reset the Gemm node's ``beta`` attribute to 1.0 when it is present.

    Args:
        gemm_node: Node object exposing an iterable ``attribute`` collection
            whose items have a ``name`` and a float field ``f``.

    Returns:
        Always 1.0, whether or not a ``beta`` attribute was found.
    """
    for attr in gemm_node.attribute:
        if attr.name == "beta":
            # Only the first matching attribute is overwritten.
            attr.f = 1.0
            break
    return 1.0
class QLinearGemm(QOpMatMul):
    """Quantizes a ``Gemm`` node into a single ``QGemm`` node in ``ms_domain``.

    Falls back to the parent class's ``quantize`` whenever the node cannot be
    expressed as a ``QGemm`` (missing output quantization parameters, inputs
    that could not be quantized, or a bias that is not an initializer).
    """

    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        """Replace the wrapped ``Gemm`` node with a ``QGemm`` node.

        On success the generated nodes are appended to
        ``self.quantizer.new_nodes`` and the quantized output is recorded in
        ``self.quantizer.quantized_value_map`` under the original output name.

        Returns:
            The result of ``super().quantize()`` when falling back, otherwise
            ``None``.
        """
        node = self.node
        assert node.op_type == "Gemm"

        (
            data_found,
            output_scale_name,
            output_zp_name,
            _,
            _,
        ) = self.quantizer._get_quantization_params(node.output[0])

        # Per-channel weight quantization is only used for an initializer weight.
        use_per_channel_weight = self.quantizer.is_input_a_initializer(
            node.input[1]
        ) and self.quantizer.is_per_channel()

        # The activation (input[0]) is quantized identically on both paths.
        (
            quantized_input_names,
            zero_point_names,
            scale_names,
            nodes,
        ) = self.quantizer.quantize_activation(node, [0])

        # The names list may come back as None (the original code tested for
        # this, but only after calling .append/.extend on it, which would
        # raise first). Check before the first use so the fallback is reachable.
        if quantized_input_names is None:
            return super().quantize()

        if use_per_channel_weight:
            # Quantize along axis 0 when B is transposed, otherwise along axis 1.
            quant_weight_tuple = self.quantizer.quantize_weight_per_channel(
                node.input[1],
                onnx_proto.TensorProto.INT8,
                0 if is_B_transposed(node) else 1,
            )
            # Tuple layout as consumed here: (quantized name, zero point name, scale name).
            quantized_input_names.append(quant_weight_tuple[0])
            zero_point_names.append(quant_weight_tuple[1])
            scale_names.append(quant_weight_tuple[2])
        else:
            # Quantize the weight (input[1]) through the generic weight path.
            (
                quantized_input_names_weight,
                zero_point_names_weight,
                scale_names_weight,
                nodes_weight,
            ) = self.quantizer.quantize_weight(node, [1], reduce_range=self.quantizer.reduce_range)

            # Same guard as for the activation: never extend with a None result.
            if quantized_input_names_weight is None:
                return super().quantize()

            quantized_input_names.extend(quantized_input_names_weight)
            zero_point_names.extend(zero_point_names_weight)
            scale_names.extend(scale_names_weight)
            nodes.extend(nodes_weight)

        # Without quantization parameters for the output, QGemm cannot be built.
        if not data_found:
            return super().quantize()

        quantized_bias_name = ""
        if len(node.input) == 3:
            # Only a bias that is an initializer can be quantized statically.
            if not self.quantizer.is_input_a_initializer(node.input[2]):
                return super().quantize()

            quantized_bias_name = self.quantizer.quantize_bias_static(
                node.input[2], node.input[0], node.input[1], get_beta(self.node)
            )

        qgemm_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
        qgemm_name = node.name + "_quant" if node.name != "" else ""

        # Forward every attribute except beta, which was already handed to
        # quantize_bias_static above.
        kwargs = {}
        for attribute in node.attribute:
            if attribute.name != "beta":
                kwargs.update(attribute_to_kwarg(attribute))
        kwargs["domain"] = ms_domain

        # Input order: (A, A_scale, A_zp, B, B_scale, B_zp, bias, Y_scale, Y_zp).
        qgemm_inputs = []
        for i in range(2):
            qgemm_inputs.extend([quantized_input_names[i], scale_names[i], zero_point_names[i]])
        qgemm_inputs.extend([quantized_bias_name, output_scale_name, output_zp_name])

        qgemm_node = onnx.helper.make_node("QGemm", qgemm_inputs, [qgemm_output], qgemm_name, **kwargs)
        nodes.append(qgemm_node)

        # Create an entry for this quantized value
        q_output = QuantizedValue(
            node.output[0],
            qgemm_output,
            output_scale_name,
            output_zp_name,
            QuantizedValueType.Input,
        )
        self.quantizer.quantized_value_map[node.output[0]] = q_output

        self.quantizer.new_nodes += nodes
class QDQGemm(QDQOperatorBase):
    """QDQ-format handler for ``Gemm`` nodes.

    Asks the quantizer to quantize the node's activation, weight and
    (when it is an initializer) bias tensors.
    """

    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        """Request quantization of every tensor attached to the ``Gemm`` node.

        The input activation is always handled; the output only when
        ``disable_qdq_for_node_output`` is falsy. A bias that is not an
        initializer is left untouched and a warning is logged instead.
        """
        gemm = self.node
        assert gemm.op_type == "Gemm"

        quantizer = self.quantizer

        # Activations: input always, output only if not disabled for this node.
        quantizer.quantize_activation_tensor(gemm.input[0])
        if not self.disable_qdq_for_node_output:
            quantizer.quantize_activation_tensor(gemm.output[0])

        # Weight: axis 0 when B is transposed, axis 1 otherwise (per-channel only).
        if quantizer.is_per_channel():
            channel_axis = 0 if is_B_transposed(gemm) else 1
            quantizer.quantize_weight_tensor_per_channel(gemm.input[1], channel_axis)
        else:
            quantizer.quantize_weight_tensor(gemm.input[1])

        # Nothing more to do unless the node has exactly three inputs (a bias).
        if len(gemm.input) != 3:
            return

        bias_name = gemm.input[2]
        if not quantizer.is_input_a_initializer(bias_name):
            logging.warning(
                "Bias of Gemm node '{}' is not constant. Please exclude this node for better performance.".format(
                    self.node.name
                )
            )
            return

        # The current beta is passed to the bias quantization call, after which
        # the node's beta attribute is reset to 1.0.
        quantizer.quantize_bias_tensor(bias_name, gemm.input[0], gemm.input[1], get_beta(gemm))
        set_default_beta(gemm)