m2m model translation
import logging
from typing import Dict, List, Union

import numpy
import torch
from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)
class TypeHelper:
    @staticmethod
    def get_input_type(ort_session: InferenceSession, name: str) -> str:
        """Return the ONNX type string (like "tensor(float)") of a session input by name."""
        for i, input in enumerate(ort_session.get_inputs()):
            if input.name == name:
                return input.type
        raise ValueError(f"input name {name} not found")

    @staticmethod
    def get_output_type(ort_session, name: str) -> str:
        """Return the ONNX type string of a session output by name."""
        for i, output in enumerate(ort_session.get_outputs()):
            if output.name == name:
                return output.type
        raise ValueError(f"output name {name} not found")

    @staticmethod
    def ort_type_to_numpy_type(ort_type: str):
        """Map an ONNX Runtime type string to the corresponding numpy type."""
        ort_type_to_numpy_type_map = {
            "tensor(int64)": numpy.longlong,
            "tensor(int32)": numpy.intc,
            "tensor(float)": numpy.float32,
            "tensor(float16)": numpy.float16,
            "tensor(bool)": bool,
        }
        if ort_type not in ort_type_to_numpy_type_map:
            raise ValueError(f"{ort_type} not found in map")
        return ort_type_to_numpy_type_map[ort_type]

    @staticmethod
    def ort_type_to_torch_type(ort_type: str):
        """Map an ONNX Runtime type string to the corresponding torch dtype."""
        ort_type_to_torch_type_map = {
            "tensor(int64)": torch.int64,
            "tensor(int32)": torch.int32,
            "tensor(float)": torch.float32,
            "tensor(float16)": torch.float16,
            "tensor(bool)": torch.bool,
        }
        if ort_type not in ort_type_to_torch_type_map:
            raise ValueError(f"{ort_type} not found in map")
        return ort_type_to_torch_type_map[ort_type]

    @staticmethod
    def numpy_type_to_torch_type(numpy_type: numpy.dtype):
        """Map a numpy type to the corresponding torch dtype."""
        numpy_type_to_torch_type_map = {
            numpy.longlong: torch.int64,
            numpy.intc: torch.int32,
            numpy.int32: torch.int32,
            numpy.float32: torch.float32,
            numpy.float16: torch.float16,
            bool: torch.bool,
        }
        if numpy_type not in numpy_type_to_torch_type_map:
            raise ValueError(f"{numpy_type} not found in map")
        return numpy_type_to_torch_type_map[numpy_type]

    @staticmethod
    def torch_type_to_numpy_type(torch_type: torch.dtype):
        """Map a torch dtype to the corresponding numpy type."""
        torch_type_to_numpy_type_map = {
            torch.int64: numpy.longlong,
            torch.int32: numpy.intc,
            torch.float32: numpy.float32,
            torch.float16: numpy.float16,
            torch.bool: bool,
        }
        if torch_type not in torch_type_to_numpy_type_map:
            raise ValueError(f"{torch_type} not found in map")
        return torch_type_to_numpy_type_map[torch_type]

    @staticmethod
    def get_io_numpy_type_map(ort_session: InferenceSession) -> Dict[str, numpy.dtype]:
        """Create a mapping from input/output name to numpy data type."""
        name_to_numpy_type = {}
        for input in ort_session.get_inputs():
            name_to_numpy_type[input.name] = TypeHelper.ort_type_to_numpy_type(input.type)
        for output in ort_session.get_outputs():
            name_to_numpy_type[output.name] = TypeHelper.ort_type_to_numpy_type(output.type)
        return name_to_numpy_type
class IOBindingHelper:
    @staticmethod
    def get_output_buffers(ort_session: InferenceSession, output_shapes, device):
        """Return a dictionary with output name as key and a 1-D tensor as value. Each tensor has enough space for the given shape."""
        output_buffers = {}
        for name, shape in output_shapes.items():
            ort_type = TypeHelper.get_output_type(ort_session, name)
            torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
            output_buffers[name] = torch.empty(numpy.prod(shape), dtype=torch_type, device=device)
        return output_buffers
    @staticmethod
    def prepare_io_binding(
        ort_session,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        past: List[torch.Tensor],
        output_buffers,
        output_shapes,
        name_to_np_type=None,
    ):
        """Return an IO binding object for a session."""
        if name_to_np_type is None:
            name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_session)

        # Bind inputs and outputs to the onnxruntime session
        io_binding = ort_session.io_binding()

        # Bind inputs
        assert input_ids.is_contiguous()
        io_binding.bind_input(
            "input_ids",
            input_ids.device.type,
            0,
            name_to_np_type["input_ids"],
            list(input_ids.size()),
            input_ids.data_ptr(),
        )

        if past is not None:
            for i, past_i in enumerate(past):
                assert past_i.is_contiguous()

                data_ptr = past_i.data_ptr()
                if data_ptr == 0:
                    # When past_sequence_length is 0, its data_ptr will be zero, and IO Binding asserts that data_ptr shall not be zero.
                    # Work around it by passing the data pointer of input_ids; the actual data is not used for past, so it does not matter.
                    data_ptr = input_ids.data_ptr()

                io_binding.bind_input(
                    f"past_{i}",
                    past_i.device.type,
                    0,
                    name_to_np_type[f"past_{i}"],
                    list(past_i.size()),
                    data_ptr,
                )

        if attention_mask is not None:
            assert attention_mask.is_contiguous()
            io_binding.bind_input(
                "attention_mask",
                attention_mask.device.type,
                0,
                name_to_np_type["attention_mask"],
                list(attention_mask.size()),
                attention_mask.data_ptr(),
            )

        if position_ids is not None:
            assert position_ids.is_contiguous()
            io_binding.bind_input(
                "position_ids",
                position_ids.device.type,
                0,
                name_to_np_type["position_ids"],
                list(position_ids.size()),
                position_ids.data_ptr(),
            )

        # Bind outputs
        for output in ort_session.get_outputs():
            output_name = output.name
            output_buffer = output_buffers[output_name]
            logger.debug(f"{output_name} device type={output_buffer.device.type} shape={list(output_buffer.size())}")
            io_binding.bind_output(
                output_name,
                output_buffer.device.type,
                0,
                name_to_np_type[output_name],
                output_shapes[output_name],
                output_buffer.data_ptr(),
            )

        return io_binding
    @staticmethod
    def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
        """Copy results out of the IO binding buffers. Returns a list of numpy arrays, or torch tensors when return_numpy is False."""
        ort_outputs = []
        for output in ort_session.get_outputs():
            output_name = output.name
            buffer = output_buffers[output_name]
            shape = output_shapes[output_name]
            copy_tensor = buffer[0 : numpy.prod(shape)].reshape(shape).clone().detach()
            if return_numpy:
                ort_outputs.append(copy_tensor.cpu().numpy())
            else:
                ort_outputs.append(copy_tensor)
        return ort_outputs
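
The sketch below shows one way these helpers could be combined to run a decoder-style ONNX model with IO binding; it is illustrative only, not part of the module above. The model path ("model.onnx"), the input names (input_ids, attention_mask, position_ids), the vocabulary size, and the output shapes are assumptions for the example; the real names and shapes depend on the exported model.

if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Hypothetical model path; the model is assumed to accept input_ids,
    # attention_mask and position_ids (no past_* inputs on the first step).
    session = InferenceSession("model.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

    batch_size, seq_len = 1, 8
    input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.int64, device=device).contiguous()
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64, device=device).contiguous()
    position_ids = torch.arange(seq_len, dtype=torch.int64, device=device).unsqueeze(0).contiguous()

    # Assumed output shapes; in practice these are derived from the model config
    # (for example vocabulary size for logits, layer/head counts for present_* outputs).
    output_shapes = {output.name: [batch_size, seq_len, 1000] for output in session.get_outputs()}
    output_buffers = IOBindingHelper.get_output_buffers(session, output_shapes, device)

    io_binding = IOBindingHelper.prepare_io_binding(
        session,
        input_ids,
        position_ids,
        attention_mask,
        past=None,  # no past key/value tensors on the first decoding step
        output_buffers=output_buffers,
        output_shapes=output_shapes,
        name_to_np_type=TypeHelper.get_io_numpy_type_map(session),
    )
    session.run_with_iobinding(io_binding)

    outputs = IOBindingHelper.get_outputs_from_io_binding_buffer(session, output_buffers, output_shapes)
    print([o.shape for o in outputs])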