# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple

import numpy as np
from fusion_attention import AttentionMask
from fusion_base import Fusion
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)
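
# This fusion matches the per-tensor quantized (QuantizeLinear/DequantizeLinear) multi-head
# attention subgraph that feeds a QOrderedLayerNormalization node and replaces it with a
# single com.microsoft QOrderedAttention node, rewiring the surrounding Q/DQ pairs.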
class FusionQOrderedAttention(Fusion):
    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_mask = attention_mask

        # Warning flags read by get_num_heads_and_hidden_size; initialize them here
        # so the first mismatch does not raise an AttributeError.
        self.num_heads_warning = True
        self.hidden_size_warning = True

        super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization")

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q

        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """
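        # Example: a fused Reshape target shape of [0, 0, 12, 64] yields
        # num_heads = 12, head_size = 64 and hidden_size = 12 * 64 = 768.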
        # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
        q_shape = self.model.get_initializer(reshape_q.input[1])
        if q_shape is None:
            logger.debug(f"{reshape_q.input[1]} is not an initializer.")
            # Check if the second input to Reshape flows through a Constant node
            # TODO: Investigate why FusionAttention doesn't have such logic
            constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1])
            if constant_node is None:
                return self.num_heads, self.hidden_size  # Fall back to user specified value
            else:
                constant_node = constant_node[0]

                if len(constant_node.attribute) != 1:
                    return self.num_heads, self.hidden_size  # Fall back to user specified value

                # This is assuming it is a Tensor attribute (this is a safe assumption)
                q_shape = constant_node.attribute[0].t
        q_shape_value = NumpyHelper.to_array(q_shape)
        if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
            logger.debug(f"q_shape_value={q_shape_value}. Expected values are like [0, 0, num_heads, head_size].")
            return self.num_heads, self.hidden_size  # Fall back to user specified value
        num_heads = q_shape_value[2]
        head_size = q_shape_value[3]
        hidden_size = num_heads * head_size
        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once
        return num_heads, hidden_size

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        add_before_layernorm = self.model.match_parent_path(
            normalize_node,
            ["QuantizeLinear", "Add"],
            [0, 0],
        )

        if add_before_layernorm is not None:
            start_node = add_before_layernorm[-1]
        else:
            return
        # Input QDQ nodes
        dequantize_input = self.model.match_parent_path(
            start_node,
            ["DequantizeLinear"],
            [None],
        )

        if dequantize_input is None:
            logger.debug("fuse_qordered_attention: failed to match input qdq nodes path")
            return

        dequantize_input = dequantize_input[-1]
        # QKV nodes
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"],
            [None, None, 0, 0, 0, 0, 0],
        )

        if qkv_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qkv path")
            return

        (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes
        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qkv, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qkv, self.model):
            return
        # Identify the root input to the Attention node
        other_inputs = []
        for input in start_node.input:
            if input not in output_name_to_node:
                continue

            if input == qkv_nodes[0].output[0]:
                continue

            other_inputs.append(input)

        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]
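        # The fused QOrderedAttention node is only formed when this root input feeds all
        # three Q, K and V projection MatMuls (checked later before creating the node).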
        # V nodes
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if v_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match v path")
            return

        (_, _, dequantize_v, quantize_v, add_v, matmul_v) = v_nodes
        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_v, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v, self.model):
            return
        # V MatMul weight
        dequantize_v_matmul_weight = self.model.match_parent_path(matmul_v, ["DequantizeLinear"], [1])

        if dequantize_v_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match v path")
            return

        dequantize_v_matmul_weight = dequantize_v_matmul_weight[0]

        if self.model.get_constant_value(dequantize_v_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v_matmul_weight, self.model, False):
            return
        # QK nodes
        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            [
                "DequantizeLinear",
                "QuantizeLinear",
                "Softmax",
                "Add",
                "Div",
                "DequantizeLinear",
                "QuantizeLinear",
                "MatMul",
            ],
            [0, 0, 0, 0, None, 0, 0, 0],
        )

        if qk_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qk path")
            return

        (
            dequantize_qk_softmax,
            quantize_qk_softmax,
            softmax_qk,
            add_qk,
            div_qk,
            dequantize_qk,
            quantize_qk,
            matmul_qk,
        ) = qk_nodes
        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk, self.model):
            return
        # Q nodes
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [0, 0, 0, 0, 0, None],
        )

        if q_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match q path")
            return

        (_, reshape_q, dequantize_q, quantize_q, add_q, matmul_q) = q_nodes
        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_q, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q, self.model):
            return
        # Q MatMul weight
        dequantize_q_matmul_weight = self.model.match_parent_path(matmul_q, ["DequantizeLinear"], [1])

        if dequantize_q_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match q path")
            return

        dequantize_q_matmul_weight = dequantize_q_matmul_weight[0]

        if self.model.get_constant_value(dequantize_q_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q_matmul_weight, self.model, False):
            return
        # K nodes
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if k_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match k path")
            return

        (_, _, dequantize_k, quantize_k, add_k, matmul_k) = k_nodes
        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_k, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k, self.model):
            return
        # K MatMul weight
        dequantize_k_matmul_weight = self.model.match_parent_path(matmul_k, ["DequantizeLinear"], [1])

        if dequantize_k_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match k path")
            return

        dequantize_k_matmul_weight = dequantize_k_matmul_weight[0]

        if self.model.get_constant_value(dequantize_k_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k_matmul_weight, self.model, False):
            return
        # Mask nodes
        mask_nodes = self.model.match_parent_path(
            add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]
        )

        if mask_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match mask_nodes path")
            return
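        # The matched path is the usual additive attention-mask pattern: the 2D mask is
        # unsqueezed twice, cast to float, and (typically) turned into (1 - mask) * large_negative
        # before being added to the scaled QK scores.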
        # Ascertain `qkv_hidden_sizes` attribute value
        q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
        k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
        v_weight = self.model.get_initializer(dequantize_v_matmul_weight.input[0])

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        qw_out_size = np.prod(qw.shape[1:])
        kw_out_size = np.prod(kw.shape[1:])
        vw_out_size = np.prod(vw.shape[1:])
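        # Each entry of `qkv_hidden_sizes` is the output (column) dimension of the
        # corresponding projection weight, so Q, K and V may have different hidden sizes.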
        # Form QOrderedAttention node
        if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])

            # Ascertain `num_heads` and `hidden_size`
            num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
            # Formulate the inputs
            # Actual quantized input
            attention_inputs = [dequantize_input.input[0]]
            attention_inputs.append(dequantize_input.input[1])

            attention_inputs.append(dequantize_q.input[1])
            attention_inputs.append(dequantize_k.input[1])
            attention_inputs.append(dequantize_v.input[1])

            attention_inputs.append(dequantize_q_matmul_weight.input[0])
            attention_inputs.append(dequantize_k_matmul_weight.input[0])
            attention_inputs.append(dequantize_v_matmul_weight.input[0])

            attention_inputs.append(dequantize_q_matmul_weight.input[1])
            attention_inputs.append(dequantize_k_matmul_weight.input[1])
            attention_inputs.append(dequantize_v_matmul_weight.input[1])
            if self.model.get_initializer(add_q.input[0]):
                attention_inputs.append(add_q.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_q.input[1])

            if self.model.get_initializer(add_k.input[0]):
                attention_inputs.append(add_k.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_k.input[1])

            if self.model.get_initializer(add_v.input[0]):
                attention_inputs.append(add_v.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_v.input[1])
            attention_inputs.append(quantize_qk.input[1])
            attention_inputs.append(quantize_qk_softmax.input[1])
            attention_inputs.append(dequantize_qkv.input[1])

            # Mask input
            if mask_index is not None:
                attention_inputs.append(mask_index)
            else:
                attention_inputs.append("")
            # The MatMul weight 'B' and 'bias' need some post-processing
            # Transpose weight 'B' from order ROW to order COL
            # This offline transpose is needed only while using the CUDA EP
            # TODO: Make this fusion logic EP-agnostic?
            q_weight_tensor = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(q_weight_tensor)

            k_weight_tensor = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(k_weight_tensor)

            v_weight_tensor = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(v_weight_tensor)
            # Name and create Attention node
            attention_node_name = self.model.create_node_name("QOrderedAttention")

            attention_node = helper.make_node(
                "QOrderedAttention",
                inputs=attention_inputs,
                outputs=[reshape_qkv.output[0]],
                name=attention_node_name,
            )
            self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0])
            self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0])
            attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
            attention_node.attribute.extend([helper.make_attribute("order_input", 1)])
            attention_node.attribute.extend([helper.make_attribute("order_weight", 0)])
            attention_node.attribute.extend([helper.make_attribute("order_output", 1)])
            attention_node.attribute.extend(
                [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
            )

            attention_node.domain = "com.microsoft"
            self.nodes_to_add.append(attention_node)
            self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
            self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)
            self.nodes_to_remove.extend(
                [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight]
            )

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True
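
# Minimal usage sketch (hypothetical driver code; in practice this fusion runs as part of
# the transformer optimizer together with the other Q-ordered fusions):
#
#   import onnx
#   model = OnnxModel(onnx.load("bert_qordered.onnx"))  # placeholder model path
#   fusion = FusionQOrderedAttention(
#       model, hidden_size=768, num_heads=12, attention_mask=AttentionMask(model)
#   )
#   fusion.apply()  # walks the graph, calling fuse() on every QOrderedLayerNormalization match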