图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

207 lines
7.3 KiB

  1. import logging
  2. from typing import Dict, List, Union
  3. import numpy
  4. import torch
  5. from onnxruntime import InferenceSession
  6. logger = logging.getLogger(__name__)
  7. class TypeHelper:
  8. @staticmethod
  9. def get_input_type(ort_session: InferenceSession, name: str) -> str:
  10. for i, input in enumerate(ort_session.get_inputs()):
  11. if input.name == name:
  12. return input.type
  13. raise ValueError(f"input name {name} not found")
  14. @staticmethod
  15. def get_output_type(ort_session, name: str) -> str:
  16. for i, output in enumerate(ort_session.get_outputs()):
  17. if output.name == name:
  18. return output.type
  19. raise ValueError(f"output name {name} not found")
  20. @staticmethod
  21. def ort_type_to_numpy_type(ort_type: str):
  22. ort_type_to_numpy_type_map = {
  23. "tensor(int64)": numpy.longlong,
  24. "tensor(int32)": numpy.intc,
  25. "tensor(float)": numpy.float32,
  26. "tensor(float16)": numpy.float16,
  27. "tensor(bool)": bool,
  28. }
  29. if ort_type not in ort_type_to_numpy_type_map:
  30. raise ValueError(f"{ort_type} not found in map")
  31. return ort_type_to_numpy_type_map[ort_type]
  32. @staticmethod
  33. def ort_type_to_torch_type(ort_type: str):
  34. ort_type_to_torch_type_map = {
  35. "tensor(int64)": torch.int64,
  36. "tensor(int32)": torch.int32,
  37. "tensor(float)": torch.float32,
  38. "tensor(float16)": torch.float16,
  39. "tensor(bool)": torch.bool,
  40. }
  41. if ort_type not in ort_type_to_torch_type_map:
  42. raise ValueError(f"{ort_type} not found in map")
  43. return ort_type_to_torch_type_map[ort_type]
  44. @staticmethod
  45. def numpy_type_to_torch_type(numpy_type: numpy.dtype):
  46. numpy_type_to_torch_type_map = {
  47. numpy.longlong: torch.int64,
  48. numpy.intc: torch.int32,
  49. numpy.int32: torch.int32,
  50. numpy.float32: torch.float32,
  51. numpy.float16: torch.float16,
  52. bool: torch.bool,
  53. }
  54. if numpy_type not in numpy_type_to_torch_type_map:
  55. raise ValueError(f"{numpy_type} not found in map")
  56. return numpy_type_to_torch_type_map[numpy_type]
  57. @staticmethod
  58. def torch_type_to_numpy_type(torch_type: torch.dtype):
  59. torch_type_to_numpy_type_map = {
  60. torch.int64: numpy.longlong,
  61. torch.int32: numpy.intc,
  62. torch.float32: numpy.float32,
  63. torch.float16: numpy.float16,
  64. torch.bool: bool,
  65. }
  66. if torch_type not in torch_type_to_numpy_type_map:
  67. raise ValueError(f"{torch_type} not found in map")
  68. return torch_type_to_numpy_type_map[torch_type]
  69. @staticmethod
  70. def get_io_numpy_type_map(ort_session: InferenceSession) -> Dict[str, numpy.dtype]:
  71. """Create a mapping from input/output name to numpy data type"""
  72. name_to_numpy_type = {}
  73. for input in ort_session.get_inputs():
  74. name_to_numpy_type[input.name] = TypeHelper.ort_type_to_numpy_type(input.type)
  75. for output in ort_session.get_outputs():
  76. name_to_numpy_type[output.name] = TypeHelper.ort_type_to_numpy_type(output.type)
  77. return name_to_numpy_type
  78. class IOBindingHelper:
  79. @staticmethod
  80. def get_output_buffers(ort_session: InferenceSession, output_shapes, device):
  81. """Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape."""
  82. output_buffers = {}
  83. for name, shape in output_shapes.items():
  84. ort_type = TypeHelper.get_output_type(ort_session, name)
  85. torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
  86. output_buffers[name] = torch.empty(numpy.prod(shape), dtype=torch_type, device=device)
  87. return output_buffers
  88. @staticmethod
  89. def prepare_io_binding(
  90. ort_session,
  91. input_ids: torch.Tensor,
  92. position_ids: torch.Tensor,
  93. attention_mask: torch.Tensor,
  94. past: List[torch.Tensor],
  95. output_buffers,
  96. output_shapes,
  97. name_to_np_type=None,
  98. ):
  99. """Returnas IO binding object for a session."""
  100. if name_to_np_type is None:
  101. name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_session)
  102. # Bind inputs and outputs to onnxruntime session
  103. io_binding = ort_session.io_binding()
  104. # Bind inputs
  105. assert input_ids.is_contiguous()
  106. io_binding.bind_input(
  107. "input_ids",
  108. input_ids.device.type,
  109. 0,
  110. name_to_np_type["input_ids"],
  111. list(input_ids.size()),
  112. input_ids.data_ptr(),
  113. )
  114. if past is not None:
  115. for i, past_i in enumerate(past):
  116. assert past_i.is_contiguous()
  117. data_ptr = past_i.data_ptr()
  118. if data_ptr == 0:
  119. # When past_sequence_length is 0, its data_ptr will be zero. IO Binding asserts that data_ptr shall not be zero.
  120. # Here we workaround and pass data pointer of input_ids. Actual data is not used for past so it does not matter.
  121. data_ptr = input_ids.data_ptr()
  122. io_binding.bind_input(
  123. f"past_{i}",
  124. past_i.device.type,
  125. 0,
  126. name_to_np_type[f"past_{i}"],
  127. list(past_i.size()),
  128. data_ptr,
  129. )
  130. if attention_mask is not None:
  131. assert attention_mask.is_contiguous()
  132. io_binding.bind_input(
  133. "attention_mask",
  134. attention_mask.device.type,
  135. 0,
  136. name_to_np_type["attention_mask"],
  137. list(attention_mask.size()),
  138. attention_mask.data_ptr(),
  139. )
  140. if position_ids is not None:
  141. assert position_ids.is_contiguous()
  142. io_binding.bind_input(
  143. "position_ids",
  144. position_ids.device.type,
  145. 0,
  146. name_to_np_type["position_ids"],
  147. list(position_ids.size()),
  148. position_ids.data_ptr(),
  149. )
  150. # Bind outputs
  151. for output in ort_session.get_outputs():
  152. output_name = output.name
  153. output_buffer = output_buffers[output_name]
  154. logger.debug(f"{output_name} device type={output_buffer.device.type} shape={list(output_buffer.size())}")
  155. io_binding.bind_output(
  156. output_name,
  157. output_buffer.device.type,
  158. 0,
  159. name_to_np_type[output_name],
  160. output_shapes[output_name],
  161. output_buffer.data_ptr(),
  162. )
  163. return io_binding
  164. @staticmethod
  165. def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
  166. """Copy results to cpu. Returns a list of numpy array."""
  167. ort_outputs = []
  168. for output in ort_session.get_outputs():
  169. output_name = output.name
  170. buffer = output_buffers[output_name]
  171. shape = output_shapes[output_name]
  172. copy_tensor = buffer[0 : numpy.prod(shape)].reshape(shape).clone().detach()
  173. if return_numpy:
  174. ort_outputs.append(copy_tensor.cpu().numpy())
  175. else:
  176. ort_outputs.append(copy_tensor)
  177. return ort_outputs