# m2m model translation
# NOTE: repository web-page residue (topic-selection hint, "68 lines", "3.0 KiB")
# was captured here with the file; it is not part of the module.

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class PastKeyValuesHelper:
    """Helper functions to process past key values for encoder-decoder model.

    Two orderings of the per-layer key/value states are handled:

    * grouped by layer:
      ``(key_self_0, value_self_0, key_cross_0, value_cross_0), (..._1), ...``
    * grouped by attention kind:
      all self-attention entries for every layer, then all cross-attention
      entries for every layer.
    """

    @staticmethod
    def get_past_names(num_layers, present: bool = False):
        """Return the state names, self-attention first, then cross-attention.

        Args:
            num_layers: Number of layers to generate names for.
            present: When true, names start with ``present_``; otherwise
                they start with ``past_``.

        Returns:
            A list of ``4 * num_layers`` names: for each layer
            ``<prefix>_key_self_<i>`` and ``<prefix>_value_self_<i>``,
            followed by, for each layer, ``<prefix>_key_cross_<i>`` and
            ``<prefix>_value_cross_<i>``.
        """
        # The prefix does not depend on the layer, so pick it once.
        prefix = "present" if present else "past"
        self_names = []
        cross_names = []
        for i in range(num_layers):
            self_names.extend([f"{prefix}_key_self_{i}", f"{prefix}_value_self_{i}"])
            cross_names.extend([f"{prefix}_key_cross_{i}", f"{prefix}_value_cross_{i}"])
        return self_names + cross_names

    @staticmethod
    def group_by_self_or_cross(present_key_values):
        """Split present state from grouped by layer to grouped by self/cross attention.

        Before: (present_key_self_0, present_value_self_0, present_key_cross_0, present_value_cross_0),
                (present_key_self_1, present_value_self_1, present_key_cross_1, present_value_cross_1), ...
        After:  (present_key_self_0, present_value_self_0, present_key_self_1, present_value_self_1, ...),
                (present_key_cross_0, present_value_cross_0, present_key_cross_1, present_value_cross_1, ...)

        Args:
            present_key_values: Iterable of per-layer groups, each holding
                exactly four items in the order key_self, value_self,
                key_cross, value_cross.

        Returns:
            A ``(present_self, present_cross)`` pair of lists.

        Raises:
            AssertionError: If a layer group does not contain four items.
        """
        present_self = []
        present_cross = []
        for present_layer_i in present_key_values:
            # Raised explicitly (not via ``assert``) so the check still runs
            # under ``python -O``; the exception type is kept for callers.
            if len(present_layer_i) != 4:
                raise AssertionError(f"Expected to have four items. Got {len(present_layer_i)}")
            key_self, value_self, key_cross, value_cross = present_layer_i
            present_self.extend([key_self, value_self])
            present_cross.extend([key_cross, value_cross])
        return present_self, present_cross

    @staticmethod
    def group_by_layer(past, num_layers):
        """Reorder past state from grouped by self/cross attention to grouped by layer.

        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
                past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
        After:  (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...

        Args:
            past: Flat sequence of ``4 * num_layers`` items, self-attention
                entries first, then cross-attention entries.
            num_layers: Number of layers represented in ``past``.

        Returns:
            A tuple with one four-item list per layer.

        Raises:
            AssertionError: If ``len(past) != 4 * num_layers``.
        """
        # Raised explicitly (not via ``assert``) so an over-long ``past`` is
        # never silently regrouped when Python runs with ``-O``.
        if len(past) != 4 * num_layers:
            raise AssertionError(
                f"Expected {4 * num_layers} items for {num_layers} layers. Got {len(past)}"
            )
        # Cross-attention entries start right after the 2 * num_layers
        # self-attention entries.
        cross_offset = 2 * num_layers
        return tuple(
            [
                past[2 * i],
                past[2 * i + 1],
                past[cross_offset + 2 * i],
                past[cross_offset + 2 * i + 1],
            ]
            for i in range(num_layers)
        )