m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

308 lines
10 KiB

7 months ago
  1. import itertools
  2. from sympy.core.add import Add
  3. from sympy.core.expr import Expr
  4. from sympy.core.function import expand as _expand
  5. from sympy.core.mul import Mul
  6. from sympy.core.singleton import S
  7. from sympy.matrices.common import ShapeError
  8. from sympy.matrices.expressions.matexpr import MatrixExpr
  9. from sympy.matrices.expressions.matmul import MatMul
  10. from sympy.matrices.expressions.special import ZeroMatrix
  11. from sympy.stats.rv import RandomSymbol, is_random
  12. from sympy.core.sympify import _sympify
  13. from sympy.stats.symbolic_probability import Variance, Covariance, Expectation
  14. class ExpectationMatrix(Expectation, MatrixExpr):
  15. """
  16. Expectation of a random matrix expression.
  17. Examples
  18. ========
  19. >>> from sympy.stats import ExpectationMatrix, Normal
  20. >>> from sympy.stats.rv import RandomMatrixSymbol
  21. >>> from sympy import symbols, MatrixSymbol, Matrix
  22. >>> k = symbols("k")
  23. >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k)
  24. >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1)
  25. >>> ExpectationMatrix(X)
  26. ExpectationMatrix(X)
  27. >>> ExpectationMatrix(A*X).shape
  28. (k, 1)
  29. To expand the expectation in its expression, use ``expand()``:
  30. >>> ExpectationMatrix(A*X + B*Y).expand()
  31. A*ExpectationMatrix(X) + B*ExpectationMatrix(Y)
  32. >>> ExpectationMatrix((X + Y)*(X - Y).T).expand()
  33. ExpectationMatrix(X*X.T) - ExpectationMatrix(X*Y.T) + ExpectationMatrix(Y*X.T) - ExpectationMatrix(Y*Y.T)
  34. To evaluate the ``ExpectationMatrix``, use ``doit()``:
  35. >>> N11, N12 = Normal('N11', 11, 1), Normal('N12', 12, 1)
  36. >>> N21, N22 = Normal('N21', 21, 1), Normal('N22', 22, 1)
  37. >>> M11, M12 = Normal('M11', 1, 1), Normal('M12', 2, 1)
  38. >>> M21, M22 = Normal('M21', 3, 1), Normal('M22', 4, 1)
  39. >>> x1 = Matrix([[N11, N12], [N21, N22]])
  40. >>> x2 = Matrix([[M11, M12], [M21, M22]])
  41. >>> ExpectationMatrix(x1 + x2).doit()
  42. Matrix([
  43. [12, 14],
  44. [24, 26]])
  45. """
  46. def __new__(cls, expr, condition=None):
  47. expr = _sympify(expr)
  48. if condition is None:
  49. if not is_random(expr):
  50. return expr
  51. obj = Expr.__new__(cls, expr)
  52. else:
  53. condition = _sympify(condition)
  54. obj = Expr.__new__(cls, expr, condition)
  55. obj._shape = expr.shape
  56. obj._condition = condition
  57. return obj
  58. @property
  59. def shape(self):
  60. return self._shape
  61. def expand(self, **hints):
  62. expr = self.args[0]
  63. condition = self._condition
  64. if not is_random(expr):
  65. return expr
  66. if isinstance(expr, Add):
  67. return Add.fromiter(Expectation(a, condition=condition).expand()
  68. for a in expr.args)
  69. expand_expr = _expand(expr)
  70. if isinstance(expand_expr, Add):
  71. return Add.fromiter(Expectation(a, condition=condition).expand()
  72. for a in expand_expr.args)
  73. elif isinstance(expr, (Mul, MatMul)):
  74. rv = []
  75. nonrv = []
  76. postnon = []
  77. for a in expr.args:
  78. if is_random(a):
  79. if rv:
  80. rv.extend(postnon)
  81. else:
  82. nonrv.extend(postnon)
  83. postnon = []
  84. rv.append(a)
  85. elif a.is_Matrix:
  86. postnon.append(a)
  87. else:
  88. nonrv.append(a)
  89. # In order to avoid infinite-looping (MatMul may call .doit() again),
  90. # do not rebuild
  91. if len(nonrv) == 0:
  92. return self
  93. return Mul.fromiter(nonrv)*Expectation(Mul.fromiter(rv),
  94. condition=condition)*Mul.fromiter(postnon)
  95. return self
  96. class VarianceMatrix(Variance, MatrixExpr):
  97. """
  98. Variance of a random matrix probability expression. Also known as
  99. Covariance matrix, auto-covariance matrix, dispersion matrix,
  100. or variance-covariance matrix.
  101. Examples
  102. ========
  103. >>> from sympy.stats import VarianceMatrix
  104. >>> from sympy.stats.rv import RandomMatrixSymbol
  105. >>> from sympy import symbols, MatrixSymbol
  106. >>> k = symbols("k")
  107. >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k)
  108. >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1)
  109. >>> VarianceMatrix(X)
  110. VarianceMatrix(X)
  111. >>> VarianceMatrix(X).shape
  112. (k, k)
  113. To expand the variance in its expression, use ``expand()``:
  114. >>> VarianceMatrix(A*X).expand()
  115. A*VarianceMatrix(X)*A.T
  116. >>> VarianceMatrix(A*X + B*Y).expand()
  117. 2*A*CrossCovarianceMatrix(X, Y)*B.T + A*VarianceMatrix(X)*A.T + B*VarianceMatrix(Y)*B.T
  118. """
  119. def __new__(cls, arg, condition=None):
  120. arg = _sympify(arg)
  121. if 1 not in arg.shape:
  122. raise ShapeError("Expression is not a vector")
  123. shape = (arg.shape[0], arg.shape[0]) if arg.shape[1] == 1 else (arg.shape[1], arg.shape[1])
  124. if condition:
  125. obj = Expr.__new__(cls, arg, condition)
  126. else:
  127. obj = Expr.__new__(cls, arg)
  128. obj._shape = shape
  129. obj._condition = condition
  130. return obj
  131. @property
  132. def shape(self):
  133. return self._shape
  134. def expand(self, **hints):
  135. arg = self.args[0]
  136. condition = self._condition
  137. if not is_random(arg):
  138. return ZeroMatrix(*self.shape)
  139. if isinstance(arg, RandomSymbol):
  140. return self
  141. elif isinstance(arg, Add):
  142. rv = []
  143. for a in arg.args:
  144. if is_random(a):
  145. rv.append(a)
  146. variances = Add(*map(lambda xv: Variance(xv, condition).expand(), rv))
  147. map_to_covar = lambda x: 2*Covariance(*x, condition=condition).expand()
  148. covariances = Add(*map(map_to_covar, itertools.combinations(rv, 2)))
  149. return variances + covariances
  150. elif isinstance(arg, (Mul, MatMul)):
  151. nonrv = []
  152. rv = []
  153. for a in arg.args:
  154. if is_random(a):
  155. rv.append(a)
  156. else:
  157. nonrv.append(a)
  158. if len(rv) == 0:
  159. return ZeroMatrix(*self.shape)
  160. # Avoid possible infinite loops with MatMul:
  161. if len(nonrv) == 0:
  162. return self
  163. # Variance of many multiple matrix products is not implemented:
  164. if len(rv) > 1:
  165. return self
  166. return Mul.fromiter(nonrv)*Variance(Mul.fromiter(rv),
  167. condition)*(Mul.fromiter(nonrv)).transpose()
  168. # this expression contains a RandomSymbol somehow:
  169. return self
  170. class CrossCovarianceMatrix(Covariance, MatrixExpr):
  171. """
  172. Covariance of a random matrix probability expression.
  173. Examples
  174. ========
  175. >>> from sympy.stats import CrossCovarianceMatrix
  176. >>> from sympy.stats.rv import RandomMatrixSymbol
  177. >>> from sympy import symbols, MatrixSymbol
  178. >>> k = symbols("k")
  179. >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k)
  180. >>> C, D = MatrixSymbol("C", k, k), MatrixSymbol("D", k, k)
  181. >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1)
  182. >>> Z, W = RandomMatrixSymbol("Z", k, 1), RandomMatrixSymbol("W", k, 1)
  183. >>> CrossCovarianceMatrix(X, Y)
  184. CrossCovarianceMatrix(X, Y)
  185. >>> CrossCovarianceMatrix(X, Y).shape
  186. (k, k)
  187. To expand the covariance in its expression, use ``expand()``:
  188. >>> CrossCovarianceMatrix(X + Y, Z).expand()
  189. CrossCovarianceMatrix(X, Z) + CrossCovarianceMatrix(Y, Z)
  190. >>> CrossCovarianceMatrix(A*X, Y).expand()
  191. A*CrossCovarianceMatrix(X, Y)
  192. >>> CrossCovarianceMatrix(A*X, B.T*Y).expand()
  193. A*CrossCovarianceMatrix(X, Y)*B
  194. >>> CrossCovarianceMatrix(A*X + B*Y, C.T*Z + D.T*W).expand()
  195. A*CrossCovarianceMatrix(X, W)*D + A*CrossCovarianceMatrix(X, Z)*C + B*CrossCovarianceMatrix(Y, W)*D + B*CrossCovarianceMatrix(Y, Z)*C
  196. """
  197. def __new__(cls, arg1, arg2, condition=None):
  198. arg1 = _sympify(arg1)
  199. arg2 = _sympify(arg2)
  200. if (1 not in arg1.shape) or (1 not in arg2.shape) or (arg1.shape[1] != arg2.shape[1]):
  201. raise ShapeError("Expression is not a vector")
  202. shape = (arg1.shape[0], arg2.shape[0]) if arg1.shape[1] == 1 and arg2.shape[1] == 1 \
  203. else (1, 1)
  204. if condition:
  205. obj = Expr.__new__(cls, arg1, arg2, condition)
  206. else:
  207. obj = Expr.__new__(cls, arg1, arg2)
  208. obj._shape = shape
  209. obj._condition = condition
  210. return obj
  211. @property
  212. def shape(self):
  213. return self._shape
  214. def expand(self, **hints):
  215. arg1 = self.args[0]
  216. arg2 = self.args[1]
  217. condition = self._condition
  218. if arg1 == arg2:
  219. return VarianceMatrix(arg1, condition).expand()
  220. if not is_random(arg1) or not is_random(arg2):
  221. return ZeroMatrix(*self.shape)
  222. if isinstance(arg1, RandomSymbol) and isinstance(arg2, RandomSymbol):
  223. return CrossCovarianceMatrix(arg1, arg2, condition)
  224. coeff_rv_list1 = self._expand_single_argument(arg1.expand())
  225. coeff_rv_list2 = self._expand_single_argument(arg2.expand())
  226. addends = [a*CrossCovarianceMatrix(r1, r2, condition=condition)*b.transpose()
  227. for (a, r1) in coeff_rv_list1 for (b, r2) in coeff_rv_list2]
  228. return Add.fromiter(addends)
  229. @classmethod
  230. def _expand_single_argument(cls, expr):
  231. # return (coefficient, random_symbol) pairs:
  232. if isinstance(expr, RandomSymbol):
  233. return [(S.One, expr)]
  234. elif isinstance(expr, Add):
  235. outval = []
  236. for a in expr.args:
  237. if isinstance(a, (Mul, MatMul)):
  238. outval.append(cls._get_mul_nonrv_rv_tuple(a))
  239. elif is_random(a):
  240. outval.append((S.One, a))
  241. return outval
  242. elif isinstance(expr, (Mul, MatMul)):
  243. return [cls._get_mul_nonrv_rv_tuple(expr)]
  244. elif is_random(expr):
  245. return [(S.One, expr)]
  246. @classmethod
  247. def _get_mul_nonrv_rv_tuple(cls, m):
  248. rv = []
  249. nonrv = []
  250. for a in m.args:
  251. if is_random(a):
  252. rv.append(a)
  253. else:
  254. nonrv.append(a)
  255. return (Mul.fromiter(nonrv), Mul.fromiter(rv))