图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

109 lines
3.0 KiB

  1. from sympy.core.containers import Tuple
  2. from sympy.core.symbol import symbols
  3. from sympy.matrices.dense import Matrix
  4. from sympy.physics.quantum.trace import Tr
  5. from sympy.testing.pytest import raises, warns_deprecated_sympy
  6. def test_trace_new():
  7. a, b, c, d, Y = symbols('a b c d Y')
  8. A, B, C, D = symbols('A B C D', commutative=False)
  9. assert Tr(a + b) == a + b
  10. assert Tr(A + B) == Tr(A) + Tr(B)
  11. #check trace args not implicitly permuted
  12. assert Tr(C*D*A*B).args[0].args == (C, D, A, B)
  13. # check for mul and adds
  14. assert Tr((a*b) + ( c*d)) == (a*b) + (c*d)
  15. # Tr(scalar*A) = scalar*Tr(A)
  16. assert Tr(a*A) == a*Tr(A)
  17. assert Tr(a*A*B*b) == a*b*Tr(A*B)
  18. # since A is symbol and not commutative
  19. assert isinstance(Tr(A), Tr)
  20. #POW
  21. assert Tr(pow(a, b)) == a**b
  22. assert isinstance(Tr(pow(A, a)), Tr)
  23. #Matrix
  24. M = Matrix([[1, 1], [2, 2]])
  25. assert Tr(M) == 3
  26. ##test indices in different forms
  27. #no index
  28. t = Tr(A)
  29. assert t.args[1] == Tuple()
  30. #single index
  31. t = Tr(A, 0)
  32. assert t.args[1] == Tuple(0)
  33. #index in a list
  34. t = Tr(A, [0])
  35. assert t.args[1] == Tuple(0)
  36. t = Tr(A, [0, 1, 2])
  37. assert t.args[1] == Tuple(0, 1, 2)
  38. #index is tuple
  39. t = Tr(A, (0))
  40. assert t.args[1] == Tuple(0)
  41. t = Tr(A, (1, 2))
  42. assert t.args[1] == Tuple(1, 2)
  43. #trace indices test
  44. t = Tr((A + B), [2])
  45. assert t.args[0].args[1] == Tuple(2) and t.args[1].args[1] == Tuple(2)
  46. t = Tr(a*A, [2, 3])
  47. assert t.args[1].args[1] == Tuple(2, 3)
  48. #class with trace method defined
  49. #to simulate numpy objects
  50. class Foo:
  51. def trace(self):
  52. return 1
  53. assert Tr(Foo()) == 1
  54. #argument test
  55. # check for value error, when either/both arguments are not provided
  56. raises(ValueError, lambda: Tr())
  57. raises(ValueError, lambda: Tr(A, 1, 2))
  58. def test_trace_doit():
  59. a, b, c, d = symbols('a b c d')
  60. A, B, C, D = symbols('A B C D', commutative=False)
  61. #TODO: needed while testing reduced density operations, etc.
  62. def test_permute():
  63. A, B, C, D, E, F, G = symbols('A B C D E F G', commutative=False)
  64. t = Tr(A*B*C*D*E*F*G)
  65. assert t.permute(0).args[0].args == (A, B, C, D, E, F, G)
  66. assert t.permute(2).args[0].args == (F, G, A, B, C, D, E)
  67. assert t.permute(4).args[0].args == (D, E, F, G, A, B, C)
  68. assert t.permute(6).args[0].args == (B, C, D, E, F, G, A)
  69. assert t.permute(8).args[0].args == t.permute(1).args[0].args
  70. assert t.permute(-1).args[0].args == (B, C, D, E, F, G, A)
  71. assert t.permute(-3).args[0].args == (D, E, F, G, A, B, C)
  72. assert t.permute(-5).args[0].args == (F, G, A, B, C, D, E)
  73. assert t.permute(-8).args[0].args == t.permute(-1).args[0].args
  74. t = Tr((A + B)*(B*B)*C*D)
  75. assert t.permute(2).args[0].args == (C, D, (A + B), (B**2))
  76. t1 = Tr(A*B)
  77. t2 = t1.permute(1)
  78. assert id(t1) != id(t2) and t1 == t2
  79. def test_deprecated_core_trace():
  80. with warns_deprecated_sympy():
  81. from sympy.core.trace import Tr # noqa:F401