图片解析应用
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

78 lines
3.2 KiB

from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType
from .base_operator import QuantOperatorBase
from .qdq_base_operator import QDQOperatorBase
# For operators that support 8-bit operations directly and whose output can
# reuse input[0]'s type, zero point, and scale; for example, Transpose, Reshape, etc.
class Direct8BitOp(QuantOperatorBase):
    """Quantize an operator that can work on 8-bit data directly.

    The quantized output reuses the scale and zero point associated with
    ``input[0]`` (for example Transpose or Reshape). Only the first input and
    the first output of the node are rewired.
    """

    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        """Rewire the node to quantized tensors and queue it on the quantizer.

        When the quantizer's ``force_quantize_no_input_check`` flag is set,
        ``input[0]`` is quantized on demand. Otherwise the node is rewired
        only if ``input[0]`` already has a quantized value registered, and is
        queued unchanged if it does not.
        """
        quantizer = self.quantizer
        op_node = self.node

        if quantizer.force_quantize_no_input_check:
            # Force-quantize these ops when possible; use the exclude-node
            # list if this is not what you want.
            if not quantizer.is_valid_quantize_weight(op_node.input[0]):
                # input[0] was rejected by the quantizer's check: defer to the
                # base-class handling.
                super().quantize()
                return

            q_input_names, zp_names, scale_names, pending_nodes = quantizer.quantize_activation(op_node, [0])
            if q_input_names is None:
                return super().quantize()

            # Register the quantized output, sharing the scale / zero point
            # that were produced for input[0].
            float_output = op_node.output[0]
            q_output = QuantizedValue(
                float_output,
                float_output + TENSOR_NAME_QUANT_SUFFIX,
                scale_names[0],
                zp_names[0],
                QuantizedValueType.Input,
            )
            quantizer.quantized_value_map[float_output] = q_output

            op_node.input[0] = q_input_names[0]
            op_node.output[0] = q_output.q_name
            pending_nodes.append(op_node)
            quantizer.new_nodes += pending_nodes
            return

        # Backward-compatible path: quantize only when input[0] has already
        # been quantized; otherwise keep the node as it is.
        q_input = quantizer.find_quantized_value(op_node.input[0])
        if q_input is None:
            quantizer.new_nodes += [op_node]
            return

        # The output inherits scale, zero point and value type from input[0].
        float_output = op_node.output[0]
        q_output = QuantizedValue(
            float_output,
            float_output + TENSOR_NAME_QUANT_SUFFIX,
            q_input.scale_name,
            q_input.zp_name,
            q_input.value_type,
        )
        quantizer.quantized_value_map[float_output] = q_output

        op_node.input[0] = q_input.q_name
        op_node.output[0] = q_output.q_name
        quantizer.new_nodes += [op_node]


class QDQDirect8BitOp(QDQOperatorBase):
    """QDQ counterpart of the direct 8-bit handler.

    Asks the quantizer to quantize the node's first input and/or first output
    as activation tensors; the output call passes ``input[0]`` as its second
    argument.
    """

    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        """Request activation quantization for ``input[0]`` / ``output[0]``.

        With ``force_quantize_no_input_check`` set, ``input[0]`` is always
        quantized. Without it, nothing happens unless the quantizer reports
        ``input[0]`` as already quantized. In both cases the output is
        handled only when ``disable_qdq_for_node_output`` is false.
        """
        quantizer = self.quantizer
        op_node = self.node

        if quantizer.force_quantize_no_input_check:
            quantizer.quantize_activation_tensor(op_node.input[0])
        elif not quantizer.is_tensor_quantized(op_node.input[0]):
            # Not forcing and the input is not quantized: leave the node alone.
            return

        if self.disable_qdq_for_node_output:
            return

        quantizer.quantize_activation_tensor(op_node.output[0], op_node.input[0])