# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import List, Optional, Union

import numpy
import onnx
import torch
from past_helper import PastKeyValuesHelper
from t5_encoder import T5EncoderInputs
from transformers import MT5Config, T5Config

from onnxruntime import InferenceSession

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from io_binding_helper import TypeHelper  # noqa: E402
from onnx_model import OnnxModel  # noqa: E402
from torch_onnx_export_helper import torch_onnx_export  # noqa: E402

logger = logging.getLogger(__name__)


class T5DecoderInit(torch.nn.Module):
    """A T5 decoder with LM head used to create the initial past key values.

    This model is called only once, at the start of decoding.
    """

    def __init__(
        self,
        decoder: torch.nn.Module,
        lm_head: torch.nn.Module,
        config: Union[T5Config, MT5Config],
        decoder_start_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config
        self.decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor,
    ):
        if decoder_input_ids is None:
            batch_size = encoder_attention_mask.shape[0]
            decoder_input_ids = (
                torch.ones(
                    (batch_size, 1),
                    dtype=torch.long,
                    device=encoder_attention_mask.device,
                )
                * self.decoder_start_token_id
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

        sequence_output = decoder_outputs.last_hidden_state
        present_key_values = decoder_outputs.past_key_values

        # T5 ties input/output embeddings by default, so the decoder output is
        # rescaled by d_model**-0.5 before the LM head (as in the Hugging Face implementation).
        sequence_output = sequence_output * (self.config.d_model**-0.5)

        lm_logits = self.lm_head(sequence_output)
        past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
        return lm_logits, past_self, past_cross


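# A minimal usage sketch for T5DecoderInit (illustrative, not part of this
# module). It assumes a Hugging Face checkpoint such as "t5-small"; the input
# tensors are placeholders:
#
#   from transformers import T5ForConditionalGeneration
#
#   t5 = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
#   decoder_init = T5DecoderInit(t5.decoder, t5.lm_head, t5.config)
#   with torch.no_grad():
#       lm_logits, past_self, past_cross = decoder_init(
#           decoder_input_ids, encoder_attention_mask, encoder_hidden_states
#       )

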
class T5Decoder(torch.nn.Module):
    """A T5 decoder with LM head and past key values"""

    def __init__(self, decoder, lm_head, config):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config

    def forward(self, decoder_input_ids, encoder_attention_mask, encoder_hidden_states, *past):
        past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

        sequence_output = decoder_outputs.last_hidden_state
        present_key_values = decoder_outputs.past_key_values

        sequence_output = sequence_output * (self.config.d_model**-0.5)

        lm_logits = self.lm_head(sequence_output)
        present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)

        # Do not return present_cross since they are identical to the corresponding past_cross inputs.
        return lm_logits, present_self


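# Note on the flat `past` layout consumed by T5Decoder.forward above: as built
# by T5DecoderInputs.create_dummy below, the 4 * num_layers past tensors hold
# all self-attention key/value states first (2 per layer), followed by all
# cross-attention key/value states (2 per layer); PastKeyValuesHelper is
# assumed to regroup them into per-layer tuples via group_by_layer.

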
class T5DecoderInputs:
    def __init__(
        self,
        decoder_input_ids,
        encoder_attention_mask,
        encoder_hidden_states,
        past_key_values=None,
    ):
        self.decoder_input_ids: torch.LongTensor = decoder_input_ids
        self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
        self.encoder_hidden_states: Union[torch.FloatTensor, torch.HalfTensor] = encoder_hidden_states
        self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values

    @staticmethod
    def create_dummy(
        config: Union[T5Config, MT5Config],
        batch_size: int,
        encode_sequence_length: int,
        past_decode_sequence_length: int,
        device: torch.device,
        float16: bool = False,
        use_int32_inputs: bool = False,
    ) -> "T5DecoderInputs":
        """Create dummy inputs for T5Decoder.

        Args:
            config (Union[T5Config, MT5Config]): model configuration
            batch_size (int): batch size
            encode_sequence_length (int): sequence length of input_ids for encoder
            past_decode_sequence_length (int): past sequence length of input_ids for decoder
            device (torch.device): device of output tensors
            float16 (bool): whether to create inputs in float16 instead of float32. Defaults to False.
            use_int32_inputs (bool): whether to use int32 instead of int64 for some inputs. Defaults to False.

        Returns:
            T5DecoderInputs: dummy inputs for decoder
        """
        hidden_size: int = config.d_model
        num_attention_heads: int = config.num_heads
        num_layers: int = config.num_layers
        vocab_size: int = config.vocab_size

        # Do not use head_size = hidden_size / num_attention_heads here.
        # For example, mt5-small has d_model=512 and num_heads=6.
        head_size: int = config.d_kv

        sequence_length: int = 1  # fixed for decoding
        decoder_input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,
            size=(batch_size, sequence_length),
            dtype=(torch.int32 if use_int32_inputs else torch.int64),
            device=device,
        )

        encoder_inputs = T5EncoderInputs.create_dummy(
            batch_size,
            encode_sequence_length,
            vocab_size,
            device,
            use_int32_inputs=use_int32_inputs,
        )

        float_type = torch.float16 if float16 else torch.float32
        encoder_hidden_state = torch.rand(
            batch_size,
            encode_sequence_length,
            hidden_size,
            dtype=float_type,
            device=device,
        )

        if past_decode_sequence_length > 0:
            self_attention_past_shape = [
                batch_size,
                num_attention_heads,
                past_decode_sequence_length,
                head_size,
            ]
            cross_attention_past_shape = [
                batch_size,
                num_attention_heads,
                encode_sequence_length,
                head_size,
            ]

            past = []
            for _ in range(2 * num_layers):
                past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))

            for _ in range(2 * num_layers):
                past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
        else:
            past = None

        return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, encoder_hidden_state, past)

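    # A minimal sketch of building dummy inputs (the sizes and "cpu" device are
    # illustrative assumptions):
    #
    #   inputs = T5DecoderInputs.create_dummy(
    #       config, batch_size=2, encode_sequence_length=3,
    #       past_decode_sequence_length=5, device=torch.device("cpu")
    #   )
    #   input_list = inputs.to_list()  # 3 + 4 * num_layers tensors when past is present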
    def to_list(self) -> List:
        input_list = [
            self.decoder_input_ids,
            self.encoder_attention_mask,
            self.encoder_hidden_states,
        ]
        if self.past_key_values:
            input_list.extend(self.past_key_values)
        return input_list

    def to_fp32(self):
        encoder_hidden_state = self.encoder_hidden_states.to(dtype=torch.float32)
        past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
        return T5DecoderInputs(
            self.decoder_input_ids.clone(),
            self.encoder_attention_mask.clone(),
            encoder_hidden_state,
            past,
        )


class T5DecoderHelper:
    @staticmethod
    def export_onnx(
        decoder: Union[T5Decoder, T5DecoderInit],
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export decoder to ONNX.

        Args:
            decoder (Union[T5Decoder, T5DecoderInit]): decoder object
            device (torch.device): device of decoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs. Defaults to False.
        """
        assert isinstance(decoder, (T5Decoder, T5DecoderInit))

        inputs = T5DecoderInputs.create_dummy(
            decoder.config,
            batch_size=2,
            encode_sequence_length=3,
            past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )
        input_list = inputs.to_list()

        past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False)
        present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True)
        present_self_names = present_names[: 2 * decoder.config.num_layers]

        input_past_names = past_names if isinstance(decoder, T5Decoder) else []
        output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
        output_names = ["logits"] + output_present_names

        # Shape of input tensors (sequence_length==1):
        #    input_ids: (batch_size, sequence_length)
        #    encoder_attention_mask: (batch_size, encode_sequence_length)
        #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        input_names = ["input_ids"]
        input_names.append("encoder_attention_mask")
        input_names.append("encoder_hidden_states")
        input_names.extend(input_past_names)

        dynamic_axes = {
            "input_ids": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
            "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
            "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
            "logits": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
        }

        for name in input_past_names:
            dynamic_axes[name] = {
                0: "batch_size",
                2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
            }

        for name in output_present_names:
            if "cross" in name:
                dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
            else:  # self attention past state
                if isinstance(decoder, T5Decoder):
                    dynamic_axes[name] = {
                        0: "batch_size",
                        2: "past_decode_sequence_length + 1",
                    }
                else:
                    dynamic_axes[name] = {
                        0: "batch_size",
                        # 2: 'sequence_length'
                    }

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                decoder,
                args=tuple(input_list),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                # Export went to a temporary directory first; reload it and
                # consolidate all external weight tensors into a single file
                # next to the final model.
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

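    # Export sketch (illustrative; the path is an assumption and "decoder" is a
    # T5Decoder or T5DecoderInit wrapping a PyTorch checkpoint):
    #
    #   T5DecoderHelper.export_onnx(decoder, torch.device("cpu"), "t5_decoder.onnx")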
    @staticmethod
    def onnxruntime_inference(ort_session, inputs: T5DecoderInputs):
        """Run inference of ONNX model."""
        logger.debug("start onnxruntime_inference")

        ort_inputs = {
            "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
            "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
            "encoder_hidden_states": numpy.ascontiguousarray(inputs.encoder_hidden_states.cpu().numpy()),
        }

        if inputs.past_key_values:
            # Each layer contributes 4 past tensors: self-attention key/value
            # and cross-attention key/value.
            assert len(inputs.past_key_values) % 4 == 0
            num_layers = len(inputs.past_key_values) // 4
            past_names = PastKeyValuesHelper.get_past_names(num_layers)
            for i, past_tensor in enumerate(inputs.past_key_values):
                ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())

        ort_outputs = ort_session.run(None, ort_inputs)
        return ort_outputs

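    # Example call (illustrative; "session" is an InferenceSession over the
    # exported decoder and "inputs" comes from T5DecoderInputs.create_dummy):
    #
    #   outputs = T5DecoderHelper.onnxruntime_inference(session, inputs)
    #   logits = outputs[0]  # (batch_size, 1, vocab_size)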
    @staticmethod
    def verify_onnx(
        model: Union[T5Decoder, T5DecoderInit],
        ort_session: InferenceSession,
        device: torch.device,
        use_int32_inputs: bool,
        max_cases: int = 4,
    ):
        """Compare the results from PyTorch and OnnxRuntime to verify that the ONNX model is good."""
        float16: bool = TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)"

        test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
        test_cases_max_diff = []
        for (
            batch_size,
            encode_sequence_length,
            past_decode_sequence_length,
        ) in test_cases[:max_cases]:
            if isinstance(model, T5DecoderInit):
                past_decode_sequence_length = 0

            inputs = T5DecoderInputs.create_dummy(
                model.config,
                batch_size,
                encode_sequence_length,
                past_decode_sequence_length,
                device=device,
                float16=float16,
                use_int32_inputs=use_int32_inputs,
            )

            # We use the fp32 PyTorch model as baseline even when the ONNX model is fp16.
            input_list = inputs.to_fp32().to_list()

            # Run inference of PyTorch model
            with torch.no_grad():
                torch_outputs = model(*input_list)

            ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)

            max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
            max_diff_all = max_diff
            logger.debug(f"logits max_diff={max_diff}")

            for i in range(2 * model.config.num_layers):
                max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
                logger.debug(f"self attention past state {i} max_diff={max_diff}")
                max_diff_all = max(max_diff_all, max_diff)

            if isinstance(model, T5DecoderInit):
                for i in range(2 * model.config.num_layers):
                    max_diff = numpy.amax(
                        numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.num_layers + i])
                    )
                    logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                    max_diff_all = max(max_diff_all, max_diff)

            test_cases_max_diff.append(max_diff_all)
            logger.info(
                f"batch_size={batch_size}, encode_sequence_length={encode_sequence_length}, "
                f"past_decode_sequence_length={past_decode_sequence_length}, max_diff={max_diff_all}"
            )

        return max_diff_all
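
# End-to-end sketch tying the helpers together (illustrative only; the
# checkpoint name, output path, and provider are assumptions):
#
#   from transformers import T5ForConditionalGeneration
#
#   device = torch.device("cpu")
#   t5 = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
#   decoder = T5Decoder(t5.decoder, t5.lm_head, t5.config).to(device)
#   T5DecoderHelper.export_onnx(decoder, device, "t5_decoder.onnx")
#   session = InferenceSession("t5_decoder.onnx", providers=["CPUExecutionProvider"])
#   max_diff = T5DecoderHelper.verify_onnx(decoder, session, device, use_int32_inputs=False)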