# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List, Tuple, Union
from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionEmbedLayerNoMask(Fusion):
"""
Fuse embedding layer into one node (EmbedLayerNormalization).
It supports the following model types: BERT, DistilBert, ALBert.
"""
def __init__(self, model: OnnxModel, description: str = "no mask"):
super().__init__(
model,
"EmbedLayerNormalization",
["LayerNormalization", "SkipLayerNormalization"],
description,
)
self.utils = FusionUtils(model)
self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True)
# The following will be reset in each fuse call of FusionEmbedLayerNormalization
self.attention = None
self.embed_node = None
def match_two_gather(self, add: NodeProto) -> Union[None, Tuple[NodeProto, NodeProto]]:
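        """Match an Add node whose inputs 0 and 1 are both produced by Gather nodes.

        Returns:
            The two Gather nodes (parent of input 0, parent of input 1), or None if the pattern does not match.
        """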
gather_0_path = self.model.match_parent_path(add, ["Gather"], [0])
if gather_0_path is None:
return None
gather_1_path = self.model.match_parent_path(add, ["Gather"], [1])
if gather_1_path is None:
return None
return gather_0_path[0], gather_1_path[0]
def check_attention_subgraph(
self,
layernorm: NodeProto,
input_name_to_nodes: Dict[str, List[NodeProto]],
is_distil_bert: bool,
) -> bool:
"""Check that LayerNormalization has a child of Attention node or subgraph like Attention.
Args:
layernorm (NodeProto): LayerNormalization node
input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
is_distil_bert (bool): whether it is DistilBert or not
Returns:
bool: whether there is Attention node or subgraph like Attention
"""
self.attention = self.model.find_first_child_by_type(
layernorm, "Attention", input_name_to_nodes, recursive=False
)
if self.attention is not None:
return True
if layernorm.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[layernorm.output[0]]
children_types = sorted([child.op_type for child in children])
        # Try to find MultiHeadAttention
if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
for node in children:
if node.op_type == "SkipLayerNormalization":
path1 = self.model.match_parent_path(
node,
["Add", "MatMul", "MultiHeadAttention", "MatMul"],
[None, None, 0, 0],
)
if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
self.cross_attention = path1[2]
return True
        # In case the user disables attention fusion, check whether the subgraph looks like Attention.
        # For ALBERT, there is a MatMul + Add between the embedding layer and the attention.
if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
grandchildren = input_name_to_nodes[children[0].output[0]]
if (
len(grandchildren) == 1
and grandchildren[0].op_type == "Add"
and grandchildren[0].output[0] in input_name_to_nodes
):
nodes = input_name_to_nodes[grandchildren[0].output[0]]
for node in nodes:
if node.op_type == "Attention":
self.attention = node
return True
children_types = sorted([child.op_type for child in nodes])
# Two Shape nodes might be merged by ORT
if is_distil_bert:
            # SkipLayerNormalization might exist when the model has already been optimized by ORT.
if (
children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
):
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
else:
if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [
"MatMul",
"MatMul",
"MatMul",
"SkipLayerNormalization",
]:
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
return True
def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
""" Match position embedding path from input_ids to Gather for DistilBert.
Pattern is like the following:
(input_ids)
|
Shape
| \
| Gather (indices=1)
| |
| Cast (optional)
| |
| Range (start=0, end=*, delta=1)
| |
| Unsqueeze
| /
Expand
|
Gather
"""
# remove after tests pass
path1 = self.model.match_parent_path(position_embedding_gather, ["Expand", "Shape"], [1, 1])
if path1 is None:
path1 = self.model.match_parent_path(
position_embedding_gather,
["Expand", "Where", "Reshape", "Shape"],
[1, 1, 2, 0],
)
if path1 is None:
return False
expand, shape = path1[0], path1[-1]
if shape.input[0] != input_ids:
return False
_, path2, _ = self.model.match_parent_paths(
expand,
[
(["Unsqueeze", "Range", "Cast", "Gather", "Shape"], [0, 0, 1, 0, 0]),
(["Unsqueeze", "Range", "Gather", "Shape"], [0, 0, 1, 0]),
],
output_name_to_node,
)
if path2 is None:
return False
range_node = path2[1]
if not (
self.utils.check_node_input_value(range_node, 0, 0) and self.utils.check_node_input_value(range_node, 2, 1)
):
return False
gather_node = path2[-2]
if not (self.utils.check_node_input_value(gather_node, 1, 1)):
return False
shape_node = path2[-1]
if shape_node.input[0] != input_ids:
return False
return True
def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node):
"""Match position embedding path from input_ids to Gather for Roberta.
Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):
(input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
| ^
V |
+------------------------------+
Roberta new pattern from transformers v4.9:
(input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
| ^
V |
+-------------------------------------------+
start_node = position_embedding_gather
start_index = 1
# match optional Cast node.
parent = self.model.get_parent(start_node, start_index, output_name_to_node)
if parent is None:
return
if parent.op_type == "Cast":
if OnnxModel.get_node_attribute(parent, "to") != 7:
return
start_node = parent
start_index = 0
i, path, return_indices = self.model.match_parent_paths(
start_node,
[ (['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
(['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
output_name_to_node)
if path is not None:
# constant input of Add shall be 1.
i, value = self.model.get_constant_input(path[0])
if value != 1:
return False
_, self.padding_word_id = self.model.get_constant_input(path[-1])
return input_ids == path[-1].input[0]
"""
return False
def match_position_embedding_bert(self, position_embedding_gather, input_ids, output_name_to_node):
""" Match position embedding path from input_ids to Gather for BERT.
BERT Embedding Layer Pattern:
(input_ids)
/ \
/ Shape
/ |
/ Gather (indices=1)
/ |
/ Add (optional, B=0)
/ |
Gather (segment_ids) Unsqueeze (axes=0)
\ | |
\ Gather Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
\ / |
Add Gather
\ /
Add
|
LayerNormalization
"""
path = self.model.match_parent_path(
position_embedding_gather,
["Slice", "Unsqueeze"],
[1, 2],
output_name_to_node,
)
if path is None:
return False
slice, unsqueeze = path
slice_weight = self.model.get_constant_value(slice.input[0])
if not (
slice_weight is not None
and len(slice_weight.shape) == 2
and slice_weight.shape[0] == 1
and self.utils.check_node_input_value(slice, 1, [0])
and self.utils.check_node_input_value(slice, 3, [1])
and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))
):
return False
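        # Unsqueeze stores axes as an attribute before opset 13, and as an input from opset 13 onwards.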
opset_version = self.model.get_opset_version()
if opset_version < 13:
if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
return False
else:
if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
return False
node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
if node is None:
return False
if node.op_type == "Add":
if not self.utils.check_node_input_value(node, 1, 0):
return False
gather = self.model.get_parent(node, 0, output_name_to_node)
else:
gather = node
if gather is None or gather.op_type != "Gather":
return False
if not (self.utils.check_node_input_value(gather, 1, 1)):
return False
shape = self.model.get_parent(gather, 0, output_name_to_node)
if shape is None or shape.op_type != "Shape":
return False
return input_ids == shape.input[0]
def match_position_embedding(self, position_embedding_gather, input_ids, output_name_to_node):
if self.match_position_embedding_bert(position_embedding_gather, input_ids, output_name_to_node):
return True
# TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel
# related: https://github.com/huggingface/transformers/issues/10736
# if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node):
# return True
if self.match_position_embedding_distilbert(position_embedding_gather, input_ids, output_name_to_node):
return True
return False
def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
"""Sanity check of embedding weights, and match hidden_size of weights and shape of inputs."""
input_ids = word_embedding_gather.input[1]
segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
position_ids = position_embedding_gather.input[1]
if self.shape_infer_helper is not None:
input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids)
position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids)
assert input_ids_shape and position_ids_shape
if not (
len(input_ids_shape) == 2
and len(position_ids_shape) == 2
and input_ids_shape[1] == position_ids_shape[1]
):
                logger.info(
                    "Cannot fuse EmbedLayerNormalization: input_ids and position_ids do not match in the 2nd dimension: {} vs {}".format(
                        input_ids_shape, position_ids_shape
                    )
                )
return False
if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids):
                logger.info(
                    "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids do not have the same shape: {} != {}".format(
                        input_ids_shape,
                        self.shape_infer_helper.get_edge_shape(segment_ids),
                    )
                )
return False
word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0])
if word_embedding_table is None or len(word_embedding_table.shape) != 2:
logger.info("Cannot fuse EmbedLayerNormalization: word embedding table is not expected")
return False
position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0])
if (
position_embedding_table is None
or len(position_embedding_table.shape) != 2
or (word_embedding_table.shape[1] != position_embedding_table.shape[1])
):
logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected")
return False
if segment_ids:
segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0])
if (
segment_embedding_table is None
or len(segment_embedding_table.shape) != 2
or (word_embedding_table.shape[1] != segment_embedding_table.shape[1])
):
logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
return False
        # In the normal case, the word embedding table is the largest, the segment embedding table is the smallest,
        # and the position embedding table is in between.
        # TODO: use other information (like initializer names) to identify different embedding weights automatically.
if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
logger.warning(
f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
)
if segment_ids:
if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
logger.warning(
f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
)
if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
logger.warning(
f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
)
return True
def cast_to_int32(self, input_name: str) -> Tuple[str, Union[None, NodeProto]]:
"""Cast a graph input or node input to int32.
Args:
input_name (str): name of graph input or node input
Returns:
            A tuple of the cast input name and the Cast node:
            int32_output (str): the input name if the input is already int32, otherwise the output name of the Cast node.
            input_cast_node (Union[None, NodeProto]): the Cast node, or None if the input is already int32.
"""
input_cast_node = None
graph_input = self.model.find_graph_input(input_name)
if graph_input is not None:
if graph_input.type.tensor_type.elem_type != TensorProto.INT32:
int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
else:
int32_output = input_name
else:
int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
return int32_output, input_cast_node
def create_fused_node(
self,
input_ids: str,
layernorm: NodeProto,
word_embedding_gather: NodeProto,
position_embedding_gather: NodeProto,
segment_embedding_gather: Union[None, NodeProto],
position_ids: str = None,
embedding_sum_output=False,
):
"""Create an EmbedLayerNormalization node. Note that segment embedding is optional.
Args:
input_ids (str): input_ids for word embeddings
layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
word_embedding_gather (NodeProto): the Gather node for word embedding
position_embedding_gather (NodeProto): the Gather node for position embedding
segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None.
Returns:
NodeProto: the EmbedLayerNormalization node created.
"""
nodes_to_add = []
input_ids, _ = self.cast_to_int32(input_ids)
node_name = self.model.create_node_name("EmbedLayerNormalization")
if layernorm.op_type == "LayerNormalization":
gamma = layernorm.input[1]
beta = layernorm.input[2]
else: # SkipLayerNormalization
gamma = layernorm.input[2]
beta = layernorm.input[3]
embed_node_inputs = None
if segment_embedding_gather is not None:
segment_ids, _ = self.cast_to_int32(segment_embedding_gather.input[1])
embed_node_inputs = [
input_ids,
segment_ids,
word_embedding_gather.input[0],
position_embedding_gather.input[0],
segment_embedding_gather.input[0],
gamma,
beta,
]
else: # no segment embedding
embed_node_inputs = [
input_ids,
"",
word_embedding_gather.input[0],
position_embedding_gather.input[0],
"",
gamma,
beta,
]
if position_ids is not None:
# Adding an empty input for mask before position_ids
embed_node_inputs.append("")
position_ids, _ = self.cast_to_int32(position_ids)
embed_node_inputs.append(position_ids)
embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
if embedding_sum_output:
embed_node_outputs.append(node_name + "_embedding_sum")
embed_node = helper.make_node(
"EmbedLayerNormalization",
embed_node_inputs,
outputs=embed_node_outputs,
name=node_name,
)
embed_node.domain = "com.microsoft"
# Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
for att in layernorm.attribute:
if att.name == "epsilon":
embed_node.attribute.extend([att])
# Set default value to 1e-12 if no attribute is found.
# OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
if len(embed_node.attribute) == 0:
embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
# Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add.
nodes_to_add.append(embed_node)
for node in nodes_to_add:
self.node_name_to_graph_name[node.name] = self.this_graph_name
self.nodes_to_add.extend(nodes_to_add)
self.embed_node = embed_node
return embed_node
def finish_fusion(self, layernorm, embed_node):
self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
        # Use graph pruning to remove nodes that are no longer needed.
self.prune_graph = True
def is_embedding_sum_needed(self, add_before_layer_norm):
"""Check that Add before layer norm has an output to add before next layernorm
Args:
add_before_layer_norm (NodeProto): Add before any LayerNormalization node in topological order of graph
Returns:
bool: whether there is an extra output needed out of embed layer norm node
"""
nodes = self.model.get_children(add_before_layer_norm)
return len(nodes) > 1
def fuse_gpt2(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
# graph checks
# gpt2 has no segment embedding, subgraph pattern is like
# input_ids position_ids
# | |
# Gather Gather
# \ /
# Add _ _ _ _ _
# | |
# LayerNormalization |
# | |
# Attention |
# | |
# Matmul |
# | /
# Add /
# \ /
# Add
two_gather = self.match_two_gather(add_before_layernorm)
if two_gather is None:
return False
word_embedding_gather, position_embedding_gather = two_gather
input_ids = word_embedding_gather.input[1]
position_ids = position_embedding_gather.input[1]
if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
return False
if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
return False
        # If add_before_layernorm is an Add node, add_output is its first output.
        # If add_before_layernorm is a SkipLayerNormalization node, add_output is its (optional) fourth output.
add_output = None
optional_embedding_sum_output = False
if (add_before_layernorm.op_type == "Add" and self.is_embedding_sum_needed(add_before_layernorm)) or (
add_before_layernorm.op_type == "SkipLayerNormalization" and len(add_before_layernorm.output) >= 4
):
optional_embedding_sum_output = True
add_output = (
add_before_layernorm.output[0]
if add_before_layernorm.op_type == "Add"
else add_before_layernorm.output[3]
)
# make the fused node
embed_node = self.create_fused_node(
input_ids,
layernorm,
word_embedding_gather,
position_embedding_gather,
None,
position_ids,
optional_embedding_sum_output,
)
        # Redirect consumers of the layernorm output; the extra embedding sum output is handled below when needed.
self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
if optional_embedding_sum_output:
self.model.replace_input_of_all_nodes(add_output, embed_node.output[2])
return True
def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
"""Fuse embedding layer for DistilBert
Args:
layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
"""
# DistilBert has no segment embedding, subgraph pattern is like
# input_ids
# | \
# | (position_embedding_subgraph)
# | |
# Gather Gather
# \ /
# Add
# |
# LayerNormalization
two_gather = self.match_two_gather(add_before_layernorm)
if two_gather is None:
return False
word_embedding_gather, position_embedding_gather = two_gather
input_ids = word_embedding_gather.input[1]
if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=True):
return False
if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
return False
if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
return False
embed_node = self.create_fused_node(
input_ids, layernorm, word_embedding_gather, position_embedding_gather, None
)
self.finish_fusion(layernorm, embed_node)
return True
def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
"""Fuse embedding layer for Bert
Args:
layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
"""
add_2_gather = self.model.match_parent_path(add_before_layernorm, ["Add"], [0])
if add_2_gather is None:
return False
two_gather = self.match_two_gather(add_2_gather[0])
if two_gather is None:
return False
word_embedding_gather, segment_embedding_gather = two_gather
input_ids = word_embedding_gather.input[1]
if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
return False
position_embedding_path = self.model.match_parent_path(add_before_layernorm, ["Gather"], [1])
if position_embedding_path is None:
return False
position_embedding_gather = position_embedding_path[0]
if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
if not self.match_position_embedding(segment_embedding_gather, input_ids, output_name_to_node):
return False
# position and segment are switched
temp = segment_embedding_gather
segment_embedding_gather = position_embedding_gather
position_embedding_gather = temp
if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather):
return False
embed_node = self.create_fused_node(
input_ids,
layernorm,
word_embedding_gather,
position_embedding_gather,
segment_embedding_gather,
)
self.finish_fusion(layernorm, embed_node)
return True
def fuse(self, node, input_name_to_nodes, output_name_to_node):
if node.op_type == "LayerNormalization":
first_add_path = self.model.match_parent_path(node, ["Add"], [0])
if first_add_path is None:
return
add_before_layernorm = first_add_path[0]
else: # SkipLayerNormalization
add_before_layernorm = node # Add is fused into SkipLayerNormalization
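        # Try the supported embedding patterns in order: GPT-2 (word + position Gather), DistilBert
        # (position subgraph derived from input_ids), then BERT (word + segment + position).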
if self.fuse_gpt2(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
return
if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
return
if self.fuse_bert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
return
class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
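    """Fusion with mask support: in addition to FusionEmbedLayerNoMask, it can attach the attention
    mask to the EmbedLayerNormalization node and feed the resulting mask index back to
    Attention/MultiHeadAttention nodes when use_mask_index is enabled.
    """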
def __init__(self, model: OnnxModel, use_mask_index=False):
super().__init__(model, "with mask")
self.use_mask_index = use_mask_index
def replace_mask(self, mask_int32, attention_nodes):
# Inputs of EmbedLayerNorm: input_ids, segment_ids (optional), word_embedding, position_embedding,
# segment_embedding (optional), gamma, beta, mask (optional), position_ids (optional)
embed_node = self.embed_node
if len(embed_node.input) == 7:
embed_node.input.append(mask_int32)
logger.debug("append mask to %s", embed_node.name)
elif len(embed_node.input) > 7 and embed_node.input[7] == "":
embed_node.input[7] = mask_int32
logger.debug("replace mask in %s", embed_node.name)
else:
logger.debug("skip mask in %s", embed_node.name)
return
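        # The mask index produced by EmbedLayerNormalization replaces input 3 of Attention and input 4 of MultiHeadAttention.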
for attention_node in attention_nodes:
logger.debug("update mask_index in %s", attention_node.name)
if attention_node.op_type == "Attention":
attention_node.input[3] = embed_node.output[1]
elif attention_node.op_type == "MultiHeadAttention":
attention_node.input[4] = embed_node.output[1]
def fuse(self, node, input_name_to_nodes, output_name_to_node):
# Reset attention and embed_node so that we know fusion is successful when they are not None.
self.attention = None
self.cross_attention = None
self.embed_node = None
super().fuse(node, input_name_to_nodes, output_name_to_node)
if self.embed_node is None:
return
if not self.use_mask_index:
logger.debug("--use_mask_index is not set: EmbedLayerNormalization will not have mask")
self.increase_counter("EmbedLayerNormalization(no mask)")
return
if self.attention is None and self.cross_attention is None:
logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
self.increase_counter("EmbedLayerNormalization(no mask)")
return
if self.attention:
mask_int32 = self.attention.input[3]
else:
mask_int32 = self.cross_attention.input[4]
children_nodes = input_name_to_nodes[mask_int32]
if self.model.find_graph_input(mask_int32):
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
self.replace_mask(mask_int32, attention_nodes)
self.increase_counter("EmbedLayerNormalization(with mask)")
return
if mask_int32 not in output_name_to_node:
logger.debug("EmbedLayerNormalization will not have mask since %s is not a node output", mask_int32)
self.increase_counter("EmbedLayerNormalization(no mask)")
return
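        # The mask may be pre-processed by a Cast or ReduceSum node. For ReduceSum, use its input as the raw mask,
        # and remove the node when all of its consumers are attention nodes.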
node = output_name_to_node[mask_int32]
if node.op_type in ["ReduceSum", "Cast"]:
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
if node.op_type == "ReduceSum":
mask_int32 = node.input[0]
if len(children_nodes) == len(attention_nodes):
self.nodes_to_remove.append(node)
self.replace_mask(mask_int32, attention_nodes)
self.increase_counter("EmbedLayerNormalization(with mask)")
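

# Example usage (a minimal sketch, not part of the original module; the module name
# "fusion_embedlayer" and the model paths are assumptions for illustration only):
#
#   import onnx
#   from onnx_model import OnnxModel
#   from fusion_embedlayer import FusionEmbedLayerNormalization
#
#   onnx_model = OnnxModel(onnx.load("bert_model.onnx"))
#   fusion = FusionEmbedLayerNormalization(onnx_model, use_mask_index=True)
#   fusion.apply()  # apply() is inherited from the Fusion base class
#   onnx_model.save_model_to_file("bert_model_fused.onnx")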