import logging
from typing import Dict, List

import numpy
import torch

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)


class TypeHelper:
    @staticmethod
    def get_input_type(ort_session: InferenceSession, name: str) -> str:
        for graph_input in ort_session.get_inputs():
            if graph_input.name == name:
                return graph_input.type
        raise ValueError(f"input name {name} not found")

    @staticmethod
    def get_output_type(ort_session: InferenceSession, name: str) -> str:
        for graph_output in ort_session.get_outputs():
            if graph_output.name == name:
                return graph_output.type

        raise ValueError(f"output name {name} not found")

    @staticmethod
    def ort_type_to_numpy_type(ort_type: str):
        ort_type_to_numpy_type_map = {
            "tensor(int64)": numpy.longlong,
            "tensor(int32)": numpy.intc,
            "tensor(float)": numpy.float32,
            "tensor(float16)": numpy.float16,
            "tensor(bool)": bool,
        }
        if ort_type not in ort_type_to_numpy_type_map:
            raise ValueError(f"{ort_type} not found in map")

        return ort_type_to_numpy_type_map[ort_type]

    @staticmethod
    def ort_type_to_torch_type(ort_type: str):
        ort_type_to_torch_type_map = {
            "tensor(int64)": torch.int64,
            "tensor(int32)": torch.int32,
            "tensor(float)": torch.float32,
            "tensor(float16)": torch.float16,
            "tensor(bool)": torch.bool,
        }
        if ort_type not in ort_type_to_torch_type_map:
            raise ValueError(f"{ort_type} not found in map")

        return ort_type_to_torch_type_map[ort_type]

    @staticmethod
    def numpy_type_to_torch_type(numpy_type: numpy.dtype):
        numpy_type_to_torch_type_map = {
            numpy.longlong: torch.int64,
            numpy.intc: torch.int32,
            numpy.int32: torch.int32,
            numpy.float32: torch.float32,
            numpy.float16: torch.float16,
            bool: torch.bool,
        }
        if numpy_type not in numpy_type_to_torch_type_map:
            raise ValueError(f"{numpy_type} not found in map")

        return numpy_type_to_torch_type_map[numpy_type]

    @staticmethod
    def torch_type_to_numpy_type(torch_type: torch.dtype):
        torch_type_to_numpy_type_map = {
            torch.int64: numpy.longlong,
            torch.int32: numpy.intc,
            torch.float32: numpy.float32,
            torch.float16: numpy.float16,
            torch.bool: bool,
        }
        if torch_type not in torch_type_to_numpy_type_map:
            raise ValueError(f"{torch_type} not found in map")

        return torch_type_to_numpy_type_map[torch_type]

    @staticmethod
    def get_io_numpy_type_map(ort_session: InferenceSession) -> Dict[str, numpy.dtype]:
        """Create a mapping from input/output name to numpy data type."""
        name_to_numpy_type = {}
        for graph_input in ort_session.get_inputs():
            name_to_numpy_type[graph_input.name] = TypeHelper.ort_type_to_numpy_type(graph_input.type)

        for graph_output in ort_session.get_outputs():
            name_to_numpy_type[graph_output.name] = TypeHelper.ort_type_to_numpy_type(graph_output.type)
        return name_to_numpy_type
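
# Illustrative note (not used by the helpers themselves): the TypeHelper
# mappings compose, so a dtype can be translated across ORT, numpy, and torch
# representations without touching a model, for example:
#
#   np_type = TypeHelper.ort_type_to_numpy_type("tensor(float)")      # numpy.float32
#   torch_type = TypeHelper.numpy_type_to_torch_type(np_type)         # torch.float32
#   assert TypeHelper.torch_type_to_numpy_type(torch_type) is numpy.float32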


class IOBindingHelper:
    @staticmethod
    def get_output_buffers(ort_session: InferenceSession, output_shapes, device):
        """Returns a dictionary mapping each output name to a 1-D tensor with enough space for the given shape."""
        output_buffers = {}
        for name, shape in output_shapes.items():
            ort_type = TypeHelper.get_output_type(ort_session, name)
            torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
            output_buffers[name] = torch.empty(numpy.prod(shape), dtype=torch_type, device=device)
        return output_buffers

    @staticmethod
    def prepare_io_binding(
        ort_session,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        past: List[torch.Tensor],
        output_buffers,
        output_shapes,
        name_to_np_type=None,
    ):
        """Returns an IO binding object for a session."""
        if name_to_np_type is None:
            name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_session)

        # Bind inputs and outputs to the onnxruntime session.
        io_binding = ort_session.io_binding()

        # Bind inputs.
        assert input_ids.is_contiguous()
        io_binding.bind_input(
            "input_ids",
            input_ids.device.type,
            0,
            name_to_np_type["input_ids"],
            list(input_ids.size()),
            input_ids.data_ptr(),
        )

        if past is not None:
            for i, past_i in enumerate(past):
                assert past_i.is_contiguous()

                data_ptr = past_i.data_ptr()
                if data_ptr == 0:
                    # When past_sequence_length is 0, the tensor's data_ptr is zero, but IO binding
                    # requires a non-zero pointer. Work around this by passing the data pointer of
                    # input_ids; the data is never read for an empty past, so it does not matter.
                    data_ptr = input_ids.data_ptr()

                io_binding.bind_input(
                    f"past_{i}",
                    past_i.device.type,
                    0,
                    name_to_np_type[f"past_{i}"],
                    list(past_i.size()),
                    data_ptr,
                )

        if attention_mask is not None:
            assert attention_mask.is_contiguous()
            io_binding.bind_input(
                "attention_mask",
                attention_mask.device.type,
                0,
                name_to_np_type["attention_mask"],
                list(attention_mask.size()),
                attention_mask.data_ptr(),
            )

        if position_ids is not None:
            assert position_ids.is_contiguous()
            io_binding.bind_input(
                "position_ids",
                position_ids.device.type,
                0,
                name_to_np_type["position_ids"],
                list(position_ids.size()),
                position_ids.data_ptr(),
            )

        # Bind outputs.
        for output in ort_session.get_outputs():
            output_name = output.name
            output_buffer = output_buffers[output_name]
            logger.debug(f"{output_name} device type={output_buffer.device.type} shape={list(output_buffer.size())}")
            io_binding.bind_output(
                output_name,
                output_buffer.device.type,
                0,
                name_to_np_type[output_name],
                output_shapes[output_name],
                output_buffer.data_ptr(),
            )

        return io_binding

    @staticmethod
    def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
        """Copy results to CPU. Returns a list of numpy arrays (or torch tensors when return_numpy is False)."""
        ort_outputs = []
        for output in ort_session.get_outputs():
            output_name = output.name
            buffer = output_buffers[output_name]
            shape = output_shapes[output_name]
            copy_tensor = buffer[0 : numpy.prod(shape)].reshape(shape).clone().detach()
            if return_numpy:
                ort_outputs.append(copy_tensor.cpu().numpy())
            else:
                ort_outputs.append(copy_tensor)
        return ort_outputs
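

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes a GPT-2 style decoder exported
# to ONNX with inputs named input_ids / position_ids / attention_mask /
# past_0..past_{n-1} and outputs named logits / present_0..present_{n-1}.
# The model path, dtypes, vocabulary size, and shape constants below are
# placeholders that must be replaced with values matching the actual export.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    session = InferenceSession("gpt2_with_past.onnx", providers=["CPUExecutionProvider"])
    device = torch.device("cpu")

    batch_size, seq_len = 1, 8
    num_layers, num_heads, head_size, vocab_size = 12, 12, 64, 50257

    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.int64)
    position_ids = torch.arange(seq_len, dtype=torch.int64).repeat(batch_size, 1)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.float32)
    # Empty past for the first decoding step (past_sequence_length == 0); the
    # zero data_ptr of these tensors is handled inside prepare_io_binding.
    past = [
        torch.empty(2, batch_size, num_heads, 0, head_size, dtype=torch.float32)
        for _ in range(num_layers)
    ]

    # Output shapes for the first step: logits plus one present tensor per layer.
    output_shapes = {"logits": [batch_size, seq_len, vocab_size]}
    for i in range(num_layers):
        output_shapes[f"present_{i}"] = [2, batch_size, num_heads, seq_len, head_size]

    output_buffers = IOBindingHelper.get_output_buffers(session, output_shapes, device)
    io_binding = IOBindingHelper.prepare_io_binding(
        session, input_ids, position_ids, attention_mask, past, output_buffers, output_shapes
    )
    session.run_with_iobinding(io_binding)
    outputs = IOBindingHelper.get_outputs_from_io_binding_buffer(session, output_buffers, output_shapes)
    logger.info("output shapes: %s", [o.shape for o in outputs])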