图片解析应用 (Image Parsing Application)
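
The code below appears to be operators/activation.py from ONNX Runtime's quantization tooling: it defines the handlers that rewrite float activation ops (Relu, Clip, Sigmoid) during static quantization, in both the QLinear and QDQ formats.
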
import onnx
from onnx import onnx_pb as onnx_proto

from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
from .base_operator import QuantOperatorBase
from .qdq_base_operator import QDQOperatorBase


class QLinearActivation(QuantOperatorBase):
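    """Replace a float activation with its QLinear* counterpart, or fold
    Relu/Clip into the preceding quantized op where possible."""
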
    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def QuantizeClipRelu(self):
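        """Fold a Relu/Clip that follows an already-quantized op by reusing
        the producer's quantization record for this node's output."""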
        node = self.node
        assert node.op_type == "Relu" or node.op_type == "Clip"

        # When mode is QLinearOps, the output quantization params are calculated based on outputs from
        # activation nodes, therefore these nodes can be removed from the graph if they follow a quantized op.
        # If the input to this node is not quantized, keep this node.
        # If the activation is symmetric, do not quantize the op and simply return.
        if node.input[0] not in self.quantizer.quantized_value_map or self.quantizer.is_activation_symmetric:
            return super().quantize()

        quantized_value = self.quantizer.quantized_value_map[node.input[0]]
        self.quantizer.quantized_value_map[node.output[0]] = quantized_value

    def quantize(self):
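        """Quantize this activation: Relu/Clip take the folding path above,
        other op types are rewritten to a QLinear* node in the ms domain."""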
        node = self.node
        if node.op_type == "Relu" or node.op_type == "Clip":
            self.QuantizeClipRelu()
            return

        nnapi_sigmoid_option = "extra.Sigmoid.nnapi"
        sigmoid_nnapi_mode = (
            node.op_type == "Sigmoid"
            and nnapi_sigmoid_option in self.quantizer.extra_options
            and self.quantizer.extra_options[nnapi_sigmoid_option]
        )
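        # NNAPI requires a quantized Sigmoid output to use scale 1/256 and
        # zero point 0, hence the fixed parameters when this option is set.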
        use_scale = 1 / 256.0 if sigmoid_nnapi_mode else None
        use_zeropoint = 0 if sigmoid_nnapi_mode else None

        # No assert on op_type, as it is controlled by the registry;
        # only try to quantize when quantization parameters are given for it.
        (
            data_found,
            output_scale_name,
            output_zp_name,
            _,
            _,
        ) = self.quantizer._get_quantization_params(node.output[0], use_scale, use_zeropoint)

        (
            quantized_input_names,
            zero_point_names,
            scale_names,
            nodes,
        ) = self.quantizer.quantize_activation(node, [0])
        if not data_found or quantized_input_names is None:
            return super().quantize()

        qlinear_activation_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
        qlinear_activation_name = ""
        if node.name != "":
            qlinear_activation_name = node.name + "_quant"
        kwargs = {}
        for attribute in node.attribute:
            kwargs.update(attribute_to_kwarg(attribute))
        kwargs["domain"] = ms_domain

        qlinear_activation_inputs = [
            quantized_input_names[0],
            scale_names[0],
            zero_point_names[0],
            output_scale_name,
            output_zp_name,
        ]

        qlinear_activation_node = onnx.helper.make_node(
            "QLinear" + node.op_type,
            qlinear_activation_inputs,
            [qlinear_activation_output],
            qlinear_activation_name,
            **kwargs,
        )

        # Create an entry for this quantized value
        q_output = QuantizedValue(
            node.output[0],
            qlinear_activation_output,
            output_scale_name,
            output_zp_name,
            QuantizedValueType.Input,
        )
        self.quantizer.quantized_value_map[node.output[0]] = q_output

        nodes.append(qlinear_activation_node)
        self.quantizer.new_nodes += nodes


class QDQRemovableActivation(QDQOperatorBase):
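    """In QDQ mode, drop an activation whose effect can be absorbed by the
    upstream quantized op; otherwise insert Q/DQ around its tensors."""
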
    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        node = self.node

        # If the input to this node is not quantized, keep this node
        if not self.quantizer.is_tensor_quantized(node.input[0]):
            return

        if not self.quantizer.is_activation_symmetric and self.quantizer.try_replacing_upstream_output(
            node.input[0], node.output[0]
        ):
            self.quantizer.remove_node(self.node)
        else:
            self.quantizer.quantize_activation_tensor(node.input[0])
            if not self.disable_qdq_for_node_output:
                self.quantizer.quantize_activation_tensor(node.output[0])
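
For context, two hedged sketches of how this file is typically used. First, ONNX Runtime's quantizer dispatches on op type through a registry (registry.py in the same package); the table below is an illustrative assumption, not the verbatim registry:

# Assumed shape of the dispatch tables (illustrative only):
QLinearOpsRegistry = {
    "Relu": QLinearActivation,     # folded into the upstream quantized op
    "Clip": QLinearActivation,     # same folding path as Relu
    "Sigmoid": QLinearActivation,  # rewritten to a QLinearSigmoid node
}
QDQRegistry = {
    "Relu": QDQRemovableActivation,
    "Clip": QDQRemovableActivation,
}

Op types without a registry entry presumably fall back to the generic QuantOperatorBase path. Second, a minimal sketch (with hypothetical tensor names) of the node that quantize() emits for a Sigmoid: the float op is replaced by its QLinear* counterpart in the com.microsoft domain, carrying the input and output scale/zero-point tensors in the same order as qlinear_activation_inputs above:

import onnx

qnode = onnx.helper.make_node(
    "QLinearSigmoid",
    ["X_quantized", "X_scale", "X_zero_point", "Y_scale", "Y_zero_point"],
    ["Y_quantized"],
    "sigmoid_quant",
    domain="com.microsoft",  # the value of ms_domain in quant_utils
)
print(onnx.helper.printable_node(qnode))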