m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

207 lines
7.3 KiB

import logging
from typing import Dict, List, Union
import numpy
import torch
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class TypeHelper:
    """Translate tensor element types between ONNX Runtime, NumPy and PyTorch."""

    @staticmethod
    def get_input_type(ort_session: "InferenceSession", name: str) -> str:
        """Return the ORT type string of the session input called ``name``.

        Raises:
            ValueError: if the session has no input with that name.
        """
        for node in ort_session.get_inputs():
            if node.name == name:
                return node.type
        raise ValueError(f"input name {name} not found")

    @staticmethod
    def get_output_type(ort_session, name: str) -> str:
        """Return the ORT type string of the session output called ``name``.

        Raises:
            ValueError: if the session has no output with that name.
        """
        for node in ort_session.get_outputs():
            if node.name == name:
                return node.type
        raise ValueError(f"output name {name} not found")

    @staticmethod
    def ort_type_to_numpy_type(ort_type: str):
        """Map an ORT type string such as ``"tensor(float)"`` to a NumPy type.

        Raises:
            ValueError: if ``ort_type`` is not one of the supported strings.
        """
        ort_type_to_numpy_type_map = {
            "tensor(int64)": numpy.longlong,
            "tensor(int32)": numpy.intc,
            "tensor(float)": numpy.float32,
            "tensor(float16)": numpy.float16,
            "tensor(bool)": bool,
        }
        try:
            return ort_type_to_numpy_type_map[ort_type]
        except KeyError:
            raise ValueError(f"{ort_type} not found in map") from None

    @staticmethod
    def ort_type_to_torch_type(ort_type: str):
        """Map an ORT type string such as ``"tensor(float)"`` to a torch dtype.

        Raises:
            ValueError: if ``ort_type`` is not one of the supported strings.
        """
        ort_type_to_torch_type_map = {
            "tensor(int64)": torch.int64,
            "tensor(int32)": torch.int32,
            "tensor(float)": torch.float32,
            "tensor(float16)": torch.float16,
            "tensor(bool)": torch.bool,
        }
        try:
            return ort_type_to_torch_type_map[ort_type]
        except KeyError:
            raise ValueError(f"{ort_type} not found in map") from None

    @staticmethod
    def numpy_type_to_torch_type(numpy_type: numpy.dtype):
        """Map a NumPy type to the equivalent torch dtype.

        ``numpy_type`` may be a scalar class (``numpy.float32``), a
        ``numpy.dtype`` instance, or anything else ``numpy.dtype()`` accepts.

        Raises:
            ValueError: if the type has no supported torch equivalent.
        """
        try:
            requested = numpy.dtype(numpy_type)
        except (TypeError, ValueError):
            raise ValueError(f"{numpy_type} not found in map") from None

        # Compare by dtype equivalence rather than by class identity: aliases
        # such as numpy.int64 and numpy.longlong can be distinct classes that
        # describe the same element type, and dtype instances never hash like
        # the scalar classes do.
        numpy_type_to_torch_type_pairs = (
            (numpy.longlong, torch.int64),
            (numpy.intc, torch.int32),
            (numpy.int32, torch.int32),
            (numpy.float32, torch.float32),
            (numpy.float16, torch.float16),
            (bool, torch.bool),
        )
        for known_type, torch_type in numpy_type_to_torch_type_pairs:
            if requested == numpy.dtype(known_type):
                return torch_type
        raise ValueError(f"{numpy_type} not found in map")

    @staticmethod
    def torch_type_to_numpy_type(torch_type: torch.dtype):
        """Map a torch dtype to the NumPy type used for the same elements.

        Raises:
            ValueError: if ``torch_type`` is not one of the supported dtypes.
        """
        torch_type_to_numpy_type_map = {
            torch.int64: numpy.longlong,
            torch.int32: numpy.intc,
            torch.float32: numpy.float32,
            torch.float16: numpy.float16,
            torch.bool: bool,
        }
        try:
            return torch_type_to_numpy_type_map[torch_type]
        except KeyError:
            raise ValueError(f"{torch_type} not found in map") from None

    @staticmethod
    def get_io_numpy_type_map(ort_session: "InferenceSession") -> Dict[str, numpy.dtype]:
        """Create a mapping from every input/output name to its NumPy type.

        Raises:
            ValueError: if any input or output has an unsupported ORT type.
        """
        name_to_numpy_type = {}
        for node in ort_session.get_inputs():
            name_to_numpy_type[node.name] = TypeHelper.ort_type_to_numpy_type(node.type)
        for node in ort_session.get_outputs():
            name_to_numpy_type[node.name] = TypeHelper.ort_type_to_numpy_type(node.type)
        return name_to_numpy_type
