图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

166 lines
5.5 KiB

  1. import logging
  2. import numpy as np
  3. import onnx
  4. from onnx import onnx_pb as onnx_proto
  5. from ..quant_utils import (
  6. TENSOR_NAME_QUANT_SUFFIX,
  7. QuantizedValue,
  8. QuantizedValueType,
  9. attribute_to_kwarg,
  10. find_by_name,
  11. get_mul_node,
  12. ms_domain,
  13. )
  14. from .base_operator import QuantOperatorBase
  15. from .matmul import QOpMatMul
  16. from .qdq_base_operator import QDQOperatorBase
  17. def is_B_transposed(gemm_node):
  18. transB_attribute = [attr for attr in gemm_node.attribute if attr.name == "transB"]
  19. if len(transB_attribute):
  20. return 0 < onnx.helper.get_attribute_value(transB_attribute[0])
  21. return False
  22. def get_beta(gemm_node):
  23. beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"]
  24. if len(beta_attribute):
  25. return onnx.helper.get_attribute_value(beta_attribute[0])
  26. return 1.0
  27. def set_default_beta(gemm_node):
  28. beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"]
  29. if len(beta_attribute):
  30. beta_attribute[0].f = 1.0
  31. return 1.0
  32. class QLinearGemm(QOpMatMul):
  33. def __init__(self, onnx_quantizer, onnx_node):
  34. super().__init__(onnx_quantizer, onnx_node)
  35. def quantize(self):
  36. node = self.node
  37. assert node.op_type == "Gemm"
  38. (
  39. data_found,
  40. output_scale_name,
  41. output_zp_name,
  42. _,
  43. _,
  44. ) = self.quantizer._get_quantization_params(node.output[0])
  45. if self.quantizer.is_input_a_initializer(node.input[1]) and self.quantizer.is_per_channel():
  46. (
  47. quantized_input_names,
  48. zero_point_names,
  49. scale_names,
  50. nodes,
  51. ) = self.quantizer.quantize_activation(node, [0])
  52. quant_weight_tuple = self.quantizer.quantize_weight_per_channel(
  53. node.input[1],
  54. onnx_proto.TensorProto.INT8,
  55. 0 if is_B_transposed(node) else 1,
  56. )
  57. quantized_input_names.append(quant_weight_tuple[0])
  58. zero_point_names.append(quant_weight_tuple[1])
  59. scale_names.append(quant_weight_tuple[2])
  60. else:
  61. # Get Quantized from both activation(input[0]) and weight(input[1])
  62. (
  63. quantized_input_names,
  64. zero_point_names,
  65. scale_names,
  66. nodes,
  67. ) = self.quantizer.quantize_activation(node, [0])
  68. (
  69. quantized_input_names_weight,
  70. zero_point_names_weight,
  71. scale_names_weight,
  72. nodes_weight,
  73. ) = self.quantizer.quantize_weight(node, [1], reduce_range=self.quantizer.reduce_range)
  74. quantized_input_names.extend(quantized_input_names_weight)
  75. zero_point_names.extend(zero_point_names_weight)
  76. scale_names.extend(scale_names_weight)
  77. nodes.extend(nodes_weight)
  78. if not data_found or quantized_input_names is None:
  79. return super().quantize()
  80. quantized_bias_name = ""
  81. if len(node.input) == 3:
  82. if not self.quantizer.is_input_a_initializer(node.input[2]):
  83. return super().quantize()
  84. quantized_bias_name = self.quantizer.quantize_bias_static(
  85. node.input[2], node.input[0], node.input[1], get_beta(self.node)
  86. )
  87. qgemm_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
  88. qgemm_name = qgemm_name = node.name + "_quant" if node.name != "" else ""
  89. kwargs = {}
  90. for attribute in node.attribute:
  91. if attribute.name != "beta":
  92. kwargs.update(attribute_to_kwarg(attribute))
  93. kwargs["domain"] = ms_domain
  94. # generate input
  95. qgemm_inputs = []
  96. for i in range(2):
  97. qgemm_inputs.extend([quantized_input_names[i], scale_names[i], zero_point_names[i]])
  98. qgemm_inputs.extend([quantized_bias_name, output_scale_name, output_zp_name])
  99. qgemm_node = onnx.helper.make_node("QGemm", qgemm_inputs, [qgemm_output], qgemm_name, **kwargs)
  100. nodes.append(qgemm_node)
  101. # Create an entry for this quantized value
  102. q_output = QuantizedValue(
  103. node.output[0],
  104. qgemm_output,
  105. output_scale_name,
  106. output_zp_name,
  107. QuantizedValueType.Input,
  108. )
  109. self.quantizer.quantized_value_map[node.output[0]] = q_output
  110. self.quantizer.new_nodes += nodes
  111. class QDQGemm(QDQOperatorBase):
  112. def __init__(self, onnx_quantizer, onnx_node):
  113. super().__init__(onnx_quantizer, onnx_node)
  114. def quantize(self):
  115. node = self.node
  116. assert node.op_type == "Gemm"
  117. self.quantizer.quantize_activation_tensor(node.input[0])
  118. if not self.disable_qdq_for_node_output:
  119. self.quantizer.quantize_activation_tensor(node.output[0])
  120. if self.quantizer.is_per_channel():
  121. self.quantizer.quantize_weight_tensor_per_channel(node.input[1], 0 if is_B_transposed(node) else 1)
  122. else:
  123. self.quantizer.quantize_weight_tensor(node.input[1])
  124. if len(node.input) == 3:
  125. if self.quantizer.is_input_a_initializer(node.input[2]):
  126. self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1], get_beta(self.node))
  127. set_default_beta(self.node)
  128. else:
  129. logging.warning(
  130. "Bias of Gemm node '{}' is not constant. Please exclude this node for better performance.".format(
  131. self.node.name
  132. )
  133. )