m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

172 lines
5.5 KiB

7 months ago
  1. from sympy.stats import Expectation, Normal, Variance, Covariance
  2. from sympy.testing.pytest import raises
  3. from sympy.core.symbol import symbols
  4. from sympy.matrices.common import ShapeError
  5. from sympy.matrices.dense import Matrix
  6. from sympy.matrices.expressions.matexpr import MatrixSymbol
  7. from sympy.matrices.expressions.special import ZeroMatrix
  8. from sympy.stats.rv import RandomMatrixSymbol
  9. from sympy.stats.symbolic_multivariate_probability import (ExpectationMatrix,
  10. VarianceMatrix, CrossCovarianceMatrix)
  11. j, k = symbols("j,k")
  12. A = MatrixSymbol("A", k, k)
  13. B = MatrixSymbol("B", k, k)
  14. C = MatrixSymbol("C", k, k)
  15. D = MatrixSymbol("D", k, k)
  16. a = MatrixSymbol("a", k, 1)
  17. b = MatrixSymbol("b", k, 1)
  18. A2 = MatrixSymbol("A2", 2, 2)
  19. B2 = MatrixSymbol("B2", 2, 2)
  20. X = RandomMatrixSymbol("X", k, 1)
  21. Y = RandomMatrixSymbol("Y", k, 1)
  22. Z = RandomMatrixSymbol("Z", k, 1)
  23. W = RandomMatrixSymbol("W", k, 1)
  24. R = RandomMatrixSymbol("R", k, k)
  25. X2 = RandomMatrixSymbol("X2", 2, 1)
  26. normal = Normal("normal", 0, 1)
  27. m1 = Matrix([
  28. [1, j*Normal("normal2", 2, 1)],
  29. [normal, 0]
  30. ])
  31. def test_multivariate_expectation():
  32. expr = Expectation(a)
  33. assert expr == Expectation(a) == ExpectationMatrix(a)
  34. assert expr.expand() == a
  35. expr = Expectation(X)
  36. assert expr == Expectation(X) == ExpectationMatrix(X)
  37. assert expr.shape == (k, 1)
  38. assert expr.rows == k
  39. assert expr.cols == 1
  40. assert isinstance(expr, ExpectationMatrix)
  41. expr = Expectation(A*X + b)
  42. assert expr == ExpectationMatrix(A*X + b)
  43. assert expr.expand() == A*ExpectationMatrix(X) + b
  44. assert isinstance(expr, ExpectationMatrix)
  45. assert expr.shape == (k, 1)
  46. expr = Expectation(m1*X2)
  47. assert expr.expand() == expr
  48. expr = Expectation(A2*m1*B2*X2)
  49. assert expr.args[0].args == (A2, m1, B2, X2)
  50. assert expr.expand() == A2*ExpectationMatrix(m1*B2*X2)
  51. expr = Expectation((X + Y)*(X - Y).T)
  52. assert expr.expand() == ExpectationMatrix(X*X.T) - ExpectationMatrix(X*Y.T) +\
  53. ExpectationMatrix(Y*X.T) - ExpectationMatrix(Y*Y.T)
  54. expr = Expectation(A*X + B*Y)
  55. assert expr.expand() == A*ExpectationMatrix(X) + B*ExpectationMatrix(Y)
  56. assert Expectation(m1).doit() == Matrix([[1, 2*j], [0, 0]])
  57. x1 = Matrix([
  58. [Normal('N11', 11, 1), Normal('N12', 12, 1)],
  59. [Normal('N21', 21, 1), Normal('N22', 22, 1)]
  60. ])
  61. x2 = Matrix([
  62. [Normal('M11', 1, 1), Normal('M12', 2, 1)],
  63. [Normal('M21', 3, 1), Normal('M22', 4, 1)]
  64. ])
  65. assert Expectation(Expectation(x1 + x2)).doit(deep=False) == ExpectationMatrix(x1 + x2)
  66. assert Expectation(Expectation(x1 + x2)).doit() == Matrix([[12, 14], [24, 26]])
  67. def test_multivariate_variance():
  68. raises(ShapeError, lambda: Variance(A))
  69. expr = Variance(a) # type: VarianceMatrix
  70. assert expr == Variance(a) == VarianceMatrix(a)
  71. assert expr.expand() == ZeroMatrix(k, k)
  72. expr = Variance(a.T)
  73. assert expr == Variance(a.T) == VarianceMatrix(a.T)
  74. assert expr.expand() == ZeroMatrix(k, k)
  75. expr = Variance(X)
  76. assert expr == Variance(X) == VarianceMatrix(X)
  77. assert expr.shape == (k, k)
  78. assert expr.rows == k
  79. assert expr.cols == k
  80. assert isinstance(expr, VarianceMatrix)
  81. expr = Variance(A*X)
  82. assert expr == VarianceMatrix(A*X)
  83. assert expr.expand() == A*VarianceMatrix(X)*A.T
  84. assert isinstance(expr, VarianceMatrix)
  85. assert expr.shape == (k, k)
  86. expr = Variance(A*B*X)
  87. assert expr.expand() == A*B*VarianceMatrix(X)*B.T*A.T
  88. expr = Variance(m1*X2)
  89. assert expr.expand() == expr
  90. expr = Variance(A2*m1*B2*X2)
  91. assert expr.args[0].args == (A2, m1, B2, X2)
  92. assert expr.expand() == expr
  93. expr = Variance(A*X + B*Y)
  94. assert expr.expand() == 2*A*CrossCovarianceMatrix(X, Y)*B.T +\
  95. A*VarianceMatrix(X)*A.T + B*VarianceMatrix(Y)*B.T
  96. def test_multivariate_crosscovariance():
  97. raises(ShapeError, lambda: Covariance(X, Y.T))
  98. raises(ShapeError, lambda: Covariance(X, A))
  99. expr = Covariance(a.T, b.T)
  100. assert expr.shape == (1, 1)
  101. assert expr.expand() == ZeroMatrix(1, 1)
  102. expr = Covariance(a, b)
  103. assert expr == Covariance(a, b) == CrossCovarianceMatrix(a, b)
  104. assert expr.expand() == ZeroMatrix(k, k)
  105. assert expr.shape == (k, k)
  106. assert expr.rows == k
  107. assert expr.cols == k
  108. assert isinstance(expr, CrossCovarianceMatrix)
  109. expr = Covariance(A*X + a, b)
  110. assert expr.expand() == ZeroMatrix(k, k)
  111. expr = Covariance(X, Y)
  112. assert isinstance(expr, CrossCovarianceMatrix)
  113. assert expr.expand() == expr
  114. expr = Covariance(X, X)
  115. assert isinstance(expr, CrossCovarianceMatrix)
  116. assert expr.expand() == VarianceMatrix(X)
  117. expr = Covariance(X + Y, Z)
  118. assert isinstance(expr, CrossCovarianceMatrix)
  119. assert expr.expand() == CrossCovarianceMatrix(X, Z) + CrossCovarianceMatrix(Y, Z)
  120. expr = Covariance(A*X, Y)
  121. assert isinstance(expr, CrossCovarianceMatrix)
  122. assert expr.expand() == A*CrossCovarianceMatrix(X, Y)
  123. expr = Covariance(X, B*Y)
  124. assert isinstance(expr, CrossCovarianceMatrix)
  125. assert expr.expand() == CrossCovarianceMatrix(X, Y)*B.T
  126. expr = Covariance(A*X + a, B.T*Y + b)
  127. assert isinstance(expr, CrossCovarianceMatrix)
  128. assert expr.expand() == A*CrossCovarianceMatrix(X, Y)*B
  129. expr = Covariance(A*X + B*Y + a, C.T*Z + D.T*W + b)
  130. assert isinstance(expr, CrossCovarianceMatrix)
  131. assert expr.expand() == A*CrossCovarianceMatrix(X, W)*D + A*CrossCovarianceMatrix(X, Z)*C \
  132. + B*CrossCovarianceMatrix(Y, W)*D + B*CrossCovarianceMatrix(Y, Z)*C