m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

113 lines
3.2 KiB

6 months ago
  1. from sympy.core import Lambda, S, symbols
  2. from sympy.concrete import Sum
  3. from sympy.functions import adjoint, conjugate, transpose
  4. from sympy.matrices import eye, Matrix, ShapeError, ImmutableMatrix
  5. from sympy.matrices.expressions import (
  6. Adjoint, Identity, FunctionMatrix, MatrixExpr, MatrixSymbol, Trace,
  7. ZeroMatrix, trace, MatPow, MatAdd, MatMul
  8. )
  9. from sympy.matrices.expressions.special import OneMatrix
  10. from sympy.testing.pytest import raises
  11. n = symbols('n', integer=True)
  12. A = MatrixSymbol('A', n, n)
  13. B = MatrixSymbol('B', n, n)
  14. C = MatrixSymbol('C', 3, 4)
  15. def test_Trace():
  16. assert isinstance(Trace(A), Trace)
  17. assert not isinstance(Trace(A), MatrixExpr)
  18. raises(ShapeError, lambda: Trace(C))
  19. assert trace(eye(3)) == 3
  20. assert trace(Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])) == 15
  21. assert adjoint(Trace(A)) == trace(Adjoint(A))
  22. assert conjugate(Trace(A)) == trace(Adjoint(A))
  23. assert transpose(Trace(A)) == Trace(A)
  24. _ = A / Trace(A) # Make sure this is possible
  25. # Some easy simplifications
  26. assert trace(Identity(5)) == 5
  27. assert trace(ZeroMatrix(5, 5)) == 0
  28. assert trace(OneMatrix(1, 1)) == 1
  29. assert trace(OneMatrix(2, 2)) == 2
  30. assert trace(OneMatrix(n, n)) == n
  31. assert trace(2*A*B) == 2*Trace(A*B)
  32. assert trace(A.T) == trace(A)
  33. i, j = symbols('i j')
  34. F = FunctionMatrix(3, 3, Lambda((i, j), i + j))
  35. assert trace(F) == (0 + 0) + (1 + 1) + (2 + 2)
  36. raises(TypeError, lambda: Trace(S.One))
  37. assert Trace(A).arg is A
  38. assert str(trace(A)) == str(Trace(A).doit())
  39. assert Trace(A).is_commutative is True
  40. def test_Trace_A_plus_B():
  41. assert trace(A + B) == Trace(A) + Trace(B)
  42. assert Trace(A + B).arg == MatAdd(A, B)
  43. assert Trace(A + B).doit() == Trace(A) + Trace(B)
  44. def test_Trace_MatAdd_doit():
  45. # See issue #9028
  46. X = ImmutableMatrix([[1, 2, 3]]*3)
  47. Y = MatrixSymbol('Y', 3, 3)
  48. q = MatAdd(X, 2*X, Y, -3*Y)
  49. assert Trace(q).arg == q
  50. assert Trace(q).doit() == 18 - 2*Trace(Y)
  51. def test_Trace_MatPow_doit():
  52. X = Matrix([[1, 2], [3, 4]])
  53. assert Trace(X).doit() == 5
  54. q = MatPow(X, 2)
  55. assert Trace(q).arg == q
  56. assert Trace(q).doit() == 29
  57. def test_Trace_MutableMatrix_plus():
  58. # See issue #9043
  59. X = Matrix([[1, 2], [3, 4]])
  60. assert Trace(X) + Trace(X) == 2*Trace(X)
  61. def test_Trace_doit_deep_False():
  62. X = Matrix([[1, 2], [3, 4]])
  63. q = MatPow(X, 2)
  64. assert Trace(q).doit(deep=False).arg == q
  65. q = MatAdd(X, 2*X)
  66. assert Trace(q).doit(deep=False).arg == q
  67. q = MatMul(X, 2*X)
  68. assert Trace(q).doit(deep=False).arg == q
  69. def test_trace_constant_factor():
  70. # Issue 9052: gave 2*Trace(MatMul(A)) instead of 2*Trace(A)
  71. assert trace(2*A) == 2*Trace(A)
  72. X = ImmutableMatrix([[1, 2], [3, 4]])
  73. assert trace(MatMul(2, X)) == 10
  74. def test_rewrite():
  75. assert isinstance(trace(A).rewrite(Sum), Sum)
  76. def test_trace_normalize():
  77. assert Trace(B*A) != Trace(A*B)
  78. assert Trace(B*A)._normalize() == Trace(A*B)
  79. assert Trace(B*A.T)._normalize() == Trace(A*B.T)
  80. def test_trace_as_explicit():
  81. raises(ValueError, lambda: Trace(A).as_explicit())
  82. X = MatrixSymbol("X", 3, 3)
  83. assert Trace(X).as_explicit() == X[0, 0] + X[1, 1] + X[2, 2]
  84. assert Trace(eye(3)).as_explicit() == 3