图片解析应用
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

121 lines
3.8 KiB

import logging
import onnx
from onnx import onnx_pb as onnx_proto
from ..quant_utils import attribute_to_kwarg, ms_domain
from .base_operator import QuantOperatorBase
"""
Quantizes the EmbedLayerNorm fused ONNXRuntime Op.
This Quant operator keeps the input and segment IDs at int32 but will quantize all initializer and
weight inputs associated with the node to uint8.
"""
class EmbedLayerNormalizationQuant(QuantOperatorBase):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)
def should_quantize(self):
return self.quantizer.should_quantize_node(self.node)
def quantize(self):
node = self.node
assert node.op_type == "EmbedLayerNormalization"
if len(node.output) > 2:
logging.info(f"Quantization is not applied to {node.name} since it has 3 outputs")
return super().quantize()
"""
Pre-quantization EmbedLayerNorm inputs:
[0] input_ids (int32)
[1] segment_ids (int32)
[2] word_embedding (float32)
[3] position_embedding (float32)
[4] segment_embedding (float32)
[5] gamma (float32)
[6] beta (float32)
[7] mask (int32) (optional)
"""
(
quantized_input_names,
zero_point_names,
scale_names,
nodes,
) = self.quantizer.quantize_activation(node, [2, 3, 4, 5, 6])
if quantized_input_names is None:
return super().quantize()
qembed_layer_norm_name = "" if node.name == "" else node.name + "_quant"
"""
Quantized Input Tensor List
[0] input_ids (int32)
[1] segment_ids (int32)
[2] word_embedding (uint8)
[3] position_embedding (uint8)
[4] segment_embedding (uint8)
[5] gamma (uint8)
[6] beta (uint8)
[7] mask (int32) (optional)
[8] word_embedding_scale (float)
[9] position_embedding_scale (float)
[10] segment_embedding_scale (float)
[11] gamma_scale (float)
[12] beta_scale (float)
[13] word_embedding_zero_point (uint8)
[14] position_embedding_zero_point (uint8)
[15] segment_embedding_zero_point (uint8)
[16] gamma_zero_point (uint8)
[17] beta_zero_point (uint8)
"""
inputs = []
# 'input_ids'
inputs.extend([node.input[0]])
# 'segment_ids'
inputs.extend([node.input[1]])
# 'word_embedding_quant'
inputs.extend([quantized_input_names[0]])
# 'position_embedding_quant'
inputs.extend([quantized_input_names[1]])
# 'segment_embedding_quant'
inputs.extend([quantized_input_names[2]])
# 'gamma_quant'
inputs.extend([quantized_input_names[3]])
# 'beta_quant'
inputs.extend([quantized_input_names[4]])
# 'mask' (optional)
inputs.extend([node.input[7] if len(node.input) > 7 else ""])
# Add all scales:
inputs.extend([scale_names[0]])
inputs.extend([scale_names[1]])
inputs.extend([scale_names[2]])
inputs.extend([scale_names[3]])
inputs.extend([scale_names[4]])
# Add all zero points:
inputs.extend([zero_point_names[0]])
inputs.extend([zero_point_names[1]])
inputs.extend([zero_point_names[2]])
inputs.extend([zero_point_names[3]])
inputs.extend([zero_point_names[4]])
kwargs = {}
for attribute in node.attribute:
kwargs.update(attribute_to_kwarg(attribute))
kwargs["domain"] = ms_domain
qembed_layer_norm_node = onnx.helper.make_node(
"QEmbedLayerNormalization",
inputs,
node.output,
qembed_layer_norm_name,
**kwargs,
)
nodes.append(qembed_layer_norm_node)
self.quantizer.new_nodes += nodes