You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
250 lines
9.7 KiB
250 lines
9.7 KiB
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
|
|
from logging import getLogger
|
|
from typing import Dict, Optional
|
|
|
|
from fusion_base import Fusion
|
|
from onnx import helper
|
|
from onnx_model import OnnxModel
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
class FusionGelu(Fusion):
    """Fuse Erf-based Gelu subgraphs into one ``com.microsoft`` ``Gelu`` node.

    Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))). Exporters spell this formula
    with slightly different node layouts; each ``fuse_N`` method recognises
    one layout, anchored on the ``Erf`` node.
    """

    def __init__(self, model: OnnxModel):
        # NOTE(review): the two strings are presumably the fused op type and the
        # op type used as the search anchor -- confirm against Fusion.__init__.
        super().__init__(model, "Gelu", "Erf")

    def fuse(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """Try each known Gelu layout around ``erf_node``; stop at the first match.

        Args:
            erf_node: the ``Erf`` node used as the anchor of the pattern.
            input_name_to_nodes: maps a tensor name to the nodes consuming it.
            output_name_to_node: maps a tensor name to the node producing it.
        """
        for matcher in (self.fuse_1, self.fuse_2, self.fuse_3):
            if matcher(erf_node, input_name_to_nodes, output_name_to_node):
                return

    @staticmethod
    def _only_child(node, op_type: str, input_name_to_nodes: Dict):
        """Return the single consumer of ``node.output[0]`` if it is an ``op_type`` node, else None."""
        children = input_name_to_nodes.get(node.output[0])
        if children is None or len(children) != 1 or children[0].op_type != op_type:
            return None
        return children[0]

    def _match_erf_add_mul(self, erf_node, input_name_to_nodes: Dict):
        """Match the ``Erf -> Add(1) -> Mul`` chain shared by all layouts.

        Returns:
            ``(add_after_erf, mul_after_add)`` on success, otherwise None.
        """
        add_after_erf = self._only_child(erf_node, "Add", input_name_to_nodes)
        if add_after_erf is None:
            return None
        if not self.model.has_constant_input(add_after_erf, 1):
            return None

        mul_after_add = self._only_child(add_after_erf, "Mul", input_name_to_nodes)
        if mul_after_add is None:
            return None
        return add_after_erf, mul_after_add

    def _replace_with_gelu(
        self,
        subgraph_nodes,
        subgraph_input: str,
        subgraph_output: str,
        input_name_to_nodes: Dict,
        output_name_to_node: Dict,
    ) -> bool:
        """Schedule ``subgraph_nodes`` for replacement by one Gelu node.

        Returns:
            True when the replacement was scheduled, False when
            ``is_safe_to_fuse_nodes`` rejected the subgraph.
        """
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
        ):
            return False

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output])
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        # NOTE(review): make_node is called without a name, so fused_node.name is
        # presumably empty and every fused node shares this one key -- confirm
        # that is acceptable for models with several graphs.
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
        return True

    def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from PyTorch model
        Fuse Gelu with Erf into one node:
        Pattern 1:
              +-------Mul(0.5)---------------------+
              |                                    |
              |                                    v
            [root] --> Div -----> Erf --> Add --> Mul -->
                      (B=1.4142...)       (1)

        Pattern 2:
              +------------------------------------+
              |                                    |
              |                                    v
            [root] --> Div -----> Erf --> Add --> Mul -->Mul -->
                      (B=1.4142...)       (1)           (0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Args:
            erf_node: the ``Erf`` node used as the anchor of the pattern.
            input_name_to_nodes: maps a tensor name to the nodes consuming it.
            output_name_to_node: maps a tensor name to the node producing it.

        Returns:
            True when the subgraph was fused, None otherwise.
        """
        matched = self._match_erf_add_mul(erf_node, input_name_to_nodes)
        if matched is None:
            return None
        add_after_erf, mul_after_erf = matched

        div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return None

        # The divisor must be the constant sqrt(2) and sit in the second input.
        # NOTE(review): find_constant_input presumably returns the input index
        # of the matching constant -- confirm against OnnxModel.
        if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
            return None

        subgraph_input = div.input[0]

        # Index of the Mul input that does NOT come from the Add.
        another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
        if subgraph_input == mul_after_erf.input[another]:  # pattern 2
            # Guarded lookup: the Mul output may have no consumer entry at all
            # (e.g. it is a graph output); that is "no match", not an error.
            mul_half = self._only_child(mul_after_erf, "Mul", input_name_to_nodes)
            if mul_half is None:
                return None
            if not self.model.has_constant_input(mul_half, 0.5):
                return None
            subgraph_output = mul_half.output[0]
        else:  # pattern 1
            mul_half = self.model.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
            if mul_half is None:
                return None
            if not self.model.has_constant_input(mul_half, 0.5):
                return None
            # The 0.5 scaling must be applied to the same tensor that feeds Div.
            if subgraph_input not in mul_half.input:
                return None
            subgraph_output = mul_after_erf.output[0]

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
        if not self._replace_with_gelu(
            subgraph_nodes, subgraph_input, subgraph_output, input_name_to_nodes, output_name_to_node
        ):
            return None
        return True

    def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from Keras model
        Fuse Gelu with Erf into one node:
               +------------------------------------------+
               |                                          |
               |                                          v
            [root] --> Div -----> Erf --> Add --> Mul -->Mul
                      (B=1.4142...)      (A=1)   (A=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        The divisor may also be a ``Sqrt`` node with constant input 2.0 instead
        of the literal constant.

        Args:
            erf_node: the ``Erf`` node used as the anchor of the pattern.
            input_name_to_nodes: maps a tensor name to the nodes consuming it.
            output_name_to_node: maps a tensor name to the node producing it.

        Returns:
            True when the subgraph was fused, None otherwise.
        """
        matched = self._match_erf_add_mul(erf_node, input_name_to_nodes)
        if matched is None:
            return None
        add_after_erf, mul_after_erf = matched

        if not self.model.has_constant_input(mul_after_erf, 0.5):
            return None

        mul = self._only_child(mul_after_erf, "Mul", input_name_to_nodes)
        if mul is None:
            return None

        div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return None

        # Accept either a literal sqrt(2) divisor or Sqrt(2.0) as second input.
        sqrt_node = None
        if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
            sqrt_node = self.model.match_parent(div, "Sqrt", 1, output_name_to_node)
            if sqrt_node is None:
                return None
            if not self.model.has_constant_input(sqrt_node, 2.0):
                return None

        # Keep the existing requirement that the Div input has a producer node.
        # NOTE(review): get_parent presumably returns None for graph inputs and
        # initializers -- confirm against OnnxModel.
        root_node = self.model.get_parent(div, 0, output_name_to_node)
        if root_node is None:
            return None

        # Use the tensor that actually feeds Div. The producer may have several
        # outputs, so its output[0] is not necessarily this tensor.
        subgraph_input = div.input[0]
        if subgraph_input not in mul.input:
            return None

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
        if sqrt_node is not None:
            subgraph_nodes.append(sqrt_node)

        if not self._replace_with_gelu(
            subgraph_nodes, subgraph_input, mul.output[0], input_name_to_nodes, output_name_to_node
        ):
            return None
        return True

    def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from TensorFlow model
        Fuse Gelu with Erf into one node:
               +----------------------------------------------+
               |                                              |
               |                                              v
            [root] --> Mul -----> Erf  -->  Add  -->  Mul -->Mul
                   (A=0.7071067690849304)  (B=1)    (B=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Args:
            erf_node: the ``Erf`` node used as the anchor of the pattern.
            input_name_to_nodes: maps a tensor name to the nodes consuming it.
            output_name_to_node: maps a tensor name to the node producing it.

        Returns:
            True when the subgraph was fused, None otherwise.
        """
        matched = self._match_erf_add_mul(erf_node, input_name_to_nodes)
        if matched is None:
            return None
        add_after_erf, mul_half = matched

        if not self.model.has_constant_input(mul_half, 0.5):
            return None

        first_mul = self.model.match_parent(erf_node, "Mul", 0, output_name_to_node)
        if first_mul is None:
            return None

        # 0.7071... is 1/sqrt(2); the constant may be either input of the Mul.
        i = self.model.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
        if i < 0:
            return None

        # The non-constant input of the first Mul is the Gelu input.
        root_index = 0 if i == 1 else 1

        # Keep the existing requirement that this input has a producer node.
        # NOTE(review): get_parent presumably returns None for graph inputs and
        # initializers -- confirm against OnnxModel.
        root_node = self.model.get_parent(first_mul, root_index, output_name_to_node)
        if root_node is None:
            return None

        # Use the tensor that actually feeds the first Mul. The producer may
        # have several outputs, so its output[0] is not necessarily this tensor.
        subgraph_input = first_mul.input[root_index]

        last_mul = self._only_child(mul_half, "Mul", input_name_to_nodes)
        if last_mul is None:
            return None

        if subgraph_input not in last_mul.input:
            return None

        subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
        if not self._replace_with_gelu(
            subgraph_nodes, subgraph_input, last_mul.output[0], input_name_to_nodes, output_name_to_node
        ):
            return None
        return True
|