class IOBindingHelper:
    """Helpers for feeding PyTorch tensors to an ONNX Runtime session via IO binding."""

    @staticmethod
    def _element_count(shape) -> int:
        """Return the number of elements described by ``shape`` as a Python int."""
        # numpy.prod of an empty (scalar) shape is the float 1.0, which can
        # neither size a tensor nor be used as a slice bound.
        return int(numpy.prod(shape, dtype=numpy.int64))

    @staticmethod
    def _bind_input_tensor(io_binding, name, tensor, name_to_np_type, data_ptr=None):
        """Bind one contiguous tensor as the session input called ``name``.

        ``data_ptr`` overrides the tensor's own data pointer when given.
        """
        # The raw pointer is handed over together with the logical shape, so
        # the memory layout must be contiguous.
        assert tensor.is_contiguous()
        io_binding.bind_input(
            name,
            tensor.device.type,
            # device.index is None for CPU tensors; use 0 in that case.
            tensor.device.index or 0,
            name_to_np_type[name],
            list(tensor.size()),
            tensor.data_ptr() if data_ptr is None else data_ptr,
        )

    @staticmethod
    def get_output_buffers(ort_session: "InferenceSession", output_shapes, device):
        """Return a dictionary of output name as key and 1D tensor as value.

        Each tensor is uninitialised, lives on ``device``, uses the torch dtype
        matching the session output, and has room for the given shape.
        """
        output_buffers = {}
        for name, shape in output_shapes.items():
            ort_type = TypeHelper.get_output_type(ort_session, name)
            torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
            output_buffers[name] = torch.empty(
                IOBindingHelper._element_count(shape), dtype=torch_type, device=device
            )
        return output_buffers

    @staticmethod
    def prepare_io_binding(
        ort_session,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        past: List[torch.Tensor],
        output_buffers,
        output_shapes,
        name_to_np_type=None,
    ):
        """Return an IO binding object for a session.

        Binds ``input_ids``, then each ``past_{i}``, ``attention_mask`` and
        ``position_ids`` when they are not None, then every session output to
        its tensor in ``output_buffers`` with the shape from ``output_shapes``.
        ``name_to_np_type`` maps input/output names to NumPy types and is
        derived from the session when omitted.
        """
        if name_to_np_type is None:
            name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_session)

        # Bind inputs and outputs to onnxruntime session
        io_binding = ort_session.io_binding()

        # Bind inputs
        IOBindingHelper._bind_input_tensor(io_binding, "input_ids", input_ids, name_to_np_type)

        if past is not None:
            for i, past_i in enumerate(past):
                data_ptr = past_i.data_ptr()
                if data_ptr == 0:
                    # When past_sequence_length is 0, its data_ptr will be zero. IO Binding asserts
                    # that data_ptr shall not be zero. Here we workaround and pass data pointer of
                    # input_ids. Actual data is not used for past so it does not matter.
                    data_ptr = input_ids.data_ptr()
                IOBindingHelper._bind_input_tensor(
                    io_binding, f"past_{i}", past_i, name_to_np_type, data_ptr
                )

        if attention_mask is not None:
            IOBindingHelper._bind_input_tensor(
                io_binding, "attention_mask", attention_mask, name_to_np_type
            )

        if position_ids is not None:
            IOBindingHelper._bind_input_tensor(
                io_binding, "position_ids", position_ids, name_to_np_type
            )

        # Bind outputs
        for output in ort_session.get_outputs():
            output_name = output.name
            output_buffer = output_buffers[output_name]
            logger.debug(
                "%s device type=%s shape=%s",
                output_name,
                output_buffer.device.type,
                list(output_buffer.size()),
            )
            io_binding.bind_output(
                output_name,
                output_buffer.device.type,
                # device.index is None for CPU tensors; use 0 in that case.
                output_buffer.device.index or 0,
                name_to_np_type[output_name],
                output_shapes[output_name],
                output_buffer.data_ptr(),
            )

        return io_binding

    @staticmethod
    def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
        """Copy results out of the bound output buffers.

        Returns one entry per session output, in session order. Each entry is
        a detached copy of the leading elements of the buffer reshaped to the
        output's shape: a NumPy array on the CPU when ``return_numpy`` is true,
        otherwise a torch tensor.
        """
        ort_outputs = []
        for output in ort_session.get_outputs():
            output_name = output.name
            buffer = output_buffers[output_name]
            shape = output_shapes[output_name]
            element_count = IOBindingHelper._element_count(shape)
            copy_tensor = buffer[0:element_count].reshape(shape).clone().detach()
            if return_numpy:
                ort_outputs.append(copy_tensor.cpu().numpy())
            else:
                ort_outputs.append(copy_tensor)
        return ort_outputs