# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import random
import sys
import tempfile
from pathlib import Path
from typing import List, Union

import numpy
import onnx
import torch
from transformers import MT5Config, T5Config

from onnxruntime import InferenceSession

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from onnx_model import OnnxModel  # noqa: E402
from torch_onnx_export_helper import torch_onnx_export  # noqa: E402

logger = logging.getLogger(__name__)


class T5Encoder(torch.nn.Module):
    """T5 encoder that outputs only the last hidden state."""

    def __init__(self, encoder, config: Union[T5Config, MT5Config]):
        super().__init__()
        self.encoder = encoder
        self.config = config

    def forward(self, input_ids, attention_mask):
        return self.encoder(input_ids, attention_mask)[0]
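

# A minimal usage sketch: wrap the encoder stack of a Hugging Face T5
# checkpoint with T5Encoder. The checkpoint name "t5-small" and the helper
# name are only illustrative choices, not part of this helper's API.
def _example_wrap_encoder():
    from transformers import T5ForConditionalGeneration

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    return T5Encoder(model.encoder, model.config)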


class T5EncoderInputs:
    def __init__(self, input_ids, attention_mask):
        self.input_ids: torch.LongTensor = input_ids
        self.attention_mask: torch.LongTensor = attention_mask

    @staticmethod
    def create_dummy(
        batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False
    ):  # -> T5EncoderInputs
        """Create dummy inputs for the T5 encoder.

        Args:
            batch_size (int): batch size
            sequence_length (int): sequence length
            vocab_size (int): vocabulary size
            device (torch.device): device of output tensors
            use_int32_inputs (bool, optional): use int32 instead of int64 inputs. Defaults to False.

        Returns:
            T5EncoderInputs: dummy inputs for encoder
        """
        dtype = torch.int32 if use_int32_inputs else torch.int64

        input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,
            size=(batch_size, sequence_length),
            dtype=dtype,
            device=device,
        )

        attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
        if sequence_length >= 2:
            for i in range(batch_size):
                # Zero out a random-length prefix of each row to mimic padding.
                padding_position = random.randint(0, sequence_length - 1)
                attention_mask[i, :padding_position] = 0

        return T5EncoderInputs(input_ids, attention_mask)

    def to_list(self) -> List:
        return [v for v in [self.input_ids, self.attention_mask] if v is not None]
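

# A minimal sketch of driving the wrapped encoder with dummy inputs, assuming
# the illustrative _example_wrap_encoder helper sketched above.
def _example_dummy_forward():
    encoder = _example_wrap_encoder()
    inputs = T5EncoderInputs.create_dummy(
        batch_size=2,
        sequence_length=8,
        vocab_size=encoder.config.vocab_size,
        device=torch.device("cpu"),
    )
    with torch.no_grad():
        hidden_states = encoder(*inputs.to_list())
    return hidden_states  # shape: (batch_size, sequence_length, hidden_size)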


class T5EncoderHelper:
    @staticmethod
    def export_onnx(
        encoder: T5Encoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export the encoder to ONNX.

        Args:
            encoder (T5Encoder): encoder object
            device (torch.device): device of encoder object
            onnx_model_path (str): path to save the exported ONNX model
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 instead of int64 inputs. Defaults to False.
        """
        config = encoder.config
        encoder_inputs = T5EncoderInputs.create_dummy(
            batch_size=2,
            sequence_length=4,
            vocab_size=config.vocab_size,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                encoder,
                args=tuple(encoder_inputs.to_list()),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=["input_ids", "attention_mask"],
                output_names=["hidden_states"],
                dynamic_axes={
                    "input_ids": {0: "batch_size", 1: "sequence_length"},
                    "attention_mask": {0: "batch_size", 1: "sequence_length"},
                    "hidden_states": {0: "batch_size", 1: "sequence_length"},
                },
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                # Reload with external data, then save all tensors into one external data file.
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

    @staticmethod
    def onnxruntime_inference(ort_session, inputs: T5EncoderInputs):
        """Run inference of the ONNX model."""
        ort_inputs = {
            "input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()),
            "attention_mask": numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()),
        }

        return ort_session.run(None, ort_inputs)

    @staticmethod
    def verify_onnx(
        model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
    ):
        """Compare the results from PyTorch and ONNX Runtime to verify that the ONNX model is good."""
        inputs = T5EncoderInputs.create_dummy(
            batch_size=4,
            sequence_length=11,
            vocab_size=model.config.vocab_size,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )

        input_list = inputs.to_list()
        torch_outputs = model(*input_list)

        ort_outputs = T5EncoderHelper.onnxruntime_inference(ort_session, inputs)

        max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0]))
        logger.info(f"max_diff={max_diff}")

        return max_diff
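

# A minimal end-to-end sketch: export the wrapped encoder, create an ONNX
# Runtime session, and check PyTorch/ONNX Runtime parity with verify_onnx.
# The output path, the "t5-small" checkpoint behind _example_wrap_encoder,
# and the CPU execution provider are illustrative choices.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    device = torch.device("cpu")
    encoder = _example_wrap_encoder().eval().to(device)

    onnx_path = os.path.join(tempfile.gettempdir(), "t5_encoder.onnx")
    T5EncoderHelper.export_onnx(encoder, device, onnx_path)

    session = InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    print(f"max_diff={T5EncoderHelper.verify_onnx(encoder, session, device)}")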