m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

101 lines
3.3 KiB

6 months ago
  1. from .operators.activation import QDQRemovableActivation, QLinearActivation
  2. from .operators.argmax import QArgMax
  3. from .operators.attention import AttentionQuant
  4. from .operators.base_operator import QuantOperatorBase
  5. from .operators.binary_op import QLinearBinaryOp
  6. from .operators.concat import QLinearConcat
  7. from .operators.conv import ConvInteger, QDQConv, QLinearConv
  8. from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp
  9. from .operators.embed_layernorm import EmbedLayerNormalizationQuant
  10. from .operators.gather import GatherQuant, QDQGather
  11. from .operators.gavgpool import QGlobalAveragePool
  12. from .operators.gemm import QDQGemm, QLinearGemm
  13. from .operators.lstm import LSTMQuant
  14. from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
  15. from .operators.maxpool import QDQMaxPool, QMaxPool
  16. from .operators.pad import QPad
  17. from .operators.pooling import QLinearPool
  18. from .operators.qdq_base_operator import QDQOperatorBase
  19. from .operators.resize import QDQResize, QResize
  20. from .operators.softmax import QDQSoftmax, QLinearSoftmax
  21. from .operators.split import QDQSplit, QSplit
  22. from .operators.where import QDQWhere, QLinearWhere
  23. from .quant_utils import QuantizationMode
  24. CommonOpsRegistry = {
  25. "Gather": GatherQuant,
  26. "Transpose": Direct8BitOp,
  27. "EmbedLayerNormalization": EmbedLayerNormalizationQuant,
  28. }
  29. IntegerOpsRegistry = {
  30. "Conv": ConvInteger,
  31. "MatMul": MatMulInteger,
  32. "Attention": AttentionQuant,
  33. "LSTM": LSTMQuant,
  34. }
  35. IntegerOpsRegistry.update(CommonOpsRegistry)
  36. QLinearOpsRegistry = {
  37. "ArgMax": QArgMax,
  38. "Conv": QLinearConv,
  39. "Gemm": QLinearGemm,
  40. "MatMul": QLinearMatMul,
  41. "Add": QLinearBinaryOp,
  42. "Mul": QLinearBinaryOp,
  43. "Relu": QLinearActivation,
  44. "Clip": QLinearActivation,
  45. "LeakyRelu": QLinearActivation,
  46. "Sigmoid": QLinearActivation,
  47. "MaxPool": QMaxPool,
  48. "GlobalAveragePool": QGlobalAveragePool,
  49. "Split": QSplit,
  50. "Pad": QPad,
  51. "Reshape": Direct8BitOp,
  52. "Squeeze": Direct8BitOp,
  53. "Unsqueeze": Direct8BitOp,
  54. "Resize": QResize,
  55. "AveragePool": QLinearPool,
  56. "Concat": QLinearConcat,
  57. "Softmax": QLinearSoftmax,
  58. "Where": QLinearWhere,
  59. }
  60. QLinearOpsRegistry.update(CommonOpsRegistry)
  61. QDQRegistry = {
  62. "Conv": QDQConv,
  63. "Gemm": QDQGemm,
  64. "Clip": QDQRemovableActivation,
  65. "Relu": QDQRemovableActivation,
  66. "Reshape": QDQDirect8BitOp,
  67. "Transpose": QDQDirect8BitOp,
  68. "Squeeze": QDQDirect8BitOp,
  69. "Unsqueeze": QDQDirect8BitOp,
  70. "Resize": QDQResize,
  71. "MaxPool": QDQMaxPool,
  72. "AveragePool": QDQDirect8BitOp,
  73. "MatMul": QDQMatMul,
  74. "Split": QDQSplit,
  75. "Gather": QDQGather,
  76. "Softmax": QDQSoftmax,
  77. "Where": QDQWhere,
  78. }
  79. def CreateDefaultOpQuantizer(onnx_quantizer, node):
  80. return QuantOperatorBase(onnx_quantizer, node)
  81. def CreateOpQuantizer(onnx_quantizer, node):
  82. registry = IntegerOpsRegistry if onnx_quantizer.mode == QuantizationMode.IntegerOps else QLinearOpsRegistry
  83. if node.op_type in registry.keys():
  84. op_quantizer = registry[node.op_type](onnx_quantizer, node)
  85. if op_quantizer.should_quantize():
  86. return op_quantizer
  87. return QuantOperatorBase(onnx_quantizer, node)
  88. def CreateQDQQuantizer(onnx_quantizer, node):
  89. if node.op_type in QDQRegistry.keys():
  90. return QDQRegistry[node.op_type](onnx_quantizer, node)
  91. return QDQOperatorBase(onnx_quantizer, node)