图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

366 lines
16 KiB

  1. from sympy.core.numbers import (E, Rational, pi)
  2. from sympy.functions.elementary.exponential import exp
  3. from sympy.functions.elementary.miscellaneous import sqrt
  4. from sympy.core import S, symbols, I
  5. from sympy.discrete.convolutions import (
  6. convolution, convolution_fft, convolution_ntt, convolution_fwht,
  7. convolution_subset, covering_product, intersecting_product)
  8. from sympy.testing.pytest import raises
  9. from sympy.abc import x, y
  10. def test_convolution():
  11. # fft
  12. a = [1, Rational(5, 3), sqrt(3), Rational(7, 5)]
  13. b = [9, 5, 5, 4, 3, 2]
  14. c = [3, 5, 3, 7, 8]
  15. d = [1422, 6572, 3213, 5552]
  16. assert convolution(a, b) == convolution_fft(a, b)
  17. assert convolution(a, b, dps=9) == convolution_fft(a, b, dps=9)
  18. assert convolution(a, d, dps=7) == convolution_fft(d, a, dps=7)
  19. assert convolution(a, d[1:], dps=3) == convolution_fft(d[1:], a, dps=3)
  20. # prime moduli of the form (m*2**k + 1), sequence length
  21. # should be a divisor of 2**k
  22. p = 7*17*2**23 + 1
  23. q = 19*2**10 + 1
  24. # ntt
  25. assert convolution(d, b, prime=q) == convolution_ntt(b, d, prime=q)
  26. assert convolution(c, b, prime=p) == convolution_ntt(b, c, prime=p)
  27. assert convolution(d, c, prime=p) == convolution_ntt(c, d, prime=p)
  28. raises(TypeError, lambda: convolution(b, d, dps=5, prime=q))
  29. raises(TypeError, lambda: convolution(b, d, dps=6, prime=q))
  30. # fwht
  31. assert convolution(a, b, dyadic=True) == convolution_fwht(a, b)
  32. assert convolution(a, b, dyadic=False) == convolution(a, b)
  33. raises(TypeError, lambda: convolution(b, d, dps=2, dyadic=True))
  34. raises(TypeError, lambda: convolution(b, d, prime=p, dyadic=True))
  35. raises(TypeError, lambda: convolution(a, b, dps=2, dyadic=True))
  36. raises(TypeError, lambda: convolution(b, c, prime=p, dyadic=True))
  37. # subset
  38. assert convolution(a, b, subset=True) == convolution_subset(a, b) == \
  39. convolution(a, b, subset=True, dyadic=False) == \
  40. convolution(a, b, subset=True)
  41. assert convolution(a, b, subset=False) == convolution(a, b)
  42. raises(TypeError, lambda: convolution(a, b, subset=True, dyadic=True))
  43. raises(TypeError, lambda: convolution(c, d, subset=True, dps=6))
  44. raises(TypeError, lambda: convolution(a, c, subset=True, prime=q))
  45. def test_cyclic_convolution():
  46. # fft
  47. a = [1, Rational(5, 3), sqrt(3), Rational(7, 5)]
  48. b = [9, 5, 5, 4, 3, 2]
  49. assert convolution([1, 2, 3], [4, 5, 6], cycle=0) == \
  50. convolution([1, 2, 3], [4, 5, 6], cycle=5) == \
  51. convolution([1, 2, 3], [4, 5, 6])
  52. assert convolution([1, 2, 3], [4, 5, 6], cycle=3) == [31, 31, 28]
  53. a = [Rational(1, 3), Rational(7, 3), Rational(5, 9), Rational(2, 7), Rational(5, 8)]
  54. b = [Rational(3, 5), Rational(4, 7), Rational(7, 8), Rational(8, 9)]
  55. assert convolution(a, b, cycle=0) == \
  56. convolution(a, b, cycle=len(a) + len(b) - 1)
  57. assert convolution(a, b, cycle=4) == [Rational(87277, 26460), Rational(30521, 11340),
  58. Rational(11125, 4032), Rational(3653, 1080)]
  59. assert convolution(a, b, cycle=6) == [Rational(20177, 20160), Rational(676, 315), Rational(47, 24),
  60. Rational(3053, 1080), Rational(16397, 5292), Rational(2497, 2268)]
  61. assert convolution(a, b, cycle=9) == \
  62. convolution(a, b, cycle=0) + [S.Zero]
  63. # ntt
  64. a = [2313, 5323532, S(3232), 42142, 42242421]
  65. b = [S(33456), 56757, 45754, 432423]
  66. assert convolution(a, b, prime=19*2**10 + 1, cycle=0) == \
  67. convolution(a, b, prime=19*2**10 + 1, cycle=8) == \
  68. convolution(a, b, prime=19*2**10 + 1)
  69. assert convolution(a, b, prime=19*2**10 + 1, cycle=5) == [96, 17146, 2664,
  70. 15534, 3517]
  71. assert convolution(a, b, prime=19*2**10 + 1, cycle=7) == [4643, 3458, 1260,
  72. 15534, 3517, 16314, 13688]
  73. assert convolution(a, b, prime=19*2**10 + 1, cycle=9) == \
  74. convolution(a, b, prime=19*2**10 + 1) + [0]
  75. # fwht
  76. u, v, w, x, y = symbols('u v w x y')
  77. p, q, r, s, t = symbols('p q r s t')
  78. c = [u, v, w, x, y]
  79. d = [p, q, r, s, t]
  80. assert convolution(a, b, dyadic=True, cycle=3) == \
  81. [2499522285783, 19861417974796, 4702176579021]
  82. assert convolution(a, b, dyadic=True, cycle=5) == [2718149225143,
  83. 2114320852171, 20571217906407, 246166418903, 1413262436976]
  84. assert convolution(c, d, dyadic=True, cycle=4) == \
  85. [p*u + p*y + q*v + r*w + s*x + t*u + t*y,
  86. p*v + q*u + q*y + r*x + s*w + t*v,
  87. p*w + q*x + r*u + r*y + s*v + t*w,
  88. p*x + q*w + r*v + s*u + s*y + t*x]
  89. assert convolution(c, d, dyadic=True, cycle=6) == \
  90. [p*u + q*v + r*w + r*y + s*x + t*w + t*y,
  91. p*v + q*u + r*x + s*w + s*y + t*x,
  92. p*w + q*x + r*u + s*v,
  93. p*x + q*w + r*v + s*u,
  94. p*y + t*u,
  95. q*y + t*v]
  96. # subset
  97. assert convolution(a, b, subset=True, cycle=7) == [18266671799811,
  98. 178235365533, 213958794, 246166418903, 1413262436976,
  99. 2397553088697, 1932759730434]
  100. assert convolution(a[1:], b, subset=True, cycle=4) == \
  101. [178104086592, 302255835516, 244982785880, 3717819845434]
  102. assert convolution(a, b[:-1], subset=True, cycle=6) == [1932837114162,
  103. 178235365533, 213958794, 245166224504, 1413262436976, 2397553088697]
  104. assert convolution(c, d, subset=True, cycle=3) == \
  105. [p*u + p*x + q*w + r*v + r*y + s*u + t*w,
  106. p*v + p*y + q*u + s*y + t*u + t*x,
  107. p*w + q*y + r*u + t*v]
  108. assert convolution(c, d, subset=True, cycle=5) == \
  109. [p*u + q*y + t*v,
  110. p*v + q*u + r*y + t*w,
  111. p*w + r*u + s*y + t*x,
  112. p*x + q*w + r*v + s*u,
  113. p*y + t*u]
  114. raises(ValueError, lambda: convolution([1, 2, 3], [4, 5, 6], cycle=-1))
  115. def test_convolution_fft():
  116. assert all(convolution_fft([], x, dps=y) == [] for x in ([], [1]) for y in (None, 3))
  117. assert convolution_fft([1, 2, 3], [4, 5, 6]) == [4, 13, 28, 27, 18]
  118. assert convolution_fft([1], [5, 6, 7]) == [5, 6, 7]
  119. assert convolution_fft([1, 3], [5, 6, 7]) == [5, 21, 25, 21]
  120. assert convolution_fft([1 + 2*I], [2 + 3*I]) == [-4 + 7*I]
  121. assert convolution_fft([1 + 2*I, 3 + 4*I, 5 + 3*I/5], [Rational(2, 5) + 4*I/7]) == \
  122. [Rational(-26, 35) + I*48/35, Rational(-38, 35) + I*116/35, Rational(58, 35) + I*542/175]
  123. assert convolution_fft([Rational(3, 4), Rational(5, 6)], [Rational(7, 8), Rational(1, 3), Rational(2, 5)]) == \
  124. [Rational(21, 32), Rational(47, 48), Rational(26, 45), Rational(1, 3)]
  125. assert convolution_fft([Rational(1, 9), Rational(2, 3), Rational(3, 5)], [Rational(2, 5), Rational(3, 7), Rational(4, 9)]) == \
  126. [Rational(2, 45), Rational(11, 35), Rational(8152, 14175), Rational(523, 945), Rational(4, 15)]
  127. assert convolution_fft([pi, E, sqrt(2)], [sqrt(3), 1/pi, 1/E]) == \
  128. [sqrt(3)*pi, 1 + sqrt(3)*E, E/pi + pi*exp(-1) + sqrt(6),
  129. sqrt(2)/pi + 1, sqrt(2)*exp(-1)]
  130. assert convolution_fft([2321, 33123], [5321, 6321, 71323]) == \
  131. [12350041, 190918524, 374911166, 2362431729]
  132. assert convolution_fft([312313, 31278232], [32139631, 319631]) == \
  133. [10037624576503, 1005370659728895, 9997492572392]
  134. raises(TypeError, lambda: convolution_fft(x, y))
  135. raises(ValueError, lambda: convolution_fft([x, y], [y, x]))
  136. def test_convolution_ntt():
  137. # prime moduli of the form (m*2**k + 1), sequence length
  138. # should be a divisor of 2**k
  139. p = 7*17*2**23 + 1
  140. q = 19*2**10 + 1
  141. r = 2*500000003 + 1 # only for sequences of length 1 or 2
  142. # s = 2*3*5*7 # composite modulus
  143. assert all(convolution_ntt([], x, prime=y) == [] for x in ([], [1]) for y in (p, q, r))
  144. assert convolution_ntt([2], [3], r) == [6]
  145. assert convolution_ntt([2, 3], [4], r) == [8, 12]
  146. assert convolution_ntt([32121, 42144, 4214, 4241], [32132, 3232, 87242], p) == [33867619,
  147. 459741727, 79180879, 831885249, 381344700, 369993322]
  148. assert convolution_ntt([121913, 3171831, 31888131, 12], [17882, 21292, 29921, 312], q) == \
  149. [8158, 3065, 3682, 7090, 1239, 2232, 3744]
  150. assert convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], p) == \
  151. convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], q)
  152. assert convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], p) == \
  153. convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], q)
  154. raises(ValueError, lambda: convolution_ntt([2, 3], [4, 5], r))
  155. raises(ValueError, lambda: convolution_ntt([x, y], [y, x], q))
  156. raises(TypeError, lambda: convolution_ntt(x, y, p))
  157. def test_convolution_fwht():
  158. assert convolution_fwht([], []) == []
  159. assert convolution_fwht([], [1]) == []
  160. assert convolution_fwht([1, 2, 3], [4, 5, 6]) == [32, 13, 18, 27]
  161. assert convolution_fwht([Rational(5, 7), Rational(6, 8), Rational(7, 3)], [2, 4, Rational(6, 7)]) == \
  162. [Rational(45, 7), Rational(61, 14), Rational(776, 147), Rational(419, 42)]
  163. a = [1, Rational(5, 3), sqrt(3), Rational(7, 5), 4 + 5*I]
  164. b = [94, 51, 53, 45, 31, 27, 13]
  165. c = [3 + 4*I, 5 + 7*I, 3, Rational(7, 6), 8]
  166. assert convolution_fwht(a, b) == [53*sqrt(3) + 366 + 155*I,
  167. 45*sqrt(3) + Rational(5848, 15) + 135*I,
  168. 94*sqrt(3) + Rational(1257, 5) + 65*I,
  169. 51*sqrt(3) + Rational(3974, 15),
  170. 13*sqrt(3) + 452 + 470*I,
  171. Rational(4513, 15) + 255*I,
  172. 31*sqrt(3) + Rational(1314, 5) + 265*I,
  173. 27*sqrt(3) + Rational(3676, 15) + 225*I]
  174. assert convolution_fwht(b, c) == [Rational(1993, 2) + 733*I, Rational(6215, 6) + 862*I,
  175. Rational(1659, 2) + 527*I, Rational(1988, 3) + 551*I, 1019 + 313*I, Rational(3955, 6) + 325*I,
  176. Rational(1175, 2) + 52*I, Rational(3253, 6) + 91*I]
  177. assert convolution_fwht(a[3:], c) == [Rational(-54, 5) + I*293/5, -1 + I*204/5,
  178. Rational(133, 15) + I*35/6, Rational(409, 30) + 15*I, Rational(56, 5), 32 + 40*I, 0, 0]
  179. u, v, w, x, y, z = symbols('u v w x y z')
  180. assert convolution_fwht([u, v], [x, y]) == [u*x + v*y, u*y + v*x]
  181. assert convolution_fwht([u, v, w], [x, y]) == \
  182. [u*x + v*y, u*y + v*x, w*x, w*y]
  183. assert convolution_fwht([u, v, w], [x, y, z]) == \
  184. [u*x + v*y + w*z, u*y + v*x, u*z + w*x, v*z + w*y]
  185. raises(TypeError, lambda: convolution_fwht(x, y))
  186. raises(TypeError, lambda: convolution_fwht(x*y, u + v))
  187. def test_convolution_subset():
  188. assert convolution_subset([], []) == []
  189. assert convolution_subset([], [Rational(1, 3)]) == []
  190. assert convolution_subset([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7]
  191. a = [1, Rational(5, 3), sqrt(3), 4 + 5*I]
  192. b = [64, 71, 55, 47, 33, 29, 15]
  193. c = [3 + I*2/3, 5 + 7*I, 7, Rational(7, 5), 9]
  194. assert convolution_subset(a, b) == [64, Rational(533, 3), 55 + 64*sqrt(3),
  195. 71*sqrt(3) + Rational(1184, 3) + 320*I, 33, 84,
  196. 15 + 33*sqrt(3), 29*sqrt(3) + 157 + 165*I]
  197. assert convolution_subset(b, c) == [192 + I*128/3, 533 + I*1486/3,
  198. 613 + I*110/3, Rational(5013, 5) + I*1249/3,
  199. 675 + 22*I, 891 + I*751/3,
  200. 771 + 10*I, Rational(3736, 5) + 105*I]
  201. assert convolution_subset(a, c) == convolution_subset(c, a)
  202. assert convolution_subset(a[:2], b) == \
  203. [64, Rational(533, 3), 55, Rational(416, 3), 33, 84, 15, 25]
  204. assert convolution_subset(a[:2], c) == \
  205. [3 + I*2/3, 10 + I*73/9, 7, Rational(196, 15), 9, 15, 0, 0]
  206. u, v, w, x, y, z = symbols('u v w x y z')
  207. assert convolution_subset([u, v, w], [x, y]) == [u*x, u*y + v*x, w*x, w*y]
  208. assert convolution_subset([u, v, w, x], [y, z]) == \
  209. [u*y, u*z + v*y, w*y, w*z + x*y]
  210. assert convolution_subset([u, v], [x, y, z]) == \
  211. convolution_subset([x, y, z], [u, v])
  212. raises(TypeError, lambda: convolution_subset(x, z))
  213. raises(TypeError, lambda: convolution_subset(Rational(7, 3), u))
  214. def test_covering_product():
  215. assert covering_product([], []) == []
  216. assert covering_product([], [Rational(1, 3)]) == []
  217. assert covering_product([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7]
  218. a = [1, Rational(5, 8), sqrt(7), 4 + 9*I]
  219. b = [66, 81, 95, 49, 37, 89, 17]
  220. c = [3 + I*2/3, 51 + 72*I, 7, Rational(7, 15), 91]
  221. assert covering_product(a, b) == [66, Rational(1383, 8), 95 + 161*sqrt(7),
  222. 130*sqrt(7) + 1303 + 2619*I, 37,
  223. Rational(671, 4), 17 + 54*sqrt(7),
  224. 89*sqrt(7) + Rational(4661, 8) + 1287*I]
  225. assert covering_product(b, c) == [198 + 44*I, 7740 + 10638*I,
  226. 1412 + I*190/3, Rational(42684, 5) + I*31202/3,
  227. 9484 + I*74/3, 22163 + I*27394/3,
  228. 10621 + I*34/3, Rational(90236, 15) + 1224*I]
  229. assert covering_product(a, c) == covering_product(c, a)
  230. assert covering_product(b, c[:-1]) == [198 + 44*I, 7740 + 10638*I,
  231. 1412 + I*190/3, Rational(42684, 5) + I*31202/3,
  232. 111 + I*74/3, 6693 + I*27394/3,
  233. 429 + I*34/3, Rational(23351, 15) + 1224*I]
  234. assert covering_product(a, c[:-1]) == [3 + I*2/3,
  235. Rational(339, 4) + I*1409/12, 7 + 10*sqrt(7) + 2*sqrt(7)*I/3,
  236. -403 + 772*sqrt(7)/15 + 72*sqrt(7)*I + I*12658/15]
  237. u, v, w, x, y, z = symbols('u v w x y z')
  238. assert covering_product([u, v, w], [x, y]) == \
  239. [u*x, u*y + v*x + v*y, w*x, w*y]
  240. assert covering_product([u, v, w, x], [y, z]) == \
  241. [u*y, u*z + v*y + v*z, w*y, w*z + x*y + x*z]
  242. assert covering_product([u, v], [x, y, z]) == \
  243. covering_product([x, y, z], [u, v])
  244. raises(TypeError, lambda: covering_product(x, z))
  245. raises(TypeError, lambda: covering_product(Rational(7, 3), u))
  246. def test_intersecting_product():
  247. assert intersecting_product([], []) == []
  248. assert intersecting_product([], [Rational(1, 3)]) == []
  249. assert intersecting_product([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7]
  250. a = [1, sqrt(5), Rational(3, 8) + 5*I, 4 + 7*I]
  251. b = [67, 51, 65, 48, 36, 79, 27]
  252. c = [3 + I*2/5, 5 + 9*I, 7, Rational(7, 19), 13]
  253. assert intersecting_product(a, b) == [195*sqrt(5) + Rational(6979, 8) + 1886*I,
  254. 178*sqrt(5) + 520 + 910*I, Rational(841, 2) + 1344*I,
  255. 192 + 336*I, 0, 0, 0, 0]
  256. assert intersecting_product(b, c) == [Rational(128553, 19) + I*9521/5,
  257. Rational(17820, 19) + 1602*I, Rational(19264, 19), Rational(336, 19), 1846, 0, 0, 0]
  258. assert intersecting_product(a, c) == intersecting_product(c, a)
  259. assert intersecting_product(b[1:], c[:-1]) == [Rational(64788, 19) + I*8622/5,
  260. Rational(12804, 19) + 1152*I, Rational(11508, 19), Rational(252, 19), 0, 0, 0, 0]
  261. assert intersecting_product(a, c[:-2]) == \
  262. [Rational(-99, 5) + 10*sqrt(5) + 2*sqrt(5)*I/5 + I*3021/40,
  263. -43 + 5*sqrt(5) + 9*sqrt(5)*I + 71*I, Rational(245, 8) + 84*I, 0]
  264. u, v, w, x, y, z = symbols('u v w x y z')
  265. assert intersecting_product([u, v, w], [x, y]) == \
  266. [u*x + u*y + v*x + w*x + w*y, v*y, 0, 0]
  267. assert intersecting_product([u, v, w, x], [y, z]) == \
  268. [u*y + u*z + v*y + w*y + w*z + x*y, v*z + x*z, 0, 0]
  269. assert intersecting_product([u, v], [x, y, z]) == \
  270. intersecting_product([x, y, z], [u, v])
  271. raises(TypeError, lambda: intersecting_product(x, z))
  272. raises(TypeError, lambda: intersecting_product(u, Rational(8, 3)))