m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

101 lines
3.3 KiB

from .operators.activation import QDQRemovableActivation, QLinearActivation
from .operators.argmax import QArgMax
from .operators.attention import AttentionQuant
from .operators.base_operator import QuantOperatorBase
from .operators.binary_op import QLinearBinaryOp
from .operators.concat import QLinearConcat
from .operators.conv import ConvInteger, QDQConv, QLinearConv
from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp
from .operators.embed_layernorm import EmbedLayerNormalizationQuant
from .operators.gather import GatherQuant, QDQGather
from .operators.gavgpool import QGlobalAveragePool
from .operators.gemm import QDQGemm, QLinearGemm
from .operators.lstm import LSTMQuant
from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
from .operators.maxpool import QDQMaxPool, QMaxPool
from .operators.pad import QPad
from .operators.pooling import QLinearPool
from .operators.qdq_base_operator import QDQOperatorBase
from .operators.resize import QDQResize, QResize
from .operators.softmax import QDQSoftmax, QLinearSoftmax
from .operators.split import QDQSplit, QSplit
from .operators.where import QDQWhere, QLinearWhere
from .quant_utils import QuantizationMode
CommonOpsRegistry = {
"Gather": GatherQuant,
"Transpose": Direct8BitOp,
"EmbedLayerNormalization": EmbedLayerNormalizationQuant,
}
IntegerOpsRegistry = {
"Conv": ConvInteger,
"MatMul": MatMulInteger,
"Attention": AttentionQuant,
"LSTM": LSTMQuant,
}
IntegerOpsRegistry.update(CommonOpsRegistry)
QLinearOpsRegistry = {
"ArgMax": QArgMax,
"Conv": QLinearConv,
"Gemm": QLinearGemm,
"MatMul": QLinearMatMul,
"Add": QLinearBinaryOp,
"Mul": QLinearBinaryOp,
"Relu": QLinearActivation,
"Clip": QLinearActivation,
"LeakyRelu": QLinearActivation,
"Sigmoid": QLinearActivation,
"MaxPool": QMaxPool,
"GlobalAveragePool": QGlobalAveragePool,
"Split": QSplit,
"Pad": QPad,
"Reshape": Direct8BitOp,
"Squeeze": Direct8BitOp,
"Unsqueeze": Direct8BitOp,
"Resize": QResize,
"AveragePool": QLinearPool,
"Concat": QLinearConcat,
"Softmax": QLinearSoftmax,
"Where": QLinearWhere,
}
QLinearOpsRegistry.update(CommonOpsRegistry)
QDQRegistry = {
"Conv": QDQConv,
"Gemm": QDQGemm,
"Clip": QDQRemovableActivation,
"Relu": QDQRemovableActivation,
"Reshape": QDQDirect8BitOp,
"Transpose": QDQDirect8BitOp,
"Squeeze": QDQDirect8BitOp,
"Unsqueeze": QDQDirect8BitOp,
"Resize": QDQResize,
"MaxPool": QDQMaxPool,
"AveragePool": QDQDirect8BitOp,
"MatMul": QDQMatMul,
"Split": QDQSplit,
"Gather": QDQGather,
"Softmax": QDQSoftmax,
"Where": QDQWhere,
}
def CreateDefaultOpQuantizer(onnx_quantizer, node):
return QuantOperatorBase(onnx_quantizer, node)
def CreateOpQuantizer(onnx_quantizer, node):
registry = IntegerOpsRegistry if onnx_quantizer.mode == QuantizationMode.IntegerOps else QLinearOpsRegistry
if node.op_type in registry.keys():
op_quantizer = registry[node.op_type](onnx_quantizer, node)
if op_quantizer.should_quantize():
return op_quantizer
return QuantOperatorBase(onnx_quantizer, node)
def CreateQDQQuantizer(onnx_quantizer, node):
if node.op_type in QDQRegistry.keys():
return QDQRegistry[node.op_type](onnx_quantizer, node)
return QDQOperatorBase(onnx_quantizer, node)