图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

62 lines
2.3 KiB

  1. import onnx
  2. from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
  3. from .base_operator import QuantOperatorBase
  4. class QGlobalAveragePool(QuantOperatorBase):
  5. def __init__(self, onnx_quantizer, onnx_node):
  6. super().__init__(onnx_quantizer, onnx_node)
  7. def quantize(self):
  8. node = self.node
  9. assert node.op_type == "GlobalAveragePool"
  10. # If input to this node is not quantized then keep this node.
  11. if node.input[0] not in self.quantizer.quantized_value_map:
  12. return super().quantize()
  13. quantized_input_value = self.quantizer.quantized_value_map[node.input[0]]
  14. # Create an entry for output quantized value.
  15. quantized_input_value = self.quantizer.quantized_value_map[node.input[0]]
  16. (
  17. data_found,
  18. output_scale_name_from_parameter,
  19. output_zp_name_from_parameter,
  20. _,
  21. _,
  22. ) = self.quantizer._get_quantization_params(node.output[0])
  23. # Just use input scale and zp if parameters for output is not specified.
  24. output_scale_name = output_scale_name_from_parameter if data_found else quantized_input_value.scale_name
  25. output_zp_name = output_zp_name_from_parameter if data_found else quantized_input_value.zp_name
  26. quantized_output_value = QuantizedValue(
  27. node.output[0],
  28. node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
  29. output_scale_name,
  30. output_zp_name,
  31. QuantizedValueType.Input,
  32. )
  33. self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
  34. kwargs = {}
  35. for attribute in node.attribute:
  36. kwargs.update(attribute_to_kwarg(attribute))
  37. kwargs["domain"] = ms_domain
  38. kwargs["channels_last"] = 0
  39. qnode_name = node.name + "_quant" if node.name != "" else ""
  40. qnode = onnx.helper.make_node(
  41. "QLinear" + node.op_type,
  42. [
  43. quantized_input_value.q_name,
  44. quantized_input_value.scale_name,
  45. quantized_input_value.zp_name,
  46. output_scale_name,
  47. output_zp_name,
  48. ],
  49. [quantized_output_value.q_name],
  50. qnode_name,
  51. **kwargs,
  52. )
  53. self.quantizer.new_nodes += [qnode]