# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

import numpy as np
from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import TensorProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionGptAttentionPastBase(Fusion):
    """Base class for GPT Attention Fusion with past state"""

    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", ["LayerNormalization", "SkipLayerNormalization"], "with past")
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from attention mask name to the name of its int32 cast
        self.mask_filter_value = None

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
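        """Match the subgraph that reads past K/V through Gather nodes (see Pattern 1 below).

        Returns the name of the past tensor if matched, otherwise None.
        """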
        # Pattern 1:
        #                          {past}
        #                         /      \
        #                        /        \
        #   Gather(axes=0, indices=0)    Gather(indices=1)
        #           |                        |
        #   Transpose (perm=0,1,3,2)         |
        #           |                        |
        #       Concat_k                 Concat_v
        #           |                      /
        #   Transpose (perm=0,1,3,2)      /
        #           |                    /
        #       Unsqueeze          Unsqueeze
        #            \               /
        #             \             /
        #                 Concat
        #                   |
        #               {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather is None or gather.op_type != "Gather":
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if self.model.find_constant_input(gather, 1) != 1:
            logger.debug("match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        parent = self.model.get_parent(concat_k, 0, output_name_to_node)
        if parent is not None and parent.op_type == "Gather":
            gather_past_k = parent
        else:
            past_k_nodes = self.model.match_parent_path(concat_k, ["Transpose", "Gather"], [0, 0])
            if past_k_nodes is None:
                logger.debug("match_past_pattern_1: failed to match Transpose and Gather")
                return None
            gather_past_k = past_k_nodes[-1]

        if self.model.find_constant_input(gather_past_k, 0) != 1:
            logger.debug("match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past of k and v to be the same tensor")
            return None

        return past

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
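        """Match the subgraph that reads past K/V through Split + Squeeze nodes (see Pattern 2 below).

        Returns the name of the past tensor if matched, otherwise None.
        """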
        # Pattern 2:
        #          Split (QKV)
        #         /    |    \
        #        /     |     +----------------------+
        #              |                             |
        #              |         {past}              |
        #              |           |                 |
        #          Reshape       Split            Reshape
        #              |        /     \              |
        #       Transpose_k  Squeeze  Squeeze    Transpose_v
        #              |        |         \         /
        #              +--------+          \       /
        #                  |                \     /
        #              Concat_k            Concat_v
        #                  |                  |
        #              Unsqueeze          Unsqueeze
        #                   \                /
        #                        Concat
        #                          |
        #                      {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze is None or squeeze.op_type != "Squeeze":
            logger.debug("match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split is None or split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, "axes", [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not FusionUtils.check_node_attribute(split, "split", [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None
        else:
            # In opset 13+, axes of Squeeze and split of Split are inputs instead of attributes.
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None

        if not FusionUtils.check_node_attribute(split, "axis", 0, default_value=0):
            logger.debug("match_past_pattern_2: attribute axis of Split is not expected in past path")
            return None
        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k, ["Squeeze", "Split"], [0, 0])
        if past_k_nodes is None:
            logger.debug("match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past of k and v to be the same tensor")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
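        """Trace concat_v -> Unsqueeze -> Concat to find the present output tensor name, or return None."""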
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, "Unsqueeze", input_name_to_nodes, recursive=False
        )
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None
        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v, "Concat", input_name_to_nodes, recursive=False
        )
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def cast_attention_mask(self, input_name):
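        """Cast the attention mask to int32, reusing the cached cast output if this mask was seen before."""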
        if input_name in self.casted_attention_mask:
            attention_mask_input_name = self.casted_attention_mask[input_name]
        elif self.model.find_graph_input(input_name):
            _, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        else:
            attention_mask_input_name, _ = self.utils.cast_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        return attention_mask_input_name


class FusionGptAttention(FusionGptAttentionPastBase):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    """

    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, num_heads)

    def create_attention_node(
        self,
        fc_weight,
        fc_bias,
        gemm_qkv,
        past,
        present,
        input,
        output,
        mask,
        is_unidirectional,
    ):
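        """Create the fused com.microsoft Attention node, plus the MatMul and Add that reapply the
        output projection weight and bias taken from the original Gemm (gemm_qkv)."""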
        attention_node_name = self.model.create_node_name("GptAttention")
        attention_node = helper.make_node(
            "Attention",
            inputs=[input, fc_weight, fc_bias, mask, past],
            outputs=[attention_node_name + "_output", present],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend(
            [
                helper.make_attribute("num_heads", self.num_heads),
                helper.make_attribute("unidirectional", 1 if is_unidirectional else 0),
            ]
        )

        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        matmul_node = helper.make_node(
            "MatMul",
            inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
            outputs=[attention_node_name + "_matmul_output"],
            name=attention_node_name + "_matmul",
        )

        add_node = helper.make_node(
            "Add",
            inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
            outputs=[output],
            name=attention_node_name + "_add",
        )
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])
        self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
        self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
        self.node_name_to_graph_name[add_node.name] = self.this_graph_name

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
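        """Match the GPT-2 attention-with-past subgraph rooted at a (Skip)LayerNormalization node
        and replace it with a single fused Attention node."""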
        past = None
        present = None
        return_indice = []

        is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
        qkv_nodes = None

        if not is_normalize_node_skiplayernorm:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                [0, None, 0, 0, 0, 0, 0],
                output_name_to_node=output_name_to_node,
                return_indice=return_indice,
            )  # yapf: disable
        else:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                [None, 0, 0, 0, 0, 0],
                output_name_to_node=output_name_to_node,
                return_indice=return_indice,
            )  # yapf: disable

        if qkv_nodes is None:
            return
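        # With a plain LayerNormalization root, qkv_nodes starts at the residual Add; remember the
        # other input of that Add so it can be checked against the normalization before attention.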
        another_input = None
        if not is_normalize_node_skiplayernorm:
            (
                add_qkv,
                reshape_qkv,
                gemm_qkv,
                reshape_1,
                reshape_2,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes

            another_input = add_qkv.input[1 - return_indice[0]]
        else:
            (
                reshape_qkv,
                gemm_qkv,
                reshape_1,
                reshape_2,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes

        v_nodes = self.model.match_parent_path(matmul_qkv, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_fc) = v_nodes
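        # Q, K and V all come from one fully connected layer applied to the output of the normalization
        # before attention. Depending on the exporter, it appears either as Reshape + Gemm + Reshape or
        # as MatMul + Add; try both shapes, with either normalization variant.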
        # Try to match the pattern using Gemm + LayerNormalization.
        fc_nodes = self.model.match_parent_path(
            split_fc,
            ["Reshape", "Gemm", "Reshape", "LayerNormalization"],
            [0, 0, 0, 0],
            output_name_to_node,
        )

        # Try to match the pattern using Gemm + SkipLayerNormalization.
        if fc_nodes is None:
            fc_nodes = self.model.match_parent_path(
                split_fc,
                ["Reshape", "Gemm", "Reshape", "SkipLayerNormalization"],
                [0, 0, 0, 0],
                output_name_to_node,
            )

        # Try to match the pattern using MatMul instead of Gemm.
        if fc_nodes is None:
            # MatMul + LayerNormalization
            fc_nodes = self.model.match_parent_path(
                split_fc,
                ["Add", "MatMul", "LayerNormalization"],
                [0, None, 0],
                output_name_to_node,
            )

            # MatMul + SkipLayerNormalization
            if fc_nodes is None:
                fc_nodes = self.model.match_parent_path(
                    split_fc,
                    ["Add", "MatMul", "SkipLayerNormalization"],
                    [0, None, 0],
                    output_name_to_node,
                )

            if fc_nodes is None:
                logger.debug("fuse_attention: failed to match fc path")
                return

            fc_weight = fc_nodes[1].input[1]
            i, _ = self.model.get_constant_input(fc_nodes[0])
            fc_bias = fc_nodes[0].input[i]
        else:
            fc_weight = fc_nodes[1].input[1]
            fc_bias = fc_nodes[1].input[2]

        layernorm_before_attention = fc_nodes[-1]
        # `another_input` will be non-None only if
        # (1) SkipLayerNorm fusion wasn't turned ON, or
        # (2) SkipLayerNorm fusion was turned ON but the upstream layer's LayerNorm + Add was not
        #     fused into a SkipLayerNorm. This can happen if the inputs to the Add node have different shapes.
        # So, keep the following check whether SkipLayerNorm fusion is turned ON or OFF.
        if another_input is not None and another_input not in layernorm_before_attention.input:
            logger.debug("Upstream Add and (Skip)LayerNormalization shall have one same input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        concat_k_to_match = None
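        # Older GPT-2 exports apply the causal mask with Sub/Mul around the Q*K score before Softmax;
        # try that pattern first, then fall back to the newer Where-based pattern below.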
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0])
|
|
if qk_nodes is not None:
|
|
(softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
|
|
mask_nodes = self.model.match_parent_path(
|
|
sub_qk,
|
|
[
|
|
"Mul",
|
|
"Sub",
|
|
"Slice",
|
|
"Slice",
|
|
"Unsqueeze",
|
|
"Sub",
|
|
"Squeeze",
|
|
"Slice",
|
|
"Shape",
|
|
"Div",
|
|
],
|
|
[1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
|
|
) # yapf: disable
|
|
if mask_nodes is None:
|
|
logger.debug("fuse_attention: failed to match unidirectional mask path")
|
|
return
|
|
div_mask = mask_nodes[-1]
|
|
slice_mask = mask_nodes[3]
|
|
|
|
if div_qk != div_mask:
|
|
logger.debug("fuse_attention: skip since div_qk != div_mask")
|
|
return
|
|
|
|
if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
|
|
_, mul_val = self.model.get_constant_input(mask_nodes[0])
|
|
if mul_val != -10000:
|
|
self.mask_filter_value = -mul_val
|
|
|
|
        else:
            # New pattern for GPT-2 exported with PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [
                    (["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0]),
                    (["Softmax", "Add", "Where", "Div", "MatMul"], [0, 0, None, 1, 0]),
                ],
                output_name_to_node,
            )
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [
                        (
                            ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze", "Reshape"],
                            [None, 0, 1, 0, 0, 0],
                        ),
                        (
                            ["Mul", "Sub", "Unsqueeze", "Unsqueeze", "Reshape"],
                            [None, 0, 1, 0, 0],
                        ),
                        (
                            ["Mul", "Sub", "Unsqueeze", "Unsqueeze"],
                            [None, 0, 1, 0],
                        ),  # useless Cast and Reshape nodes are removed.
                    ],
                    output_name_to_node,
                )  # yapf: disable
                if input_mask_nodes is None:
                    logger.debug("fuse_attention: failed to match input attention mask path")
                    return
                if len(input_mask_nodes) > 1 and input_mask_nodes[0].op_type == "Mul":
                    _, mul_val = self.model.get_constant_input(input_mask_nodes[0])
                    if mul_val != -10000:
                        self.mask_filter_value = mul_val

            mask_nodes = self.model.match_parent_path(
                where_qk,
                [
                    "Cast",
                    "Slice",
                    "Slice",
                    "Unsqueeze",
                    "Sub",
                    "Squeeze",
                    "Slice",
                    "Shape",
                ],
                [0, 0, 0, 1, 0, 0, 0, 0],
                output_name_to_node,
            )  # yapf: disable
            if mask_nodes is None:
                # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep.
                logger.debug("fuse_attention: failed to match mask path")
                return

            slice_mask = mask_nodes[2]

            div_or_concat = self.model.get_parent(mask_nodes[-1], 0, output_name_to_node)
            if div_or_concat is None:
                logger.debug("fuse_attention: failed to match mask path")
                return
            if div_or_concat.op_type == "Div":
                div_mask = div_or_concat
                if div_qk != div_mask:
                    logger.debug("fuse_attention: skip since div_qk != div_mask")
                    return
            elif div_or_concat.op_type == "Concat":
                concat_k_to_match = div_or_concat
            else:
                logger.debug("fuse_attention: failed to match mask path")
                return
        # Validate that the mask data is either lower triangular (unidirectional) or all ones.
        mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0]))
        if not (
            len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1) and mask_data.shape[2] == mask_data.shape[3]
        ):
            logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
            return
        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
            return
        q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_fc != split_q:
            logger.debug("fuse_attention: skip since split_fc != split_q")
            return

        k_nodes = self.model.match_parent_path(matmul_qk, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
        if k_nodes is None:
            # This pattern is from PyTorch 1.7.1 and Transformers 4.6.1.
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Transpose", "Concat", "Transpose", "Reshape", "Split"],
                [1, 0, 1, 0, 0],
            )
            if k_nodes is None:
                logger.debug("fuse_attention: failed to match k path")
                return
            else:
                (_, concat_k, transpose_k, reshape_k, split_k) = k_nodes
        else:
            (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_fc != split_k:
            logger.debug("fuse_attention: skip since split_fc != split_k")
            return

        if concat_k_to_match and concat_k != concat_k_to_match:
            logger.debug("fuse_attention: skip since concat_k != concat_k_to_match")
            return
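        # The fused Attention op consumes the mask as int32, so cast it here (or reuse a cached cast).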
attention_mask_input_name = ""
|
|
if input_mask_nodes is not None:
|
|
input_name = input_mask_nodes[-1].input[0]
|
|
attention_mask_input_name = self.cast_attention_mask(input_name)
|
|
|
|
        # Match past and present paths.
        past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or self.match_past_pattern_2(
            concat_k, concat_v, output_name_to_node
        )
        if past is None:
            logger.info("fuse_attention: failed to match past path")
            return
        if not self.model.find_graph_input(past):
            logger.debug("past is not graph input.")
            # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select the beam index,
            # so past is not a graph input.

        present = self.match_present(concat_v, input_name_to_nodes)
        if present is None:
            logger.info("fuse_attention: failed to match present path")
            return
        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph output")
            return
        self.create_attention_node(
            fc_weight,
            fc_bias,
            gemm_qkv,
            past,
            present,
            layernorm_before_attention.output[0],
            reshape_qkv.output[0],
            attention_mask_input_name,
            is_unidirectional,
        )

        # We rely on prune_graph() to clean up the old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True