m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

113 lines
3.2 KiB

from sympy.core import Lambda, S, symbols
from sympy.concrete import Sum
from sympy.functions import adjoint, conjugate, transpose
from sympy.matrices import eye, Matrix, ShapeError, ImmutableMatrix
from sympy.matrices.expressions import (
Adjoint, Identity, FunctionMatrix, MatrixExpr, MatrixSymbol, Trace,
ZeroMatrix, trace, MatPow, MatAdd, MatMul
)
from sympy.matrices.expressions.special import OneMatrix
from sympy.testing.pytest import raises
# Integer-valued symbolic dimension used for the square matrix symbols.
n = symbols('n', integer=True)
# A and B are n x n matrix symbols; C is a fixed 3 x 4 (non-square) symbol.
A, B = (MatrixSymbol(label, n, n) for label in ('A', 'B'))
C = MatrixSymbol('C', 3, 4)
def test_Trace():
    """Check construction, basic identities and simple evaluations of Trace."""
    # A Trace wraps a matrix but is not itself a matrix expression.
    assert isinstance(Trace(A), Trace)
    assert not isinstance(Trace(A), MatrixExpr)
    # C is 3 x 4, so taking its trace must be rejected.
    raises(ShapeError, lambda: Trace(C))
    assert trace(eye(3)) == 3
    assert trace(Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])) == 15

    assert adjoint(Trace(A)) == trace(Adjoint(A))
    assert conjugate(Trace(A)) == trace(Adjoint(A))
    assert transpose(Trace(A)) == Trace(A)

    _ = A / Trace(A)  # Make sure this is possible

    # Some easy simplifications
    assert trace(Identity(5)) == 5
    assert trace(ZeroMatrix(5, 5)) == 0
    assert trace(OneMatrix(1, 1)) == 1
    assert trace(OneMatrix(2, 2)) == 2
    assert trace(OneMatrix(n, n)) == n
    assert trace(2*A*B) == 2*Trace(A*B)
    assert trace(A.T) == trace(A)

    # Entry (i, j) of F is i + j, so the diagonal entries are 0, 2 and 4.
    i, j = symbols('i j')
    F = FunctionMatrix(3, 3, Lambda((i, j), i + j))
    assert trace(F) == (0 + 0) + (1 + 1) + (2 + 2)

    # A non-matrix argument is rejected.
    raises(TypeError, lambda: Trace(S.One))

    assert Trace(A).arg is A

    assert str(trace(A)) == str(Trace(A).doit())

    assert Trace(A).is_commutative is True
def test_Trace_A_plus_B():
    """Check that the trace of a sum splits into a sum of traces."""
    assert trace(A + B) == Trace(A) + Trace(B)
    # The unevaluated Trace keeps the sum as its argument until doit().
    assert Trace(A + B).arg == MatAdd(A, B)
    assert Trace(A + B).doit() == Trace(A) + Trace(B)
def test_Trace_MatAdd_doit():
    """Check doit() on the Trace of a MatAdd mixing explicit and symbolic terms."""
    # See issue #9028
    X = ImmutableMatrix([[1, 2, 3]]*3)
    Y = MatrixSymbol('Y', 3, 3)
    q = MatAdd(X, 2*X, Y, -3*Y)
    assert Trace(q).arg == q
    # trace(X) is 1 + 2 + 3 = 6, so X + 2*X contributes 18; Y - 3*Y gives -2*Y.
    assert Trace(q).doit() == 18 - 2*Trace(Y)
def test_Trace_MatPow_doit():
    """Check doit() on the Trace of an explicit matrix and of its MatPow."""
    X = Matrix([[1, 2], [3, 4]])
    assert Trace(X).doit() == 5
    q = MatPow(X, 2)
    assert Trace(q).arg == q
    # X**2 is [[7, 10], [15, 22]], whose diagonal sums to 29.
    assert Trace(q).doit() == 29
def test_Trace_MutableMatrix_plus():
    """Check that two Trace objects of an explicit Matrix can be added."""
    # See issue #9043
    X = Matrix([[1, 2], [3, 4]])
    assert Trace(X) + Trace(X) == 2*Trace(X)
def test_Trace_doit_deep_False():
    """Check that doit(deep=False) leaves the Trace argument unchanged."""
    X = Matrix([[1, 2], [3, 4]])

    # The same expectation is checked for a power, a sum and a product.
    q = MatPow(X, 2)
    assert Trace(q).doit(deep=False).arg == q

    q = MatAdd(X, 2*X)
    assert Trace(q).doit(deep=False).arg == q

    q = MatMul(X, 2*X)
    assert Trace(q).doit(deep=False).arg == q
def test_trace_constant_factor():
    """Check that a scalar factor is pulled out of the trace cleanly."""
    # Issue 9052: gave 2*Trace(MatMul(A)) instead of 2*Trace(A)
    assert trace(2*A) == 2*Trace(A)
    X = ImmutableMatrix([[1, 2], [3, 4]])
    # trace(X) is 1 + 4 = 5, so the doubled matrix has trace 10.
    assert trace(MatMul(2, X)) == 10
def test_rewrite():
    """Check that a symbolic trace can be rewritten as a Sum."""
    assert isinstance(trace(A).rewrite(Sum), Sum)
def test_trace_normalize():
    """Check that _normalize() maps differently ordered traces to one form."""
    # Without normalisation the two orderings compare unequal.
    assert Trace(B*A) != Trace(A*B)
    assert Trace(B*A)._normalize() == Trace(A*B)
    assert Trace(B*A.T)._normalize() == Trace(A*B.T)
def test_trace_as_explicit():
    """Check as_explicit() for symbolic and fixed matrix dimensions."""
    # A has the symbolic dimension n, so it cannot be written out explicitly.
    raises(ValueError, lambda: Trace(A).as_explicit())

    # With a fixed 3 x 3 shape the trace expands to the sum of diagonal entries.
    X = MatrixSymbol("X", 3, 3)
    assert Trace(X).as_explicit() == X[0, 0] + X[1, 1] + X[2, 2]
    assert Trace(eye(3)).as_explicit() == 3