import logging
import tempfile
from enum import Enum
from pathlib import Path

import numpy
import onnx
from onnx import external_data_helper
from onnx import onnx_pb as onnx_proto

from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

__producer__ = "onnx.quantize"
__version__ = "0.1.0"

onnx_domain = "ai.onnx"
ms_domain = "com.microsoft"

QUANT_OP_NAME = "QuantizeLinear"
QUANT_INPUT_SUFFIX = "_QuantizeLinear_Input"
DEQUANT_OP_NAME = "DequantizeLinear"
DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output"
TENSOR_NAME_QUANT_SUFFIX = "_quantized"

# Mapping from onnx TensorProto data type enum values to their names.
type_to_name = {
    1: "FLOAT",
    2: "UINT8",
    3: "INT8",
    4: "UINT16",
    5: "INT16",
    6: "INT32",
    7: "INT64",
    8: "STRING",
    9: "BOOL",
    10: "FLOAT16",
    11: "DOUBLE",
    12: "UINT32",
    13: "UINT64",
    14: "COMPLEX64",
    15: "COMPLEX128",
}


# Quantization mode
# IntegerOps: Use IntegerOps in quantized model. Only ConvInteger and MatMulInteger ops are supported now.
# QLinearOps: Use QLinearOps in quantized model. Only QLinearConv and QLinearMatMul ops are supported now.
class QuantizationMode(Enum):
    IntegerOps = 0
    QLinearOps = 1

    def __str__(self):
        return self.name

    @staticmethod
    def from_string(mode):
        try:
            return QuantizationMode[mode]
        except KeyError:
            raise ValueError(f"Unknown QuantizationMode: {mode!r}")


class QuantizedValueType(Enum):
    Input = 0
    Initializer = 1

    def __str__(self):
        return self.name

    @staticmethod
    def from_string(v):
        try:
            return QuantizedValueType[v]
        except KeyError:
            raise ValueError(f"Unknown QuantizedValueType: {v!r}")


class QuantType(Enum):
    QInt8 = 0
    QUInt8 = 1

    def __str__(self):
        return self.name

    @staticmethod
    def from_string(t):
        try:
            return QuantType[t]
        except KeyError:
            raise ValueError(f"Unknown QuantType: {t!r}")


class QuantFormat(Enum):
    QOperator = 0
    QDQ = 1

    def __str__(self):
        return self.name

    @staticmethod
    def from_string(format):
        try:
            return QuantFormat[format]
        except KeyError:
            raise ValueError(f"Unknown QuantFormat: {format!r}")
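
# Illustrative usage of the enum helpers above (not part of the original
# module): `from_string` maps a plain string, e.g. from a CLI flag, to the
# corresponding enum member.
#
#   QuantType.from_string("QInt8")       # -> QuantType.QInt8
#   str(QuantFormat.QDQ)                 # -> "QDQ"
#   QuantizationMode.from_string("bad")  # -> raises ValueError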


ONNX_TYPE_TO_NP_TYPE = {
    onnx_proto.TensorProto.INT8: numpy.dtype("int8"),
    onnx_proto.TensorProto.UINT8: numpy.dtype("uint8"),
}


def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
    assert (
        qType in ONNX_TYPE_TO_NP_TYPE
    ), "Unexpected data type {} requested. Only INT8 and UINT8 are supported.".format(qType)
    dtype = ONNX_TYPE_TO_NP_TYPE[qType]
    cliplow = max(0 if dtype == numpy.uint8 else -127, -127 if low is None else low)
    cliphigh = min(255 if dtype == numpy.uint8 else 127, 255 if high is None else high)
    arr_fp32 = numpy.asarray((arr.astype(numpy.float32) / scale).round() + zero_point)
    numpy.clip(arr_fp32, cliplow, cliphigh, out=arr_fp32)
    return arr_fp32.astype(dtype)
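
# Worked example for quantize_nparray (illustrative, not part of the original
# module): with scale 2/255 and zero point 128, the float range [-1, 1] maps
# onto the full uint8 range; values outside the range are clipped.
#
#   quantize_nparray(
#       onnx_proto.TensorProto.UINT8,
#       numpy.array([-1.0, 0.0, 1.0]),
#       scale=2 / 255,
#       zero_point=128,
#   )
#   # -> array([  0, 128, 255], dtype=uint8)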


def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False):
    """Calculate the scale s and zero point z for the quantization relation
    r = s(q-z), where r are the original values and q are the corresponding
    quantized values.

    r and z are calculated such that every value within [rmin,rmax] has an
    approximate representation within [qmin,qmax]. In addition, qmin <= z <=
    qmax is enforced. If the symmetric flag is set to True, the interval
    [rmin,rmax] is symmetrized to [-absmax, +absmax], where
    absmax = max(abs(rmin), abs(rmax)).

    :parameter rmin: minimum value of r
    :parameter rmax: maximum value of r
    :parameter qmin: minimum value representable by the target quantization data type
    :parameter qmax: maximum value representable by the target quantization data type
    :return: zero point and scale [z, s]
    """
    if qmin > 0 or qmax < 0:
        raise ValueError(f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while qmin:{qmin}, qmax:{qmax}")

    # Adjust rmin and rmax such that 0 is included in the range. This is
    # required to make sure zero can be represented by the quantization data
    # type (i.e. to make sure qmin <= zero_point <= qmax)
    rmin = min(rmin, 0)
    rmax = max(rmax, 0)

    if symmetric:
        absmax = max(abs(rmin), abs(rmax))
        rmin = -absmax
        rmax = +absmax

    scale = (rmax - rmin) / float(qmax - qmin)
    if scale < numpy.finfo(numpy.float32).tiny:
        scale = 1.0
        zero_point = 0
    else:
        zero_point = round(qmin - rmin / scale)

    return [zero_point, scale]
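
# Worked example for compute_scale_zp (illustrative, not part of the original
# module): mapping the float range [-1, 1] onto asymmetric uint8 [0, 255]
# gives scale = (1 - (-1)) / 255 = 2/255 and
# zero_point = round(0 - (-1) / scale) = 128.
#
#   compute_scale_zp(-1.0, 1.0, 0, 255)                      # -> [128, 2/255]
#   compute_scale_zp(-1.0, 0.5, -127, 127, symmetric=True)   # -> [0, 2/254]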


def quantize_data(data, qType, symmetric, reduce_range=False):
    """
    :param data: data to quantize
    :param qType: data type to quantize to. Supported types UINT8 and INT8
    :param symmetric: whether symmetric quantization is used or not. This is applied to INT8.
    :return: minimum, maximum, zero point, scale, and quantized weights

    To pack weights, we compute a linear transformation

    - when data `type == uint8` mode, from `[rmin, rmax]` -> :math:`[0, 2^{b}-1]` and
    - when data `type == int8`, from `[-m , m]` -> :math:`[-(2^{b-1}-1), 2^{b-1}-1]` where
      `m = max(abs(rmin), abs(rmax))`

    and add necessary intermediate nodes to transform quantized weight to full weight using the equation

    :math:`r = S(q-z)`, where

    - *r*: real original value
    - *q*: quantized value
    - *S*: scale
    - *z*: zero point
    """
    rmin = 0
    rmax = 0
    zero_point = 0
    scale = 1.0
    if len(data):
        rmin = min(data)
        rmax = max(data)
        qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)

        zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric)

    quantized_data = quantize_nparray(qType, numpy.asarray(data), scale, zero_point)

    return rmin, rmax, zero_point, scale, quantized_data
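
# Illustrative end-to-end call (not part of the original module): quantize a
# small weight list to asymmetric uint8.
#
#   rmin, rmax, zp, scale, q = quantize_data(
#       [-1.0, 0.0, 1.0], onnx_proto.TensorProto.UINT8, symmetric=False
#   )
#   # rmin=-1.0, rmax=1.0, zp=128, scale=2/255,
#   # q=array([  0, 128, 255], dtype=uint8)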


def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False):
    """
    Return qmin and qmax, the minimum and maximum value representable by the given qType

    :parameter qType: onnx.onnx_pb.TensorProto.UINT8 or onnx.onnx_pb.TensorProto.INT8
    :return: qmin, qmax
    """
    if qType == onnx_proto.TensorProto.UINT8:
        (qmin, qmax) = (0, 127) if reduce_range else (0, 255)
    elif qType == onnx_proto.TensorProto.INT8:
        if symmetric:
            (qmin, qmax) = (-64, 64) if reduce_range else (-127, 127)
        else:
            (qmin, qmax) = (-64, 64) if reduce_range else (-128, 127)
    else:
        raise ValueError("Unexpected data type {} requested. Only INT8 and UINT8 are supported.".format(qType))
    return qmin, qmax


def get_qrange_for_qType(qType, reduce_range=False, symmetric=False):
    """
    Helper function to get the quantization range for a type.

    parameter qType: quantization type.
    return: quantization range.
    """
    qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
    return qmax - qmin
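
# Quick reference for the ranges above (illustrative, not part of the
# original module):
#
#   get_qmin_qmax_for_qType(onnx_proto.TensorProto.INT8, symmetric=True)   # -> (-127, 127)
#   get_qmin_qmax_for_qType(onnx_proto.TensorProto.INT8)                   # -> (-128, 127)
#   get_qrange_for_qType(onnx_proto.TensorProto.UINT8, reduce_range=True)  # -> 127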


class QuantizedInitializer:
    """
    Represents a linearly quantized weight input from ONNX operators
    """

    def __init__(
        self,
        name,
        initializer,
        rmins,
        rmaxs,
        zero_points,
        scales,
        data=None,
        quantized_data=None,
        axis=None,
    ):
        self.name = name
        self.initializer = initializer  # TensorProto initializer in ONNX graph
        self.rmins = rmins  # List of minimum range for each axis
        self.rmaxs = rmaxs  # List of maximum range for each axis
        # 1D tensor of zero points computed for each axis. scalar if axis is empty
        self.zero_points = zero_points
        self.scales = scales  # 1D tensor of scales computed for each axis. scalar if axis is empty
        # Default to fresh lists rather than shared mutable default arguments.
        self.data = data if data is not None else []  # original data from initializer TensorProto
        self.quantized_data = quantized_data if quantized_data is not None else []  # weight-packed data from data
        # Scalar to specify which dimension in the initializer to weight pack.
        # If empty, a single zero point and scale are computed from a single rmin and rmax.
        self.axis = axis


class QuantizedValue:
    """
    Represents a linearly quantized value (input/output/initializer)
    """

    def __init__(
        self,
        name,
        new_quantized_name,
        scale_name,
        zero_point_name,
        quantized_value_type,
        axis=None,
    ):
        self.original_name = name
        self.q_name = new_quantized_name
        self.scale_name = scale_name
        self.zp_name = zero_point_name
        self.value_type = quantized_value_type
        self.axis = axis


class BiasToQuantize:
    """
    Represents a bias to be quantized
    """

    def __init__(self, bias_name, input_name, weight_name):
        self.bias_name = bias_name
        self.input_name = input_name
        self.weight_name = weight_name


def attribute_to_kwarg(attribute):
    """
    Convert attribute to kwarg format for use with onnx.helper.make_node.

    :parameter attribute: attribute in AttributeProto format.
    :return: attribute in {key: value} format.
    """
    if attribute.type == 0:
        raise ValueError("attribute {} does not have type specified.".format(attribute.name))

    # Based on attribute type definitions from AttributeProto
    # definition in https://github.com/onnx/onnx/blob/main/onnx/onnx.proto
    if attribute.type == 1:
        value = attribute.f
    elif attribute.type == 2:
        value = attribute.i
    elif attribute.type == 3:
        value = attribute.s
    elif attribute.type == 4:
        value = attribute.t
    elif attribute.type == 5:
        value = attribute.g
    elif attribute.type == 6:
        value = attribute.floats
    elif attribute.type == 7:
        value = attribute.ints
    elif attribute.type == 8:
        value = attribute.strings
    elif attribute.type == 9:
        value = attribute.tensors
    elif attribute.type == 10:
        value = attribute.graphs
    else:
        raise ValueError("attribute {} has unsupported type {}.".format(attribute.name, attribute.type))

    return {attribute.name: value}
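
# Illustrative round trip (not part of the original module): rebuild a node
# while preserving its attributes. `node` here is a hypothetical NodeProto.
#
#   attr = onnx.helper.make_attribute("alpha", 0.01)
#   attribute_to_kwarg(attr)   # -> {"alpha": 0.01} (approximately; f is stored as float32)
#
#   kwargs = {}
#   for a in node.attribute:
#       kwargs.update(attribute_to_kwarg(a))
#   new_node = onnx.helper.make_node(node.op_type, node.input, node.output, **kwargs)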


def find_by_name(item_name, item_list):
    """
    Helper function to find item by name in a list.

    parameter item_name: name of the item.
    parameter item_list: list of items.
    return: item if found. None otherwise.
    """
    items = [item for item in item_list if item.name == item_name]
    return items[0] if len(items) > 0 else None


def get_elem_index(elem_name, elem_list):
    """
    Helper function to return index of an item in a node list
    """
    elem_idx = -1
    for i, elem in enumerate(elem_list):
        if elem == elem_name:
            elem_idx = i
    return elem_idx


def get_mul_node(inputs, output, name):
    """
    Helper function to create a Mul node.

    parameter inputs: list of input names.
    parameter output: output name.
    parameter name: name of the node.
    return: Mul node in NodeProto format.
    """
    return onnx.helper.make_node("Mul", inputs, [output], name)
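
# Illustrative usage (not part of the original module): look up an
# initializer by name and find where a tensor appears among a node's inputs.
# "model.onnx" and "conv1.weight" are hypothetical.
#
#   model = onnx.load("model.onnx")
#   weight = find_by_name("conv1.weight", model.graph.initializer)  # TensorProto or None
#   node = model.graph.node[0]
#   idx = get_elem_index("conv1.weight", node.input)  # -1 if not an input of this node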


def generate_identified_filename(filename: Path, identifier: str) -> Path:
    """
    Helper function to generate an identifiable filepath by concatenating the given identifier as a suffix.
    """
    return filename.parent.joinpath(filename.stem + identifier + filename.suffix)
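
# Example (illustrative, not part of the original module):
#
#   generate_identified_filename(Path("models/resnet.onnx"), "-inferred")
#   # -> Path("models/resnet-inferred.onnx")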


def apply_plot(hist, hist_edges):
    import sys

    import matplotlib.pyplot as plt
    import numpy

    numpy.set_printoptions(threshold=sys.maxsize)
    print("Histogram:")
    print(hist)
    print("Histogram Edges:")
    print(hist_edges)
    plt.stairs(hist, hist_edges, fill=True)
    plt.xlabel("Tensor value")
    plt.ylabel("Counts")
    plt.title("Tensor value vs. Counts")
    plt.show()


def write_calibration_table(calibration_cache):
    """
    Helper function to write calibration table to files.
    """
    import json

    import flatbuffers

    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable

    logging.info("calibration cache: {}".format(calibration_cache))

    with open("calibration.json", "w") as file:
        file.write(json.dumps(calibration_cache))  # use `json.loads` to do the reverse

    # Serialize data using FlatBuffers
    builder = flatbuffers.Builder(1024)
    key_value_list = []
    for key in sorted(calibration_cache.keys()):
        values = calibration_cache[key]
        value = str(max(abs(values[0]), abs(values[1])))

        flat_key = builder.CreateString(key)
        flat_value = builder.CreateString(value)

        KeyValue.KeyValueStart(builder)
        KeyValue.KeyValueAddKey(builder, flat_key)
        KeyValue.KeyValueAddValue(builder, flat_value)
        key_value = KeyValue.KeyValueEnd(builder)

        key_value_list.append(key_value)

    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
    for key_value in key_value_list:
        builder.PrependUOffsetTRelative(key_value)
    main_dict = builder.EndVector()

    TrtTable.TrtTableStart(builder)
    TrtTable.TrtTableAddDict(builder, main_dict)
    cal_table = TrtTable.TrtTableEnd(builder)

    builder.Finish(cal_table)
    buf = builder.Output()

    with open("calibration.flatbuffers", "wb") as file:
        file.write(buf)

    # Deserialize data (for validation); intentionally disabled by default.
    if False:
        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
        dict_len = cal_table.DictLength()
        for i in range(dict_len):
            key_value = cal_table.Dict(i)
            logging.info(key_value.Key())
            logging.info(key_value.Value())

    # write plain text
    with open("calibration.cache", "w") as file:
        for key in sorted(calibration_cache.keys()):
            value = calibration_cache[key]
            s = key + " " + str(max(abs(value[0]), abs(value[1])))
            file.write(s)
            file.write("\n")
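
# Illustrative usage (not part of the original module): the cache maps tensor
# names to their calibrated [min, max] ranges; three files are written to the
# current working directory.
#
#   write_calibration_table({"conv1_out": [-3.2, 3.1], "fc_out": [-0.4, 6.5]})
#   # -> calibration.json, calibration.flatbuffers, calibration.cache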


def smooth_distribution(p, eps=0.0001):
    """Given a discrete distribution (may have not been normalized to 1),
    smooth it by replacing zeros with eps multiplied by a scaling factor
    and taking the corresponding amount off the non-zero values.

    Ref: http://web.engr.illinois.edu/~hanj/cs412/bk3/KL-divergence.pdf
    https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
    """
    import numpy as np

    is_zeros = (p == 0).astype(np.float32)
    is_nonzeros = (p != 0).astype(np.float32)
    n_zeros = is_zeros.sum()
    n_nonzeros = p.size - n_zeros

    if not n_nonzeros:
        # raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
        return -1
    eps1 = eps * float(n_zeros) / float(n_nonzeros)
    assert eps1 < 1.0, "n_zeros=%d, n_nonzeros=%d, eps1=%f" % (
        n_zeros,
        n_nonzeros,
        eps1,
    )

    hist = p.astype(np.float32)
    hist += eps * is_zeros + (-eps1) * is_nonzeros
    assert (hist <= 0).sum() == 0

    return hist
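
# Worked example for smooth_distribution (illustrative, not part of the
# original module): one zero bin out of three, so each of the two non-zero
# bins gives up eps1 = 0.0001 * 1/2 of its mass to the zero bin.
#
#   import numpy as np
#   smooth_distribution(np.array([0.0, 3.0, 1.0]))
#   # -> array([1.0e-04, 2.99995e+00, 9.9995e-01], dtype=float32)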


def model_has_external_data(model_path: Path):
    model = onnx.load(model_path.as_posix(), load_external_data=False)
    for initializer in model.graph.initializer:
        if external_data_helper.uses_external_data(initializer):
            return True
    return False


def optimize_model(model_path: Path, opt_model_path: Path):
    """
    Generate a model that applies graph optimization (constant folding, etc.)
    and save it to opt_model_path.

    parameter model_path: path to the original onnx model
    parameter opt_model_path: path to the optimized onnx model
    """
    sess_option = SessionOptions()
    sess_option.optimized_model_filepath = opt_model_path.as_posix()
    sess_option.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
    _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"])
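
# Illustrative usage (not part of the original module): constructing the
# InferenceSession is what triggers onnxruntime to write the optimized model
# to disk; the session itself is discarded. Paths are hypothetical.
#
#   optimize_model(Path("model.onnx"), Path("model-opt.onnx"))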


def add_pre_process_metadata(model):
    """Tag the model that it went through quantization pre-processing"""
    metadata_props = {"onnx.quant.pre_process": "onnxruntime.quant"}
    if model.metadata_props:
        for prop in model.metadata_props:
            metadata_props.update({prop.key: prop.value})
    onnx.helper.set_model_props(model, metadata_props)


def model_has_pre_process_metadata(model):
    """Check whether the model went through quantization pre-processing"""
    if model.metadata_props:
        for prop in model.metadata_props:
            if prop.key == "onnx.quant.pre_process" and prop.value == "onnxruntime.quant":
                return True
    return False


def add_infer_metadata(model):
    metadata_props = {"onnx.infer": "onnxruntime.quant"}
    if model.metadata_props:
        for p in model.metadata_props:
            metadata_props.update({p.key: p.value})
    onnx.helper.set_model_props(model, metadata_props)


def model_has_infer_metadata(model):
    if model.metadata_props:
        for p in model.metadata_props:
            if p.key == "onnx.infer" and p.value == "onnxruntime.quant":
                return True
    return False
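
# Illustrative round trip for the metadata taggers above (not part of the
# original module); "model.onnx" is a hypothetical path.
#
#   model = onnx.load("model.onnx")
#   add_pre_process_metadata(model)
#   model_has_pre_process_metadata(model)  # -> True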


def load_model_with_shape_infer(model_path: Path):
    inferred_model_path = generate_identified_filename(model_path, "-inferred")
    onnx.shape_inference.infer_shapes_path(str(model_path), str(inferred_model_path))
    model = onnx.load(inferred_model_path.as_posix())
    inferred_model_path.unlink()
    return model


def load_model(model_path: Path, need_optimize: bool):
    with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
        if need_optimize and not model_has_external_data(model_path):
            opt_model_path = Path(quant_tmp_dir).joinpath("model.onnx")
            optimize_model(model_path, opt_model_path)
            model_path = opt_model_path

        model = load_model_with_shape_infer(model_path)
        add_infer_metadata(model)
        return model
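
# Illustrative usage (not part of the original module): load a model with
# shape inference applied, optionally running onnxruntime's basic graph
# optimizations first (skipped automatically when the model uses external
# data). "model.onnx" is a hypothetical path.
#
#   model = load_model(Path("model.onnx"), need_optimize=True)
#   model_has_infer_metadata(model)  # -> True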


def save_and_reload_model(model):
    with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
        model_path = Path(quant_tmp_dir).joinpath("model.onnx")
        onnx.external_data_helper.convert_model_to_external_data(model, all_tensors_to_one_file=True)
        onnx.save_model(model, model_path.as_posix())
        return load_model(model_path, False)


def clone_model_with_shape_infer(model):
    if model_has_infer_metadata(model):
        cloned_model = onnx_proto.ModelProto()
        cloned_model.CopyFrom(model)
    else:
        cloned_model = save_and_reload_model(model)
    return cloned_model


def tensor_proto_to_array(initializer):
    if initializer.data_type == onnx_proto.TensorProto.FLOAT:
        return onnx.numpy_helper.to_array(initializer)
    raise ValueError(
        f"Only float type is supported. Weight {initializer.name} is {type_to_name[initializer.data_type]}"
    )
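
# Illustrative usage for tensor_proto_to_array (not part of the original
# module):
#
#   init = onnx.numpy_helper.from_array(numpy.ones((2, 2), numpy.float32), "w")
#   tensor_proto_to_array(init)
#   # -> array([[1., 1.], [1., 1.]], dtype=float32)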


def add_quant_suffix(tensor_name):
    return tensor_name + "_QuantizeLinear"


def add_quant_input_suffix(tensor_name):
    return tensor_name + QUANT_INPUT_SUFFIX


def add_quant_output_suffix(tensor_name):
    return tensor_name + "_QuantizeLinear_Output"


def add_dequant_suffix(tensor_name):
    return tensor_name + "_DequantizeLinear"


def add_dequant_input_suffix(tensor_name):
    return tensor_name + "_DequantizeLinear_Input"


def add_dequant_output_suffix(tensor_name):
    return tensor_name + DEQUANT_OUTPUT_SUFFIX
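
# Naming convention examples (illustrative, not part of the original module):
#
#   add_quant_suffix("input_0")           # -> "input_0_QuantizeLinear"
#   add_quant_input_suffix("input_0")     # -> "input_0_QuantizeLinear_Input"
#   add_dequant_output_suffix("input_0")  # -> "input_0_DequantizeLinear_Output"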