图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

87 lines
3.0 KiB

  1. import onnx
  2. from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
  3. from .base_operator import QuantOperatorBase
  4. from .qdq_base_operator import QDQOperatorBase
  5. class QLinearWhere(QuantOperatorBase):
  6. def should_quantize(self):
  7. return True
  8. def quantize(self):
  9. node = self.node
  10. assert node.op_type == "Where"
  11. if not self.quantizer.force_quantize_no_input_check:
  12. self.quantizer.new_nodes += [node]
  13. return
  14. (
  15. data_found,
  16. output_scale_name,
  17. output_zp_name,
  18. _,
  19. _,
  20. ) = self.quantizer._get_quantization_params(node.output[0])
  21. (
  22. q_input_names,
  23. zero_point_names,
  24. scale_names,
  25. nodes,
  26. ) = self.quantizer.quantize_activation(node, [1, 2])
  27. if not data_found or q_input_names is None:
  28. return super().quantize()
  29. qlinear_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
  30. qlinear_output_name = node.name + "_quant" if node.name != "" else ""
  31. q_output = QuantizedValue(
  32. node.output[0],
  33. qlinear_output,
  34. output_scale_name,
  35. output_zp_name,
  36. QuantizedValueType.Input,
  37. )
  38. self.quantizer.quantized_value_map[node.output[0]] = q_output
  39. kwargs = {}
  40. for attribute in node.attribute:
  41. kwargs.update(attribute_to_kwarg(attribute))
  42. kwargs["domain"] = ms_domain
  43. qlwhere_inputs = [
  44. node.input[0],
  45. q_input_names[0],
  46. scale_names[0],
  47. zero_point_names[0],
  48. q_input_names[1],
  49. scale_names[1],
  50. zero_point_names[1],
  51. output_scale_name,
  52. output_zp_name,
  53. ]
  54. qlwhere_node = onnx.helper.make_node(
  55. "QLinearWhere", qlwhere_inputs, [qlinear_output], qlinear_output_name, **kwargs
  56. )
  57. self.quantizer.new_nodes += nodes
  58. self.quantizer.new_nodes += [qlwhere_node]
  59. class QDQWhere(QDQOperatorBase):
  60. def quantize(self):
  61. node = self.node
  62. assert node.op_type == "Where"
  63. if self.quantizer.force_quantize_no_input_check:
  64. if not self.quantizer.is_tensor_quantized(node.input[1]):
  65. self.quantizer.quantize_activation_tensor(node.input[1])
  66. if not self.quantizer.is_tensor_quantized(node.input[2]):
  67. self.quantizer.quantize_activation_tensor(node.input[2])
  68. if not self.disable_qdq_for_node_output:
  69. for output in node.output:
  70. self.quantizer.quantize_activation_tensor(output)
  71. elif (
  72. self.quantizer.is_tensor_quantized(node.input[1])
  73. and self.quantizer.is_tensor_quantized(node.input[2])
  74. and not self.disable_qdq_for_node_output
  75. ):
  76. for output in node.output:
  77. self.quantizer.quantize_activation_tensor(output)