图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

64 lines
2.1 KiB

  1. from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType
  2. from .base_operator import QuantOperatorBase
  3. from .qdq_base_operator import QDQOperatorBase
  4. """
  5. Quantize Gather
  6. """
  7. class GatherQuant(QuantOperatorBase):
  8. def __init__(self, onnx_quantizer, onnx_node):
  9. super().__init__(onnx_quantizer, onnx_node)
  10. def should_quantize(self):
  11. if not self.quantizer.should_quantize_node(self.node):
  12. return False
  13. return self.quantizer.is_valid_quantize_weight(self.node.input[0])
  14. def quantize(self):
  15. node = self.node
  16. assert node.op_type == "Gather"
  17. (
  18. quantized_input_names,
  19. zero_point_names,
  20. scale_names,
  21. nodes,
  22. ) = self.quantizer.quantize_activation(node, [0])
  23. if quantized_input_names is None:
  24. return super().quantize()
  25. gather_new_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
  26. # Create an entry for this quantized value
  27. q_output = QuantizedValue(
  28. node.output[0],
  29. gather_new_output,
  30. scale_names[0],
  31. zero_point_names[0],
  32. QuantizedValueType.Input,
  33. )
  34. self.quantizer.quantized_value_map[node.output[0]] = q_output
  35. node.output[0] = gather_new_output
  36. node.input[0] = quantized_input_names[0]
  37. nodes.append(node)
  38. self.quantizer.new_nodes += nodes
  39. class QDQGather(QDQOperatorBase):
  40. def __init__(self, onnx_quantizer, onnx_node):
  41. super().__init__(onnx_quantizer, onnx_node)
  42. def quantize(self):
  43. node = self.node
  44. assert node.op_type == "Gather"
  45. if self.quantizer.is_valid_quantize_weight(node.input[0]) or self.quantizer.force_quantize_no_input_check:
  46. self.quantizer.quantize_activation_tensor(node.input[0])
  47. self.quantizer.quantize_activation_tensor(node.output[0], node.input[0])
  48. elif self.quantizer.is_tensor_quantized(node.input[0]):
  49. self.quantizer.quantize_activation_tensor(node.output[0], node.input[0])