图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

67 lines
2.2 KiB

  1. import onnx
  2. from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
  3. from .base_operator import QuantOperatorBase
  4. class QLinearPool(QuantOperatorBase):
  5. def __init__(self, onnx_quantizer, onnx_node):
  6. super().__init__(onnx_quantizer, onnx_node)
  7. def quantize(self):
  8. node = self.node
  9. # only try to quantize when given quantization parameters for it
  10. (
  11. data_found,
  12. output_scale_name,
  13. output_zp_name,
  14. _,
  15. _,
  16. ) = self.quantizer._get_quantization_params(node.output[0])
  17. # get quantized input tensor names, quantize input if needed
  18. (
  19. quantized_input_names,
  20. input_zero_point_names,
  21. input_scale_names,
  22. nodes,
  23. ) = self.quantizer.quantize_activation(node, [0])
  24. if not data_found or quantized_input_names is None:
  25. return super().quantize()
  26. # Create an entry for output quantized value.
  27. qlinear_output_name = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
  28. quantized_output_value = QuantizedValue(
  29. node.output[0],
  30. qlinear_output_name,
  31. output_scale_name,
  32. output_zp_name,
  33. QuantizedValueType.Input,
  34. )
  35. self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
  36. # Create qlinear pool node for given type (AveragePool, etc)
  37. kwargs = {}
  38. for attribute in node.attribute:
  39. kwargs.update(attribute_to_kwarg(attribute))
  40. kwargs["domain"] = ms_domain
  41. qlinear_node_name = node.name + "_quant" if node.name != "" else ""
  42. qnode = onnx.helper.make_node(
  43. "QLinear" + node.op_type,
  44. [
  45. quantized_input_names[0],
  46. input_scale_names[0],
  47. input_zero_point_names[0],
  48. output_scale_name,
  49. output_zp_name,
  50. ],
  51. [qlinear_output_name],
  52. qlinear_node_name,
  53. **kwargs,
  54. )
  55. # add all newly created nodes
  56. nodes.append(qnode)
  57. self.quantizer.new_nodes += nodes