图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

115 lines
4.6 KiB

  1. import numpy
  2. import onnx
  3. from onnx import onnx_pb as onnx_proto
  4. from ..quant_utils import QuantType, attribute_to_kwarg, ms_domain
  5. from .base_operator import QuantOperatorBase
  6. """
  7. Quantize LSTM
  8. """
  class LSTMQuant(QuantOperatorBase):
      """Quantization handler that rewrites an ``LSTM`` node as ``DynamicQuantizeLSTM``."""

      def __init__(self, onnx_quantizer, onnx_node):
          super().__init__(onnx_quantizer, onnx_node)

      def quantize(self):
          """
          Replace the wrapped LSTM node with a ``DynamicQuantizeLSTM`` node.

          The input weight (``node.input[1]``) and the recurrent weight
          (``node.input[2]``) are quantized to INT8 along axis 0, their last two
          axes are swapped, and the resulting weight / scale / zero-point
          initializers are wired into a new node in ``ms_domain``.

          When either weight is not a valid quantizable initializer, or is not
          rank 3, the base-class ``quantize`` is invoked instead.

          return: None. New nodes are appended to ``self.quantizer.new_nodes``.
          """
          lstm = self.node
          assert lstm.op_type == "LSTM"

          quantizer = self.quantizer

          # Guard 1: both weights must be quantizable initializers.
          weights_ok = quantizer.is_valid_quantize_weight(lstm.input[1]) and quantizer.is_valid_quantize_weight(
              lstm.input[2]
          )
          if not weights_ok:
              super().quantize()
              return

          model = quantizer.model
          input_weight = model.get_initializer(lstm.input[1])
          recurrent_weight = model.get_initializer(lstm.input[2])

          # Guard 2: only rank-3 weights are handled here.
          if not (len(input_weight.dims) == 3 and len(recurrent_weight.dims) == 3):
              super().quantize()
              return

          # Presumably (num_directions, 4 * hidden_size, input_size / hidden_size),
          # as in the ONNX LSTM specification -- confirm against the operator docs.
          w_dirs, w_rows, w_cols = input_weight.dims
          r_dirs, r_rows, r_cols = recurrent_weight.dims

          if quantizer.is_per_channel():
              # Merge the two leading axes in place so that quantizing along axis 0
              # sees one channel per (direction, row) pair.
              # NOTE(review): the dims of the original initializers are not restored
              # afterwards -- confirm nothing else reads them.
              for tensor, merged_rows in (
                  (input_weight, w_dirs * w_rows),
                  (recurrent_weight, r_dirs * r_rows),
              ):
                  del tensor.dims[0]
                  tensor.dims[0] = merged_rows

          int8 = onnx_proto.TensorProto.INT8
          w_quantized = quantizer.quantize_weight_per_channel(lstm.input[1], int8, 0)
          r_quantized = quantizer.quantize_weight_per_channel(lstm.input[2], int8, 0)

          # Each result is indexed as (weight name, zero-point name, scale name).
          w_name, w_zp_name, w_scale_name = w_quantized[0], w_quantized[1], w_quantized[2]
          r_name, r_zp_name, r_scale_name = r_quantized[0], r_quantized[1], r_quantized[2]

          w_flat = model.get_initializer(w_name)
          r_flat = model.get_initializer(r_name)

          # Restore the rank-3 shape, then swap the last two axes.
          w_array = numpy.transpose(
              numpy.reshape(onnx.numpy_helper.to_array(w_flat), (w_dirs, w_rows, w_cols)),
              (0, 2, 1),
          )
          r_array = numpy.transpose(
              numpy.reshape(onnx.numpy_helper.to_array(r_flat), (r_dirs, r_rows, r_cols)),
              (0, 2, 1),
          )

          w_transposed = onnx.numpy_helper.from_array(w_array, w_name)
          r_transposed = onnx.numpy_helper.from_array(r_array, r_name)

          # Swap the untransposed quantized weights for the transposed ones (same names).
          model.remove_initializers([w_flat, r_flat])
          model.add_initializer(w_transposed)
          model.add_initializer(r_transposed)

          w_zero_point = model.get_initializer(w_zp_name)
          r_zero_point = model.get_initializer(r_zp_name)
          w_scale = model.get_initializer(w_scale_name)
          r_scale = model.get_initializer(r_scale_name)

          if quantizer.is_per_channel():
              # Undo the axis merge on the per-channel parameters.
              w_param_shape = [w_dirs, w_rows]
              r_param_shape = [r_dirs, r_rows]
              w_zero_point.dims[:] = w_param_shape
              r_zero_point.dims[:] = r_param_shape
              w_scale.dims[:] = w_param_shape
              r_scale.dims[:] = r_param_shape

          # Inputs 3..7 of the original node are optional; absent ones become "".
          input_count = len(lstm.input)
          optional_inputs = [lstm.input[idx] if input_count > idx else "" for idx in range(3, 8)]
          inputs = [
              lstm.input[0],
              w_name,
              r_name,
              *optional_inputs,
              w_scale_name,
              w_zp_name,
              r_scale_name,
              r_zp_name,
          ]

          # Carry over every attribute of the original node, then set the domain.
          kwargs = {}
          for attr in lstm.attribute:
              kwargs.update(attribute_to_kwarg(attr))
          kwargs["domain"] = ms_domain

          if lstm.name == "":
              quant_lstm_name = ""
          else:
              quant_lstm_name = lstm.name + "_quant"

          quant_lstm_node = onnx.helper.make_node("DynamicQuantizeLSTM", inputs, lstm.output, quant_lstm_name, **kwargs)
          quantizer.new_nodes.append(quant_lstm_node)

          dequantize_node = quantizer._dequantize_value(lstm.input[0])
          if dequantize_node is not None:
              quantizer.new_nodes.append(dequantize_node)