图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

308 lines
10 KiB

  1. import itertools
  2. from sympy.core.add import Add
  3. from sympy.core.expr import Expr
  4. from sympy.core.function import expand as _expand
  5. from sympy.core.mul import Mul
  6. from sympy.core.singleton import S
  7. from sympy.matrices.common import ShapeError
  8. from sympy.matrices.expressions.matexpr import MatrixExpr
  9. from sympy.matrices.expressions.matmul import MatMul
  10. from sympy.matrices.expressions.special import ZeroMatrix
  11. from sympy.stats.rv import RandomSymbol, is_random
  12. from sympy.core.sympify import _sympify
  13. from sympy.stats.symbolic_probability import Variance, Covariance, Expectation
class ExpectationMatrix(Expectation, MatrixExpr):
    """
    Expectation of a random matrix expression.

    Examples
    ========

    >>> from sympy.stats import ExpectationMatrix, Normal
    >>> from sympy.stats.rv import RandomMatrixSymbol
    >>> from sympy import symbols, MatrixSymbol, Matrix
    >>> k = symbols("k")
    >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k)
    >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1)
    >>> ExpectationMatrix(X)
    ExpectationMatrix(X)
    >>> ExpectationMatrix(A*X).shape
    (k, 1)

    To expand the expectation in its expression, use ``expand()``:

    >>> ExpectationMatrix(A*X + B*Y).expand()
    A*ExpectationMatrix(X) + B*ExpectationMatrix(Y)
    >>> ExpectationMatrix((X + Y)*(X - Y).T).expand()
    ExpectationMatrix(X*X.T) - ExpectationMatrix(X*Y.T) + ExpectationMatrix(Y*X.T) - ExpectationMatrix(Y*Y.T)

    To evaluate the ``ExpectationMatrix``, use ``doit()``:

    >>> N11, N12 = Normal('N11', 11, 1), Normal('N12', 12, 1)
    >>> N21, N22 = Normal('N21', 21, 1), Normal('N22', 22, 1)
    >>> M11, M12 = Normal('M11', 1, 1), Normal('M12', 2, 1)
    >>> M21, M22 = Normal('M21', 3, 1), Normal('M22', 4, 1)
    >>> x1 = Matrix([[N11, N12], [N21, N22]])
    >>> x2 = Matrix([[M11, M12], [M21, M22]])
    >>> ExpectationMatrix(x1 + x2).doit()
    Matrix([
    [12, 14],
    [24, 26]])
    """
    def __new__(cls, expr, condition=None):
        # Build an unevaluated expectation node around the matrix expression
        # ``expr``, optionally conditioned on ``condition``.
        expr = _sympify(expr)
        if condition is None:
            # Unconditioned expectation of a non-random expression is the
            # expression itself, so no node is created.
            if not is_random(expr):
                return expr
            obj = Expr.__new__(cls, expr)
        else:
            condition = _sympify(condition)
            obj = Expr.__new__(cls, expr, condition)
        # The expectation has the same shape as the wrapped expression.
        obj._shape = expr.shape
        obj._condition = condition
        return obj

    @property
    def shape(self):
        # Shape of the wrapped expression, cached at construction time.
        return self._shape

    def expand(self, **hints):
        """Distribute the expectation over sums and pull non-random factors
        out of products.

        Returns the argument itself when it is not random, a sum of
        expectations for additive expressions, a product with constant
        factors pulled out for ``Mul``/``MatMul`` arguments, and ``self``
        when nothing can be expanded.
        """
        expr = self.args[0]
        condition = self._condition
        if not is_random(expr):
            return expr
        # Linearity: E[a + b] = E[a] + E[b].
        if isinstance(expr, Add):
            return Add.fromiter(Expectation(a, condition=condition).expand()
                    for a in expr.args)
        # A product may become a sum after algebraic expansion
        # (e.g. (X + Y)*(X - Y).T); distribute over that sum if so.
        expand_expr = _expand(expr)
        if isinstance(expand_expr, Add):
            return Add.fromiter(Expectation(a, condition=condition).expand()
                    for a in expand_expr.args)
        elif isinstance(expr, (Mul, MatMul)):
            # Partition the factors, preserving their left-to-right order:
            # - nonrv: scalar non-random factors, plus non-random matrices
            #   that appear before the first random factor (pulled out left);
            # - rv: random factors, plus non-random matrices sandwiched
            #   between two random factors (they must stay inside E[...]);
            # - postnon: non-random matrices after the last random factor
            #   (pulled out on the right).
            rv = []
            nonrv = []
            postnon = []
            for a in expr.args:
                if is_random(a):
                    # Flush the buffered matrices: inside E[...] if a random
                    # factor was already seen, otherwise to the left prefix.
                    if rv:
                        rv.extend(postnon)
                    else:
                        nonrv.extend(postnon)
                    postnon = []
                    rv.append(a)
                elif a.is_Matrix:
                    # Non-commuting constant: buffer until we know whether it
                    # sits between random factors or trails them.
                    postnon.append(a)
                else:
                    # Commuting (scalar) constant can always be pulled out.
                    nonrv.append(a)
            # In order to avoid infinite-looping (MatMul may call .doit() again),
            # do not rebuild
            # NOTE(review): this also returns ``self`` when only trailing
            # constant matrices (``postnon``) could be pulled out — confirm
            # that is intended.
            if len(nonrv) == 0:
                return self
            return Mul.fromiter(nonrv)*Expectation(Mul.fromiter(rv),
                    condition=condition)*Mul.fromiter(postnon)
        # Any other random expression is left unexpanded.
        return self
  96. class VarianceMatrix(Variance, MatrixExpr):
  97. """
  98. Variance of a random matrix probability expression. Also known as
  99. Covariance matrix, auto-covariance matrix, dispersion matrix,
  100. or variance-covariance matrix.
  101. Examples
  102. ========
  103. >>> from sympy.stats import VarianceMatrix
  104. >>> from sympy.stats.rv import RandomMatrixSymbol
  105. >>> from sympy import symbols, MatrixSymbol
  106. >>> k = symbols("k")
  107. >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k)
  108. >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1)
  109. >>> VarianceMatrix(X)
  110. VarianceMatrix(X)
  111. >>> VarianceMatrix(X).shape
  112. (k, k)
  113. To expand the variance in its expression, use ``expand()``:
  114. >>> VarianceMatrix(A*X).expand()
  115. A*VarianceMatrix(X)*A.T
  116. >>> VarianceMatrix(A*X + B*Y).expand()
  117. 2*A*CrossCovarianceMatrix(X, Y)*B.T + A*VarianceMatrix(X)*A.T + B*VarianceMatrix(Y)*B.T
  118. """
  119. def __new__(cls, arg, condition=None):
  120. arg = _sympify(arg)
  121. if 1 not in arg.shape:
  122. raise ShapeError("Expression is not a vector")
  123. shape = (arg.shape[0], arg.shape[0]) if arg.shape[1] == 1 else (arg.shape[1], arg.shape[1])
  124. if condition:
  125. obj = Expr.__new__(cls, arg, condition)
  126. else:
  127. obj = Expr.__new__(cls, arg)
  128. obj._shape = shape
  129. obj._condition = condition
  130. return obj
  131. @property
  132. def shape(self):
  133. return self._shape
  134. def expand(self, **hints):
  135. arg = self.args[0]
  136. condition = self._condition
  137. if not is_random(arg):
  138. return ZeroMatrix(*self.shape)
  139. if isinstance(arg, RandomSymbol):
  140. return self
  141. elif isinstance(arg, Add):
  142. rv = []
  143. for a in arg.args:
  144. if is_random(a):
  145. rv.append(a)
  146. variances = Add(*map(lambda xv: Variance(xv, condition).expand(), rv))
  147. map_to_covar = lambda x: 2*Covariance(*x, condition=condition).expand()
  148. covariances = Add(*map(map_to_covar, itertools.combinations(rv, 2)))
  149. return variances + covariances
  150. elif isinstance(arg, (Mul, MatMul)):
  151. nonrv = []
  152. rv = []
  153. for a in arg.args:
  154. if is_random(a):
  155. rv.append(a)
  156. else:
  157. nonrv.append(a)
  158. if len(rv) == 0:
  159. return ZeroMatrix(*self.shape)
  160. # Avoid possible infinite loops with MatMul:
  161. if len(nonrv) == 0:
  162. return self
  163. # Variance of many multiple matrix products is not implemented:
  164. if len(rv) > 1:
  165. return self
  166. return Mul.fromiter(nonrv)*Variance(Mul.fromiter(rv),
  167. condition)*(Mul.fromiter(nonrv)).transpose()
  168. # this expression contains a RandomSymbol somehow:
  169. return self
  170. class CrossCovarianceMatrix(Covariance, MatrixExpr):
  171. """
  172. Covariance of a random matrix probability expression.
  173. Examples
  174. ========
  175. >>> from sympy.stats import CrossCovarianceMatrix
  176. >>> from sympy.stats.rv import RandomMatrixSymbol
  177. >>> from sympy import symbols, MatrixSymbol
  178. >>> k = symbols("k")
  179. >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k)
  180. >>> C, D = MatrixSymbol("C", k, k), MatrixSymbol("D", k, k)
  181. >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1)
  182. >>> Z, W = RandomMatrixSymbol("Z", k, 1), RandomMatrixSymbol("W", k, 1)
  183. >>> CrossCovarianceMatrix(X, Y)
  184. CrossCovarianceMatrix(X, Y)
  185. >>> CrossCovarianceMatrix(X, Y).shape
  186. (k, k)
  187. To expand the covariance in its expression, use ``expand()``:
  188. >>> CrossCovarianceMatrix(X + Y, Z).expand()
  189. CrossCovarianceMatrix(X, Z) + CrossCovarianceMatrix(Y, Z)
  190. >>> CrossCovarianceMatrix(A*X, Y).expand()
  191. A*CrossCovarianceMatrix(X, Y)
  192. >>> CrossCovarianceMatrix(A*X, B.T*Y).expand()
  193. A*CrossCovarianceMatrix(X, Y)*B
  194. >>> CrossCovarianceMatrix(A*X + B*Y, C.T*Z + D.T*W).expand()
  195. A*CrossCovarianceMatrix(X, W)*D + A*CrossCovarianceMatrix(X, Z)*C + B*CrossCovarianceMatrix(Y, W)*D + B*CrossCovarianceMatrix(Y, Z)*C
  196. """
  197. def __new__(cls, arg1, arg2, condition=None):
  198. arg1 = _sympify(arg1)
  199. arg2 = _sympify(arg2)
  200. if (1 not in arg1.shape) or (1 not in arg2.shape) or (arg1.shape[1] != arg2.shape[1]):
  201. raise ShapeError("Expression is not a vector")
  202. shape = (arg1.shape[0], arg2.shape[0]) if arg1.shape[1] == 1 and arg2.shape[1] == 1 \
  203. else (1, 1)
  204. if condition:
  205. obj = Expr.__new__(cls, arg1, arg2, condition)
  206. else:
  207. obj = Expr.__new__(cls, arg1, arg2)
  208. obj._shape = shape
  209. obj._condition = condition
  210. return obj
  211. @property
  212. def shape(self):
  213. return self._shape
  214. def expand(self, **hints):
  215. arg1 = self.args[0]
  216. arg2 = self.args[1]
  217. condition = self._condition
  218. if arg1 == arg2:
  219. return VarianceMatrix(arg1, condition).expand()
  220. if not is_random(arg1) or not is_random(arg2):
  221. return ZeroMatrix(*self.shape)
  222. if isinstance(arg1, RandomSymbol) and isinstance(arg2, RandomSymbol):
  223. return CrossCovarianceMatrix(arg1, arg2, condition)
  224. coeff_rv_list1 = self._expand_single_argument(arg1.expand())
  225. coeff_rv_list2 = self._expand_single_argument(arg2.expand())
  226. addends = [a*CrossCovarianceMatrix(r1, r2, condition=condition)*b.transpose()
  227. for (a, r1) in coeff_rv_list1 for (b, r2) in coeff_rv_list2]
  228. return Add.fromiter(addends)
  229. @classmethod
  230. def _expand_single_argument(cls, expr):
  231. # return (coefficient, random_symbol) pairs:
  232. if isinstance(expr, RandomSymbol):
  233. return [(S.One, expr)]
  234. elif isinstance(expr, Add):
  235. outval = []
  236. for a in expr.args:
  237. if isinstance(a, (Mul, MatMul)):
  238. outval.append(cls._get_mul_nonrv_rv_tuple(a))
  239. elif is_random(a):
  240. outval.append((S.One, a))
  241. return outval
  242. elif isinstance(expr, (Mul, MatMul)):
  243. return [cls._get_mul_nonrv_rv_tuple(expr)]
  244. elif is_random(expr):
  245. return [(S.One, expr)]
  246. @classmethod
  247. def _get_mul_nonrv_rv_tuple(cls, m):
  248. rv = []
  249. nonrv = []
  250. for a in m.args:
  251. if is_random(a):
  252. rv.append(a)
  253. else:
  254. nonrv.append(a)
  255. return (Mul.fromiter(nonrv), Mul.fromiter(rv))