You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
174 lines
6.1 KiB
174 lines
6.1 KiB
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
# license information.
|
|
# --------------------------------------------------------------------------
|
|
|
|
import logging
|
|
import os
|
|
import random
|
|
import sys
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import List, Union
|
|
|
|
import numpy
|
|
import onnx
|
|
import torch
|
|
from transformers import MT5Config, T5Config
|
|
|
|
from onnxruntime import InferenceSession
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
from onnx_model import OnnxModel # noqa: E402
|
|
from torch_onnx_export_helper import torch_onnx_export # noqa: E402
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class T5Encoder(torch.nn.Module):
    """Wrapper around a T5/MT5 encoder that returns only the last hidden state."""

    def __init__(self, encoder, config: Union[T5Config, MT5Config]):
        super().__init__()
        self.config = config
        self.encoder = encoder

    def forward(self, input_ids, attention_mask):
        # The wrapped encoder returns a sequence of outputs; element 0 is the
        # last hidden state, which is all this wrapper exposes.
        encoder_outputs = self.encoder(input_ids, attention_mask)
        return encoder_outputs[0]
|
|
|
|
|
|
class T5EncoderInputs:
    """Holds the two input tensors (token ids and attention mask) for the T5 encoder."""

    def __init__(self, input_ids, attention_mask):
        self.input_ids: torch.LongTensor = input_ids
        self.attention_mask: torch.LongTensor = attention_mask

    @staticmethod
    def create_dummy(
        batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False
    ):  # -> T5EncoderInputs
        """Create dummy inputs for T5 encoder.

        Args:
            batch_size (int): batch size
            sequence_length (int): sequence length
            vocab_size (int): vocabulary size
            device (torch.device): device of output tensors
            use_int32_inputs (bool, optional): use int32 instead of int64 tensors. Defaults to False.

        Returns:
            T5EncoderInputs: dummy inputs for encoder
        """
        dtype = torch.int32 if use_int32_inputs else torch.int64

        input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,
            size=(batch_size, sequence_length),
            dtype=dtype,
            device=device,
        )

        attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
        if sequence_length >= 2:
            # Zero a random-length prefix of each row so the mask looks like
            # left padding; the final position always stays attended.
            for row in range(batch_size):
                prefix_length = random.randint(0, sequence_length - 1)
                attention_mask[row, :prefix_length] = 0
        return T5EncoderInputs(input_ids, attention_mask)

    def to_list(self) -> List:
        """Return the non-None input tensors in positional order."""
        candidates = (self.input_ids, self.attention_mask)
        return [tensor for tensor in candidates if tensor is not None]
|
|
|
|
|
|
class T5EncoderHelper:
    @staticmethod
    def export_onnx(
        encoder: T5Encoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export encoder to ONNX

        Args:
            encoder (T5Encoder): encoder object
            device (torch.device): device of encoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 dummy inputs for tracing. Defaults to False.
        """
        config = encoder.config
        # Small fixed-size dummy batch is enough for tracing; dynamic axes below
        # keep batch and sequence dimensions flexible in the exported graph.
        encoder_inputs = T5EncoderInputs.create_dummy(
            batch_size=2,
            sequence_length=4,
            vocab_size=config.vocab_size,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            # With external data format, export to a temporary location first and
            # re-save afterwards so all tensors land in one external file.
            export_target = temp_onnx_model_path if use_external_data_format else onnx_model_path
            dynamic_axes = {
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "hidden_states": {0: "batch_size", 1: "sequence_length"},
            }
            torch_onnx_export(
                encoder,
                args=tuple(encoder_inputs.to_list()),
                f=export_target,
                export_params=True,
                input_names=["input_ids", "attention_mask"],
                output_names=["hidden_states"],
                dynamic_axes=dynamic_axes,
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

    @staticmethod
    def onnxruntime_inference(ort_session, inputs: T5EncoderInputs):
        """Run inference of ONNX model."""
        # ORT expects contiguous host (CPU) numpy arrays keyed by input name.
        feeds = (("input_ids", inputs.input_ids), ("attention_mask", inputs.attention_mask))
        ort_inputs = {name: numpy.ascontiguousarray(tensor.cpu().numpy()) for name, tensor in feeds}

        return ort_session.run(None, ort_inputs)

    @staticmethod
    def verify_onnx(
        model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
        inputs = T5EncoderInputs.create_dummy(
            batch_size=4,
            sequence_length=11,
            vocab_size=model.config.vocab_size,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )
        torch_outputs = model(*inputs.to_list())

        ort_outputs = T5EncoderHelper.onnxruntime_inference(ort_session, inputs)

        # Largest absolute element-wise difference between the two backends.
        max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0]))

        logger.info(f"max_diff={max_diff}")

        return max_diff
|