图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

73 lines
2.5 KiB

  1. import onnx
  2. from onnx import onnx_pb as onnx_proto
  3. from ..quant_utils import attribute_to_kwarg, ms_domain
  4. from .base_operator import QuantOperatorBase
  5. """
  6. Quantize Attention
  7. """
  8. class AttentionQuant(QuantOperatorBase):
  9. def __init__(self, onnx_quantizer, onnx_node):
  10. super().__init__(onnx_quantizer, onnx_node)
  11. def should_quantize(self):
  12. return self.quantizer.should_quantize_node(self.node)
  13. def quantize(self):
  14. """
  15. parameter node: Attention node.
  16. parameter new_nodes_list: List of new nodes created before processing this node.
  17. return: a list of nodes in topological order that represents quantized Attention node.
  18. """
  19. node = self.node
  20. assert node.op_type == "Attention"
  21. # TODO This is a temporary fix to stop exporting QAttention with qkv_hidden_sizes
  22. # attribute. This needs to be removed once the QAttention for varied q,k,v sizes
  23. # is implemented
  24. for attr in node.attribute:
  25. if "qkv_hidden_sizes" == attr.name:
  26. return super().quantize()
  27. (
  28. quantized_input_names,
  29. zero_point_names,
  30. scale_names,
  31. nodes,
  32. ) = self.quantizer.quantize_activation(node, [0])
  33. (
  34. quantized_input_names_weight,
  35. zero_point_names_weight,
  36. scale_names_weight,
  37. nodes_weight,
  38. ) = self.quantizer.quantize_weight(node, [1], reduce_range=True, op_level_per_channel=True)
  39. quantized_input_names.extend(quantized_input_names_weight)
  40. zero_point_names.extend(zero_point_names_weight)
  41. scale_names.extend(scale_names_weight)
  42. nodes.extend(nodes_weight)
  43. if quantized_input_names is None:
  44. return super().quantize()
  45. qattention_name = "" if node.name == "" else node.name + "_quant"
  46. inputs = []
  47. inputs.extend(quantized_input_names)
  48. inputs.extend([node.input[2]])
  49. inputs.extend(scale_names)
  50. inputs.extend([node.input[3] if len(node.input) > 3 else ""])
  51. inputs.extend(zero_point_names)
  52. inputs.extend([node.input[4] if len(node.input) > 4 else ""])
  53. kwargs = {}
  54. for attribute in node.attribute:
  55. kwargs.update(attribute_to_kwarg(attribute))
  56. kwargs["domain"] = ms_domain
  57. qattention_node = onnx.helper.make_node("QAttention", inputs, node.output, qattention_name, **kwargs)
  58. nodes.append(qattention_node)
  59. self.quantizer.new_nodes += nodes