m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

887 lines
32 KiB

6 months ago
  1. from textwrap import dedent
  2. from itertools import islice, product
  3. from sympy.core.basic import Basic
  4. from sympy.core.numbers import Integer
  5. from sympy.core.sorting import ordered
  6. from sympy.core.symbol import (Dummy, symbols)
  7. from sympy.functions.combinatorial.factorials import factorial
  8. from sympy.matrices.dense import Matrix
  9. from sympy.combinatorics import RGS_enum, RGS_unrank, Permutation
  10. from sympy.utilities.iterables import (
  11. _partition, _set_partitions, binary_partitions, bracelets, capture,
  12. cartes, common_prefix, common_suffix, connected_components, dict_merge,
  13. filter_symbols, flatten, generate_bell, generate_derangements,
  14. generate_involutions, generate_oriented_forest, group, has_dups, ibin,
  15. iproduct, kbins, minlex, multiset, multiset_combinations,
  16. multiset_partitions, multiset_permutations, necklaces, numbered_symbols,
  17. partitions, permutations, postfixes,
  18. prefixes, reshape, rotate_left, rotate_right, runs, sift,
  19. strongly_connected_components, subsets, take, topological_sort, unflatten,
  20. uniq, variations, ordered_partitions, rotations, is_palindromic, iterable,
  21. NotIterable, multiset_derangements)
  22. from sympy.utilities.enumerative import (
  23. factoring_visitor, multiset_partitions_taocp )
  24. from sympy.core.singleton import S
  25. from sympy.testing.pytest import raises, warns_deprecated_sympy
  26. w, x, y, z = symbols('w,x,y,z')
  27. def test_deprecated_iterables():
  28. from sympy.utilities.iterables import default_sort_key, ordered
  29. with warns_deprecated_sympy():
  30. assert list(ordered([y, x])) == [x, y]
  31. with warns_deprecated_sympy():
  32. assert sorted([y, x], key=default_sort_key) == [x, y]
  33. def test_is_palindromic():
  34. assert is_palindromic('')
  35. assert is_palindromic('x')
  36. assert is_palindromic('xx')
  37. assert is_palindromic('xyx')
  38. assert not is_palindromic('xy')
  39. assert not is_palindromic('xyzx')
  40. assert is_palindromic('xxyzzyx', 1)
  41. assert not is_palindromic('xxyzzyx', 2)
  42. assert is_palindromic('xxyzzyx', 2, -1)
  43. assert is_palindromic('xxyzzyx', 2, 6)
  44. assert is_palindromic('xxyzyx', 1)
  45. assert not is_palindromic('xxyzyx', 2)
  46. assert is_palindromic('xxyzyx', 2, 2 + 3)
  47. def test_flatten():
  48. assert flatten((1, (1,))) == [1, 1]
  49. assert flatten((x, (x,))) == [x, x]
  50. ls = [[(-2, -1), (1, 2)], [(0, 0)]]
  51. assert flatten(ls, levels=0) == ls
  52. assert flatten(ls, levels=1) == [(-2, -1), (1, 2), (0, 0)]
  53. assert flatten(ls, levels=2) == [-2, -1, 1, 2, 0, 0]
  54. assert flatten(ls, levels=3) == [-2, -1, 1, 2, 0, 0]
  55. raises(ValueError, lambda: flatten(ls, levels=-1))
  56. class MyOp(Basic):
  57. pass
  58. assert flatten([MyOp(x, y), z]) == [MyOp(x, y), z]
  59. assert flatten([MyOp(x, y), z], cls=MyOp) == [x, y, z]
  60. assert flatten({1, 11, 2}) == list({1, 11, 2})
  61. def test_iproduct():
  62. assert list(iproduct()) == [()]
  63. assert list(iproduct([])) == []
  64. assert list(iproduct([1,2,3])) == [(1,),(2,),(3,)]
  65. assert sorted(iproduct([1, 2], [3, 4, 5])) == [
  66. (1,3),(1,4),(1,5),(2,3),(2,4),(2,5)]
  67. assert sorted(iproduct([0,1],[0,1],[0,1])) == [
  68. (0,0,0),(0,0,1),(0,1,0),(0,1,1),(1,0,0),(1,0,1),(1,1,0),(1,1,1)]
  69. assert iterable(iproduct(S.Integers)) is True
  70. assert iterable(iproduct(S.Integers, S.Integers)) is True
  71. assert (3,) in iproduct(S.Integers)
  72. assert (4, 5) in iproduct(S.Integers, S.Integers)
  73. assert (1, 2, 3) in iproduct(S.Integers, S.Integers, S.Integers)
  74. triples = set(islice(iproduct(S.Integers, S.Integers, S.Integers), 1000))
  75. for n1, n2, n3 in triples:
  76. assert isinstance(n1, Integer)
  77. assert isinstance(n2, Integer)
  78. assert isinstance(n3, Integer)
  79. for t in set(product(*([range(-2, 3)]*3))):
  80. assert t in iproduct(S.Integers, S.Integers, S.Integers)
  81. def test_group():
  82. assert group([]) == []
  83. assert group([], multiple=False) == []
  84. assert group([1]) == [[1]]
  85. assert group([1], multiple=False) == [(1, 1)]
  86. assert group([1, 1]) == [[1, 1]]
  87. assert group([1, 1], multiple=False) == [(1, 2)]
  88. assert group([1, 1, 1]) == [[1, 1, 1]]
  89. assert group([1, 1, 1], multiple=False) == [(1, 3)]
  90. assert group([1, 2, 1]) == [[1], [2], [1]]
  91. assert group([1, 2, 1], multiple=False) == [(1, 1), (2, 1), (1, 1)]
  92. assert group([1, 1, 2, 2, 2, 1, 3, 3]) == [[1, 1], [2, 2, 2], [1], [3, 3]]
  93. assert group([1, 1, 2, 2, 2, 1, 3, 3], multiple=False) == [(1, 2),
  94. (2, 3), (1, 1), (3, 2)]
  95. def test_subsets():
  96. # combinations
  97. assert list(subsets([1, 2, 3], 0)) == [()]
  98. assert list(subsets([1, 2, 3], 1)) == [(1,), (2,), (3,)]
  99. assert list(subsets([1, 2, 3], 2)) == [(1, 2), (1, 3), (2, 3)]
  100. assert list(subsets([1, 2, 3], 3)) == [(1, 2, 3)]
  101. l = list(range(4))
  102. assert list(subsets(l, 0, repetition=True)) == [()]
  103. assert list(subsets(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
  104. assert list(subsets(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
  105. (0, 3), (1, 1), (1, 2),
  106. (1, 3), (2, 2), (2, 3),
  107. (3, 3)]
  108. assert list(subsets(l, 3, repetition=True)) == [(0, 0, 0), (0, 0, 1),
  109. (0, 0, 2), (0, 0, 3),
  110. (0, 1, 1), (0, 1, 2),
  111. (0, 1, 3), (0, 2, 2),
  112. (0, 2, 3), (0, 3, 3),
  113. (1, 1, 1), (1, 1, 2),
  114. (1, 1, 3), (1, 2, 2),
  115. (1, 2, 3), (1, 3, 3),
  116. (2, 2, 2), (2, 2, 3),
  117. (2, 3, 3), (3, 3, 3)]
  118. assert len(list(subsets(l, 4, repetition=True))) == 35
  119. assert list(subsets(l[:2], 3, repetition=False)) == []
  120. assert list(subsets(l[:2], 3, repetition=True)) == [(0, 0, 0),
  121. (0, 0, 1),
  122. (0, 1, 1),
  123. (1, 1, 1)]
  124. assert list(subsets([1, 2], repetition=True)) == \
  125. [(), (1,), (2,), (1, 1), (1, 2), (2, 2)]
  126. assert list(subsets([1, 2], repetition=False)) == \
  127. [(), (1,), (2,), (1, 2)]
  128. assert list(subsets([1, 2, 3], 2)) == \
  129. [(1, 2), (1, 3), (2, 3)]
  130. assert list(subsets([1, 2, 3], 2, repetition=True)) == \
  131. [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
  132. def test_variations():
  133. # permutations
  134. l = list(range(4))
  135. assert list(variations(l, 0, repetition=False)) == [()]
  136. assert list(variations(l, 1, repetition=False)) == [(0,), (1,), (2,), (3,)]
  137. assert list(variations(l, 2, repetition=False)) == [(0, 1), (0, 2), (0, 3), (1, 0), (1, 2), (1, 3), (2, 0), (2, 1), (2, 3), (3, 0), (3, 1), (3, 2)]
  138. assert list(variations(l, 3, repetition=False)) == [(0, 1, 2), (0, 1, 3), (0, 2, 1), (0, 2, 3), (0, 3, 1), (0, 3, 2), (1, 0, 2), (1, 0, 3), (1, 2, 0), (1, 2, 3), (1, 3, 0), (1, 3, 2), (2, 0, 1), (2, 0, 3), (2, 1, 0), (2, 1, 3), (2, 3, 0), (2, 3, 1), (3, 0, 1), (3, 0, 2), (3, 1, 0), (3, 1, 2), (3, 2, 0), (3, 2, 1)]
  139. assert list(variations(l, 0, repetition=True)) == [()]
  140. assert list(variations(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
  141. assert list(variations(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
  142. (0, 3), (1, 0), (1, 1),
  143. (1, 2), (1, 3), (2, 0),
  144. (2, 1), (2, 2), (2, 3),
  145. (3, 0), (3, 1), (3, 2),
  146. (3, 3)]
  147. assert len(list(variations(l, 3, repetition=True))) == 64
  148. assert len(list(variations(l, 4, repetition=True))) == 256
  149. assert list(variations(l[:2], 3, repetition=False)) == []
  150. assert list(variations(l[:2], 3, repetition=True)) == [
  151. (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
  152. (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)
  153. ]
  154. def test_cartes():
  155. assert list(cartes([1, 2], [3, 4, 5])) == \
  156. [(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5)]
  157. assert list(cartes()) == [()]
  158. assert list(cartes('a')) == [('a',)]
  159. assert list(cartes('a', repeat=2)) == [('a', 'a')]
  160. assert list(cartes(list(range(2)))) == [(0,), (1,)]
  161. def test_filter_symbols():
  162. s = numbered_symbols()
  163. filtered = filter_symbols(s, symbols("x0 x2 x3"))
  164. assert take(filtered, 3) == list(symbols("x1 x4 x5"))
  165. def test_numbered_symbols():
  166. s = numbered_symbols(cls=Dummy)
  167. assert isinstance(next(s), Dummy)
  168. assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \
  169. symbols('C2')
  170. def test_sift():
  171. assert sift(list(range(5)), lambda _: _ % 2) == {1: [1, 3], 0: [0, 2, 4]}
  172. assert sift([x, y], lambda _: _.has(x)) == {False: [y], True: [x]}
  173. assert sift([S.One], lambda _: _.has(x)) == {False: [1]}
  174. assert sift([0, 1, 2, 3], lambda x: x % 2, binary=True) == (
  175. [1, 3], [0, 2])
  176. assert sift([0, 1, 2, 3], lambda x: x % 3 == 1, binary=True) == (
  177. [1], [0, 2, 3])
  178. raises(ValueError, lambda:
  179. sift([0, 1, 2, 3], lambda x: x % 3, binary=True))
  180. def test_take():
  181. X = numbered_symbols()
  182. assert take(X, 5) == list(symbols('x0:5'))
  183. assert take(X, 5) == list(symbols('x5:10'))
  184. assert take([1, 2, 3, 4, 5], 5) == [1, 2, 3, 4, 5]
  185. def test_dict_merge():
  186. assert dict_merge({}, {1: x, y: z}) == {1: x, y: z}
  187. assert dict_merge({1: x, y: z}, {}) == {1: x, y: z}
  188. assert dict_merge({2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
  189. assert dict_merge({1: x, y: z}, {2: z}) == {1: x, 2: z, y: z}
  190. assert dict_merge({1: y, 2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
  191. assert dict_merge({1: x, y: z}, {1: y, 2: z}) == {1: y, 2: z, y: z}
  192. def test_prefixes():
  193. assert list(prefixes([])) == []
  194. assert list(prefixes([1])) == [[1]]
  195. assert list(prefixes([1, 2])) == [[1], [1, 2]]
  196. assert list(prefixes([1, 2, 3, 4, 5])) == \
  197. [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
  198. def test_postfixes():
  199. assert list(postfixes([])) == []
  200. assert list(postfixes([1])) == [[1]]
  201. assert list(postfixes([1, 2])) == [[2], [1, 2]]
  202. assert list(postfixes([1, 2, 3, 4, 5])) == \
  203. [[5], [4, 5], [3, 4, 5], [2, 3, 4, 5], [1, 2, 3, 4, 5]]
  204. def test_topological_sort():
  205. V = [2, 3, 5, 7, 8, 9, 10, 11]
  206. E = [(7, 11), (7, 8), (5, 11),
  207. (3, 8), (3, 10), (11, 2),
  208. (11, 9), (11, 10), (8, 9)]
  209. assert topological_sort((V, E)) == [3, 5, 7, 8, 11, 2, 9, 10]
  210. assert topological_sort((V, E), key=lambda v: -v) == \
  211. [7, 5, 11, 3, 10, 8, 9, 2]
  212. raises(ValueError, lambda: topological_sort((V, E + [(10, 7)])))
  213. def test_strongly_connected_components():
  214. assert strongly_connected_components(([], [])) == []
  215. assert strongly_connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
  216. V = [1, 2, 3]
  217. E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
  218. assert strongly_connected_components((V, E)) == [[1, 2, 3]]
  219. V = [1, 2, 3, 4]
  220. E = [(1, 2), (2, 3), (3, 2), (3, 4)]
  221. assert strongly_connected_components((V, E)) == [[4], [2, 3], [1]]
  222. V = [1, 2, 3, 4]
  223. E = [(1, 2), (2, 1), (3, 4), (4, 3)]
  224. assert strongly_connected_components((V, E)) == [[1, 2], [3, 4]]
  225. def test_connected_components():
  226. assert connected_components(([], [])) == []
  227. assert connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
  228. V = [1, 2, 3]
  229. E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
  230. assert connected_components((V, E)) == [[1, 2, 3]]
  231. V = [1, 2, 3, 4]
  232. E = [(1, 2), (2, 3), (3, 2), (3, 4)]
  233. assert connected_components((V, E)) == [[1, 2, 3, 4]]
  234. V = [1, 2, 3, 4]
  235. E = [(1, 2), (3, 4)]
  236. assert connected_components((V, E)) == [[1, 2], [3, 4]]
  237. def test_rotate():
  238. A = [0, 1, 2, 3, 4]
  239. assert rotate_left(A, 2) == [2, 3, 4, 0, 1]
  240. assert rotate_right(A, 1) == [4, 0, 1, 2, 3]
  241. A = []
  242. B = rotate_right(A, 1)
  243. assert B == []
  244. B.append(1)
  245. assert A == []
  246. B = rotate_left(A, 1)
  247. assert B == []
  248. B.append(1)
  249. assert A == []
  250. def test_multiset_partitions():
  251. A = [0, 1, 2, 3, 4]
  252. assert list(multiset_partitions(A, 5)) == [[[0], [1], [2], [3], [4]]]
  253. assert len(list(multiset_partitions(A, 4))) == 10
  254. assert len(list(multiset_partitions(A, 3))) == 25
  255. assert list(multiset_partitions([1, 1, 1, 2, 2], 2)) == [
  256. [[1, 1, 1, 2], [2]], [[1, 1, 1], [2, 2]], [[1, 1, 2, 2], [1]],
  257. [[1, 1, 2], [1, 2]], [[1, 1], [1, 2, 2]]]
  258. assert list(multiset_partitions([1, 1, 2, 2], 2)) == [
  259. [[1, 1, 2], [2]], [[1, 1], [2, 2]], [[1, 2, 2], [1]],
  260. [[1, 2], [1, 2]]]
  261. assert list(multiset_partitions([1, 2, 3, 4], 2)) == [
  262. [[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]],
  263. [[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]],
  264. [[1], [2, 3, 4]]]
  265. assert list(multiset_partitions([1, 2, 2], 2)) == [
  266. [[1, 2], [2]], [[1], [2, 2]]]
  267. assert list(multiset_partitions(3)) == [
  268. [[0, 1, 2]], [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]],
  269. [[0], [1], [2]]]
  270. assert list(multiset_partitions(3, 2)) == [
  271. [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]]]
  272. assert list(multiset_partitions([1] * 3, 2)) == [[[1], [1, 1]]]
  273. assert list(multiset_partitions([1] * 3)) == [
  274. [[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]]
  275. a = [3, 2, 1]
  276. assert list(multiset_partitions(a)) == \
  277. list(multiset_partitions(sorted(a)))
  278. assert list(multiset_partitions(a, 5)) == []
  279. assert list(multiset_partitions(a, 1)) == [[[1, 2, 3]]]
  280. assert list(multiset_partitions(a + [4], 5)) == []
  281. assert list(multiset_partitions(a + [4], 1)) == [[[1, 2, 3, 4]]]
  282. assert list(multiset_partitions(2, 5)) == []
  283. assert list(multiset_partitions(2, 1)) == [[[0, 1]]]
  284. assert list(multiset_partitions('a')) == [[['a']]]
  285. assert list(multiset_partitions('a', 2)) == []
  286. assert list(multiset_partitions('ab')) == [[['a', 'b']], [['a'], ['b']]]
  287. assert list(multiset_partitions('ab', 1)) == [[['a', 'b']]]
  288. assert list(multiset_partitions('aaa', 1)) == [['aaa']]
  289. assert list(multiset_partitions([1, 1], 1)) == [[[1, 1]]]
  290. ans = [('mpsyy',), ('mpsy', 'y'), ('mps', 'yy'), ('mps', 'y', 'y'),
  291. ('mpyy', 's'), ('mpy', 'sy'), ('mpy', 's', 'y'), ('mp', 'syy'),
  292. ('mp', 'sy', 'y'), ('mp', 's', 'yy'), ('mp', 's', 'y', 'y'),
  293. ('msyy', 'p'), ('msy', 'py'), ('msy', 'p', 'y'), ('ms', 'pyy'),
  294. ('ms', 'py', 'y'), ('ms', 'p', 'yy'), ('ms', 'p', 'y', 'y'),
  295. ('myy', 'ps'), ('myy', 'p', 's'), ('my', 'psy'), ('my', 'ps', 'y'),
  296. ('my', 'py', 's'), ('my', 'p', 'sy'), ('my', 'p', 's', 'y'),
  297. ('m', 'psyy'), ('m', 'psy', 'y'), ('m', 'ps', 'yy'),
  298. ('m', 'ps', 'y', 'y'), ('m', 'pyy', 's'), ('m', 'py', 'sy'),
  299. ('m', 'py', 's', 'y'), ('m', 'p', 'syy'),
  300. ('m', 'p', 'sy', 'y'), ('m', 'p', 's', 'yy'),
  301. ('m', 'p', 's', 'y', 'y')]
  302. assert list(tuple("".join(part) for part in p)
  303. for p in multiset_partitions('sympy')) == ans
  304. factorings = [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3],
  305. [6, 2, 2], [2, 2, 2, 3]]
  306. assert list(factoring_visitor(p, [2,3]) for
  307. p in multiset_partitions_taocp([3, 1])) == factorings
  308. def test_multiset_combinations():
  309. ans = ['iii', 'iim', 'iip', 'iis', 'imp', 'ims', 'ipp', 'ips',
  310. 'iss', 'mpp', 'mps', 'mss', 'pps', 'pss', 'sss']
  311. assert [''.join(i) for i in
  312. list(multiset_combinations('mississippi', 3))] == ans
  313. M = multiset('mississippi')
  314. assert [''.join(i) for i in
  315. list(multiset_combinations(M, 3))] == ans
  316. assert [''.join(i) for i in multiset_combinations(M, 30)] == []
  317. assert list(multiset_combinations([[1], [2, 3]], 2)) == [[[1], [2, 3]]]
  318. assert len(list(multiset_combinations('a', 3))) == 0
  319. assert len(list(multiset_combinations('a', 0))) == 1
  320. assert list(multiset_combinations('abc', 1)) == [['a'], ['b'], ['c']]
  321. raises(ValueError, lambda: list(multiset_combinations({0: 3, 1: -1}, 2)))
  322. def test_multiset_permutations():
  323. ans = ['abby', 'abyb', 'aybb', 'baby', 'bayb', 'bbay', 'bbya', 'byab',
  324. 'byba', 'yabb', 'ybab', 'ybba']
  325. assert [''.join(i) for i in multiset_permutations('baby')] == ans
  326. assert [''.join(i) for i in multiset_permutations(multiset('baby'))] == ans
  327. assert list(multiset_permutations([0, 0, 0], 2)) == [[0, 0]]
  328. assert list(multiset_permutations([0, 2, 1], 2)) == [
  329. [0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]
  330. assert len(list(multiset_permutations('a', 0))) == 1
  331. assert len(list(multiset_permutations('a', 3))) == 0
  332. for nul in ([], {}, ''):
  333. assert list(multiset_permutations(nul)) == [[]]
  334. assert list(multiset_permutations(nul, 0)) == [[]]
  335. # impossible requests give no result
  336. assert list(multiset_permutations(nul, 1)) == []
  337. assert list(multiset_permutations(nul, -1)) == []
  338. def test():
  339. for i in range(1, 7):
  340. print(i)
  341. for p in multiset_permutations([0, 0, 1, 0, 1], i):
  342. print(p)
  343. assert capture(lambda: test()) == dedent('''\
  344. 1
  345. [0]
  346. [1]
  347. 2
  348. [0, 0]
  349. [0, 1]
  350. [1, 0]
  351. [1, 1]
  352. 3
  353. [0, 0, 0]
  354. [0, 0, 1]
  355. [0, 1, 0]
  356. [0, 1, 1]
  357. [1, 0, 0]
  358. [1, 0, 1]
  359. [1, 1, 0]
  360. 4
  361. [0, 0, 0, 1]
  362. [0, 0, 1, 0]
  363. [0, 0, 1, 1]
  364. [0, 1, 0, 0]
  365. [0, 1, 0, 1]
  366. [0, 1, 1, 0]
  367. [1, 0, 0, 0]
  368. [1, 0, 0, 1]
  369. [1, 0, 1, 0]
  370. [1, 1, 0, 0]
  371. 5
  372. [0, 0, 0, 1, 1]
  373. [0, 0, 1, 0, 1]
  374. [0, 0, 1, 1, 0]
  375. [0, 1, 0, 0, 1]
  376. [0, 1, 0, 1, 0]
  377. [0, 1, 1, 0, 0]
  378. [1, 0, 0, 0, 1]
  379. [1, 0, 0, 1, 0]
  380. [1, 0, 1, 0, 0]
  381. [1, 1, 0, 0, 0]
  382. 6\n''')
  383. raises(ValueError, lambda: list(multiset_permutations({0: 3, 1: -1})))
  384. def test_partitions():
  385. ans = [[{}], [(0, {})]]
  386. for i in range(2):
  387. assert list(partitions(0, size=i)) == ans[i]
  388. assert list(partitions(1, 0, size=i)) == ans[i]
  389. assert list(partitions(6, 2, 2, size=i)) == ans[i]
  390. assert list(partitions(6, 2, None, size=i)) != ans[i]
  391. assert list(partitions(6, None, 2, size=i)) != ans[i]
  392. assert list(partitions(6, 2, 0, size=i)) == ans[i]
  393. assert [p for p in partitions(6, k=2)] == [
  394. {2: 3}, {1: 2, 2: 2}, {1: 4, 2: 1}, {1: 6}]
  395. assert [p for p in partitions(6, k=3)] == [
  396. {3: 2}, {1: 1, 2: 1, 3: 1}, {1: 3, 3: 1}, {2: 3}, {1: 2, 2: 2},
  397. {1: 4, 2: 1}, {1: 6}]
  398. assert [p for p in partitions(8, k=4, m=3)] == [
  399. {4: 2}, {1: 1, 3: 1, 4: 1}, {2: 2, 4: 1}, {2: 1, 3: 2}] == [
  400. i for i in partitions(8, k=4, m=3) if all(k <= 4 for k in i)
  401. and sum(i.values()) <=3]
  402. assert [p for p in partitions(S(3), m=2)] == [
  403. {3: 1}, {1: 1, 2: 1}]
  404. assert [i for i in partitions(4, k=3)] == [
  405. {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}] == [
  406. i for i in partitions(4) if all(k <= 3 for k in i)]
  407. # Consistency check on output of _partitions and RGS_unrank.
  408. # This provides a sanity test on both routines. Also verifies that
  409. # the total number of partitions is the same in each case.
  410. # (from pkrathmann2)
  411. for n in range(2, 6):
  412. i = 0
  413. for m, q in _set_partitions(n):
  414. assert q == RGS_unrank(i, n)
  415. i += 1
  416. assert i == RGS_enum(n)
  417. def test_binary_partitions():
  418. assert [i[:] for i in binary_partitions(10)] == [[8, 2], [8, 1, 1],
  419. [4, 4, 2], [4, 4, 1, 1], [4, 2, 2, 2], [4, 2, 2, 1, 1],
  420. [4, 2, 1, 1, 1, 1], [4, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2],
  421. [2, 2, 2, 2, 1, 1], [2, 2, 2, 1, 1, 1, 1], [2, 2, 1, 1, 1, 1, 1, 1],
  422. [2, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
  423. assert len([j[:] for j in binary_partitions(16)]) == 36
  424. def test_bell_perm():
  425. assert [len(set(generate_bell(i))) for i in range(1, 7)] == [
  426. factorial(i) for i in range(1, 7)]
  427. assert list(generate_bell(3)) == [
  428. (0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)]
  429. # generate_bell and trotterjohnson are advertised to return the same
  430. # permutations; this is not technically necessary so this test could
  431. # be removed
  432. for n in range(1, 5):
  433. p = Permutation(range(n))
  434. b = generate_bell(n)
  435. for bi in b:
  436. assert bi == tuple(p.array_form)
  437. p = p.next_trotterjohnson()
  438. raises(ValueError, lambda: list(generate_bell(0))) # XXX is this consistent with other permutation algorithms?
  439. def test_involutions():
  440. lengths = [1, 2, 4, 10, 26, 76]
  441. for n, N in enumerate(lengths):
  442. i = list(generate_involutions(n + 1))
  443. assert len(i) == N
  444. assert len({Permutation(j)**2 for j in i}) == 1
  445. def test_derangements():
  446. assert len(list(generate_derangements(list(range(6))))) == 265
  447. assert ''.join(''.join(i) for i in generate_derangements('abcde')) == (
  448. 'badecbaecdbcaedbcdeabceadbdaecbdeacbdecabeacdbedacbedcacabedcadebcaebd'
  449. 'cdaebcdbeacdeabcdebaceabdcebadcedabcedbadabecdaebcdaecbdcaebdcbeadceab'
  450. 'dcebadeabcdeacbdebacdebcaeabcdeadbceadcbecabdecbadecdabecdbaedabcedacb'
  451. 'edbacedbca')
  452. assert list(generate_derangements([0, 1, 2, 3])) == [
  453. [1, 0, 3, 2], [1, 2, 3, 0], [1, 3, 0, 2], [2, 0, 3, 1],
  454. [2, 3, 0, 1], [2, 3, 1, 0], [3, 0, 1, 2], [3, 2, 0, 1], [3, 2, 1, 0]]
  455. assert list(generate_derangements([0, 1, 2, 2])) == [
  456. [2, 2, 0, 1], [2, 2, 1, 0]]
  457. assert list(generate_derangements('ba')) == [list('ab')]
  458. # multiset_derangements
  459. D = multiset_derangements
  460. assert list(D('abb')) == []
  461. assert [''.join(i) for i in D('ab')] == ['ba']
  462. assert [''.join(i) for i in D('abc')] == ['bca', 'cab']
  463. assert [''.join(i) for i in D('aabb')] == ['bbaa']
  464. assert [''.join(i) for i in D('aabbcccc')] == [
  465. 'ccccaabb', 'ccccabab', 'ccccabba', 'ccccbaab', 'ccccbaba',
  466. 'ccccbbaa']
  467. assert [''.join(i) for i in D('aabbccc')] == [
  468. 'cccabba', 'cccabab', 'cccaabb', 'ccacbba', 'ccacbab',
  469. 'ccacabb', 'cbccbaa', 'cbccaba', 'cbccaab', 'bcccbaa',
  470. 'bcccaba', 'bcccaab']
  471. assert [''.join(i) for i in D('books')] == ['kbsoo', 'ksboo',
  472. 'sbkoo', 'skboo', 'oksbo', 'oskbo', 'okbso', 'obkso', 'oskob',
  473. 'oksob', 'osbok', 'obsok']
  474. assert list(generate_derangements([[3], [2], [2], [1]])) == [
  475. [[2], [1], [3], [2]], [[2], [3], [1], [2]]]
  476. def test_necklaces():
  477. def count(n, k, f):
  478. return len(list(necklaces(n, k, f)))
  479. m = []
  480. for i in range(1, 8):
  481. m.append((
  482. i, count(i, 2, 0), count(i, 2, 1), count(i, 3, 1)))
  483. assert Matrix(m) == Matrix([
  484. [1, 2, 2, 3],
  485. [2, 3, 3, 6],
  486. [3, 4, 4, 10],
  487. [4, 6, 6, 21],
  488. [5, 8, 8, 39],
  489. [6, 14, 13, 92],
  490. [7, 20, 18, 198]])
  491. def test_bracelets():
  492. bc = [i for i in bracelets(2, 4)]
  493. assert Matrix(bc) == Matrix([
  494. [0, 0],
  495. [0, 1],
  496. [0, 2],
  497. [0, 3],
  498. [1, 1],
  499. [1, 2],
  500. [1, 3],
  501. [2, 2],
  502. [2, 3],
  503. [3, 3]
  504. ])
  505. bc = [i for i in bracelets(4, 2)]
  506. assert Matrix(bc) == Matrix([
  507. [0, 0, 0, 0],
  508. [0, 0, 0, 1],
  509. [0, 0, 1, 1],
  510. [0, 1, 0, 1],
  511. [0, 1, 1, 1],
  512. [1, 1, 1, 1]
  513. ])
  514. def test_generate_oriented_forest():
  515. assert list(generate_oriented_forest(5)) == [[0, 1, 2, 3, 4],
  516. [0, 1, 2, 3, 3], [0, 1, 2, 3, 2], [0, 1, 2, 3, 1], [0, 1, 2, 3, 0],
  517. [0, 1, 2, 2, 2], [0, 1, 2, 2, 1], [0, 1, 2, 2, 0], [0, 1, 2, 1, 2],
  518. [0, 1, 2, 1, 1], [0, 1, 2, 1, 0], [0, 1, 2, 0, 1], [0, 1, 2, 0, 0],
  519. [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 1], [0, 1, 1, 0, 0],
  520. [0, 1, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]]
  521. assert len(list(generate_oriented_forest(10))) == 1842
  522. def test_unflatten():
  523. r = list(range(10))
  524. assert unflatten(r) == list(zip(r[::2], r[1::2]))
  525. assert unflatten(r, 5) == [tuple(r[:5]), tuple(r[5:])]
  526. raises(ValueError, lambda: unflatten(list(range(10)), 3))
  527. raises(ValueError, lambda: unflatten(list(range(10)), -2))
  528. def test_common_prefix_suffix():
  529. assert common_prefix([], [1]) == []
  530. assert common_prefix(list(range(3))) == [0, 1, 2]
  531. assert common_prefix(list(range(3)), list(range(4))) == [0, 1, 2]
  532. assert common_prefix([1, 2, 3], [1, 2, 5]) == [1, 2]
  533. assert common_prefix([1, 2, 3], [1, 3, 5]) == [1]
  534. assert common_suffix([], [1]) == []
  535. assert common_suffix(list(range(3))) == [0, 1, 2]
  536. assert common_suffix(list(range(3)), list(range(3))) == [0, 1, 2]
  537. assert common_suffix(list(range(3)), list(range(4))) == []
  538. assert common_suffix([1, 2, 3], [9, 2, 3]) == [2, 3]
  539. assert common_suffix([1, 2, 3], [9, 7, 3]) == [3]
  540. def test_minlex():
  541. assert minlex([1, 2, 0]) == (0, 1, 2)
  542. assert minlex((1, 2, 0)) == (0, 1, 2)
  543. assert minlex((1, 0, 2)) == (0, 2, 1)
  544. assert minlex((1, 0, 2), directed=False) == (0, 1, 2)
  545. assert minlex('aba') == 'aab'
  546. assert minlex(('bb', 'aaa', 'c', 'a'), key=len) == ('c', 'a', 'bb', 'aaa')
  547. def test_ordered():
  548. assert list(ordered((x, y), hash, default=False)) in [[x, y], [y, x]]
  549. assert list(ordered((x, y), hash, default=False)) == \
  550. list(ordered((y, x), hash, default=False))
  551. assert list(ordered((x, y))) == [x, y]
  552. seq, keys = [[[1, 2, 1], [0, 3, 1], [1, 1, 3], [2], [1]],
  553. (lambda x: len(x), lambda x: sum(x))]
  554. assert list(ordered(seq, keys, default=False, warn=False)) == \
  555. [[1], [2], [1, 2, 1], [0, 3, 1], [1, 1, 3]]
  556. raises(ValueError, lambda:
  557. list(ordered(seq, keys, default=False, warn=True)))
  558. def test_runs():
  559. assert runs([]) == []
  560. assert runs([1]) == [[1]]
  561. assert runs([1, 1]) == [[1], [1]]
  562. assert runs([1, 1, 2]) == [[1], [1, 2]]
  563. assert runs([1, 2, 1]) == [[1, 2], [1]]
  564. assert runs([2, 1, 1]) == [[2], [1], [1]]
  565. from operator import lt
  566. assert runs([2, 1, 1], lt) == [[2, 1], [1]]
  567. def test_reshape():
  568. seq = list(range(1, 9))
  569. assert reshape(seq, [4]) == \
  570. [[1, 2, 3, 4], [5, 6, 7, 8]]
  571. assert reshape(seq, (4,)) == \
  572. [(1, 2, 3, 4), (5, 6, 7, 8)]
  573. assert reshape(seq, (2, 2)) == \
  574. [(1, 2, 3, 4), (5, 6, 7, 8)]
  575. assert reshape(seq, (2, [2])) == \
  576. [(1, 2, [3, 4]), (5, 6, [7, 8])]
  577. assert reshape(seq, ((2,), [2])) == \
  578. [((1, 2), [3, 4]), ((5, 6), [7, 8])]
  579. assert reshape(seq, (1, [2], 1)) == \
  580. [(1, [2, 3], 4), (5, [6, 7], 8)]
  581. assert reshape(tuple(seq), ([[1], 1, (2,)],)) == \
  582. (([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],))
  583. assert reshape(tuple(seq), ([1], 1, (2,))) == \
  584. (([1], 2, (3, 4)), ([5], 6, (7, 8)))
  585. assert reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)]) == \
  586. [[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]]
  587. raises(ValueError, lambda: reshape([0, 1], [-1]))
  588. raises(ValueError, lambda: reshape([0, 1], [3]))
  589. def test_uniq():
  590. assert list(uniq(p for p in partitions(4))) == \
  591. [{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}]
  592. assert list(uniq(x % 2 for x in range(5))) == [0, 1]
  593. assert list(uniq('a')) == ['a']
  594. assert list(uniq('ababc')) == list('abc')
  595. assert list(uniq([[1], [2, 1], [1]])) == [[1], [2, 1]]
  596. assert list(uniq(permutations(i for i in [[1], 2, 2]))) == \
  597. [([1], 2, 2), (2, [1], 2), (2, 2, [1])]
  598. assert list(uniq([2, 3, 2, 4, [2], [1], [2], [3], [1]])) == \
  599. [2, 3, 4, [2], [1], [3]]
  600. f = [1]
  601. raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
  602. f = [[1]]
  603. raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
  604. def test_kbins():
  605. assert len(list(kbins('1123', 2, ordered=1))) == 24
  606. assert len(list(kbins('1123', 2, ordered=11))) == 36
  607. assert len(list(kbins('1123', 2, ordered=10))) == 10
  608. assert len(list(kbins('1123', 2, ordered=0))) == 5
  609. assert len(list(kbins('1123', 2, ordered=None))) == 3
  610. def test1():
  611. for orderedval in [None, 0, 1, 10, 11]:
  612. print('ordered =', orderedval)
  613. for p in kbins([0, 0, 1], 2, ordered=orderedval):
  614. print(' ', p)
  615. assert capture(lambda : test1()) == dedent('''\
  616. ordered = None
  617. [[0], [0, 1]]
  618. [[0, 0], [1]]
  619. ordered = 0
  620. [[0, 0], [1]]
  621. [[0, 1], [0]]
  622. ordered = 1
  623. [[0], [0, 1]]
  624. [[0], [1, 0]]
  625. [[1], [0, 0]]
  626. ordered = 10
  627. [[0, 0], [1]]
  628. [[1], [0, 0]]
  629. [[0, 1], [0]]
  630. [[0], [0, 1]]
  631. ordered = 11
  632. [[0], [0, 1]]
  633. [[0, 0], [1]]
  634. [[0], [1, 0]]
  635. [[0, 1], [0]]
  636. [[1], [0, 0]]
  637. [[1, 0], [0]]\n''')
  638. def test2():
  639. for orderedval in [None, 0, 1, 10, 11]:
  640. print('ordered =', orderedval)
  641. for p in kbins(list(range(3)), 2, ordered=orderedval):
  642. print(' ', p)
  643. assert capture(lambda : test2()) == dedent('''\
  644. ordered = None
  645. [[0], [1, 2]]
  646. [[0, 1], [2]]
  647. ordered = 0
  648. [[0, 1], [2]]
  649. [[0, 2], [1]]
  650. [[0], [1, 2]]
  651. ordered = 1
  652. [[0], [1, 2]]
  653. [[0], [2, 1]]
  654. [[1], [0, 2]]
  655. [[1], [2, 0]]
  656. [[2], [0, 1]]
  657. [[2], [1, 0]]
  658. ordered = 10
  659. [[0, 1], [2]]
  660. [[2], [0, 1]]
  661. [[0, 2], [1]]
  662. [[1], [0, 2]]
  663. [[0], [1, 2]]
  664. [[1, 2], [0]]
  665. ordered = 11
  666. [[0], [1, 2]]
  667. [[0, 1], [2]]
  668. [[0], [2, 1]]
  669. [[0, 2], [1]]
  670. [[1], [0, 2]]
  671. [[1, 0], [2]]
  672. [[1], [2, 0]]
  673. [[1, 2], [0]]
  674. [[2], [0, 1]]
  675. [[2, 0], [1]]
  676. [[2], [1, 0]]
  677. [[2, 1], [0]]\n''')
  678. def test_has_dups():
  679. assert has_dups(set()) is False
  680. assert has_dups(list(range(3))) is False
  681. assert has_dups([1, 2, 1]) is True
  682. assert has_dups([[1], [1]]) is True
  683. assert has_dups([[1], [2]]) is False
  684. def test__partition():
  685. assert _partition('abcde', [1, 0, 1, 2, 0]) == [
  686. ['b', 'e'], ['a', 'c'], ['d']]
  687. assert _partition('abcde', [1, 0, 1, 2, 0], 3) == [
  688. ['b', 'e'], ['a', 'c'], ['d']]
  689. output = (3, [1, 0, 1, 2, 0])
  690. assert _partition('abcde', *output) == [['b', 'e'], ['a', 'c'], ['d']]
  691. def test_ordered_partitions():
  692. from sympy.functions.combinatorial.numbers import nT
  693. f = ordered_partitions
  694. assert list(f(0, 1)) == [[]]
  695. assert list(f(1, 0)) == [[]]
  696. for i in range(1, 7):
  697. for j in [None] + list(range(1, i)):
  698. assert (
  699. sum(1 for p in f(i, j, 1)) ==
  700. sum(1 for p in f(i, j, 0)) ==
  701. nT(i, j))
  702. def test_rotations():
  703. assert list(rotations('ab')) == [['a', 'b'], ['b', 'a']]
  704. assert list(rotations(range(3))) == [[0, 1, 2], [1, 2, 0], [2, 0, 1]]
  705. assert list(rotations(range(3), dir=-1)) == [[0, 1, 2], [2, 0, 1], [1, 2, 0]]
  706. def test_ibin():
  707. assert ibin(3) == [1, 1]
  708. assert ibin(3, 3) == [0, 1, 1]
  709. assert ibin(3, str=True) == '11'
  710. assert ibin(3, 3, str=True) == '011'
  711. assert list(ibin(2, 'all')) == [(0, 0), (0, 1), (1, 0), (1, 1)]
  712. assert list(ibin(2, '', str=True)) == ['00', '01', '10', '11']
  713. raises(ValueError, lambda: ibin(-.5))
  714. raises(ValueError, lambda: ibin(2, 1))
  715. def test_iterable():
  716. assert iterable(0) is False
  717. assert iterable(1) is False
  718. assert iterable(None) is False
  719. class Test1(NotIterable):
  720. pass
  721. assert iterable(Test1()) is False
  722. class Test2(NotIterable):
  723. _iterable = True
  724. assert iterable(Test2()) is True
  725. class Test3:
  726. pass
  727. assert iterable(Test3()) is False
  728. class Test4:
  729. _iterable = True
  730. assert iterable(Test4()) is True
  731. class Test5:
  732. def __iter__(self):
  733. yield 1
  734. assert iterable(Test5()) is True
  735. class Test6(Test5):
  736. _iterable = False
  737. assert iterable(Test6()) is False