# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import logging
import pathlib

import onnx
from onnx import version_converter

import onnxruntime as ort


def iterate_graph_per_node_func(graph, per_node_func, **func_args):
    """
    Iterate the graph including subgraphs calling the per_node_func for each node.
    :param graph: Graph to iterate
    :param per_node_func: Function to call for each node. Signature is fn(node: onnx.NodeProto, **kwargs)
    :param func_args: The keyword args to pass through.
    """
    for node in graph.node:
        per_node_func(node, **func_args)
        # recurse into subgraphs for control flow nodes (Scan/Loop/If)
        for attr in node.attribute:
            if attr.HasField("g"):
                iterate_graph_per_node_func(attr.g, per_node_func, **func_args)
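
# Example (a sketch): count how often each operator appears, including nodes in
# Scan/Loop/If subgraphs. The model path and the `_count_op` helper are
# hypothetical, not part of this module.
#
#   def _count_op(node, counts):
#       counts[node.op_type] = counts.get(node.op_type, 0) + 1
#
#   model = onnx.load("model.onnx")
#   op_counts = {}
#   iterate_graph_per_node_func(model.graph, _count_op, counts=op_counts)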


def iterate_graph_per_graph_func(graph, per_graph_func, **func_args):
    """
    Iterate the graph including subgraphs calling the per_graph_func for each Graph.
    :param graph: Graph to iterate
    :param per_graph_func: Function to call for each graph. Signature is fn(graph: onnx.GraphProto, **kwargs)
    :param func_args: The keyword args to pass through.
    """
    per_graph_func(graph, **func_args)

    for node in graph.node:
        # recurse into subgraphs for control flow nodes (Scan/Loop/If)
        for attr in node.attribute:
            if attr.HasField("g"):
                iterate_graph_per_graph_func(attr.g, per_graph_func, **func_args)


def get_opsets_imported(model: onnx.ModelProto):
    """
    Get the opsets imported by the model.
    :param model: Model to check.
    :return: Map of domain to opset.
    """
    opsets = {}
    for entry in model.opset_import:
        # if empty the domain is ai.onnx
        domain = entry.domain or "ai.onnx"
        opsets[domain] = entry.version

    return opsets
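
# Example (a sketch; "model.onnx" is a placeholder path):
#
#   model = onnx.load("model.onnx")
#   opsets = get_opsets_imported(model)  # e.g. {"ai.onnx": 14, "com.microsoft": 1}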


def update_onnx_opset(
    model_path: pathlib.Path, opset: int, out_path: pathlib.Path = None, logger: logging.Logger = None
):
    """
    Helper to update the opset of a model using onnx version_converter. Target opset must be greater than current opset.
    :param model_path: Path to model to update
    :param opset: Opset to update model to
    :param out_path: Optional output path for updated model to be saved to.
    :param logger: Optional logger for diagnostic output
    :returns: Updated onnx.ModelProto
    """
    model_path_str = str(model_path.resolve(strict=True))
    if logger:
        logger.info("Updating %s to opset %d", model_path_str, opset)

    model = onnx.load(model_path_str)
    new_model = version_converter.convert_version(model, opset)

    if out_path:
        onnx.save(new_model, str(out_path))
        if logger:
            logger.info("Saved updated model to %s", out_path)

    return new_model
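
# Example (a sketch; the paths and target opset are placeholders):
#
#   updated = update_onnx_opset(pathlib.Path("model.onnx"), 15,
#                               out_path=pathlib.Path("model.opset15.onnx"),
#                               logger=logging.getLogger(__name__))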


def optimize_model(
    model_path: pathlib.Path,
    output_path: pathlib.Path,
    level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
    log_level: int = 3,
):
    """
    Optimize an ONNX model using ONNX Runtime to the specified level.
    :param model_path: Path to ONNX model
    :param output_path: Path to save optimized model to.
    :param level: onnxruntime.GraphOptimizationLevel to use. Default is ORT_ENABLE_BASIC.
    :param log_level: Log level. Defaults to Error (3) so we don't get output about unused initializers being removed.
                      Warning (2) or Info (1) may be desirable in some scenarios.
    """
    so = ort.SessionOptions()
    so.optimized_model_filepath = str(output_path.resolve())
    so.graph_optimization_level = level
    so.log_severity_level = log_level

    # creating the session will optimize the model and write it to output_path
    _ = ort.InferenceSession(str(model_path.resolve(strict=True)), so, providers=["CPUExecutionProvider"])
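
# Example (a sketch; the paths are placeholders). With the default
# ORT_ENABLE_BASIC level only optimizations that produce standard ONNX
# operators are applied:
#
#   optimize_model(pathlib.Path("model.onnx"), pathlib.Path("model.opt.onnx"))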


def _replace_symbolic_dim_value(graph: onnx.GraphProto, **kwargs):
    param_to_replace = kwargs["dim_param"]
    value = kwargs["value"]

    def update_dim_values(value_infos):
        for vi in value_infos:
            if vi.type.HasField("tensor_type"):
                shape = vi.type.tensor_type.shape
                if shape:
                    for dim in shape.dim:
                        if dim.HasField("dim_param") and dim.dim_param == param_to_replace:
                            dim.Clear()
                            dim.dim_value = value

    update_dim_values(graph.input)
    update_dim_values(graph.output)
    update_dim_values(graph.value_info)


def _remove_invalid_dim_values_impl(graph: onnx.GraphProto):
    def clear_invalid_values(value):
        if value.type.HasField("tensor_type"):
            shape = value.type.tensor_type.shape
            if shape:
                for dim in shape.dim:
                    if dim.HasField("dim_value") and dim.dim_value < 1:
                        dim.Clear()

    for i in graph.input:
        clear_invalid_values(i)
    for o in graph.output:
        clear_invalid_values(o)
    for vi in graph.value_info:
        clear_invalid_values(vi)


def remove_invalid_dim_values(graph: onnx.GraphProto):
    """
    Iterate the graph and subgraphs, unsetting any dim_value entries that have a value of less than 1.
    These are typically erroneously inserted by a converter to represent a dynamic dimension.
    :param graph: GraphProto to update
    """
    iterate_graph_per_graph_func(graph, _remove_invalid_dim_values_impl)


def make_dim_param_fixed(graph: onnx.GraphProto, param_name: str, value: int):
    """
    Iterate all values in the graph, replacing dim_param in a tensor shape with the provided value.
    :param graph: GraphProto to update
    :param param_name: dim_param to set
    :param value: value to use
    """
    iterate_graph_per_graph_func(graph, _replace_symbolic_dim_value, dim_param=param_name, value=value)
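
# Example (a sketch): pin a symbolic batch dimension to 1. The model path and
# the dim_param name "batch_size" are placeholders; use the symbolic name your
# model actually declares.
#
#   model = onnx.load("model.onnx")
#   make_dim_param_fixed(model.graph, "batch_size", 1)
#   onnx.save(model, "model.fixed.onnx")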


def make_input_shape_fixed(graph: onnx.GraphProto, input_name: str, fixed_shape: list[int]):
    """
    Update the named graph input to set shape to the provided value. This can be used to set unknown dims as well
    as to replace dim values.
    If setting the input shape replaces a dim_param, update any other values in the graph that use the dim_param.
    :param graph: Graph to update
    :param input_name: Name of graph input to update.
    :param fixed_shape: Shape to use.
    """
    # remove any invalid dim values first. typically this is a dim_value of -1.
    remove_invalid_dim_values(graph)

    for i in graph.input:
        if i.name == input_name:
            if not i.type.HasField("tensor_type"):
                raise ValueError(f"Input {input_name} is not a tensor")

            # graph inputs are required to have a shape to provide the rank
            shape = i.type.tensor_type.shape
            if len(shape.dim) != len(fixed_shape):
                raise ValueError(f"Rank mismatch. Existing:{len(shape.dim)} Replacement:{len(fixed_shape)}")

            for idx, dim in enumerate(shape.dim):
                # check any existing fixed dims match
                if dim.HasField("dim_value"):
                    if dim.dim_value != fixed_shape[idx]:
                        raise ValueError(
                            f"Can't replace existing fixed size of {dim.dim_value} with {fixed_shape[idx]} "
                            f"for dimension {idx + 1}"
                        )
                elif dim.HasField("dim_param"):
                    # replacing a dim_param requires updating the value throughout the entire graph
                    make_dim_param_fixed(graph, dim.dim_param, fixed_shape[idx])
                else:
                    # replacing an unknown dim
                    dim.Clear()
                    dim.dim_value = fixed_shape[idx]

            return

    raise ValueError(
        f"Input {input_name} was not found in graph inputs. "
        f'Valid input names are: {",".join([i.name for i in graph.input])}'
    )


def fix_output_shapes(model: onnx.ModelProto):
    """
    Update the output shapes of a model where the input shape/s were made fixed, if possible.
    This is mainly to make the model usage clearer if the output shapes can be inferred from the new input shapes.
    :param model: Model that had input shapes fixed.
    """
    # get a version of the model with shape inferencing info in it. this will provide fixed output shapes if possible.
    m2 = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(m2)

    for idx, o in enumerate(model.graph.output):
        if not is_fixed_size_tensor(o):
            new_o = m2.graph.output[idx]
            if is_fixed_size_tensor(new_o):
                o.type.tensor_type.shape.CopyFrom(new_o.type.tensor_type.shape)
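
# Example (a sketch; the input name and shape are placeholders): fix a graph
# input to a concrete shape, then propagate fixed shapes to the outputs where
# shape inferencing can determine them.
#
#   model = onnx.load("model.onnx")
#   make_input_shape_fixed(model.graph, "input", [1, 3, 224, 224])
#   fix_output_shapes(model)
#   onnx.save(model, "model.fixed.onnx")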


def _create_producer_consumer_link(
    node_to_producers: dict, node_to_consumers: dict, producer: onnx.NodeProto, consumer: onnx.NodeProto
):
    """
    Create links between two nodes for a value produced by one and consumed by the other.
    :param node_to_producers: Map of NodeProto to set of nodes that produce values the node consumes as inputs.
    :param node_to_consumers: Map of NodeProto to set of nodes that consume values the node produces as outputs.
    :param producer: Producer node
    :param consumer: Consumer node
    """
    if consumer not in node_to_producers:
        node_to_producers[consumer] = set()

    if producer not in node_to_consumers:
        node_to_consumers[producer] = set()

    # add entry mapping this node to the producer of this input
    node_to_producers[consumer].add(producer)
    node_to_consumers[producer].add(consumer)


def _map_node_dependencies(graph: onnx.GraphProto, node_to_producers: dict, node_to_consumers: dict):
    graph_inputs = set([i.name for i in graph.input])
    initializers = set([i.name for i in graph.initializer])

    # map of value name to node that creates it. copy parent values but override if values get shadowed
    producers = {}

    implicit_inputs = set()

    def is_local_value(value):
        return value in producers or value in initializers or value in graph_inputs

    for node in graph.node:
        inputs = [i for i in node.input]

        for attr in node.attribute:
            if attr.HasField("g"):
                subgraph_implicit_inputs = _map_node_dependencies(attr.g, node_to_producers, node_to_consumers)
                inputs += subgraph_implicit_inputs

        for i in inputs:
            if not i:
                # missing optional input
                continue

            if is_local_value(i):
                if i in producers:
                    producer = producers[i]
                    _create_producer_consumer_link(node_to_producers, node_to_consumers, producer, node)
            else:
                implicit_inputs.add(i)

        for o in node.output:
            producers[o] = node

    return implicit_inputs


def get_producer_consumer_maps(graph: onnx.GraphProto):
    """
    Get maps for connections between the node that produces each value and the nodes that consume the value.
    Processing includes subgraphs. As the map key is a Node instance from the Graph there should be no ambiguity.
    :param graph: Graph to process.
    :return: Tuple with two maps.
             First is node_to_producers map of a node to set of all nodes producing input it consumes.
             Second is node_to_consumers map of a node to set of all nodes consuming output it creates.
             e.g. NodeA and NodeB provide inputs to NodeC. NodeC provides input to NodeD
                  node_to_consumers[NodeA] = set([NodeC])
                  node_to_consumers[NodeB] = set([NodeC])
                  node_to_producers[NodeC] = set([NodeA, NodeB])
                  node_to_consumers[NodeC] = set([NodeD])
                  node_to_producers[NodeD] = set([NodeC])
    """
    # use a hash of the object id for NodeProto.
    # we need this for the partitioning checker where we keep maps with nodes as the key.
    onnx.NodeProto.__hash__ = lambda self: id(self)

    node_to_producers = {}  # map of node instance to nodes producing input values it consumes
    node_to_consumers = {}  # map of node instance to nodes consuming output values it produces

    implicit_inputs = _map_node_dependencies(graph, node_to_producers, node_to_consumers)

    # top level graph should have no implicit inputs
    if implicit_inputs:
        raise ValueError(
            "This appears to be an invalid model with missing inputs of " f'{",".join(sorted(implicit_inputs))}'
        )

    return node_to_producers, node_to_consumers
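
# Example (a sketch; the model path is a placeholder): list the direct
# consumers of each node's outputs.
#
#   model = onnx.load("model.onnx")
#   node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph)
#   for node, consumers in node_to_consumers.items():
#       print(node.name or node.op_type, "->", [c.name or c.op_type for c in consumers])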


def is_fixed_size_tensor(value: onnx.ValueInfoProto):
    """
    Check if value is a tensor with a fixed shape.
    :param value: onnx.ValueInfoProto to check
    :return: True if value is a tensor, with a shape, where all dimensions have fixed values.
    """
    is_fixed = False
    if value.type.HasField("tensor_type"):
        shape = value.type.tensor_type.shape
        if shape:
            is_fixed = True  # scalar has no dims so set to True and unset if we hit a dim without a valid value
            for dim in shape.dim:
                if dim.HasField("dim_value") and dim.dim_value > 0:
                    continue

                # anything else means it's a dynamic value
                is_fixed = False
                break

    return is_fixed
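
# Example (a sketch; the model path is a placeholder): report graph outputs
# that still have dynamic shapes.
#
#   model = onnx.load("model.onnx")
#   for o in model.graph.output:
#       if not is_fixed_size_tensor(o):
#           print(f"{o.name} has a dynamic shape")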


def get_optimization_level(level):
    """Convert string to GraphOptimizationLevel."""
    if level == "disable":
        return ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    if level == "basic":
        # Constant folding and other optimizations that only use ONNX operators
        return ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    if level == "extended":
        # Optimizations using custom operators, excluding NCHWc and NHWC layout optimizers
        return ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    if level == "all":
        return ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    raise ValueError("Invalid optimization level of " + level)
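
# Example (a sketch; the paths and the "basic" choice are placeholders): map a
# command-line string to the enum and pass it to optimize_model.
#
#   level = get_optimization_level("basic")
#   optimize_model(pathlib.Path("model.onnx"), pathlib.Path("model.opt.onnx"), level=level)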