# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger
from typing import Tuple

import numpy as np
from fusion_attention import AttentionMask
from fusion_base import Fusion
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)

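# Fuses a quantized (QDQ) attention subgraph into a single
# com.microsoft::QOrderedAttention node, starting the match from a downstream
# QOrderedLayerNormalization node.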
class FusionQOrderedAttention(Fusion):
    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_mask = attention_mask

        # One-time warning flags read by get_num_heads_and_hidden_size();
        # initialize them here so the first read cannot raise AttributeError.
        self.num_heads_warning = True
        self.hidden_size_warning = True

        super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization")

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q

        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """
        # We assume that reshape fusion has already run, so the shape is a tensor like [0, 0, num_heads, head_size]
        q_shape = self.model.get_initializer(reshape_q.input[1])
        if q_shape is None:
            logger.debug(f"{reshape_q.input[1]} is not an initializer.")

            # Check if the second input to Reshape flows through a Constant node
            # TODO: Investigate why FusionAttention doesn't have such logic
            constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1])

            if constant_node is None:
                return self.num_heads, self.hidden_size  # Fall back to the user-specified values
            else:
                constant_node = constant_node[0]

                if len(constant_node.attribute) != 1:
                    return self.num_heads, self.hidden_size  # Fall back to the user-specified values

                # This assumes the attribute is a Tensor (a safe assumption for Constant nodes)
                q_shape = constant_node.attribute[0].t

        q_shape_value = NumpyHelper.to_array(q_shape)
        if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
            logger.debug(f"q_shape_value={q_shape_value}. Expected values are like [0, 0, num_heads, head_size].")
            return self.num_heads, self.hidden_size  # Fall back to the user-specified values

        num_heads = q_shape_value[2]
        head_size = q_shape_value[3]
        hidden_size = num_heads * head_size

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

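    # fuse() walks upward from a QOrderedLayerNormalization node, matching the
    # quantized attention subgraph piece by piece: the input Q/DQ, the fused-QKV
    # output path, the Q, K and V projections, the Q*K^T -> Softmax -> *V chain,
    # and the attention-mask path. Only when every piece matches and every Q/DQ
    # node passes validation is the subgraph replaced by a single
    # QOrderedAttention node.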
    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        add_before_layernorm = self.model.match_parent_path(
            normalize_node,
            ["QuantizeLinear", "Add"],
            [0, 0],
        )

        if add_before_layernorm is not None:
            start_node = add_before_layernorm[-1]
        else:
            return

        # Input QDQ nodes
        dequantize_input = self.model.match_parent_path(
            start_node,
            ["DequantizeLinear"],
            [None],
        )

        if dequantize_input is None:
            logger.debug("fuse_qordered_attention: failed to match input qdq nodes path")
            return

        dequantize_input = dequantize_input[-1]

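        # This DequantizeLinear provides the quantized attention input and its
        # scale; both become the leading inputs of the fused node built below.
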
        # QKV nodes
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"],
            [None, None, 0, 0, 0, 0, 0],
        )

        if qkv_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qkv path")
            return

        (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qkv, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qkv, self.model):
            return

        # Identify the root input to the Attention node
        other_inputs = []
        for input in start_node.input:
            if input not in output_name_to_node:
                continue

            if input == qkv_nodes[0].output[0]:
                continue

            other_inputs.append(input)

        if len(other_inputs) != 1:
            return

        root_input = other_inputs[0]

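        # Match the V, Q and K projection paths. Each has the same shape:
        # Transpose <- Reshape <- DequantizeLinear <- QuantizeLinear <- Add (bias) <- MatMul (projection).
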
        # V nodes
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if v_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match v path")
            return

        (_, _, dequantize_v, quantize_v, add_v, matmul_v) = v_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_v, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v, self.model):
            return

        # V MatMul weight
        dequantize_v_matmul_weight = self.model.match_parent_path(matmul_v, ["DequantizeLinear"], [1])

        if dequantize_v_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match v weight path")
            return

        dequantize_v_matmul_weight = dequantize_v_matmul_weight[0]

        if self.model.get_constant_value(dequantize_v_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear has the proper zero points and scales
        # (per-channel scales are supported for weights alone)
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v_matmul_weight, self.model, False):
            return

        # QK nodes
        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            [
                "DequantizeLinear",
                "QuantizeLinear",
                "Softmax",
                "Add",
                "Div",
                "DequantizeLinear",
                "QuantizeLinear",
                "MatMul",
            ],
            [0, 0, 0, 0, None, 0, 0, 0],
        )

        if qk_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qk path")
            return

        (
            dequantize_qk_softmax,
            quantize_qk_softmax,
            softmax_qk,
            add_qk,
            div_qk,
            dequantize_qk,
            quantize_qk,
            matmul_qk,
        ) = qk_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk_softmax, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk, self.model):
            return

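        # The Q path feeds input 0 of the Q*K^T MatMul and the K path feeds
        # input 1; both mirror the projection structure matched for V above.
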
        # Q nodes
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [0, 0, 0, 0, 0, None],
        )

        if q_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match q path")
            return

        (_, reshape_q, dequantize_q, quantize_q, add_q, matmul_q) = q_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_q, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q, self.model):
            return

        # Q MatMul weight
        dequantize_q_matmul_weight = self.model.match_parent_path(matmul_q, ["DequantizeLinear"], [1])

        if dequantize_q_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match q weight path")
            return

        dequantize_q_matmul_weight = dequantize_q_matmul_weight[0]

        if self.model.get_constant_value(dequantize_q_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear has the proper zero points and scales
        # (per-channel scales are supported for weights alone)
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q_matmul_weight, self.model, False):
            return

        # K nodes
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [1, 0, 0, 0, 0, None],
        )

        if k_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match k path")
            return

        (_, _, dequantize_k, quantize_k, add_k, matmul_k) = k_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not FusionUtils.check_qdq_node_for_fusion(quantize_k, self.model):
            return

        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k, self.model):
            return

        # K MatMul weight
        dequantize_k_matmul_weight = self.model.match_parent_path(matmul_k, ["DequantizeLinear"], [1])

        if dequantize_k_matmul_weight is None:
            logger.debug("fuse_qordered_attention: failed to match k weight path")
            return

        dequantize_k_matmul_weight = dequantize_k_matmul_weight[0]

        if self.model.get_constant_value(dequantize_k_matmul_weight.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear has the proper zero points and scales
        # (per-channel scales are supported for weights alone)
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k_matmul_weight, self.model, False):
            return

        # Mask nodes
        mask_nodes = self.model.match_parent_path(
            add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]
        )

        if mask_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match mask path")
            return

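        # The qkv_hidden_sizes attribute is derived from the output dimension of
        # each projection weight, so the Q, K and V sizes may differ.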
        # Ascertain `qkv_hidden_sizes` attribute value
        q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
        k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
        v_weight = self.model.get_initializer(dequantize_v_matmul_weight.input[0])

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        qw_out_size = np.prod(qw.shape[1:])
        kw_out_size = np.prod(kw.shape[1:])
        vw_out_size = np.prod(vw.shape[1:])

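        # For example, a standard BERT-base graph has 768x768 projection
        # weights, so all three sizes come out as 768 (12 heads * 64).
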
        # Form QOrderedAttention node
        if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])

            # Ascertain `num_heads` and `hidden_size`
            num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

            # Formulate the inputs
            # Actual quantized input and its scale
            attention_inputs = [dequantize_input.input[0]]
            attention_inputs.append(dequantize_input.input[1])

            # Scales of the Q/K/V projection outputs
            attention_inputs.append(dequantize_q.input[1])
            attention_inputs.append(dequantize_k.input[1])
            attention_inputs.append(dequantize_v.input[1])

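            # Remaining inputs, in order: the three quantized projection
            # weights, their scales, the three biases, the Q*K^T scale, the
            # softmax output scale, the fused output scale and (optionally)
            # the mask index.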
            attention_inputs.append(dequantize_q_matmul_weight.input[0])
            attention_inputs.append(dequantize_k_matmul_weight.input[0])
            attention_inputs.append(dequantize_v_matmul_weight.input[0])

            attention_inputs.append(dequantize_q_matmul_weight.input[1])
            attention_inputs.append(dequantize_k_matmul_weight.input[1])
            attention_inputs.append(dequantize_v_matmul_weight.input[1])

            if self.model.get_initializer(add_q.input[0]):
                attention_inputs.append(add_q.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_q.input[1])

            if self.model.get_initializer(add_k.input[0]):
                attention_inputs.append(add_k.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_k.input[1])

            if self.model.get_initializer(add_v.input[0]):
                attention_inputs.append(add_v.input[0])
            else:  # second input is the constant bias
                attention_inputs.append(add_v.input[1])

            attention_inputs.append(quantize_qk.input[1])
            attention_inputs.append(quantize_qk_softmax.input[1])
            attention_inputs.append(dequantize_qkv.input[1])

            # Mask input
            if mask_index is not None:
                attention_inputs.append(mask_index)
            else:
                attention_inputs.append("")

            # The MatMul weight 'B' needs some post-processing:
            # transpose it from order ROW to order COL.
            # This offline transpose is needed only while using the CUDA EP.
            # TODO: Make this fusion logic EP-agnostic?
            q_weight_tensor = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(q_weight_tensor)

            k_weight_tensor = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(k_weight_tensor)

            v_weight_tensor = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
            FusionUtils.transpose_2d_int8_tensor(v_weight_tensor)

            # Name and create the Attention node
            attention_node_name = self.model.create_node_name("QOrderedAttention")

            attention_node = helper.make_node(
                "QOrderedAttention",
                inputs=attention_inputs,
                outputs=[reshape_qkv.output[0]],
                name=attention_node_name,
            )

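            # Re-wire the downstream projection: the fused node's quantized
            # output flows through the retained DequantizeLinear and into the
            # projection MatMul.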
            self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0])
            self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0])

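            # The order_* attributes appear to follow cuBLASLt's ordering enum
            # (0 = COL, 1 = ROW): activations stay row-major, while the weights
            # were transposed to column-major above.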
            attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
            attention_node.attribute.extend([helper.make_attribute("order_input", 1)])
            attention_node.attribute.extend([helper.make_attribute("order_weight", 0)])
            attention_node.attribute.extend([helper.make_attribute("order_output", 1)])
            attention_node.attribute.extend(
                [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
            )

            attention_node.domain = "com.microsoft"

            self.nodes_to_add.append(attention_node)
            self.node_name_to_graph_name[attention_node.name] = self.this_graph_name

            self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)
            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)
            self.nodes_to_remove.extend(
                [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight]
            )

            # Use prune_graph to remove the mask nodes, since they are shared by all attention nodes.
            # self.nodes_to_remove.extend(mask_nodes)
            self.prune_graph = True