m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

112 lines
3.5 KiB

6 months ago
  1. import bisect
  2. from collections import defaultdict
  3. from sympy.core.containers import Tuple
  4. from sympy.core.numbers import Integer
  5. def _get_mapping_from_subranks(subranks):
  6. mapping = {}
  7. counter = 0
  8. for i, rank in enumerate(subranks):
  9. for j in range(rank):
  10. mapping[counter] = (i, j)
  11. counter += 1
  12. return mapping
  13. def _get_contraction_links(args, subranks, *contraction_indices):
  14. mapping = _get_mapping_from_subranks(subranks)
  15. contraction_tuples = [[mapping[j] for j in i] for i in contraction_indices]
  16. dlinks = defaultdict(dict)
  17. for links in contraction_tuples:
  18. if len(links) == 2:
  19. (arg1, pos1), (arg2, pos2) = links
  20. dlinks[arg1][pos1] = (arg2, pos2)
  21. dlinks[arg2][pos2] = (arg1, pos1)
  22. continue
  23. return args, dict(dlinks)
  24. def _sort_contraction_indices(pairing_indices):
  25. pairing_indices = [Tuple(*sorted(i)) for i in pairing_indices]
  26. pairing_indices.sort(key=lambda x: min(x))
  27. return pairing_indices
  28. def _get_diagonal_indices(flattened_indices):
  29. axes_contraction = defaultdict(list)
  30. for i, ind in enumerate(flattened_indices):
  31. if isinstance(ind, (int, Integer)):
  32. # If the indices is a number, there can be no diagonal operation:
  33. continue
  34. axes_contraction[ind].append(i)
  35. axes_contraction = {k: v for k, v in axes_contraction.items() if len(v) > 1}
  36. # Put the diagonalized indices at the end:
  37. ret_indices = [i for i in flattened_indices if i not in axes_contraction]
  38. diag_indices = list(axes_contraction)
  39. diag_indices.sort(key=lambda x: flattened_indices.index(x))
  40. diagonal_indices = [tuple(axes_contraction[i]) for i in diag_indices]
  41. ret_indices += diag_indices
  42. ret_indices = tuple(ret_indices)
  43. return diagonal_indices, ret_indices
  44. def _get_argindex(subindices, ind):
  45. for i, sind in enumerate(subindices):
  46. if ind == sind:
  47. return i
  48. if isinstance(sind, (set, frozenset)) and ind in sind:
  49. return i
  50. raise IndexError("%s not found in %s" % (ind, subindices))
  51. def _apply_recursively_over_nested_lists(func, arr):
  52. if isinstance(arr, (tuple, list, Tuple)):
  53. return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr)
  54. elif isinstance(arr, Tuple):
  55. return Tuple.fromiter(_apply_recursively_over_nested_lists(func, i) for i in arr)
  56. else:
  57. return func(arr)
  58. def _build_push_indices_up_func_transformation(flattened_contraction_indices):
  59. shifts = {0: 0}
  60. i = 0
  61. cumulative = 0
  62. while i < len(flattened_contraction_indices):
  63. j = 1
  64. while i+j < len(flattened_contraction_indices):
  65. if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]:
  66. break
  67. j += 1
  68. cumulative += j
  69. shifts[flattened_contraction_indices[i]] = cumulative
  70. i += j
  71. shift_keys = sorted(shifts.keys())
  72. def func(idx):
  73. return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]]
  74. def transform(j):
  75. if j in flattened_contraction_indices:
  76. return None
  77. else:
  78. return j - func(j)
  79. return transform
  80. def _build_push_indices_down_func_transformation(flattened_contraction_indices):
  81. N = flattened_contraction_indices[-1]+2
  82. shifts = [i for i in range(N) if i not in flattened_contraction_indices]
  83. def transform(j):
  84. if j < len(shifts):
  85. return shifts[j]
  86. else:
  87. return j + shifts[-1] - len(shifts) + 1
  88. return transform