# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple

import numpy as np
from fusion_attention import AttentionMask
from fusion_base import Fusion
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionQOrderedAttention(Fusion):
    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_mask = attention_mask

        # Emit the num_heads/hidden_size mismatch warnings at most once
        # (these flags are read in get_num_heads_and_hidden_size below).
        self.num_heads_warning = True
        self.hidden_size_warning = True

        super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization")

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q

        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """

        # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
        q_shape = self.model.get_initializer(reshape_q.input[1])

        if q_shape is None:
            logger.debug(f"{reshape_q.input[1]} is not initializer.")

            # Check if the second input to Reshape flows through a Constant node
            # TODO: Investigate why FusionAttention doesn't have such logic
            constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1])

            if constant_node is None:
                return self.num_heads, self.hidden_size  # Fall back to user-specified value
            else:
                constant_node = constant_node[0]

            if len(constant_node.attribute) != 1:
                return self.num_heads, self.hidden_size  # Fall back to user-specified value

            # This is assuming it is a Tensor attribute (this is a safe assumption)
            q_shape = constant_node.attribute[0].t

        q_shape_value = NumpyHelper.to_array(q_shape)
        if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
            logger.debug(f"q_shape_value={q_shape_value}. Expected values are like [0, 0, num_heads, head_size].")
            return self.num_heads, self.hidden_size  # Fall back to user-specified value

        num_heads = q_shape_value[2]
        head_size = q_shape_value[3]
        hidden_size = num_heads * head_size

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size
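
    # fuse() below walks upstream from the residual Add that feeds the
    # QOrderedLayerNormalization anchor node and tries to match the quantized (QDQ)
    # attention subgraph: the projection output path
    # (Add <- MatMul <- Reshape <- Transpose <- DQ <- Q <- MatMul), followed by the
    # V, QK (Softmax), Q and K paths. Every Quantize/Dequantize pair along these
    # paths must carry constant per-tensor scales and zero points (per-channel is
    # accepted only for the weight DequantizeLinear nodes). If the full pattern is
    # found, it is replaced by a single com.microsoft QOrderedAttention node.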
    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        add_before_layernorm = self.model.match_parent_path(
            normalize_node,
            ["QuantizeLinear", "Add"],
            [0, 0],
        )

        if add_before_layernorm is not None:
            start_node = add_before_layernorm[-1]
        else:
            return

        # Input QDQ nodes
        dequantize_input = self.model.match_parent_path(
            start_node,
            ["DequantizeLinear"],
            [None],
        )

        if dequantize_input is None:
            logger.debug("fuse_qordered_attention: failed to match input qdq nodes path")
            return

        dequantize_input = dequantize_input[-1]

        # QKV nodes
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"],
            [None, None, 0, 0, 0, 0, 0],
        )

        if qkv_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qkv path")
            return

        (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qkv, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qkv, self.model):
            return

        # Identify the root input to the Attention node
        other_inputs = []
        for input_name in start_node.input:
            if input_name not in output_name_to_node:
                continue

            if input_name == qkv_nodes[0].output[0]:
                continue

            other_inputs.append(input_name)

        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]

        # V nodes
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if v_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match v path")
            return

        (_, _, dequantize_v, quantize_v, add_v, matmul_v) = v_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_v, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v, self.model):
            return

        # V MatMul weight
        dequantize_v_matmul_weight = self.model.match_parent_path(matmul_v, ["DequantizeLinear"], [1])

        if dequantize_v_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match v path")
            return

        dequantize_v_matmul_weight = dequantize_v_matmul_weight[0]

        if self.model.get_constant_value(dequantize_v_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v_matmul_weight, self.model, False):
            return

        # QK nodes
        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            [
                "DequantizeLinear",
                "QuantizeLinear",
                "Softmax",
                "Add",
                "Div",
                "DequantizeLinear",
                "QuantizeLinear",
                "MatMul",
            ],
            [0, 0, 0, 0, None, 0, 0, 0],
        )

        if qk_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qk path")
            return

        (
            dequantize_qk_softmax,
            quantize_qk_softmax,
            softmax_qk,
            add_qk,
            div_qk,
            dequantize_qk,
            quantize_qk,
            matmul_qk,
        ) = qk_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk, self.model):
            return
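
        # The Q and K paths below mirror the V path above: each is
        # Transpose <- Reshape <- DQ <- Q <- Add(bias) <- MatMul, with the MatMul
        # weight arriving through its own DequantizeLinear (per-channel scales allowed).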
        # Q nodes
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [0, 0, 0, 0, 0, None],
        )

        if q_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match q path")
            return

        (_, reshape_q, dequantize_q, quantize_q, add_q, matmul_q) = q_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_q, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q, self.model):
            return

        # Q MatMul weight
        dequantize_q_matmul_weight = self.model.match_parent_path(matmul_q, ["DequantizeLinear"], [1])

        if dequantize_q_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match q path")
            return

        dequantize_q_matmul_weight = dequantize_q_matmul_weight[0]

        if self.model.get_constant_value(dequantize_q_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q_matmul_weight, self.model, False):
            return

        # K nodes
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if k_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match k path")
            return

        (_, _, dequantize_k, quantize_k, add_k, matmul_k) = k_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_k, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k, self.model):
            return

        # K MatMul weight
        dequantize_k_matmul_weight = self.model.match_parent_path(matmul_k, ["DequantizeLinear"], [1])

        if dequantize_k_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match k path")
            return

        dequantize_k_matmul_weight = dequantize_k_matmul_weight[0]

        if self.model.get_constant_value(dequantize_k_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k_matmul_weight, self.model, False):
            return

        # Mask nodes
        mask_nodes = self.model.match_parent_path(
            add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]
        )

        if mask_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match mask_nodes path")
            return

        # Ascertain `qkv_hidden_sizes` attribute value
        q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
        k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
        v_weight = self.model.get_initializer(dequantize_v_matmul_weight.input[0])

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        qw_out_size = np.prod(qw.shape[1:])
        kw_out_size = np.prod(kw.shape[1:])
        vw_out_size = np.prod(vw.shape[1:])

        # Form QOrderedAttention node
        if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])

            # Ascertain `num_heads` and `hidden_size`
            num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
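
            # Input layout of the QOrderedAttention node assembled below:
            #   quantized input and its scale,
            #   scales of the quantized Q/K/V projection outputs,
            #   quantized Q/K/V weights and their scales,
            #   Q/K/V biases,
            #   scale of the Q x K' GEMM output, scale of the softmax output,
            #   scale of the attention (values GEMM) output, and the optional mask index.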
            # Formulate the inputs
            # Actual quantized input
            attention_inputs = [dequantize_input.input[0]]
            attention_inputs.append(dequantize_input.input[1])

            attention_inputs.append(dequantize_q.input[1])
            attention_inputs.append(dequantize_k.input[1])
            attention_inputs.append(dequantize_v.input[1])

            attention_inputs.append(dequantize_q_matmul_weight.input[0])
            attention_inputs.append(dequantize_k_matmul_weight.input[0])
            attention_inputs.append(dequantize_v_matmul_weight.input[0])

            attention_inputs.append(dequantize_q_matmul_weight.input[1])
            attention_inputs.append(dequantize_k_matmul_weight.input[1])
            attention_inputs.append(dequantize_v_matmul_weight.input[1])

            if self.model.get_initializer(add_q.input[0]):
                attention_inputs.append(add_q.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_q.input[1])

            if self.model.get_initializer(add_k.input[0]):
                attention_inputs.append(add_k.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_k.input[1])

            if self.model.get_initializer(add_v.input[0]):
                attention_inputs.append(add_v.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_v.input[1])

            attention_inputs.append(quantize_qk.input[1])
            attention_inputs.append(quantize_qk_softmax.input[1])
            attention_inputs.append(dequantize_qkv.input[1])

            # Mask input
            if mask_index is not None:
                attention_inputs.append(mask_index)
            else:
                attention_inputs.append("")

            # The MatMul weight 'B' and 'bias' need some post-processing
            # Transpose weight 'B' from order ROW to order COL
            # This offline transpose is needed only while using the CUDA EP
            # TODO: Make this fusion logic EP-agnostic?
            q_weight_tensor = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(q_weight_tensor)

            k_weight_tensor = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(k_weight_tensor)

            v_weight_tensor = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(v_weight_tensor)

            # Name and create Attention node
            attention_node_name = self.model.create_node_name("QOrderedAttention")

            attention_node = helper.make_node(
                "QOrderedAttention",
                inputs=attention_inputs,
                outputs=[reshape_qkv.output[0]],
                name=attention_node_name,
            )

            self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0])
            self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0])

            attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
            attention_node.attribute.extend([helper.make_attribute("order_input", 1)])
            attention_node.attribute.extend([helper.make_attribute("order_weight", 0)])
            attention_node.attribute.extend([helper.make_attribute("order_output", 1)])
            attention_node.attribute.extend(
                [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
            )

            attention_node.domain = "com.microsoft"

            self.nodes_to_add.append(attention_node)
            self.node_name_to_graph_name[attention_node.name] = self.this_graph_name

            self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)

            self.nodes_to_remove.extend(
                [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight]
            )

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True
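
# Minimal usage sketch (illustration only): this fusion is normally driven by the
# transformer optimizer in this package, but it can also be applied directly via the
# Fusion.apply() entry point inherited from fusion_base. The model path and the
# hidden_size/num_heads values below are hypothetical examples, not defaults.
#
#   import onnx
#   from fusion_attention import AttentionMask
#   from onnx_model import OnnxModel
#
#   model = OnnxModel(onnx.load("quantized_bert.onnx"))  # hypothetical path
#   fusion = FusionQOrderedAttention(
#       model, hidden_size=768, num_heads=12, attention_mask=AttentionMask(model)
#   )
#   fusion.apply()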