图片解析应用
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

162 lines
5.8 KiB

  1. from sympy.core.add import Add
  2. from sympy.core.basic import Basic
  3. from sympy.core.containers import Tuple
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import (Symbol, symbols)
  6. from sympy.logic.boolalg import And
  7. from sympy.core.symbol import Str
  8. from sympy.unify.core import Compound, Variable
  9. from sympy.unify.usympy import (deconstruct, construct, unify, is_associative,
  10. is_commutative)
  11. from sympy.abc import x, y, z, n
  12. def test_deconstruct():
  13. expr = Basic(S(1), S(2), S(3))
  14. expected = Compound(Basic, (1, 2, 3))
  15. assert deconstruct(expr) == expected
  16. assert deconstruct(1) == 1
  17. assert deconstruct(x) == x
  18. assert deconstruct(x, variables=(x,)) == Variable(x)
  19. assert deconstruct(Add(1, x, evaluate=False)) == Compound(Add, (1, x))
  20. assert deconstruct(Add(1, x, evaluate=False), variables=(x,)) == \
  21. Compound(Add, (1, Variable(x)))
  22. def test_construct():
  23. expr = Compound(Basic, (S(1), S(2), S(3)))
  24. expected = Basic(S(1), S(2), S(3))
  25. assert construct(expr) == expected
  26. def test_nested():
  27. expr = Basic(S(1), Basic(S(2)), S(3))
  28. cmpd = Compound(Basic, (S(1), Compound(Basic, Tuple(2)), S(3)))
  29. assert deconstruct(expr) == cmpd
  30. assert construct(cmpd) == expr
  31. def test_unify():
  32. expr = Basic(S(1), S(2), S(3))
  33. a, b, c = map(Symbol, 'abc')
  34. pattern = Basic(a, b, c)
  35. assert list(unify(expr, pattern, {}, (a, b, c))) == [{a: 1, b: 2, c: 3}]
  36. assert list(unify(expr, pattern, variables=(a, b, c))) == \
  37. [{a: 1, b: 2, c: 3}]
  38. def test_unify_variables():
  39. assert list(unify(Basic(S(1), S(2)), Basic(S(1), x), {}, variables=(x,))) == [{x: 2}]
  40. def test_s_input():
  41. expr = Basic(S(1), S(2))
  42. a, b = map(Symbol, 'ab')
  43. pattern = Basic(a, b)
  44. assert list(unify(expr, pattern, {}, (a, b))) == [{a: 1, b: 2}]
  45. assert list(unify(expr, pattern, {a: 5}, (a, b))) == []
  46. def iterdicteq(a, b):
  47. a = tuple(a)
  48. b = tuple(b)
  49. return len(a) == len(b) and all(x in b for x in a)
  50. def test_unify_commutative():
  51. expr = Add(1, 2, 3, evaluate=False)
  52. a, b, c = map(Symbol, 'abc')
  53. pattern = Add(a, b, c, evaluate=False)
  54. result = tuple(unify(expr, pattern, {}, (a, b, c)))
  55. expected = ({a: 1, b: 2, c: 3},
  56. {a: 1, b: 3, c: 2},
  57. {a: 2, b: 1, c: 3},
  58. {a: 2, b: 3, c: 1},
  59. {a: 3, b: 1, c: 2},
  60. {a: 3, b: 2, c: 1})
  61. assert iterdicteq(result, expected)
  62. def test_unify_iter():
  63. expr = Add(1, 2, 3, evaluate=False)
  64. a, b, c = map(Symbol, 'abc')
  65. pattern = Add(a, c, evaluate=False)
  66. assert is_associative(deconstruct(pattern))
  67. assert is_commutative(deconstruct(pattern))
  68. result = list(unify(expr, pattern, {}, (a, c)))
  69. expected = [{a: 1, c: Add(2, 3, evaluate=False)},
  70. {a: 1, c: Add(3, 2, evaluate=False)},
  71. {a: 2, c: Add(1, 3, evaluate=False)},
  72. {a: 2, c: Add(3, 1, evaluate=False)},
  73. {a: 3, c: Add(1, 2, evaluate=False)},
  74. {a: 3, c: Add(2, 1, evaluate=False)},
  75. {a: Add(1, 2, evaluate=False), c: 3},
  76. {a: Add(2, 1, evaluate=False), c: 3},
  77. {a: Add(1, 3, evaluate=False), c: 2},
  78. {a: Add(3, 1, evaluate=False), c: 2},
  79. {a: Add(2, 3, evaluate=False), c: 1},
  80. {a: Add(3, 2, evaluate=False), c: 1}]
  81. assert iterdicteq(result, expected)
  82. def test_hard_match():
  83. from sympy.functions.elementary.trigonometric import (cos, sin)
  84. expr = sin(x) + cos(x)**2
  85. p, q = map(Symbol, 'pq')
  86. pattern = sin(p) + cos(p)**2
  87. assert list(unify(expr, pattern, {}, (p, q))) == [{p: x}]
  88. def test_matrix():
  89. from sympy.matrices.expressions.matexpr import MatrixSymbol
  90. X = MatrixSymbol('X', n, n)
  91. Y = MatrixSymbol('Y', 2, 2)
  92. Z = MatrixSymbol('Z', 2, 3)
  93. assert list(unify(X, Y, {}, variables=[n, Str('X')])) == [{Str('X'): Str('Y'), n: 2}]
  94. assert list(unify(X, Z, {}, variables=[n, Str('X')])) == []
  95. def test_non_frankenAdds():
  96. # the is_commutative property used to fail because of Basic.__new__
  97. # This caused is_commutative and str calls to fail
  98. expr = x+y*2
  99. rebuilt = construct(deconstruct(expr))
  100. # Ensure that we can run these commands without causing an error
  101. str(rebuilt)
  102. rebuilt.is_commutative
  103. def test_FiniteSet_commutivity():
  104. from sympy.sets.sets import FiniteSet
  105. a, b, c, x, y = symbols('a,b,c,x,y')
  106. s = FiniteSet(a, b, c)
  107. t = FiniteSet(x, y)
  108. variables = (x, y)
  109. assert {x: FiniteSet(a, c), y: b} in tuple(unify(s, t, variables=variables))
  110. def test_FiniteSet_complex():
  111. from sympy.sets.sets import FiniteSet
  112. a, b, c, x, y, z = symbols('a,b,c,x,y,z')
  113. expr = FiniteSet(Basic(S(1), x), y, Basic(x, z))
  114. pattern = FiniteSet(a, Basic(x, b))
  115. variables = a, b
  116. expected = tuple([{b: 1, a: FiniteSet(y, Basic(x, z))},
  117. {b: z, a: FiniteSet(y, Basic(S(1), x))}])
  118. assert iterdicteq(unify(expr, pattern, variables=variables), expected)
  119. def test_and():
  120. variables = x, y
  121. expected = tuple([{x: z > 0, y: n < 3}])
  122. assert iterdicteq(unify((z>0) & (n<3), And(x, y), variables=variables),
  123. expected)
  124. def test_Union():
  125. from sympy.sets.sets import Interval
  126. assert list(unify(Interval(0, 1) + Interval(10, 11),
  127. Interval(0, 1) + Interval(12, 13),
  128. variables=(Interval(12, 13),)))
  129. def test_is_commutative():
  130. assert is_commutative(deconstruct(x+y))
  131. assert is_commutative(deconstruct(x*y))
  132. assert not is_commutative(deconstruct(x**y))
  133. def test_commutative_in_commutative():
  134. from sympy.abc import a,b,c,d
  135. from sympy.functions.elementary.trigonometric import (cos, sin)
  136. eq = sin(3)*sin(4)*sin(5) + 4*cos(3)*cos(4)
  137. pat = a*cos(b)*cos(c) + d*sin(b)*sin(c)
  138. assert next(unify(eq, pat, variables=(a,b,c,d)))