m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

605 lines
19 KiB

7 months ago
  1. from functools import reduce
  2. import itertools
  3. from operator import add
  4. from sympy.core.add import Add
  5. from sympy.core.containers import Tuple
  6. from sympy.core.function import Function
  7. from sympy.core.mul import Mul
  8. from sympy.core.power import Pow
  9. from sympy.core.relational import Eq
  10. from sympy.core.singleton import S
  11. from sympy.core.symbol import (Symbol, symbols)
  12. from sympy.core.sympify import sympify
  13. from sympy.functions.elementary.exponential import exp
  14. from sympy.functions.elementary.miscellaneous import sqrt
  15. from sympy.functions.elementary.piecewise import Piecewise
  16. from sympy.functions.elementary.trigonometric import (cos, sin)
  17. from sympy.matrices.dense import Matrix
  18. from sympy.polys.rootoftools import CRootOf
  19. from sympy.series.order import O
  20. from sympy.simplify.cse_main import cse
  21. from sympy.simplify.simplify import signsimp
  22. from sympy.tensor.indexed import (Idx, IndexedBase)
  23. from sympy.core.function import count_ops
  24. from sympy.simplify.cse_opts import sub_pre, sub_post
  25. from sympy.functions.special.hyper import meijerg
  26. from sympy.simplify import cse_main, cse_opts
  27. from sympy.utilities.iterables import subsets
  28. from sympy.testing.pytest import XFAIL, raises
  29. from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix,
  30. ImmutableDenseMatrix, ImmutableSparseMatrix)
  31. from sympy.matrices.expressions import MatrixSymbol
  32. w, x, y, z = symbols('w,x,y,z')
  33. x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13')
  34. def test_numbered_symbols():
  35. ns = cse_main.numbered_symbols(prefix='y')
  36. assert list(itertools.islice(
  37. ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)]
  38. ns = cse_main.numbered_symbols(prefix='y')
  39. assert list(itertools.islice(
  40. ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)]
  41. ns = cse_main.numbered_symbols()
  42. assert list(itertools.islice(
  43. ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)]
  44. # Dummy "optimization" functions for testing.
  45. def opt1(expr):
  46. return expr + y
  47. def opt2(expr):
  48. return expr*z
  49. def test_preprocess_for_cse():
  50. assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y
  51. assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x
  52. assert cse_main.preprocess_for_cse(x, [(None, None)]) == x
  53. assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y
  54. assert cse_main.preprocess_for_cse(
  55. x, [(opt1, None), (opt2, None)]) == (x + y)*z
  56. def test_postprocess_for_cse():
  57. assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x
  58. assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y
  59. assert cse_main.postprocess_for_cse(x, [(None, None)]) == x
  60. assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z
  61. # Note the reverse order of application.
  62. assert cse_main.postprocess_for_cse(
  63. x, [(None, opt1), (None, opt2)]) == x*z + y
  64. def test_cse_single():
  65. # Simple substitution.
  66. e = Add(Pow(x + y, 2), sqrt(x + y))
  67. substs, reduced = cse([e])
  68. assert substs == [(x0, x + y)]
  69. assert reduced == [sqrt(x0) + x0**2]
  70. subst42, (red42,) = cse([42]) # issue_15082
  71. assert len(subst42) == 0 and red42 == 42
  72. subst_half, (red_half,) = cse([0.5])
  73. assert len(subst_half) == 0 and red_half == 0.5
  74. def test_cse_single2():
  75. # Simple substitution, test for being able to pass the expression directly
  76. e = Add(Pow(x + y, 2), sqrt(x + y))
  77. substs, reduced = cse(e)
  78. assert substs == [(x0, x + y)]
  79. assert reduced == [sqrt(x0) + x0**2]
  80. substs, reduced = cse(Matrix([[1]]))
  81. assert isinstance(reduced[0], Matrix)
  82. subst42, (red42,) = cse(42) # issue 15082
  83. assert len(subst42) == 0 and red42 == 42
  84. subst_half, (red_half,) = cse(0.5) # issue 15082
  85. assert len(subst_half) == 0 and red_half == 0.5
  86. def test_cse_not_possible():
  87. # No substitution possible.
  88. e = Add(x, y)
  89. substs, reduced = cse([e])
  90. assert substs == []
  91. assert reduced == [x + y]
  92. # issue 6329
  93. eq = (meijerg((1, 2), (y, 4), (5,), [], x) +
  94. meijerg((1, 3), (y, 4), (5,), [], x))
  95. assert cse(eq) == ([], [eq])
  96. def test_nested_substitution():
  97. # Substitution within a substitution.
  98. e = Add(Pow(w*x + y, 2), sqrt(w*x + y))
  99. substs, reduced = cse([e])
  100. assert substs == [(x0, w*x + y)]
  101. assert reduced == [sqrt(x0) + x0**2]
  102. def test_subtraction_opt():
  103. # Make sure subtraction is optimized.
  104. e = (x - y)*(z - y) + exp((x - y)*(z - y))
  105. substs, reduced = cse(
  106. [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
  107. assert substs == [(x0, (x - y)*(y - z))]
  108. assert reduced == [-x0 + exp(-x0)]
  109. e = -(x - y)*(z - y) + exp(-(x - y)*(z - y))
  110. substs, reduced = cse(
  111. [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
  112. assert substs == [(x0, (x - y)*(y - z))]
  113. assert reduced == [x0 + exp(x0)]
  114. # issue 4077
  115. n = -1 + 1/x
  116. e = n/x/(-n)**2 - 1/n/x
  117. assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \
  118. ([], [0])
  119. assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \
  120. ([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3])
  121. def test_multiple_expressions():
  122. e1 = (x + y)*z
  123. e2 = (x + y)*w
  124. substs, reduced = cse([e1, e2])
  125. assert substs == [(x0, x + y)]
  126. assert reduced == [x0*z, x0*w]
  127. l = [w*x*y + z, w*y]
  128. substs, reduced = cse(l)
  129. rsubsts, _ = cse(reversed(l))
  130. assert substs == rsubsts
  131. assert reduced == [z + x*x0, x0]
  132. l = [w*x*y, w*x*y + z, w*y]
  133. substs, reduced = cse(l)
  134. rsubsts, _ = cse(reversed(l))
  135. assert substs == rsubsts
  136. assert reduced == [x1, x1 + z, x0]
  137. l = [(x - z)*(y - z), x - z, y - z]
  138. substs, reduced = cse(l)
  139. rsubsts, _ = cse(reversed(l))
  140. assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
  141. assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
  142. assert reduced == [x1*x2, x1, x2]
  143. l = [w*y + w + x + y + z, w*x*y]
  144. assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
  145. assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
  146. assert cse([x + y, x + z]) == ([], [x + y, x + z])
  147. assert cse([x*y, z + x*y, x*y*z + 3]) == \
  148. ([(x0, x*y)], [x0, z + x0, 3 + x0*z])
  149. @XFAIL # CSE of non-commutative Mul terms is disabled
  150. def test_non_commutative_cse():
  151. A, B, C = symbols('A B C', commutative=False)
  152. l = [A*B*C, A*C]
  153. assert cse(l) == ([], l)
  154. l = [A*B*C, A*B]
  155. assert cse(l) == ([(x0, A*B)], [x0*C, x0])
  156. # Test if CSE of non-commutative Mul terms is disabled
  157. def test_bypass_non_commutatives():
  158. A, B, C = symbols('A B C', commutative=False)
  159. l = [A*B*C, A*C]
  160. assert cse(l) == ([], l)
  161. l = [A*B*C, A*B]
  162. assert cse(l) == ([], l)
  163. l = [B*C, A*B*C]
  164. assert cse(l) == ([], l)
  165. @XFAIL # CSE fails when replacing non-commutative sub-expressions
  166. def test_non_commutative_order():
  167. A, B, C = symbols('A B C', commutative=False)
  168. x0 = symbols('x0', commutative=False)
  169. l = [B+C, A*(B+C)]
  170. assert cse(l) == ([(x0, B+C)], [x0, A*x0])
  171. @XFAIL # Worked in gh-11232, but was reverted due to performance considerations
  172. def test_issue_10228():
  173. assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0])
  174. assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0])
  175. assert cse((w + 2*x + y + z, w + x + 1)) == (
  176. [(x0, w + x)], [x0 + x + y + z, x0 + 1])
  177. assert cse(((w + x + y + z)*(w - x))/(w + x)) == (
  178. [(x0, w + x)], [(x0 + y + z)*(w - x)/x0])
  179. a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m')
  180. exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2)
  181. assert cse(exprs) == (
  182. [(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1]
  183. )
  184. @XFAIL
  185. def test_powers():
  186. assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0])
  187. def test_issue_4498():
  188. assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \
  189. ([], [(w - z)/(x - y)])
  190. def test_issue_4020():
  191. assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \
  192. == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)])
  193. def test_issue_4203():
  194. assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0])
  195. def test_issue_6263():
  196. e = Eq(x*(-x + 1) + x*(x - 1), 0)
  197. assert cse(e, optimizations='basic') == ([], [True])
  198. def test_dont_cse_tuples():
  199. from sympy.core.function import Subs
  200. f = Function("f")
  201. g = Function("g")
  202. name_val, (expr,) = cse(
  203. Subs(f(x, y), (x, y), (0, 1))
  204. + Subs(g(x, y), (x, y), (0, 1)))
  205. assert name_val == []
  206. assert expr == (Subs(f(x, y), (x, y), (0, 1))
  207. + Subs(g(x, y), (x, y), (0, 1)))
  208. name_val, (expr,) = cse(
  209. Subs(f(x, y), (x, y), (0, x + y))
  210. + Subs(g(x, y), (x, y), (0, x + y)))
  211. assert name_val == [(x0, x + y)]
  212. assert expr == Subs(f(x, y), (x, y), (0, x0)) + \
  213. Subs(g(x, y), (x, y), (0, x0))
  214. def test_pow_invpow():
  215. assert cse(1/x**2 + x**2) == \
  216. ([(x0, x**2)], [x0 + 1/x0])
  217. assert cse(x**2 + (1 + 1/x**2)/x**2) == \
  218. ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)])
  219. assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \
  220. ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1])
  221. assert cse(cos(1/x**2) + sin(1/x**2)) == \
  222. ([(x0, x**(-2))], [sin(x0) + cos(x0)])
  223. assert cse(cos(x**2) + sin(x**2)) == \
  224. ([(x0, x**2)], [sin(x0) + cos(x0)])
  225. assert cse(y/(2 + x**2) + z/x**2/y) == \
  226. ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)])
  227. assert cse(exp(x**2) + x**2*cos(1/x**2)) == \
  228. ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)])
  229. assert cse((1 + 1/x**2)/x**2) == \
  230. ([(x0, x**(-2))], [x0*(x0 + 1)])
  231. assert cse(x**(2*y) + x**(-2*y)) == \
  232. ([(x0, x**(2*y))], [x0 + 1/x0])
  233. def test_postprocess():
  234. eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
  235. assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)],
  236. postprocess=cse_main.cse_separate) == \
  237. [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)],
  238. [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]]
  239. def test_issue_4499():
  240. # previously, this gave 16 constants
  241. from sympy.abc import a, b
  242. B = Function('B')
  243. G = Function('G')
  244. t = Tuple(*
  245. (a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a -
  246. b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1),
  247. sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b,
  248. sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1,
  249. sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1),
  250. (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1,
  251. sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b,
  252. -2*a))
  253. c = cse(t)
  254. ans = (
  255. [(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)),
  256. (x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)),
  257. (x8, B(b, x4)), (x9, x6*B(x2, x4))],
  258. [(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9,
  259. 1, 0, S.Half, z/2, -x3, -x1, -x0)])
  260. assert ans == c
  261. def test_issue_6169():
  262. r = CRootOf(x**6 - 4*x**5 - 2, 1)
  263. assert cse(r) == ([], [r])
  264. # and a check that the right thing is done with the new
  265. # mechanism
  266. assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y
  267. def test_cse_Indexed():
  268. len_y = 5
  269. y = IndexedBase('y', shape=(len_y,))
  270. x = IndexedBase('x', shape=(len_y,))
  271. i = Idx('i', len_y-1)
  272. expr1 = (y[i+1]-y[i])/(x[i+1]-x[i])
  273. expr2 = 1/(x[i+1]-x[i])
  274. replacements, reduced_exprs = cse([expr1, expr2])
  275. assert len(replacements) > 0
  276. def test_cse_MatrixSymbol():
  277. # MatrixSymbols have non-Basic args, so make sure that works
  278. A = MatrixSymbol("A", 3, 3)
  279. assert cse(A) == ([], [A])
  280. n = symbols('n', integer=True)
  281. B = MatrixSymbol("B", n, n)
  282. assert cse(B) == ([], [B])
  283. assert cse(A[0] * A[0]) == ([], [A[0]*A[0]])
  284. assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0])
  285. def test_cse_MatrixExpr():
  286. A = MatrixSymbol('A', 3, 3)
  287. y = MatrixSymbol('y', 3, 1)
  288. expr1 = (A.T*A).I * A * y
  289. expr2 = (A.T*A) * A * y
  290. replacements, reduced_exprs = cse([expr1, expr2])
  291. assert len(replacements) > 0
  292. replacements, reduced_exprs = cse([expr1 + expr2, expr1])
  293. assert replacements
  294. replacements, reduced_exprs = cse([A**2, A + A**2])
  295. assert replacements
  296. def test_Piecewise():
  297. f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True))
  298. ans = cse(f)
  299. actual_ans = ([(x0, x*y)],
  300. [Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))])
  301. assert ans == actual_ans
  302. def test_ignore_order_terms():
  303. eq = exp(x).series(x,0,3) + sin(y+x**3) - 1
  304. assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)])
  305. def test_name_conflict():
  306. z1 = x0 + y
  307. z2 = x2 + x3
  308. l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
  309. substs, reduced = cse(l)
  310. assert [e.subs(reversed(substs)) for e in reduced] == l
  311. def test_name_conflict_cust_symbols():
  312. z1 = x0 + y
  313. z2 = x2 + x3
  314. l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
  315. substs, reduced = cse(l, symbols("x:10"))
  316. assert [e.subs(reversed(substs)) for e in reduced] == l
  317. def test_symbols_exhausted_error():
  318. l = cos(x+y)+x+y+cos(w+y)+sin(w+y)
  319. sym = [x, y, z]
  320. with raises(ValueError):
  321. cse(l, symbols=sym)
  322. def test_issue_7840():
  323. # daveknippers' example
  324. C393 = sympify( \
  325. 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \
  326. C391 > 2.35), (C392, True)), True))'
  327. )
  328. C391 = sympify( \
  329. 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))'
  330. )
  331. C393 = C393.subs('C391',C391)
  332. # simple substitution
  333. sub = {}
  334. sub['C390'] = 0.703451854
  335. sub['C392'] = 1.01417794
  336. ss_answer = C393.subs(sub)
  337. # cse
  338. substitutions,new_eqn = cse(C393)
  339. for pair in substitutions:
  340. sub[pair[0].name] = pair[1].subs(sub)
  341. cse_answer = new_eqn[0].subs(sub)
  342. # both methods should be the same
  343. assert ss_answer == cse_answer
  344. # GitRay's example
  345. expr = sympify(
  346. "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \
  347. (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \
  348. Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \
  349. Symbol('AUTO'))), (Symbol('OFF'), true)), true))"
  350. )
  351. substitutions, new_eqn = cse(expr)
  352. # this Piecewise should be exactly the same
  353. assert new_eqn[0] == expr
  354. # there should not be any replacements
  355. assert len(substitutions) < 1
  356. def test_issue_8891():
  357. for cls in (MutableDenseMatrix, MutableSparseMatrix,
  358. ImmutableDenseMatrix, ImmutableSparseMatrix):
  359. m = cls(2, 2, [x + y, 0, 0, 0])
  360. res = cse([x + y, m])
  361. ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])])
  362. assert res == ans
  363. assert isinstance(res[1][-1], cls)
  364. def test_issue_11230():
  365. # a specific test that always failed
  366. a, b, f, k, l, i = symbols('a b f k l i')
  367. p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l]
  368. R, C = cse(p)
  369. assert not any(i.is_Mul for a in C for i in a.args)
  370. # random tests for the issue
  371. from sympy.core.random import choice
  372. from sympy.core.function import expand_mul
  373. s = symbols('a:m')
  374. # 35 Mul tests, none of which should ever fail
  375. ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)]
  376. for p in subsets(ex, 3):
  377. p = list(p)
  378. R, C = cse(p)
  379. assert not any(i.is_Mul for a in C for i in a.args)
  380. for ri in reversed(R):
  381. for i in range(len(C)):
  382. C[i] = C[i].subs(*ri)
  383. assert p == C
  384. # 35 Add tests, none of which should ever fail
  385. ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)]
  386. for p in subsets(ex, 3):
  387. p = list(p)
  388. R, C = cse(p)
  389. assert not any(i.is_Add for a in C for i in a.args)
  390. for ri in reversed(R):
  391. for i in range(len(C)):
  392. C[i] = C[i].subs(*ri)
  393. # use expand_mul to handle cases like this:
  394. # p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g]
  395. # x0 = 2*(b + e) is identified giving a rebuilt p that
  396. # is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]`
  397. assert p == [expand_mul(i) for i in C]
  398. @XFAIL
  399. def test_issue_11577():
  400. def check(eq):
  401. r, c = cse(eq)
  402. assert eq.count_ops() >= \
  403. len(r) + sum([i[1].count_ops() for i in r]) + \
  404. count_ops(c)
  405. eq = x**5*y**2 + x**5*y + x**5
  406. assert cse(eq) == (
  407. [(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1])
  408. # ([(x0, x**5*y)], [x0*y + x0 + x**5]) or
  409. # ([(x0, x**5)], [x0*y**2 + x0*y + x0])
  410. check(eq)
  411. eq = x**2/(y + 1)**2 + x/(y + 1)
  412. assert cse(eq) == (
  413. [(x0, y + 1)], [x**2/x0**2 + x/x0])
  414. # ([(x0, x/(y + 1))], [x0**2 + x0])
  415. check(eq)
  416. def test_hollow_rejection():
  417. eq = [x + 3, x + 4]
  418. assert cse(eq) == ([], eq)
  419. def test_cse_ignore():
  420. exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))]
  421. subst1, red1 = cse(exprs)
  422. assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y"
  423. subst2, red2 = cse(exprs, ignore=(y,)) # y is not allowed in substitutions
  424. assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored"
  425. assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression"
  426. def test_cse_ignore_issue_15002():
  427. l = [
  428. w*exp(x)*exp(-z),
  429. exp(y)*exp(x)*exp(-z)
  430. ]
  431. substs, reduced = cse(l, ignore=(x,))
  432. rl = [e.subs(reversed(substs)) for e in reduced]
  433. assert rl == l
  434. def test_cse__performance():
  435. nexprs, nterms = 3, 20
  436. x = symbols('x:%d' % nterms)
  437. exprs = [
  438. reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)])
  439. for i in range(nexprs)
  440. ]
  441. assert (exprs[0] + exprs[1]).simplify() == 0
  442. subst, red = cse(exprs)
  443. assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE"
  444. for i, e in enumerate(red):
  445. assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0
  446. def test_issue_12070():
  447. exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z]
  448. subst, red = cse(exprs)
  449. assert 6 >= (len(subst) + sum([v.count_ops() for k, v in subst]) +
  450. count_ops(red))
  451. def test_issue_13000():
  452. eq = x/(-4*x**2 + y**2)
  453. cse_eq = cse(eq)[1][0]
  454. assert cse_eq == eq
  455. def test_issue_18203():
  456. eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1)
  457. assert cse(eq) == ([], [eq])
  458. def test_unevaluated_mul():
  459. eq = Mul(x + y, x + y, evaluate=False)
  460. assert cse(eq) == ([(x0, x + y)], [x0**2])
  461. def test_cse_release_variables():
  462. from sympy.simplify.cse_main import cse_release_variables
  463. _0, _1, _2, _3, _4 = symbols('_:5')
  464. eqs = [(x + y - 1)**2, x,
  465. x + y, (x + y)/(2*x + 1) + (x + y - 1)**2,
  466. (2*x + 1)**(x + y)]
  467. r, e = cse(eqs, postprocess=cse_release_variables)
  468. # this can change in keeping with the intention of the function
  469. assert r, e == ([
  470. (x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1),
  471. (_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1),
  472. (x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4))
  473. r.reverse()
  474. assert eqs == [i.subs(r) for i in e]
  475. def test_cse_list():
  476. _cse = lambda x: cse(x, list=False)
  477. assert _cse(x) == ([], x)
  478. assert _cse('x') == ([], 'x')
  479. it = [x]
  480. for c in (list, tuple, set):
  481. assert _cse(c(it)) == ([], c(it))
  482. #Tuple works different from tuple:
  483. assert _cse(Tuple(*it)) == ([], Tuple(*it))
  484. d = {x: 1}
  485. assert _cse(d) == ([], d)
  486. def test_issue_18991():
  487. A = MatrixSymbol('A', 2, 2)
  488. assert signsimp(-A * A - A) == -A * A - A
  489. def test_unevaluated_Mul():
  490. m = [Mul(1, 2, evaluate=False)]
  491. assert cse(m) == ([], m)