图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

112 lines
3.5 KiB

  1. import bisect
  2. from collections import defaultdict
  3. from sympy.core.containers import Tuple
  4. from sympy.core.numbers import Integer
  5. def _get_mapping_from_subranks(subranks):
  6. mapping = {}
  7. counter = 0
  8. for i, rank in enumerate(subranks):
  9. for j in range(rank):
  10. mapping[counter] = (i, j)
  11. counter += 1
  12. return mapping
  13. def _get_contraction_links(args, subranks, *contraction_indices):
  14. mapping = _get_mapping_from_subranks(subranks)
  15. contraction_tuples = [[mapping[j] for j in i] for i in contraction_indices]
  16. dlinks = defaultdict(dict)
  17. for links in contraction_tuples:
  18. if len(links) == 2:
  19. (arg1, pos1), (arg2, pos2) = links
  20. dlinks[arg1][pos1] = (arg2, pos2)
  21. dlinks[arg2][pos2] = (arg1, pos1)
  22. continue
  23. return args, dict(dlinks)
  24. def _sort_contraction_indices(pairing_indices):
  25. pairing_indices = [Tuple(*sorted(i)) for i in pairing_indices]
  26. pairing_indices.sort(key=lambda x: min(x))
  27. return pairing_indices
  28. def _get_diagonal_indices(flattened_indices):
  29. axes_contraction = defaultdict(list)
  30. for i, ind in enumerate(flattened_indices):
  31. if isinstance(ind, (int, Integer)):
  32. # If the indices is a number, there can be no diagonal operation:
  33. continue
  34. axes_contraction[ind].append(i)
  35. axes_contraction = {k: v for k, v in axes_contraction.items() if len(v) > 1}
  36. # Put the diagonalized indices at the end:
  37. ret_indices = [i for i in flattened_indices if i not in axes_contraction]
  38. diag_indices = list(axes_contraction)
  39. diag_indices.sort(key=lambda x: flattened_indices.index(x))
  40. diagonal_indices = [tuple(axes_contraction[i]) for i in diag_indices]
  41. ret_indices += diag_indices
  42. ret_indices = tuple(ret_indices)
  43. return diagonal_indices, ret_indices
  44. def _get_argindex(subindices, ind):
  45. for i, sind in enumerate(subindices):
  46. if ind == sind:
  47. return i
  48. if isinstance(sind, (set, frozenset)) and ind in sind:
  49. return i
  50. raise IndexError("%s not found in %s" % (ind, subindices))
  51. def _apply_recursively_over_nested_lists(func, arr):
  52. if isinstance(arr, (tuple, list, Tuple)):
  53. return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr)
  54. elif isinstance(arr, Tuple):
  55. return Tuple.fromiter(_apply_recursively_over_nested_lists(func, i) for i in arr)
  56. else:
  57. return func(arr)
  58. def _build_push_indices_up_func_transformation(flattened_contraction_indices):
  59. shifts = {0: 0}
  60. i = 0
  61. cumulative = 0
  62. while i < len(flattened_contraction_indices):
  63. j = 1
  64. while i+j < len(flattened_contraction_indices):
  65. if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]:
  66. break
  67. j += 1
  68. cumulative += j
  69. shifts[flattened_contraction_indices[i]] = cumulative
  70. i += j
  71. shift_keys = sorted(shifts.keys())
  72. def func(idx):
  73. return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]]
  74. def transform(j):
  75. if j in flattened_contraction_indices:
  76. return None
  77. else:
  78. return j - func(j)
  79. return transform
  80. def _build_push_indices_down_func_transformation(flattened_contraction_indices):
  81. N = flattened_contraction_indices[-1]+2
  82. shifts = [i for i in range(N) if i not in flattened_contraction_indices]
  83. def transform(j):
  84. if j < len(shifts):
  85. return shifts[j]
  86. else:
  87. return j + shifts[-1] - len(shifts) + 1
  88. return transform