图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

72 lines
2.4 KiB

  1. import onnx
  2. from onnx import onnx_pb as onnx_proto
  3. from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
  4. from .base_operator import QuantOperatorBase
  5. class QLinearBinaryOp(QuantOperatorBase):
  6. def __init__(self, onnx_quantizer, onnx_node):
  7. super().__init__(onnx_quantizer, onnx_node)
  8. def quantize(self):
  9. node = self.node
  10. (
  11. data_found,
  12. output_scale_name,
  13. output_zp_name,
  14. _,
  15. _,
  16. ) = self.quantizer._get_quantization_params(node.output[0])
  17. (
  18. quantized_input_names,
  19. zero_point_names,
  20. scale_names,
  21. nodes,
  22. ) = self.quantizer.quantize_activation(node, [0, 1])
  23. if not data_found or quantized_input_names is None:
  24. return super().quantize()
  25. qlinear_binary_math_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
  26. qlinear_binary_math_name = node.name + "_quant" if node.name != "" else ""
  27. kwargs = {}
  28. for attribute in node.attribute:
  29. kwargs.update(attribute_to_kwarg(attribute))
  30. kwargs["domain"] = ms_domain
  31. qlinear_binary_math_inputs = []
  32. # Input 0
  33. qlinear_binary_math_inputs.append(quantized_input_names[0])
  34. qlinear_binary_math_inputs.append(scale_names[0])
  35. qlinear_binary_math_inputs.append(zero_point_names[0])
  36. # Input 1
  37. qlinear_binary_math_inputs.append(quantized_input_names[1])
  38. qlinear_binary_math_inputs.append(scale_names[1])
  39. qlinear_binary_math_inputs.append(zero_point_names[1])
  40. # Output
  41. qlinear_binary_math_inputs.append(output_scale_name)
  42. qlinear_binary_math_inputs.append(output_zp_name)
  43. qlinear_binary_math_node = onnx.helper.make_node(
  44. "QLinear" + node.op_type,
  45. qlinear_binary_math_inputs,
  46. [qlinear_binary_math_output],
  47. qlinear_binary_math_name,
  48. **kwargs,
  49. )
  50. nodes.append(qlinear_binary_math_node)
  51. # Create an entry for this quantized value
  52. q_output = QuantizedValue(
  53. node.output[0],
  54. qlinear_binary_math_output,
  55. output_scale_name,
  56. output_zp_name,
  57. QuantizedValueType.Input,
  58. )
  59. self.quantizer.quantized_value_map[node.output[0]] = q_output
  60. self.quantizer.new_nodes += nodes