图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

56 lines
2.0 KiB

  1. import onnx
  2. from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
  3. from .base_operator import QuantOperatorBase
  4. from .qdq_base_operator import QDQOperatorBase
  5. class QLinearConcat(QuantOperatorBase):
  6. def __init__(self, onnx_quantizer, onnx_node):
  7. super().__init__(onnx_quantizer, onnx_node)
  8. def quantize(self):
  9. node = self.node
  10. (
  11. data_found,
  12. output_scale_name,
  13. output_zp_name,
  14. _,
  15. _,
  16. ) = self.quantizer._get_quantization_params(node.output[0])
  17. (
  18. q_input_names,
  19. zero_point_names,
  20. scale_names,
  21. nodes,
  22. ) = self.quantizer.quantize_activation(node, [*range(0, len(node.input))])
  23. if not data_found or q_input_names is None:
  24. return super().quantize()
  25. # Create an entry for output quantized value
  26. quantized_input_value = self.quantizer.quantized_value_map[node.input[0]]
  27. quantized_output_value = QuantizedValue(
  28. node.output[0],
  29. node.output[0] + TENSOR_NAME_QUANT_SUFFIX,
  30. output_scale_name,
  31. output_zp_name,
  32. quantized_input_value.value_type,
  33. )
  34. self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value
  35. kwargs = {}
  36. for attribute in node.attribute:
  37. kwargs.update(attribute_to_kwarg(attribute))
  38. kwargs["domain"] = ms_domain
  39. qnode_name = node.name + "_quant" if node.name != "" else ""
  40. qlconcat_inputs = [output_scale_name, output_zp_name]
  41. for i in range(0, len(q_input_names)):
  42. qlconcat_inputs.extend([q_input_names[i], scale_names[i], zero_point_names[i]])
  43. qlconcat_node = onnx.helper.make_node(
  44. "QLinearConcat", qlconcat_inputs, [quantized_output_value.q_name], qnode_name, **kwargs
  45. )
  46. self.quantizer.new_nodes += nodes
  47. self.quantizer.new_nodes += [qlconcat_node]