# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import List, Optional

from fusion_attention import AttentionMask, FusionAttention
from fusion_biasgelu import FusionBiasGelu
from fusion_embedlayer import FusionEmbedLayerNormalization
from fusion_fastgelu import FusionFastGelu
from fusion_gelu import FusionGelu
from fusion_gelu_approximation import FusionGeluApproximation
from fusion_gemmfastgelu import FusionGemmFastGelu
from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF
from fusion_options import AttentionMaskFormat, FusionOptions
from fusion_qordered_attention import FusionQOrderedAttention
from fusion_qordered_gelu import FusionQOrderedGelu
from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
from fusion_qordered_matmul import FusionQOrderedMatMul
from fusion_reshape import FusionReshape
from fusion_shape import FusionShape
from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
from fusion_utils import FusionUtils
from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class BertOptimizationOptions(FusionOptions):
    """This class is deprecated"""

    def __init__(self, model_type):
        logger.warning("BertOptimizationOptions is deprecated. Please use FusionOptions instead.")
        super().__init__(model_type)


class BertOnnxModel(OnnxModel):
    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize BERT ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
        """
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.qordered_attention_fusion = FusionQOrderedAttention(
            self, self.hidden_size, self.num_heads, self.attention_mask
        )
        self.utils = FusionUtils(self)
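
    # A minimal usage sketch (illustrative only; the model path and the num_heads /
    # hidden_size values below are assumptions, not part of this module):
    #
    #   import onnx
    #   model = onnx.load("bert.onnx")
    #   bert_model = BertOnnxModel(model, num_heads=12, hidden_size=768)
    #   bert_model.optimize()
    #   bert_model.save_model_to_file("bert_optimized.onnx")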

    def fuse_attention(self):
        self.attention_fusion.apply()
        # Only relevant in models with Q-DQ nodes
        self.qordered_attention_fusion.apply()

    def fuse_gelu(self):
        fusion = FusionGelu(self)
        fusion.apply()
        fusion = FusionFastGelu(self)
        fusion.apply()
        # Only relevant in models with Q-DQ nodes
        fusion = FusionQOrderedGelu(self)
        fusion.apply()

    def fuse_bias_gelu(self, is_fastgelu):
        fusion = FusionBiasGelu(self, is_fastgelu)
        fusion.apply()

    def gelu_approximation(self):
        fusion = FusionGeluApproximation(self)
        fusion.apply()

    def fuse_gemm_fast_gelu(self):
        fusion = FusionGemmFastGelu(self)
        fusion.apply()

    def fuse_add_bias_skip_layer_norm(self):
        fusion = FusionBiasSkipLayerNormalization(self)
        fusion.apply()

    def fuse_reshape(self):
        fusion = FusionReshape(self)
        fusion.apply()

    def fuse_shape(self):
        fusion = FusionShape(self)
        fusion.apply()

    def fuse_embed_layer(self, use_mask_index):
        fusion = FusionEmbedLayerNormalization(self, use_mask_index)
        fusion.apply()

    def fuse_layer_norm(self):
        fusion = FusionLayerNormalization(self)
        fusion.apply()
        fusion = FusionLayerNormalizationTF(self)
        fusion.apply()
        # Only relevant in models with Q-DQ nodes
        fusion = FusionQOrderedLayerNormalization(self)
        fusion.apply()

    def fuse_skip_layer_norm(self):
        fusion = FusionSkipLayerNormalization(self)
        fusion.apply()

    # Only relevant in models with Q-DQ nodes
    def fuse_qordered_mamtul(self):
        fusion = FusionQOrderedMatMul(self)
        fusion.apply()
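
    # Each wrapper above instantiates one Fusion* pass and calls apply() on it, so a
    # single pass can also be run in isolation, e.g. FusionGelu(bert_model).apply()
    # (bert_model here is an assumed BertOnnxModel instance, as in the sketch above).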

    def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool):
        """
        Get graph inputs that feed into nodes of the given type (like EmbedLayerNormalization or Attention).
        Returns a list of graph input names, filtered by whether the input is consumed through a Cast node
        (casted=True) or consumed directly (casted=False).
        """
        graph_inputs = []

        output_name_to_node = self.output_name_to_node()
        nodes = self.get_nodes_by_op_type(op_type)
        for node in nodes:
            bert_inputs = [node.input[i] for i in input_indices if i < len(node.input)]
            for bert_input in bert_inputs:
                if self.find_graph_input(bert_input):
                    if not casted:
                        graph_inputs.append(bert_input)
                elif bert_input in output_name_to_node:
                    parent = output_name_to_node[bert_input]
                    if parent.op_type == "Cast" and self.find_graph_input(parent.input[0]) is not None:
                        if casted:
                            graph_inputs.append(parent.input[0])
        return graph_inputs

    def get_graph_inputs_from_fused_nodes(self, casted: bool):
        inputs = self.get_graph_inputs_from_node_type("EmbedLayerNormalization", [0, 1, 7], casted)
        inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted)
        return inputs
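
    # Hedged illustration: in a fused BERT graph whose int64 inputs are cast to int32
    # before EmbedLayerNormalization, casted=True collects the original graph inputs that
    # feed those Cast nodes, while casted=False collects inputs consumed directly. The
    # names are assumptions for a typical export, not derived from this code:
    #   get_graph_inputs_from_fused_nodes(casted=True)  ->  ["input_ids", "segment_ids", "input_mask"]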

    def change_graph_input_type(
        self,
        graph: GraphProto,
        graph_input: ValueInfoProto,
        new_type: int = TensorProto.INT32,
    ):
        """Change a graph input's type, and add a Cast node if needed.

        Args:
            graph (GraphProto): graph
            graph_input (ValueInfoProto): input of the graph
            new_type (int, optional): new data type. Defaults to TensorProto.INT32.

        Returns:
            NodeProto: the new Cast node that was added, or None if no Cast node was added.
            List[NodeProto]: Cast nodes that have been removed.
        """
        assert isinstance(graph, GraphProto)
        assert isinstance(graph_input, ValueInfoProto)
        assert self.find_graph_input(graph_input.name)

        if graph_input.type.tensor_type.elem_type == int(new_type):
            return None, []

        new_cast_node = None
        nodes_to_remove = []

        input_name_to_nodes = self.input_name_to_nodes()
        if graph_input.name in input_name_to_nodes:
            nodes = input_name_to_nodes[graph_input.name]

            # For children that are not Cast nodes, insert a Cast node to convert int32 back to the original data type.
            nodes_not_cast = [node for node in nodes if node.op_type != "Cast"]
            if nodes_not_cast:
                node_name = self.create_node_name("Cast")
                output_name = node_name + "_" + graph_input.name
                new_value_info = graph.value_info.add()
                new_value_info.CopyFrom(graph_input)
                new_value_info.name = output_name
                new_cast_node = helper.make_node(
                    "Cast",
                    [graph_input.name],
                    [output_name],
                    to=int(graph_input.type.tensor_type.elem_type),
                    name=node_name,
                )
                graph.node.extend([new_cast_node])

                for node in nodes_not_cast:
                    OnnxModel.replace_node_input(node, graph_input.name, output_name)

            # For children that are Cast nodes, there is no need to insert another Cast.
            # When a child casts to int32, that Cast node can be removed since the input type is int32 now.
            nodes_cast = [node for node in nodes if node.op_type == "Cast"]
            for node in nodes_cast:
                if OnnxModel.get_node_attribute(node, "to") == int(new_type):
                    self.replace_input_of_all_nodes(node.output[0], graph_input.name)
                    if not self.find_graph_output(node.output[0]):
                        nodes_to_remove.append(node)
            if nodes_to_remove:
                self.remove_nodes(nodes_to_remove)

        graph_input.type.tensor_type.elem_type = int(new_type)
        return new_cast_node, nodes_to_remove
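
    # Illustrative effect for a graph input exported as int64 (the names are made up):
    #   before:  input_ids (int64) --> Gather
    #   after:   input_ids (int32) --> Cast(to=int64) --> Gather
    # An existing child that already casts the input to int32 becomes redundant and is removed.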

    def change_graph_inputs_to_int32(self):
        """Change data type of all graph inputs to int32 type, and add Cast node if needed."""
        graph = self.graph()
        add_cast_count = 0
        remove_cast_count = 0
        for graph_input in graph.input:
            new_node, removed_nodes = self.change_graph_input_type(graph, graph_input, TensorProto.INT32)
            if new_node:
                add_cast_count += 1
            remove_cast_count += len(removed_nodes)
        logger.info(
            f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes."
        )

    def use_dynamic_axes(self, dynamic_batch_dim="batch_size", dynamic_seq_len="max_seq_len"):
        """
        Update input and output shapes to use dynamic axes.
        """
        bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(
            casted=True
        ) + self.get_graph_inputs_from_fused_nodes(casted=False)

        for input in self.model.graph.input:
            if input.name in bert_graph_inputs:
                dim_proto = input.type.tensor_type.shape.dim[0]
                dim_proto.dim_param = dynamic_batch_dim
                if dynamic_seq_len is not None:
                    dim_proto = input.type.tensor_type.shape.dim[1]
                    dim_proto.dim_param = dynamic_seq_len

        for output in self.model.graph.output:
            dim_proto = output.type.tensor_type.shape.dim[0]
            dim_proto.dim_param = dynamic_batch_dim
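
    # Illustrative result with the default axis names (the input names are assumptions
    # for a typical BERT export, not read from this code):
    #   input_ids:  ['batch_size', 'max_seq_len']
    #   input_mask: ['batch_size', 'max_seq_len']
    # Only the first dimension of each graph output is renamed to 'batch_size'.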

    def preprocess(self):
        self.adjust_reshape_and_expand()
        return

    def adjust_reshape_and_expand(self):
        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == "Reshape":
                # Clean up unnecessary Reshape nodes.
                # Find Reshape nodes whose "shape" input has no actual data, and remove them.
                reshape_shape = self.get_constant_value(node.input[1])
                if reshape_shape is not None and reshape_shape.size == 0:
                    nodes_to_remove.extend([node])
                    self.replace_input_of_all_nodes(node.output[0], node.input[0])
                    continue

                # Find the path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", and simplify the
                # graph by changing the current Reshape's input to the output of the Slice.
                reshape_path = self.match_parent_path(
                    node,
                    ["Expand", "Expand", "Reshape", "Slice"],
                    [0, 0, 0, 0],
                    self.output_name_to_node(),
                )
                if reshape_path is not None:
                    expand_node = reshape_path[-3]
                    expand_shape_value = self.get_constant_value(expand_node.input[1])

                    reshape_before_expand = reshape_path[-2]
                    shape_value = self.get_constant_value(reshape_before_expand.input[1])

                    slice_node = reshape_path[-1]
                    if (
                        expand_shape_value is not None
                        and shape_value is not None
                        and len(expand_shape_value) == 2
                        and len(shape_value) == 1
                        and expand_shape_value[1] == shape_value[0]
                    ):
                        node.input[0] = slice_node.output[0]

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")

    def clean_graph(self):
        output_name_to_node = self.output_name_to_node()
        nodes_to_remove = []
        for node in self.nodes():
            # Before:
            #   input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
            #           |                                                     |
            #           |                                                     v
            #           +----> Shape --> Gather(indices=1) --> Unsqueeze ---> Concat --> ConstantOfShape --> Cast --> EmbedLayerNormalization/ReduceSum
            # After:
            #   input_ids --> Shape --> ConstantOfShape --> Cast --> EmbedLayerNormalization/ReduceSum
            # TODO: merge ConstantOfShape --> Cast into ConstantOfShape (need to update the data type of the value)
            op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
            if node.op_type in op_input_id:
                i = op_input_id[node.op_type]
                parent_nodes = self.match_parent_path(
                    node,
                    [
                        "Cast",
                        "ConstantOfShape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                    ],
                    [i, 0, 0, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    (
                        cast,
                        constantOfShape,
                        concat,
                        unsqueeze,
                        gather,
                        shape,
                    ) = parent_nodes
                    if shape.input[0] == self.graph().input[0].name:
                        constantOfShape.input[0] = shape.output[0]
                        output_name_to_node = self.output_name_to_node()

            if node.op_type == "Attention":
                # Before:
                #   input_ids --> Shape --> ConstantOfShape --> Cast --> ReduceSum --> Attention
                # After:
                #   remove this path, and remove the optional mask_index input of the Attention node.
                parent_nodes = self.match_parent_path(
                    node,
                    ["ReduceSum", "Cast", "ConstantOfShape", "Shape"],
                    [3, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    if parent_nodes[-1].input[0] == self.graph().input[0].name:
                        attention_node = helper.make_node(
                            "Attention",
                            inputs=node.input[0 : len(node.input) - 1],
                            outputs=node.output,
                            name=node.name + "_remove_mask",
                        )
                        attention_node.domain = "com.microsoft"
                        attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
                        self.add_node(attention_node, self.get_graph_by_node(attention_node).name)
                        nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)

    def postprocess(self):
        self.clean_graph()
        self.prune_graph()

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        if (options is not None) and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()

        # Remove Cast nodes that have the same input and output data type, based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        if options is not None:
            self.attention_mask.set_mask_format(options.attention_mask_format)
            if options.use_multi_head_attention:
                self.attention_fusion = FusionAttention(
                    self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention
                )

        if (options is None) or options.enable_attention:
            self.fuse_attention()

        # Perform the MatMul fusion after the Attention fusion, as we do not
        # want to fuse the MatMuls inside the Attention subgraphs.
        if (options is None) or options.enable_qordered_matmul:
            self.fuse_qordered_mamtul()

        self.fuse_shape()

        if (options is None) or options.enable_embed_layer_norm:
            # Reading the mask format requires options to be set; default to not using mask index when options is None.
            use_mask_index = options is not None and options.attention_mask_format == AttentionMaskFormat.MaskIndexEnd
            self.fuse_embed_layer(use_mask_index)

        # Remove Reshape nodes that have the same input and output shape, based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()

        self.postprocess()

        # Bias fusion is done after postprocess to avoid an extra Reshape between the bias and Gelu/FastGelu/SkipLayerNormalization.
        if (options is None) or options.enable_bias_gelu:
            # Fuse Gelu and the Add bias before it.
            self.fuse_bias_gelu(is_fastgelu=True)
            self.fuse_bias_gelu(is_fastgelu=False)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and the Add bias before it.
            self.fuse_add_bias_skip_layer_norm()

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()

        if options is not None and options.enable_gemm_fast_gelu:
            self.fuse_gemm_fast_gelu()

        self.remove_unused_constant()

        # Use a symbolic batch dimension in inputs and outputs.
        if add_dynamic_axes:
            self.use_dynamic_axes()

        logger.info(f"opset version: {self.get_opset_version()}")
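
    # Hedged example of driving this pass pipeline directly (it is normally wrapped by
    # optimize_model in optimizer.py; the option values below are illustrative assumptions):
    #
    #   options = FusionOptions("bert")
    #   options.enable_gelu_approximation = True
    #   bert_model.optimize(options, add_dynamic_axes=True)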

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            "EmbedLayerNormalization",
            "Attention",
            "MultiHeadAttention",
            "Gelu",
            "FastGelu",
            "BiasGelu",
            "GemmFastGelu",
            "LayerNormalization",
            "SkipLayerNormalization",
        ]
        q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"]
        for op in ops + q_ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info(f"Optimized operators: {op_count}")
        return op_count
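
    # Illustrative statistics for a fully fused 12-layer BERT encoder (the counts are an
    # assumption for a typical model, not produced by this code):
    #   {"EmbedLayerNormalization": 1, "Attention": 12, "BiasGelu": 12,
    #    "SkipLayerNormalization": 24, "LayerNormalization": 0, ...}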

    def is_fully_optimized(self):
        """
        Returns True when the model is fully optimized.
        """
        op_count = self.get_fused_operator_statistics()
        embed = op_count["EmbedLayerNormalization"]
        attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"]
        gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"]
        layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"]
        is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)

        if layer_norm == 0:
            logger.debug("Layer Normalization not fused")

        if gelu == 0:
            logger.debug("Gelu/FastGelu not fused")

        if embed == 0:
            logger.debug("Embed Layer not fused")

        if attention == 0:
            logger.warning("Attention not fused")

        return is_perfect