m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

68 lines
3.0 KiB

6 months ago
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. logger = logging.getLogger(__name__)
  8. class PastKeyValuesHelper:
  9. """Helper functions to process past key values for encoder-decoder model"""
  10. @staticmethod
  11. def get_past_names(num_layers, present: bool = False):
  12. past_self_names = []
  13. past_cross_names = []
  14. for i in range(num_layers):
  15. past_self_names.extend(
  16. [f"present_key_self_{i}", f"present_value_self_{i}"]
  17. if present
  18. else [f"past_key_self_{i}", f"past_value_self_{i}"]
  19. )
  20. past_cross_names.extend(
  21. [f"present_key_cross_{i}", f"present_value_cross_{i}"]
  22. if present
  23. else [f"past_key_cross_{i}", f"past_value_cross_{i}"]
  24. )
  25. return past_self_names + past_cross_names
  26. @staticmethod
  27. def group_by_self_or_cross(present_key_values):
  28. """Split present state from grouped by layer to grouped by self/cross attention.
  29. Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
  30. After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
  31. """
  32. present_self = []
  33. present_cross = []
  34. for i, present_layer_i in enumerate(present_key_values):
  35. assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
  36. (
  37. present_key_self,
  38. present_value_self,
  39. present_key_cross,
  40. present_value_cross,
  41. ) = present_layer_i
  42. present_self.extend([present_key_self, present_value_self])
  43. present_cross.extend([present_key_cross, present_value_cross])
  44. return present_self, present_cross
  45. @staticmethod
  46. def group_by_layer(past, num_layers):
  47. """Reorder past state from grouped by self/cross attention to grouped by layer.
  48. Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
  49. After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
  50. """
  51. assert len(past) == 4 * num_layers
  52. return tuple(
  53. [
  54. past[2 * i],
  55. past[2 * i + 1],
  56. past[2 * num_layers + 2 * i],
  57. past[2 * num_layers + 2 * i + 1],
  58. ]
  59. for i in range(num_layers)
  60. )