You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
3.0 KiB
68 lines
3.0 KiB
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License. See License.txt in the project root for
|
|
# license information.
|
|
# --------------------------------------------------------------------------
|
|
|
|
import logging
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PastKeyValuesHelper:
    """Helper functions to process past key values for encoder-decoder model"""

    @staticmethod
    def get_past_names(num_layers: int, present: bool = False):
        """Build the ordered list of past (or present) key/value state names.

        All self-attention names come first (key then value, layer by layer),
        followed by all cross-attention names in the same per-layer order.

        Args:
            num_layers: Number of layers to generate names for.
            present: When True, names use the ``present_`` prefix; otherwise
                they use the ``past_`` prefix.

        Returns:
            A list of ``4 * num_layers`` strings, e.g. for one layer:
            ``["past_key_self_0", "past_value_self_0",
            "past_key_cross_0", "past_value_cross_0"]``.
        """
        # Only the prefix differs between the two naming schemes.
        prefix = "present" if present else "past"
        self_names = []
        cross_names = []
        for i in range(num_layers):
            self_names.extend([f"{prefix}_key_self_{i}", f"{prefix}_value_self_{i}"])
            cross_names.extend([f"{prefix}_key_cross_{i}", f"{prefix}_value_cross_{i}"])
        return self_names + cross_names

    @staticmethod
    def group_by_self_or_cross(present_key_values):
        """Split present state from grouped by layer to grouped by self/cross attention.

        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
        After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)

        Args:
            present_key_values: Iterable with one entry per layer; each entry
                must hold exactly four items ordered as
                (key_self, value_self, key_cross, value_cross).

        Returns:
            A pair ``(present_self, present_cross)`` of flat lists. Both are
            empty when ``present_key_values`` is empty.

        Raises:
            AssertionError: If a layer entry does not contain exactly four items.
        """
        present_self = []
        present_cross = []
        for layer_state in present_key_values:
            assert len(layer_state) == 4, f"Expected to have four items. Got {len(layer_state)}"
            key_self, value_self, key_cross, value_cross = layer_state
            present_self.extend([key_self, value_self])
            present_cross.extend([key_cross, value_cross])
        return present_self, present_cross

    @staticmethod
    def group_by_layer(past, num_layers: int):
        """Reorder past state from grouped by self/cross attention to grouped by layer.

        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
        After: [past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0], [past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1], ...

        Args:
            past: Indexable sequence of ``4 * num_layers`` items: all
                self-attention key/value pairs first, then all
                cross-attention key/value pairs.
            num_layers: Number of layers represented in ``past``.

        Returns:
            A tuple with one four-item list per layer (an empty tuple when
            ``num_layers`` is 0).

        Raises:
            AssertionError: If ``len(past)`` is not ``4 * num_layers``.
        """
        assert (
            len(past) == 4 * num_layers
        ), f"Expected {4 * num_layers} items for {num_layers} layers. Got {len(past)}"
        # Cross-attention entries start right after the 2 * num_layers self-attention entries.
        cross_offset = 2 * num_layers
        return tuple(
            [
                past[2 * i],
                past[2 * i + 1],
                past[cross_offset + 2 * i],
                past[cross_offset + 2 * i + 1],
            ]
            for i in range(num_layers)
        )
|