# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

import numpy as np
from fusion_base import Fusion
from fusion_gpt_attention import FusionGptAttentionPastBase
from fusion_utils import FusionUtils
from onnx import TensorProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


def is_close(value, expected_value):
    return abs(value - expected_value) <= 1e-6


class FusionGptAttentionMegatron(FusionGptAttentionPastBase):
    """
    Fuse GPT-2 Attention with past state subgraph from Megatron into one Attention node.
    """

    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, num_heads)

    def fuse_attention_node(
        self,
        matmul_before_split,
        add_before_split,
        past,
        present,
        input,
        reshape_qkv,
        mask,
    ):
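        # Replace the matched subgraph with a single com.microsoft::Attention node.
        # The fused QKV weight is taken from matmul_before_split.input[1] and the fused
        # QKV bias from add_before_split; past/present carry the key/value cache tensors.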
        attention_node_name = self.model.create_node_name("GptAttention")
        int32_mask = self.cast_attention_mask(mask)
        output = reshape_qkv.output[0]
        i = 1 if (add_before_split.input[0] == matmul_before_split.output[0]) else 0
        attention_node = helper.make_node(
            "Attention",
            inputs=[
                input,
                matmul_before_split.input[1],
                add_before_split.input[i],
                int32_mask,
                past,
            ],
            outputs=[output, present],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend(
            [
                helper.make_attribute("num_heads", self.num_heads),
                helper.make_attribute("unidirectional", 0),  # unidirectional shall not be ON for 4D attention mask
            ]
        )
        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        nodes_to_add = [attention_node]
        self.nodes_to_add.extend(nodes_to_add)

        for node in nodes_to_add:
            self.node_name_to_graph_name[node.name] = self.this_graph_name

        self.nodes_to_remove.append(reshape_qkv)

        # we rely on prune_graph() to clean old subgraph nodes
        self.prune_graph = True

    def match_mask(self, sub_qk, mul_qk, matmul_qk, layernorm_before_attention):
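        # Match the Megatron causal-mask subgraph (Slice -> Slice -> Sub -> Mul) that feeds
        # the Sub before Softmax. Returns the name of the graph input providing the raw
        # attention mask on success, or None when the pattern does not match.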
        mask_nodes = self.model.match_parent_path(
            sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0]
        )  # yapf: disable
        if mask_nodes is None:
            logger.debug("fuse_attention: failed to match unidirectional mask path")
            return None
        (mul_mask, sub_mask, last_slice_mask, slice_mask) = mask_nodes

        if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
            _, mul_val = self.model.get_constant_input(mask_nodes[0])
            if mul_val != 10000:
                self.mask_filter_value = -mul_val

        if mul_qk.input[1] != last_slice_mask.output[0]:
            logger.debug("fuse_attention failed: mul_qk.input[1] != last_slice_mask.output[0]")
            return None

        if not self.utils.check_node_input_value(mul_mask, 1, 10000.0):
            logger.debug("fuse_attention failed: mul_mask input 1 is not constant 10000.0")
            return None

        if not self.utils.check_node_input_value(sub_mask, 0, 1.0):
            logger.debug("fuse_attention failed: sub_mask input 0 is not constant 1.0")
            return None

        if not self.model.find_graph_input(slice_mask.input[0]):
            logger.info("expect slice_mask input 0 to be graph input")
            return None

        if not self.utils.check_node_input_value(last_slice_mask, 1, [0]):
            logger.debug("fuse_attention failed: last_slice_mask input 1 (starts) is not constant [0]")
            return None

        if not self.utils.check_node_input_value(last_slice_mask, 3, [3]):
            logger.debug("fuse_attention failed: last_slice_mask input 3 (axes) is not constant [3]")
            return None

        if not self.utils.check_node_input_value(last_slice_mask, 4, [1]):
            logger.debug("fuse_attention failed: last_slice_mask input 4 (steps) is not constant [1]")
            return None

        if not self.utils.check_node_input_value(slice_mask, 3, [2]):
            logger.debug("fuse_attention failed: slice_mask input 3 (axes) is not constant [2]")
            return None

        if not self.utils.check_node_input_value(slice_mask, 4, [1]):
            logger.debug("fuse_attention failed: slice_mask input 4 (steps) is not constant [1]")
            return None

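        # The Slice start/end values must trace back (through Unsqueeze/Gather/Shape) to the
        # output of matmul_qk and of the normalization node before attention, so the mask
        # window is derived from the current and past sequence lengths.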
        last_slice_path = self.model.match_parent_path(
            last_slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
        )
        if last_slice_path is None or last_slice_path[-1] != matmul_qk:
            logger.debug("fuse_attention: failed to match last slice path")
            return None

        first_slice_path = self.model.match_parent_path(
            slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
        )
        if first_slice_path is None or first_slice_path[-1] != matmul_qk:
            logger.debug("fuse_attention: failed to match first slice path")
            return None

        first_slice_sub = self.model.match_parent_path(
            slice_mask,
            ["Unsqueeze", "Sub", "Gather", "Shape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        if first_slice_sub is None or first_slice_sub[-1] != matmul_qk:
            logger.debug("fuse_attention: failed to match first slice sub path")
            return None

        first_slice_sub_1 = self.model.match_parent_path(
            slice_mask,
            ["Unsqueeze", "Sub", "Gather", "Shape", "LayerNormalization"],
            [1, 0, 1, 0, 0],
        )

        if first_slice_sub_1 is None:
            first_slice_sub_1 = self.model.match_parent_path(
                slice_mask,
                ["Unsqueeze", "Sub", "Gather", "Shape", "SkipLayerNormalization"],
                [1, 0, 1, 0, 0],
            )

        if first_slice_sub_1 is None or first_slice_sub_1[-1] != layernorm_before_attention:
            logger.debug("fuse_attention: failed to match first slice sub path 1")
            return None

        return slice_mask.input[0]

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
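        # Starting from the normalization node that follows attention, match the Megatron
        # GPT-2 attention subgraph (fused QKV projection, split, past/present concat, masked
        # softmax) and, when every check passes, replace it with a fused Attention node.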
        past = None
        present = None

        is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
        qkv_nodes = None

        if not is_normalize_node_skiplayernorm:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                [0, 1, None, 0, 0, 0],
                output_name_to_node=output_name_to_node,
            )  # yapf: disable
        else:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                [1, None, 0, 0, 0],
                output_name_to_node=output_name_to_node,
            )  # yapf: disable

        if qkv_nodes is None:
            return

        skip_input = None
        if not is_normalize_node_skiplayernorm:
            (
                add_skip,
                add_after_attention,
                matmul_after_attention,
                reshape_qkv,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes

            skip_input = add_skip.input[0]
        else:
            (
                add_after_attention,
                matmul_after_attention,
                reshape_qkv,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes

            skip_input = normalize_node.input[0]

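        # Match the value (V) path back from matmul_qkv: Concat (with past value) <-
        # Transpose <- Reshape <- Split of the fused QKV projection (MatMul + Add) <-
        # LayerNormalization or SkipLayerNormalization before attention.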
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            [
                "Concat",
                "Transpose",
                "Reshape",
                "Split",
                "Add",
                "MatMul",
                "LayerNormalization",
            ],
            [1, 1, 0, 0, 0, None, 0],
        )  # yapf: disable

        if v_nodes is None:
            v_nodes = self.model.match_parent_path(
                matmul_qkv,
                [
                    "Concat",
                    "Transpose",
                    "Reshape",
                    "Split",
                    "Add",
                    "MatMul",
                    "SkipLayerNormalization",
                ],
                [1, 1, 0, 0, 0, None, 0],
            )  # yapf: disable

        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (
            concat_v,
            transpose_v,
            reshape_v,
            split_v,
            add_before_split,
            matmul_before_split,
            layernorm_before_attention,
        ) = v_nodes

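        # Verify the residual (skip) connection: the skip input of the block after attention
        # must trace back to the normalization node before attention.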
        if (
            layernorm_before_attention.op_type == "LayerNormalization"
            and skip_input != layernorm_before_attention.input[0]
        ):
            logger.debug("fuse_attention: skip_input != layernorm_before_attention.input[0]")
            return

        if (
            layernorm_before_attention.op_type == "SkipLayerNormalization"
            and skip_input != layernorm_before_attention.output[3]
        ):
            logger.debug("fuse_attention: skip_input != layernorm_before_attention.output[3]")
            return

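        # Match the masked attention-score path: MatMul(Q, K') -> Mul (by the sliced mask)
        # -> Sub (of the 10000 * (1 - mask) penalty) -> Softmax over axis 3.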
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "MatMul"], [0, 0, 0, 0])
        if qk_nodes is None:
            logger.debug("fuse_attention: failed to match qk path")
            return None
        (softmax_qk, sub_qk, mul_qk, matmul_qk) = qk_nodes
        if self.model.get_node_attribute(softmax_qk, "axis") != 3:
            logger.debug("fuse_attention failed: softmax_qk axis != 3")
            return None

        attention_mask = self.match_mask(sub_qk, mul_qk, matmul_qk, layernorm_before_attention)
        if attention_mask is None:
            logger.debug("fuse_attention: failed to match attention mask")
            return

        q_nodes = self.model.match_parent_path(matmul_qk, ["Div", "Transpose", "Reshape", "Split"], [0, 0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (div_q, transpose_q, reshape_q, split_q) = q_nodes
        if split_v != split_q:
            logger.debug("fuse_attention: skip since split_v != split_q")
            return

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Concat", "Transpose", "Reshape", "Split"],
            [1, 0, 0, 1, 0, 0],
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        (div_k, _, concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_v != split_k:
            logger.debug("fuse_attention: skip since split_v != split_k")
            return

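        # Infer the head layout from the Reshape constant of shape [0, 0, num_heads, head_size];
        # a mismatch with the user-specified num_heads is resolved in favor of the value
        # detected in the graph.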
        i, value = self.model.get_constant_input(reshape_k)
        if not (
            isinstance(value, np.ndarray)
            and list(value.shape) == [4]
            and value[0] == 0
            and value[1] == 0
            and value[2] > 0
            and value[3] > 0
        ):
            logger.debug("fuse_attention: reshape constant input is not [0, 0, N, H]")
            return

        num_heads = value[2]
        if num_heads != self.num_heads:
            logger.info(f"Detected num_heads={num_heads}. Ignore user specified value {self.num_heads}")
            self.num_heads = num_heads

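        # Megatron scales Q and K each by 1/sqrt(sqrt(head_size)) so that their product is
        # effectively divided by sqrt(head_size); verify both Div constants match that value.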
        hidden_size_per_head = value[3]
        i, value = self.model.get_constant_input(div_k)
        expected_value = float(np.sqrt(np.sqrt(hidden_size_per_head)))
        if not is_close(value, expected_value):
            logger.debug(f"fuse_attention: div_k value={value} expected={expected_value}")
            return

        i, value = self.model.get_constant_input(div_q)
        if not is_close(value, expected_value):
            logger.debug(f"fuse_attention: div_q value={value} expected={expected_value}")
            return

        # Match past and present paths
        past = self.match_past_pattern_2(concat_k, concat_v, output_name_to_node)
        if past is None:
            logger.debug("fuse_attention: match past failed")
            return
        if not self.model.find_graph_input(past):
            logger.debug("fuse_attention: past is not graph input.")
            # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.

        present = self.match_present(concat_v, input_name_to_nodes)
        if present is None:
            logger.debug("fuse_attention: match present failed")
            return
        if not self.model.find_graph_output(present):
            logger.info("fuse_attention: expect present to be graph output")
            return

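        # All patterns matched: emit the fused Attention node, using the output of the
        # normalization node before attention as the attention input.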
        self.fuse_attention_node(
            matmul_before_split,
            add_before_split,
            past,
            present,
            layernorm_before_attention.output[0],
            reshape_qkv,
            attention_mask,
        )