You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
3.3 KiB
101 lines
3.3 KiB
from .operators.activation import QDQRemovableActivation, QLinearActivation
|
|
from .operators.argmax import QArgMax
|
|
from .operators.attention import AttentionQuant
|
|
from .operators.base_operator import QuantOperatorBase
|
|
from .operators.binary_op import QLinearBinaryOp
|
|
from .operators.concat import QLinearConcat
|
|
from .operators.conv import ConvInteger, QDQConv, QLinearConv
|
|
from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp
|
|
from .operators.embed_layernorm import EmbedLayerNormalizationQuant
|
|
from .operators.gather import GatherQuant, QDQGather
|
|
from .operators.gavgpool import QGlobalAveragePool
|
|
from .operators.gemm import QDQGemm, QLinearGemm
|
|
from .operators.lstm import LSTMQuant
|
|
from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
|
|
from .operators.maxpool import QDQMaxPool, QMaxPool
|
|
from .operators.pad import QPad
|
|
from .operators.pooling import QLinearPool
|
|
from .operators.qdq_base_operator import QDQOperatorBase
|
|
from .operators.resize import QDQResize, QResize
|
|
from .operators.softmax import QDQSoftmax, QLinearSoftmax
|
|
from .operators.split import QDQSplit, QSplit
|
|
from .operators.where import QDQWhere, QLinearWhere
|
|
from .quant_utils import QuantizationMode
|
|
|
|
# Op-type -> quantizer-class entries shared by the IntegerOps and QLinearOps
# registries; both of those tables merge these entries in after their own.
CommonOpsRegistry = dict(
    Gather=GatherQuant,
    Transpose=Direct8BitOp,
    EmbedLayerNormalization=EmbedLayerNormalizationQuant,
)
|
|
|
|
# Op-type -> quantizer-class table consulted by CreateOpQuantizer when
# onnx_quantizer.mode is QuantizationMode.IntegerOps. The CommonOpsRegistry
# entries are unpacked last, so on any key collision the common entry's value
# wins while the key keeps its original position (same as dict.update).
IntegerOpsRegistry = {
    "Conv": ConvInteger,
    "MatMul": MatMulInteger,
    "Attention": AttentionQuant,
    "LSTM": LSTMQuant,
    **CommonOpsRegistry,
}
|
|
|
|
# Op-type -> quantizer-class table consulted by CreateOpQuantizer for every
# mode other than QuantizationMode.IntegerOps. Op types that share a quantizer
# class are grouped with dict.fromkeys; groups sit at the same positions the
# individual keys would, so insertion order is unchanged. CommonOpsRegistry is
# unpacked last, so its values win on any key collision.
QLinearOpsRegistry = {
    "ArgMax": QArgMax,
    "Conv": QLinearConv,
    "Gemm": QLinearGemm,
    "MatMul": QLinearMatMul,
    **dict.fromkeys(("Add", "Mul"), QLinearBinaryOp),
    **dict.fromkeys(("Relu", "Clip", "LeakyRelu", "Sigmoid"), QLinearActivation),
    "MaxPool": QMaxPool,
    "GlobalAveragePool": QGlobalAveragePool,
    "Split": QSplit,
    "Pad": QPad,
    **dict.fromkeys(("Reshape", "Squeeze", "Unsqueeze"), Direct8BitOp),
    "Resize": QResize,
    "AveragePool": QLinearPool,
    "Concat": QLinearConcat,
    "Softmax": QLinearSoftmax,
    "Where": QLinearWhere,
    **CommonOpsRegistry,
}
|
|
|
|
# Op-type -> quantizer-class table consulted by CreateQDQQuantizer. Unlike the
# IntegerOps/QLinearOps tables, CommonOpsRegistry is NOT merged in here.
# Shared-class groups use dict.fromkeys at their original positions so the
# insertion order matches a flat listing.
QDQRegistry = {
    "Conv": QDQConv,
    "Gemm": QDQGemm,
    **dict.fromkeys(("Clip", "Relu"), QDQRemovableActivation),
    **dict.fromkeys(("Reshape", "Transpose", "Squeeze", "Unsqueeze"), QDQDirect8BitOp),
    "Resize": QDQResize,
    "MaxPool": QDQMaxPool,
    "AveragePool": QDQDirect8BitOp,
    "MatMul": QDQMatMul,
    "Split": QDQSplit,
    "Gather": QDQGather,
    "Softmax": QDQSoftmax,
    "Where": QDQWhere,
}
|
|
|
|
|
|
def CreateDefaultOpQuantizer(onnx_quantizer, node):
    """Build the generic fallback quantizer for *node*.

    No registry lookup happens here: the result is always a plain
    QuantOperatorBase constructed from ``onnx_quantizer`` and ``node``.
    """
    fallback = QuantOperatorBase(onnx_quantizer, node)
    return fallback
|
|
|
|
|
|
def CreateOpQuantizer(onnx_quantizer, node):
    """Return the quantizer object to use for *node*.

    The lookup table depends on ``onnx_quantizer.mode``:
    ``IntegerOpsRegistry`` for ``QuantizationMode.IntegerOps`` and
    ``QLinearOpsRegistry`` for any other mode. If ``node.op_type`` has an
    entry, that class is instantiated with ``(onnx_quantizer, node)``. The
    instance is returned when its ``should_quantize()`` is truthy. In every
    other case the generic fallback from ``CreateDefaultOpQuantizer`` (a plain
    ``QuantOperatorBase``) is returned.

    Args:
        onnx_quantizer: The model-level quantizer. This function itself only
            reads its ``mode`` attribute; it is also forwarded to the
            per-op quantizer constructor.
        node: The ONNX node being quantized. This function itself only reads
            its ``op_type`` attribute.

    Returns:
        The registered per-op quantizer, or a ``QuantOperatorBase`` fallback.
    """
    if onnx_quantizer.mode == QuantizationMode.IntegerOps:
        registry = IntegerOpsRegistry
    else:
        registry = QLinearOpsRegistry

    # Single dict lookup instead of `in registry.keys()` followed by indexing.
    # Registry values are classes, never None, so None means "not registered".
    quantizer_cls = registry.get(node.op_type)
    if quantizer_cls is not None:
        op_quantizer = quantizer_cls(onnx_quantizer, node)
        if op_quantizer.should_quantize():
            return op_quantizer

    # Unregistered op type, or the registered quantizer declined this node.
    return CreateDefaultOpQuantizer(onnx_quantizer, node)
|
|
|
|
|
|
def CreateQDQQuantizer(onnx_quantizer, node):
    """Return the QDQ-format quantizer object for *node*.

    Looks up ``node.op_type`` in ``QDQRegistry`` and instantiates the matching
    class with ``(onnx_quantizer, node)``. Op types that are not registered
    get a plain ``QDQOperatorBase`` instead.

    Unlike ``CreateOpQuantizer``, this function does not call
    ``should_quantize()``. The registered instance is always returned.
    NOTE(review): presumably the QDQ path makes that decision elsewhere —
    confirm against the caller.

    Args:
        onnx_quantizer: The model-level quantizer, forwarded to the
            constructor.
        node: The ONNX node being quantized; only ``op_type`` is read here.
    """
    # One lookup with a default replaces `in QDQRegistry.keys()` + indexing.
    quantizer_cls = QDQRegistry.get(node.op_type, QDQOperatorBase)
    return quantizer_cls(onnx_quantizer, node)
|