from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType
from .base_operator import QuantOperatorBase
from .qdq_base_operator import QDQOperatorBase

# For operators that support 8-bit operations directly and whose output can
# reuse input[0]'s type, zero point, and scale; for example, Transpose, Reshape, etc.
class Direct8BitOp(QuantOperatorBase):
    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        node = self.node

        if not self.quantizer.force_quantize_no_input_check:
            # Keep backward compatibility:
            # quantize only when input[0] is already quantized; otherwise keep the node as-is.
            quantized_input_value = self.quantizer.find_quantized_value(node.input[0])
            if quantized_input_value is None:
                self.quantizer.new_nodes += [node]
                return

            # The output reuses the input's scale, zero point, and value type.
            quantized_output_value = QuantizedValue(
                node.output[0],
                node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
                quantized_input_value.scale_name,
                quantized_input_value.zp_name,
                quantized_input_value.value_type,
            )
            self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value

            node.input[0] = quantized_input_value.q_name
            node.output[0] = quantized_output_value.q_name
            self.quantizer.new_nodes += [node]
        else:
            # Force-quantize these ops when possible; use the exclude-node list if this is not what you want.
            if not self.quantizer.is_valid_quantize_weight(node.input[0]):
                super().quantize()
                return

            (
                quantized_input_names,
                zero_point_names,
                scale_names,
                nodes,
            ) = self.quantizer.quantize_activation(node, [0])
            if quantized_input_names is None:
                return super().quantize()

            # Create an entry for the output quantized value.
            quantized_output_value = QuantizedValue(
                node.output[0],
                node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
                scale_names[0],
                zero_point_names[0],
                QuantizedValueType.Input,
            )
            self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value

            node.input[0] = quantized_input_names[0]
            node.output[0] = quantized_output_value.q_name
            nodes.append(node)
            self.quantizer.new_nodes += nodes
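
# Illustration (not part of the original file): what the Direct8BitOp rewrite does
# for a Transpose node whose input was already quantized upstream.
#
#   Before:  input_quantized --> Transpose --> output        (output still float)
#   After:   input_quantized --> Transpose --> output_quant  (name gets TENSOR_NAME_QUANT_SUFFIX)
#
# The new entry in quantized_value_map reuses the input's scale_name and zp_name,
# so no extra QuantizeLinear/DequantizeLinear pair is inserted for this op.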

class QDQDirect8BitOp(QDQOperatorBase):
    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        if self.quantizer.force_quantize_no_input_check:
            # Always add Q/DQ for the input; the output shares the input's quantization parameters.
            self.quantizer.quantize_activation_tensor(self.node.input[0])
            if not self.disable_qdq_for_node_output:
                self.quantizer.quantize_activation_tensor(self.node.output[0], self.node.input[0])
        elif self.quantizer.is_tensor_quantized(self.node.input[0]) and not self.disable_qdq_for_node_output:
            # Only propagate quantization to the output when the input is already quantized.
            self.quantizer.quantize_activation_tensor(self.node.output[0], self.node.input[0])
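
# Sketch (assumption, not from this file): operator quantizer classes like the two above
# are typically selected from a registry that maps ONNX op types to handler classes.
# The registry names, op coverage, and factory function below are hypothetical and only
# illustrate the wiring; the real mapping lives elsewhere in the quantization tooling.
DIRECT_8BIT_OPS_REGISTRY = {
    "Transpose": Direct8BitOp,  # output reuses input[0]'s scale/zero point
    "Reshape": Direct8BitOp,
    "Squeeze": Direct8BitOp,
}

QDQ_DIRECT_8BIT_OPS_REGISTRY = {
    "Transpose": QDQDirect8BitOp,
    "Reshape": QDQDirect8BitOp,
    "Squeeze": QDQDirect8BitOp,
}


def create_op_quantizer(onnx_quantizer, onnx_node):
    # Fall back to the generic base quantizer when the op type has no specialized handler.
    quantizer_class = DIRECT_8BIT_OPS_REGISTRY.get(onnx_node.op_type, QuantOperatorBase)
    return quantizer_class(onnx_quantizer, onnx_node)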