m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

75 lines
2.7 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import os
  8. import onnx
  9. import torch
  10. from transformers.modeling_utils import Conv1D
  11. logger = logging.getLogger(__name__)
  12. def _conv1d_to_linear(module):
  13. in_size, out_size = module.weight.shape
  14. linear = torch.nn.Linear(in_size, out_size)
  15. linear.weight.data = module.weight.data.T.contiguous()
  16. linear.bias.data = module.bias.data
  17. return linear
  18. def conv1d_to_linear(model):
  19. """in-place
  20. This is for Dynamic Quantization, as Conv1D is not recognized by PyTorch, convert it to nn.Linear
  21. """
  22. logger.debug("replace Conv1D with Linear")
  23. for name in list(model._modules):
  24. module = model._modules[name]
  25. if isinstance(module, Conv1D):
  26. linear = _conv1d_to_linear(module)
  27. model._modules[name] = linear
  28. else:
  29. conv1d_to_linear(module)
  30. def _get_size_of_pytorch_model(model):
  31. torch.save(model.state_dict(), "temp.p")
  32. size = os.path.getsize("temp.p") / (1024 * 1024)
  33. os.remove("temp.p")
  34. return size
  35. class QuantizeHelper:
  36. @staticmethod
  37. def quantize_torch_model(model, dtype=torch.qint8):
  38. """
  39. Usage: model = quantize_model(model)
  40. TODO: mix of in-place and return, but results are different
  41. """
  42. conv1d_to_linear(model)
  43. quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=dtype)
  44. logger.info(f"Size of full precision Torch model(MB):{_get_size_of_pytorch_model(model)}")
  45. logger.info(f"Size of quantized Torch model(MB):{_get_size_of_pytorch_model(quantized_model)}")
  46. return quantized_model
  47. @staticmethod
  48. def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data_format=False):
  49. from pathlib import Path
  50. from onnxruntime.quantization import quantize_dynamic
  51. Path(quantized_model_path).parent.mkdir(parents=True, exist_ok=True)
  52. logger.info(f"Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path)/(1024*1024)}")
  53. quantize_dynamic(
  54. onnx_model_path,
  55. quantized_model_path,
  56. use_external_data_format=use_external_data_format,
  57. )
  58. logger.info(f"quantized model saved to:{quantized_model_path}")
  59. # TODO: inlcude external data in total model size.
  60. logger.info(f"Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path)/(1024*1024)}")