# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple, Union

import numpy as np
from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)

class FusionAttentionUnet(Fusion):
    """
    Fuse Attention subgraph of UNet into one Attention node.
    """

    def __init__(
        self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool, enable_packed_kv: bool
    ):
        super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"])
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.is_cross_attention = is_cross_attention
        self.enable_packed_kv = enable_packed_kv

        # Flags to show warning only once
        self.num_heads_warning = True
        self.hidden_size_warning = True

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, layernorm_node: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q
            layernorm_node (NodeProto): LayerNormalization node that feeds the attention block

        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """
        # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
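        # Example (hypothetical values, for illustration only): a UNet attention block with hidden size 320
        # and 8 heads would carry a Reshape shape constant of [0, 0, 8, 40], so num_heads is read from
        # index 2 below and hidden_size from the LayerNormalization initializer.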
        q_shape_value = self.model.get_constant_value(reshape_q.input[1])
        if q_shape_value is None:
            logger.debug(f"{reshape_q.input[1]} is not constant.")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        if len(q_shape_value) != 4 or q_shape_value[2] <= 0:
            logger.debug(f"q_shape_value={q_shape_value}. Expected value is like [0, 0, num_heads, -1].")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        num_heads = q_shape_value[2]

        layernorm_bias = self.model.get_initializer(layernorm_node.input[1])
        if layernorm_bias is None:
            logger.debug(f"{layernorm_node.input[1]} is not an initializer.")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        hidden_size = NumpyHelper.to_array(layernorm_bias).shape[0]

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def create_attention_node(
        self,
        q_matmul: NodeProto,
        k_matmul: NodeProto,
        v_matmul: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
    ) -> Union[NodeProto, None]:
        """Create an Attention node.

        Args:
            q_matmul (NodeProto): MatMul node of the fully connected layer for Q
            k_matmul (NodeProto): MatMul node of the fully connected layer for K
            v_matmul (NodeProto): MatMul node of the fully connected layer for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        """
        is_self_attention = not self.is_cross_attention

        if is_self_attention:
            if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
                logger.debug(
                    "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
                    q_matmul.input[0],
                    k_matmul.input[0],
                    v_matmul.input[0],
                )
                return None
        else:
            if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
                logger.debug(
                    "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
                    q_matmul.input[0],
                    k_matmul.input[0],
                    v_matmul.input[0],
                )
                return None

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None
        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])
        if not (q_weight and k_weight and v_weight):
            return None

        # Sometimes weights are stored in fp16 (data_type 10 is TensorProto.FLOAT16)
        if q_weight.data_type == 10:
            logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
            return None

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)
        logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
        # Assert that q and k have the same shape as expected
        if is_self_attention:
            if qw.shape != kw.shape or qw.shape != vw.shape:
                return None

            qw_in_size = qw.shape[0]
            kw_in_size = kw.shape[0]
            vw_in_size = vw.shape[0]

            assert qw_in_size == kw_in_size and kw_in_size == vw_in_size

            if hidden_size > 0 and hidden_size != qw_in_size:
                raise ValueError(
                    f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
                    "Please provide a correct input hidden size or pass in 0"
                )

            # All the matrices can have the same shape, or the q and k matrices can have the same shape with v being different.
            # For 2d weights, the shapes would be [in_size, out_size].
            # For 3d weights, the shape would be [in_size, a, b] where a * b = out_size.
            qw_out_size = np.prod(qw.shape[1:])

            qkv_weight = np.stack((qw, kw, vw), axis=1)
            qkv_weight_dim = 3 * qw_out_size
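            # For 2-D weights of shape [in_size, out_size], np.stack((qw, kw, vw), axis=1) gives
            # [in_size, 3, out_size]; flattened into the 2-D initializer below, it is equivalent to
            # concatenating the Q, K and V output columns per input row, i.e. the packed QKV layout.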
            attention_node_name = self.model.create_node_name("Attention")

            weight = helper.make_tensor(
                name=attention_node_name + "_qkv_weight",
                data_type=TensorProto.FLOAT,
                dims=[qw_in_size, qkv_weight_dim],
                vals=qkv_weight.flatten().tolist(),
            )

            self.model.add_initializer(weight, self.this_graph_name)
        else:  # cross attention
            attention_node_name = self.model.create_node_name("MultiHeadAttention")
            if self.enable_packed_kv:
                if kw.shape != vw.shape:
                    return None

                kw_in_size = kw.shape[0]
                vw_in_size = vw.shape[0]
                assert kw_in_size == vw_in_size

                qw_out_size = qw.shape[1]
                kw_out_size = kw.shape[1]
                vw_out_size = vw.shape[1]
                assert qw_out_size == vw_out_size and kw_out_size == vw_out_size

                c = kw_in_size
                n = num_heads
                h = kw_out_size // num_heads

                # Concat and interleave weights so that the output of the fused KV GEMM has [B, S_kv, N, 2, H] shape
                kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
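                # dstack of two (c, n, h) arrays gives (c, n, 2*h), so each head's 2*h output columns hold
                # its K columns followed by its V columns; the Reshape node added below restores the
                # [batch, seq_kv, n, 2, h] view of the fused GEMM output using shape [0, 0, n, 2, h].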
                matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
                weight = helper.make_tensor(
                    name=matmul_node_name + "_weight",
                    data_type=TensorProto.FLOAT,
                    dims=[kv_weight.shape[0], kv_weight.shape[1]],
                    vals=kv_weight.flatten().tolist(),
                )

                self.model.add_initializer(weight, self.this_graph_name)

                matmul_node = helper.make_node(
                    "MatMul",
                    inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
                    outputs=[matmul_node_name + "_out"],
                    name=matmul_node_name,
                )
                self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name

                shape_tensor = helper.make_tensor(
                    name=matmul_node_name + "_reshape_shape",
                    data_type=TensorProto.INT64,
                    dims=[5],
                    vals=[0, 0, n, 2, h],
                )
                self.model.add_initializer(shape_tensor, self.this_graph_name)

                reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"],
                    outputs=[k_matmul.output[0]],
                    name=matmul_node_name + "_reshape",
                )
                self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
                self.nodes_to_add.extend([matmul_node, reshape_node])
                self.nodes_to_remove.extend([k_matmul, v_matmul])
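                # The Reshape writes to k_matmul's original output name, so the MultiHeadAttention node
                # created below can keep k_matmul.output[0] as its packed KV input while the original
                # K and V MatMul nodes are removed.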

        # No bias, use zeros
        qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
        qkv_bias_dim = 3 * hidden_size

        bias = helper.make_tensor(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[qkv_bias_dim],
            vals=qkv_bias.flatten().tolist(),
        )
        self.model.add_initializer(bias, self.this_graph_name)

        if is_self_attention:
            attention_inputs = [
                input,
                attention_node_name + "_qkv_weight",
                attention_node_name + "_qkv_bias",
            ]
        else:
            if not self.enable_packed_kv:
                attention_inputs = [
                    q_matmul.output[0],
                    k_matmul.output[0],
                    v_matmul.output[0],
                    attention_node_name + "_qkv_bias",
                ]
            else:
                attention_inputs = [
                    q_matmul.output[0],
                    k_matmul.output[0],
                ]
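        # Input signatures used here: self attention feeds (input, packed QKV weight, QKV bias) to
        # Attention; cross attention feeds the Q, K and V projection outputs plus the zero bias to
        # MultiHeadAttention, or just (Q projection, packed KV) when packed KV is enabled.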

        attention_node = helper.make_node(
            "Attention" if is_self_attention else "MultiHeadAttention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        counter_name = (
            "Attention (self attention)"
            if is_self_attention
            else "MultiHeadAttention ({})".format(
                "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
            )
        )
        self.increase_counter(counter_name)

        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)

        # In SD 1.5, for self attention, LayerNorm has parent Reshape
        if node_before_layernorm is None and not self.is_cross_attention:
            node_before_layernorm = self.model.match_parent(normalize_node, "Reshape", 0)

        if node_before_layernorm is None:
            return

        root_input = node_before_layernorm.output[0]

        children_nodes = input_name_to_nodes[root_input]
        skip_add = None
        for node in children_nodes:
            if node.op_type == "Add":  # or node.op_type == "SkipLayerNormalization":
                skip_add = node
                break

        if skip_add is None:
            return

        another_input = 1 if skip_add.input[0] == root_input else 0
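        # another_input picks whichever input of the residual Add is not the root, i.e. the branch that
        # carries the attention output. The match below walks that branch upwards through the output
        # projection (Add + MatMul) and the Reshape/Transpose/Reshape around the final QKV MatMul.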
        qkv_nodes = self.model.match_parent_path(
            skip_add,
            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [another_input, None, None, 0, 0, 0],
        )
        if qkv_nodes is None:
            return

        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes

        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (_, _, _, matmul_v) = v_nodes

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
        if qk_nodes is not None:
            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
        else:
            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
            if qk_nodes is not None:
                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
            else:
                logger.debug("fuse_attention: failed to match qk path")
                return

        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (_, _transpose_q, reshape_q, matmul_q) = q_nodes

        k_nodes = self.model.match_parent_path(
            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        (_, _, _, _, matmul_k) = k_nodes

        attention_last_node = reshape_qkv

        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node)
        if q_num_heads <= 0:
            logger.debug("fuse_attention: failed to detect num_heads")
            return

        # The number of heads is the same for all paths, so q_num_heads is passed when creating the attention node.
        new_node = self.create_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            q_num_heads,
            q_hidden_size,
            input=normalize_node.output[0],
            output=attention_last_node.output[0],
        )
        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

        # Use prune_graph to remove the remaining nodes since they are shared by all attention nodes.
        self.prune_graph = True
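
# Usage sketch (illustrative; it assumes the Fusion base class in fusion_base exposes an apply()
# helper that visits every matched LayerNormalization node and calls fuse(), which is how the
# transformer optimizer scripts typically drive these fusions; "unet.onnx" is a placeholder path):
#
#     model = OnnxModel(onnx.load("unet.onnx"))
#     fusion = FusionAttentionUnet(
#         model, hidden_size=0, num_heads=0, is_cross_attention=False, enable_packed_kv=False
#     )
#     fusion.apply()
#
# Passing 0 for hidden_size and num_heads relies on get_num_heads_and_hidden_size to detect the
# values from the graph, as suggested by the "pass in 0" hint in create_attention_node.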