m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

75 lines
2.7 KiB

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import io
import logging
import os

import onnx
import torch
from transformers.modeling_utils import Conv1D
# Module-level logger named after this module; all diagnostics below go through it.
logger = logging.getLogger(__name__)
def _conv1d_to_linear(module):
    """Build a ``torch.nn.Linear`` carrying the parameters of ``module``.

    The code treats ``module.weight`` as ``(in_features, out_features)``. ``nn.Linear``
    keeps its weight as ``(out_features, in_features)``, so the weight is transposed
    (into a contiguous tensor) on the way over.

    Note: the bias is not copied. The new layer's bias ``.data`` is the very same
    tensor as ``module.bias.data``.

    Args:
        module: layer exposing ``weight`` (2-D) and ``bias`` parameters, e.g. a
            HuggingFace ``Conv1D``.

    Returns:
        A new ``torch.nn.Linear`` holding ``module``'s weight (transposed) and bias.
    """
    n_in, n_out = module.weight.shape
    replacement = torch.nn.Linear(n_in, n_out)
    # Overwrite the freshly initialised parameters with the source layer's values.
    transposed_weight = module.weight.data.T.contiguous()
    replacement.weight.data = transposed_weight
    replacement.bias.data = module.bias.data
    return replacement
def conv1d_to_linear(model):
    """Replace every ``Conv1D`` descendant of ``model`` with an equivalent ``nn.Linear``, in place.

    This is for Dynamic Quantization: ``Conv1D`` is not recognized by PyTorch's
    dynamic quantization, which here is applied to ``torch.nn.Linear`` only.

    The module tree is walked recursively through ``_modules``. ``model`` itself is
    never replaced, even if it is a ``Conv1D``; only its descendants are. Child
    slots registered as ``None`` (``add_module(name, None)`` is allowed by PyTorch)
    are skipped.

    Args:
        model: the ``torch.nn.Module`` to rewrite; modified in place.
    """
    logger.debug("replace Conv1D with Linear")
    # Snapshot the children so replacing entries does not disturb the iteration.
    for name, child in list(model._modules.items()):
        if child is None:
            # Empty placeholder slot: nothing to convert and no children to visit.
            continue
        if isinstance(child, Conv1D):
            model._modules[name] = _conv1d_to_linear(child)
        else:
            conv1d_to_linear(child)
def _get_size_of_pytorch_model(model):
    """Return the size of ``model``'s serialized ``state_dict`` in units of 1024*1024 bytes.

    The state dict is written with ``torch.save`` into an in-memory buffer. The
    previous version used a fixed ``temp.p`` file in the current working directory.
    Using a buffer means:

    * an existing ``temp.p`` is never overwritten or deleted;
    * no write permission on the cwd is needed;
    * nothing is left behind if ``torch.save`` raises;
    * concurrent calls cannot race on the same file.

    Args:
        model: a ``torch.nn.Module`` (anything with ``state_dict()``).

    Returns:
        The serialized size as a float (the callers' log messages label it "MB").
    """
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / (1024 * 1024)
class QuantizeHelper:
    """Static helpers that dynamically quantize PyTorch models and ONNX model files."""

    @staticmethod
    def quantize_torch_model(model, dtype=torch.qint8):
        """Dynamically quantize the ``torch.nn.Linear`` layers of a PyTorch model.

        Usage: model = QuantizeHelper.quantize_torch_model(model)

        Args:
            model: the ``torch.nn.Module`` to quantize. NOTE: it is also modified in
                place. Its ``Conv1D`` layers are replaced with ``nn.Linear`` (via
                ``conv1d_to_linear``) before quantization.
            dtype: target dtype forwarded to ``torch.quantization.quantize_dynamic``
                (defaults to ``torch.qint8``).

        Returns:
            The model returned by ``torch.quantization.quantize_dynamic``.

        TODO: mix of in-place and return, but results are different
        """
        conv1d_to_linear(model)
        quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=dtype)
        # Measuring a size serializes the whole model, so only pay for it when the
        # INFO messages will actually be emitted.
        if logger.isEnabledFor(logging.INFO):
            logger.info(f"Size of full precision Torch model(MB):{_get_size_of_pytorch_model(model)}")
            logger.info(f"Size of quantized Torch model(MB):{_get_size_of_pytorch_model(quantized_model)}")
        return quantized_model

    @staticmethod
    def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data_format=False):
        """Dynamically quantize an ONNX model file with ONNX Runtime and save the result.

        Args:
            onnx_model_path: path of the full-precision ONNX model to read.
            quantized_model_path: output path; its parent directory is created if missing.
            use_external_data_format: forwarded to
                ``onnxruntime.quantization.quantize_dynamic``.
        """
        # Local imports: presumably so onnxruntime is only required when this method is used.
        from pathlib import Path

        from onnxruntime.quantization import quantize_dynamic

        Path(quantized_model_path).parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path)/(1024*1024)}")
        quantize_dynamic(
            onnx_model_path,
            quantized_model_path,
            use_external_data_format=use_external_data_format,
        )
        logger.info(f"quantized model saved to:{quantized_model_path}")
        # TODO: include external data in total model size.
        logger.info(f"Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path)/(1024*1024)}")