图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

63 lines
2.1 KiB

  1. import onnx
  2. from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg
  3. from .base_operator import QuantOperatorBase
  4. from .qdq_base_operator import QDQOperatorBase
  5. class QSplit(QuantOperatorBase):
  6. def __init__(self, onnx_quantizer, onnx_node):
  7. super().__init__(onnx_quantizer, onnx_node)
  8. def quantize(self):
  9. node = self.node
  10. (
  11. quantized_input_names,
  12. zero_point_names,
  13. scale_names,
  14. nodes,
  15. ) = self.quantizer.quantize_activation(node, [0])
  16. if quantized_input_names is None:
  17. return super().quantize()
  18. quantized_node_name = ""
  19. if node.name != "":
  20. quantized_node_name = node.name + "_quant"
  21. kwargs = {}
  22. for attribute in node.attribute:
  23. kwargs.update(attribute_to_kwarg(attribute))
  24. # Output just derive the scale/zero from input
  25. quantized_output_names = []
  26. for output_name in node.output:
  27. quantized_output_name = output_name + "quantized"
  28. quantized_output_names.append(quantized_output_name)
  29. q_output = QuantizedValue(
  30. output_name,
  31. quantized_output_name,
  32. scale_names[0],
  33. zero_point_names[0],
  34. QuantizedValueType.Input,
  35. )
  36. self.quantizer.quantized_value_map[output_name] = q_output
  37. if len(node.input) > 1:
  38. quantized_input_names.extend(node.input[1:])
  39. quantized_node = onnx.helper.make_node(
  40. node.op_type, quantized_input_names, quantized_output_names, quantized_node_name, **kwargs
  41. )
  42. nodes.append(quantized_node)
  43. self.quantizer.new_nodes += nodes
  44. class QDQSplit(QDQOperatorBase):
  45. def quantize(self):
  46. node = self.node
  47. assert node.op_type == "Split"
  48. if not self.quantizer.is_tensor_quantized(node.input[0]):
  49. self.quantizer.quantize_activation_tensor(node.input[0])
  50. if not self.disable_qdq_for_node_output:
  51. for output in node.output:
  52. self.quantizer.quantize_activation_tensor(output, node.input[0])