图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

166 lines
5.5 KiB

  1. from sympy.combinatorics import Permutation
  2. from sympy.core.expr import unchanged
  3. from sympy.matrices import Matrix
  4. from sympy.matrices.expressions import \
  5. MatMul, BlockDiagMatrix, Determinant, Inverse
  6. from sympy.matrices.expressions.matexpr import MatrixSymbol
  7. from sympy.matrices.expressions.special import ZeroMatrix, OneMatrix, Identity
  8. from sympy.matrices.expressions.permutation import \
  9. MatrixPermute, PermutationMatrix
  10. from sympy.testing.pytest import raises
  11. from sympy.core.symbol import Symbol
  def test_PermutationMatrix_basic():
      """Check construction, validation and explicit form of PermutationMatrix."""
      p = Permutation([1, 0])
      assert unchanged(PermutationMatrix, p)

      # A bare tuple instead of a Permutation is expected to be rejected.
      raises(ValueError, lambda: PermutationMatrix((0, 1, 2)))

      assert PermutationMatrix(p).as_explicit() == Matrix([[0, 1], [1, 0]])

      # Multiplying by a symbolic matrix is expected to give a MatMul.
      assert isinstance(PermutationMatrix(p)*MatrixSymbol('A', 2, 2), MatMul)
  def test_PermutationMatrix_matmul():
      """Check products of a PermutationMatrix with matrices and with itself."""
      p = Permutation([1, 2, 0])
      P = PermutationMatrix(p)
      M = Matrix([[0, 1, 2], [3, 4, 5], [6, 7, 8]])

      # Left and right products must agree with the explicit matrix product.
      assert (P*M).as_explicit() == P.as_explicit()*M
      assert (M*P).as_explicit() == M*P.as_explicit()

      # The product of two permutation matrices is compared against a third
      # PermutationMatrix.
      P1 = PermutationMatrix(Permutation([1, 2, 0]))
      P2 = PermutationMatrix(Permutation([2, 1, 0]))
      P3 = PermutationMatrix(Permutation([1, 0, 2]))
      assert P1*P2 == P3
  def test_PermutationMatrix_matpow():
      """Check integer powers of a PermutationMatrix built from [1, 2, 0]."""
      p1 = Permutation([1, 2, 0])
      P1 = PermutationMatrix(p1)
      p2 = Permutation([2, 0, 1])
      P2 = PermutationMatrix(p2)

      assert P1**2 == P2
      # The cube is expected to be the 3x3 identity.
      assert P1**3 == Identity(3)
  def test_PermutationMatrix_identity():
      """Check the ``is_Identity`` flag for [0, 1] and for [1, 0]."""
      p = Permutation([0, 1])
      assert PermutationMatrix(p).is_Identity

      p = Permutation([1, 0])
      assert not PermutationMatrix(p).is_Identity
  def test_PermutationMatrix_determinant():
      """Check ``Determinant(P).doit()`` for three size-3 permutations."""
      P = PermutationMatrix(Permutation([0, 1, 2]))
      assert Determinant(P).doit() == 1

      # [0, 2, 1] exchanges two entries; the expected determinant is -1.
      P = PermutationMatrix(Permutation([0, 2, 1]))
      assert Determinant(P).doit() == -1

      P = PermutationMatrix(Permutation([2, 0, 1]))
      assert Determinant(P).doit() == 1
  def test_PermutationMatrix_inverse():
      """Check that ``Inverse(P).doit()`` gives another PermutationMatrix."""
      # Both permutations are built from cycle arguments: (0 1 2) and (0 2 1).
      P = PermutationMatrix(Permutation(0, 1, 2))
      assert Inverse(P).doit() == PermutationMatrix(Permutation(0, 2, 1))
  def test_PermutationMatrix_rewrite_BlockDiagMatrix():
      """Check ``rewrite(BlockDiagMatrix)`` on several size-6 permutations."""
      # The identity permutation is expected to split into six 1x1 blocks.
      P = PermutationMatrix(Permutation([0, 1, 2, 3, 4, 5]))
      P0 = PermutationMatrix(Permutation([0]))
      assert P.rewrite(BlockDiagMatrix) == \
          BlockDiagMatrix(P0, P0, P0, P0, P0, P0)

      # A single exchange of 2 and 3 yields one 2x2 block among 1x1 blocks.
      P = PermutationMatrix(Permutation([0, 1, 3, 2, 4, 5]))
      P10 = PermutationMatrix(Permutation(0, 1))
      assert P.rewrite(BlockDiagMatrix) == \
          BlockDiagMatrix(P0, P0, P10, P0, P0)

      # Three adjacent exchanges yield three 2x2 blocks.
      P = PermutationMatrix(Permutation([1, 0, 3, 2, 5, 4]))
      assert P.rewrite(BlockDiagMatrix) == \
          BlockDiagMatrix(P10, P10, P10)

      # Indices 1..4 are reversed; the expected middle block is 4x4.
      P = PermutationMatrix(Permutation([0, 4, 3, 2, 1, 5]))
      P3210 = PermutationMatrix(Permutation([3, 2, 1, 0]))
      assert P.rewrite(BlockDiagMatrix) == \
          BlockDiagMatrix(P0, P3210, P0)

      # Only 1 and 4 are exchanged, yet the expected middle block is still
      # a single 4x4 block covering indices 1..4.
      P = PermutationMatrix(Permutation([0, 4, 2, 3, 1, 5]))
      P3120 = PermutationMatrix(Permutation([3, 1, 2, 0]))
      assert P.rewrite(BlockDiagMatrix) == \
          BlockDiagMatrix(P0, P3120, P0)

      # Here the rewrite is expected to be one block holding P itself.
      P = PermutationMatrix(Permutation(0, 3)(1, 4)(2, 5))
      assert P.rewrite(BlockDiagMatrix) == BlockDiagMatrix(P)
  def test_MartrixPermute_basic():
      """Check argument validation and normalisation of MatrixPermute.

      NOTE: the ``Martrix`` spelling in the function name is kept as-is so the
      test's identifier does not change.
      """
      p = Permutation(0, 1)
      P = PermutationMatrix(p)
      A = MatrixSymbol('A', 2, 2)

      # A plain Symbol is expected to be rejected in either argument position.
      raises(ValueError, lambda: MatrixPermute(Symbol('x'), p))
      raises(ValueError, lambda: MatrixPermute(A, Symbol('x')))

      # A PermutationMatrix may be passed instead of the Permutation itself.
      assert MatrixPermute(A, P) == MatrixPermute(A, p)

      # 2 as the third argument is expected to be rejected.
      raises(ValueError, lambda: MatrixPermute(A, p, 2))

      # The same cycle declared with size=3 compares equal on the 2x2 matrix.
      pp = Permutation(0, 1, size=3)
      assert MatrixPermute(A, pp) == MatrixPermute(A, p)

      # A 3-cycle applied to the 2x2 matrix is expected to be rejected.
      pp = Permutation(0, 1, 2)
      raises(ValueError, lambda: MatrixPermute(A, pp))
  def test_MatrixPermute_shape():
      """Check that MatrixPermute reports the shape of its 2x3 operand."""
      p = Permutation(0, 1)
      A = MatrixSymbol('A', 2, 3)
      assert MatrixPermute(A, p).shape == (2, 3)
  def test_MatrixPermute_explicit():
      """Compare ``MatrixPermute.as_explicit()`` with ``Matrix.permute``."""
      p = Permutation(0, 1, 2)
      A = MatrixSymbol('A', 3, 3)
      AA = A.as_explicit()

      # Third argument 0 is compared against orientation='rows'.
      assert MatrixPermute(A, p, 0).as_explicit() == \
          AA.permute(p, orientation='rows')

      # Third argument 1 is compared against orientation='cols'.
      assert MatrixPermute(A, p, 1).as_explicit() == \
          AA.permute(p, orientation='cols')
  def test_MatrixPermute_rewrite_MatMul():
      """Check that ``rewrite(MatMul)`` keeps the explicit matrix unchanged."""
      p = Permutation(0, 1, 2)
      A = MatrixSymbol('A', 3, 3)

      # Checked for both values (0 and 1) of the third argument.
      assert MatrixPermute(A, p, 0).rewrite(MatMul).as_explicit() == \
          MatrixPermute(A, p, 0).as_explicit()
      assert MatrixPermute(A, p, 1).rewrite(MatMul).as_explicit() == \
          MatrixPermute(A, p, 1).as_explicit()
  def test_MatrixPermute_doit():
      """Check ``MatrixPermute.doit()`` on symbolic, special and nested inputs."""
      # On a plain MatrixSymbol, doit() is expected to leave the expression
      # as it is.
      p = Permutation(0, 1, 2)
      A = MatrixSymbol('A', 3, 3)
      assert MatrixPermute(A, p).doit() == MatrixPermute(A, p)

      # Permutation(0, size=3): doit() must not change the explicit matrix.
      p = Permutation(0, size=3)
      A = MatrixSymbol('A', 3, 3)
      assert MatrixPermute(A, p).doit().as_explicit() == \
          MatrixPermute(A, p).as_explicit()

      # Identity operand, checked for both values of the third argument.
      p = Permutation(0, 1, 2)
      A = Identity(3)
      assert MatrixPermute(A, p, 0).doit().as_explicit() == \
          MatrixPermute(A, p, 0).as_explicit()
      assert MatrixPermute(A, p, 1).doit().as_explicit() == \
          MatrixPermute(A, p, 1).as_explicit()

      # Permuting an all-zero or all-one matrix is expected to return the
      # operand itself.
      A = ZeroMatrix(3, 3)
      assert MatrixPermute(A, p).doit() == A

      A = OneMatrix(3, 3)
      assert MatrixPermute(A, p).doit() == A

      # Nested MatrixPermute expressions: doit() must preserve the explicit
      # result for both values of the third argument.
      A = MatrixSymbol('A', 4, 4)
      p1 = Permutation(0, 1, 2, 3)
      p2 = Permutation(0, 2, 3, 1)

      expr = MatrixPermute(MatrixPermute(A, p1, 0), p2, 0)
      assert expr.as_explicit() == expr.doit().as_explicit()

      expr = MatrixPermute(MatrixPermute(A, p1, 1), p2, 1)
      assert expr.as_explicit() == expr.doit().as_explicit()