图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

605 lines
21 KiB

  1. from sympy.core.basic import Basic
  2. from sympy.core.mul import prod
  3. from sympy.core.numbers import pi
  4. from sympy.core.singleton import S
  5. from sympy.functions.elementary.exponential import exp
  6. from sympy.functions.special.gamma_functions import multigamma
  7. from sympy.core.sympify import sympify, _sympify
  8. from sympy.matrices import (ImmutableMatrix, Inverse, Trace, Determinant,
  9. MatrixSymbol, MatrixBase, Transpose, MatrixSet,
  10. matrix2numpy)
  11. from sympy.stats.rv import (_value_check, RandomMatrixSymbol, NamedArgsMixin, PSpace,
  12. _symbol_converter, MatrixDomain, Distribution)
  13. from sympy.external import import_module
  14. ################################################################################
  15. #------------------------Matrix Probability Space------------------------------#
  16. ################################################################################
  17. class MatrixPSpace(PSpace):
  18. """
  19. Represents probability space for
  20. Matrix Distributions.
  21. """
  22. def __new__(cls, sym, distribution, dim_n, dim_m):
  23. sym = _symbol_converter(sym)
  24. dim_n, dim_m = _sympify(dim_n), _sympify(dim_m)
  25. if not (dim_n.is_integer and dim_m.is_integer):
  26. raise ValueError("Dimensions should be integers")
  27. return Basic.__new__(cls, sym, distribution, dim_n, dim_m)
  28. distribution = property(lambda self: self.args[1])
  29. symbol = property(lambda self: self.args[0])
  30. @property
  31. def domain(self):
  32. return MatrixDomain(self.symbol, self.distribution.set)
  33. @property
  34. def value(self):
  35. return RandomMatrixSymbol(self.symbol, self.args[2], self.args[3], self)
  36. @property
  37. def values(self):
  38. return {self.value}
  39. def compute_density(self, expr, *args):
  40. rms = expr.atoms(RandomMatrixSymbol)
  41. if len(rms) > 1 or (not isinstance(expr, RandomMatrixSymbol)):
  42. raise NotImplementedError("Currently, no algorithm has been "
  43. "implemented to handle general expressions containing "
  44. "multiple matrix distributions.")
  45. return self.distribution.pdf(expr)
  46. def sample(self, size=(), library='scipy', seed=None):
  47. """
  48. Internal sample method
  49. Returns dictionary mapping RandomMatrixSymbol to realization value.
  50. """
  51. return {self.value: self.distribution.sample(size, library=library, seed=seed)}
  52. def rv(symbol, cls, args):
  53. args = list(map(sympify, args))
  54. dist = cls(*args)
  55. dist.check(*args)
  56. dim = dist.dimension
  57. pspace = MatrixPSpace(symbol, dist, dim[0], dim[1])
  58. return pspace.value
  59. class SampleMatrixScipy:
  60. """Returns the sample from scipy of the given distribution"""
  61. def __new__(cls, dist, size, seed=None):
  62. return cls._sample_scipy(dist, size, seed)
  63. @classmethod
  64. def _sample_scipy(cls, dist, size, seed):
  65. """Sample from SciPy."""
  66. from scipy import stats as scipy_stats
  67. import numpy
  68. scipy_rv_map = {
  69. 'WishartDistribution': lambda dist, size, rand_state: scipy_stats.wishart.rvs(
  70. df=int(dist.n), scale=matrix2numpy(dist.scale_matrix, float), size=size),
  71. 'MatrixNormalDistribution': lambda dist, size, rand_state: scipy_stats.matrix_normal.rvs(
  72. mean=matrix2numpy(dist.location_matrix, float),
  73. rowcov=matrix2numpy(dist.scale_matrix_1, float),
  74. colcov=matrix2numpy(dist.scale_matrix_2, float), size=size, random_state=rand_state)
  75. }
  76. sample_shape = {
  77. 'WishartDistribution': lambda dist: dist.scale_matrix.shape,
  78. 'MatrixNormalDistribution' : lambda dist: dist.location_matrix.shape
  79. }
  80. dist_list = scipy_rv_map.keys()
  81. if dist.__class__.__name__ not in dist_list:
  82. return None
  83. if seed is None or isinstance(seed, int):
  84. rand_state = numpy.random.default_rng(seed=seed)
  85. else:
  86. rand_state = seed
  87. samp = scipy_rv_map[dist.__class__.__name__](dist, prod(size), rand_state)
  88. return samp.reshape(size + sample_shape[dist.__class__.__name__](dist))
  89. class SampleMatrixNumpy:
  90. """Returns the sample from numpy of the given distribution"""
  91. ### TODO: Add tests after adding matrix distributions in numpy_rv_map
  92. def __new__(cls, dist, size, seed=None):
  93. return cls._sample_numpy(dist, size, seed)
  94. @classmethod
  95. def _sample_numpy(cls, dist, size, seed):
  96. """Sample from NumPy."""
  97. numpy_rv_map = {
  98. }
  99. sample_shape = {
  100. }
  101. dist_list = numpy_rv_map.keys()
  102. if dist.__class__.__name__ not in dist_list:
  103. return None
  104. import numpy
  105. if seed is None or isinstance(seed, int):
  106. rand_state = numpy.random.default_rng(seed=seed)
  107. else:
  108. rand_state = seed
  109. samp = numpy_rv_map[dist.__class__.__name__](dist, prod(size), rand_state)
  110. return samp.reshape(size + sample_shape[dist.__class__.__name__](dist))
  111. class SampleMatrixPymc:
  112. """Returns the sample from pymc3 of the given distribution"""
  113. def __new__(cls, dist, size, seed=None):
  114. return cls._sample_pymc3(dist, size, seed)
  115. @classmethod
  116. def _sample_pymc3(cls, dist, size, seed):
  117. """Sample from PyMC3."""
  118. import pymc3
  119. pymc3_rv_map = {
  120. 'MatrixNormalDistribution': lambda dist: pymc3.MatrixNormal('X',
  121. mu=matrix2numpy(dist.location_matrix, float),
  122. rowcov=matrix2numpy(dist.scale_matrix_1, float),
  123. colcov=matrix2numpy(dist.scale_matrix_2, float),
  124. shape=dist.location_matrix.shape),
  125. 'WishartDistribution': lambda dist: pymc3.WishartBartlett('X',
  126. nu=int(dist.n), S=matrix2numpy(dist.scale_matrix, float))
  127. }
  128. sample_shape = {
  129. 'WishartDistribution': lambda dist: dist.scale_matrix.shape,
  130. 'MatrixNormalDistribution' : lambda dist: dist.location_matrix.shape
  131. }
  132. dist_list = pymc3_rv_map.keys()
  133. if dist.__class__.__name__ not in dist_list:
  134. return None
  135. import logging
  136. logging.getLogger("pymc3").setLevel(logging.ERROR)
  137. with pymc3.Model():
  138. pymc3_rv_map[dist.__class__.__name__](dist)
  139. samps = pymc3.sample(draws=prod(size), chains=1, progressbar=False, random_seed=seed, return_inferencedata=False, compute_convergence_checks=False)['X']
  140. return samps.reshape(size + sample_shape[dist.__class__.__name__](dist))
  141. _get_sample_class_matrixrv = {
  142. 'scipy': SampleMatrixScipy,
  143. 'pymc3': SampleMatrixPymc,
  144. 'numpy': SampleMatrixNumpy
  145. }
  146. ################################################################################
  147. #-------------------------Matrix Distribution----------------------------------#
  148. ################################################################################
  149. class MatrixDistribution(Distribution, NamedArgsMixin):
  150. """
  151. Abstract class for Matrix Distribution.
  152. """
  153. def __new__(cls, *args):
  154. args = [ImmutableMatrix(arg) if isinstance(arg, list)
  155. else _sympify(arg) for arg in args]
  156. return Basic.__new__(cls, *args)
  157. @staticmethod
  158. def check(*args):
  159. pass
  160. def __call__(self, expr):
  161. if isinstance(expr, list):
  162. expr = ImmutableMatrix(expr)
  163. return self.pdf(expr)
  164. def sample(self, size=(), library='scipy', seed=None):
  165. """
  166. Internal sample method
  167. Returns dictionary mapping RandomSymbol to realization value.
  168. """
  169. libraries = ['scipy', 'numpy', 'pymc3']
  170. if library not in libraries:
  171. raise NotImplementedError("Sampling from %s is not supported yet."
  172. % str(library))
  173. if not import_module(library):
  174. raise ValueError("Failed to import %s" % library)
  175. samps = _get_sample_class_matrixrv[library](self, size, seed)
  176. if samps is not None:
  177. return samps
  178. raise NotImplementedError(
  179. "Sampling for %s is not currently implemented from %s"
  180. % (self.__class__.__name__, library)
  181. )
  182. ################################################################################
  183. #------------------------Matrix Distribution Types-----------------------------#
  184. ################################################################################
  185. #-------------------------------------------------------------------------------
  186. # Matrix Gamma distribution ----------------------------------------------------
  187. class MatrixGammaDistribution(MatrixDistribution):
  188. _argnames = ('alpha', 'beta', 'scale_matrix')
  189. @staticmethod
  190. def check(alpha, beta, scale_matrix):
  191. if not isinstance(scale_matrix, MatrixSymbol):
  192. _value_check(scale_matrix.is_positive_definite, "The shape "
  193. "matrix must be positive definite.")
  194. _value_check(scale_matrix.is_square, "Should "
  195. "be square matrix")
  196. _value_check(alpha.is_positive, "Shape parameter should be positive.")
  197. _value_check(beta.is_positive, "Scale parameter should be positive.")
  198. @property
  199. def set(self):
  200. k = self.scale_matrix.shape[0]
  201. return MatrixSet(k, k, S.Reals)
  202. @property
  203. def dimension(self):
  204. return self.scale_matrix.shape
  205. def pdf(self, x):
  206. alpha, beta, scale_matrix = self.alpha, self.beta, self.scale_matrix
  207. p = scale_matrix.shape[0]
  208. if isinstance(x, list):
  209. x = ImmutableMatrix(x)
  210. if not isinstance(x, (MatrixBase, MatrixSymbol)):
  211. raise ValueError("%s should be an isinstance of Matrix "
  212. "or MatrixSymbol" % str(x))
  213. sigma_inv_x = - Inverse(scale_matrix)*x / beta
  214. term1 = exp(Trace(sigma_inv_x))/((beta**(p*alpha)) * multigamma(alpha, p))
  215. term2 = (Determinant(scale_matrix))**(-alpha)
  216. term3 = (Determinant(x))**(alpha - S(p + 1)/2)
  217. return term1 * term2 * term3
  218. def MatrixGamma(symbol, alpha, beta, scale_matrix):
  219. """
  220. Creates a random variable with Matrix Gamma Distribution.
  221. The density of the said distribution can be found at [1].
  222. Parameters
  223. ==========
  224. alpha: Positive Real number
  225. Shape Parameter
  226. beta: Positive Real number
  227. Scale Parameter
  228. scale_matrix: Positive definite real square matrix
  229. Scale Matrix
  230. Returns
  231. =======
  232. RandomSymbol
  233. Examples
  234. ========
  235. >>> from sympy.stats import density, MatrixGamma
  236. >>> from sympy import MatrixSymbol, symbols
  237. >>> a, b = symbols('a b', positive=True)
  238. >>> M = MatrixGamma('M', a, b, [[2, 1], [1, 2]])
  239. >>> X = MatrixSymbol('X', 2, 2)
  240. >>> density(M)(X).doit()
  241. exp(Trace(Matrix([
  242. [-2/3, 1/3],
  243. [ 1/3, -2/3]])*X)/b)*Determinant(X)**(a - 3/2)/(3**a*sqrt(pi)*b**(2*a)*gamma(a)*gamma(a - 1/2))
  244. >>> density(M)([[1, 0], [0, 1]]).doit()
  245. exp(-4/(3*b))/(3**a*sqrt(pi)*b**(2*a)*gamma(a)*gamma(a - 1/2))
  246. References
  247. ==========
  248. .. [1] https://en.wikipedia.org/wiki/Matrix_gamma_distribution
  249. """
  250. if isinstance(scale_matrix, list):
  251. scale_matrix = ImmutableMatrix(scale_matrix)
  252. return rv(symbol, MatrixGammaDistribution, (alpha, beta, scale_matrix))
  253. #-------------------------------------------------------------------------------
  254. # Wishart Distribution ---------------------------------------------------------
  255. class WishartDistribution(MatrixDistribution):
  256. _argnames = ('n', 'scale_matrix')
  257. @staticmethod
  258. def check(n, scale_matrix):
  259. if not isinstance(scale_matrix, MatrixSymbol):
  260. _value_check(scale_matrix.is_positive_definite, "The shape "
  261. "matrix must be positive definite.")
  262. _value_check(scale_matrix.is_square, "Should "
  263. "be square matrix")
  264. _value_check(n.is_positive, "Shape parameter should be positive.")
  265. @property
  266. def set(self):
  267. k = self.scale_matrix.shape[0]
  268. return MatrixSet(k, k, S.Reals)
  269. @property
  270. def dimension(self):
  271. return self.scale_matrix.shape
  272. def pdf(self, x):
  273. n, scale_matrix = self.n, self.scale_matrix
  274. p = scale_matrix.shape[0]
  275. if isinstance(x, list):
  276. x = ImmutableMatrix(x)
  277. if not isinstance(x, (MatrixBase, MatrixSymbol)):
  278. raise ValueError("%s should be an isinstance of Matrix "
  279. "or MatrixSymbol" % str(x))
  280. sigma_inv_x = - Inverse(scale_matrix)*x / S(2)
  281. term1 = exp(Trace(sigma_inv_x))/((2**(p*n/S(2))) * multigamma(n/S(2), p))
  282. term2 = (Determinant(scale_matrix))**(-n/S(2))
  283. term3 = (Determinant(x))**(S(n - p - 1)/2)
  284. return term1 * term2 * term3
  285. def Wishart(symbol, n, scale_matrix):
  286. """
  287. Creates a random variable with Wishart Distribution.
  288. The density of the said distribution can be found at [1].
  289. Parameters
  290. ==========
  291. n: Positive Real number
  292. Represents degrees of freedom
  293. scale_matrix: Positive definite real square matrix
  294. Scale Matrix
  295. Returns
  296. =======
  297. RandomSymbol
  298. Examples
  299. ========
  300. >>> from sympy.stats import density, Wishart
  301. >>> from sympy import MatrixSymbol, symbols
  302. >>> n = symbols('n', positive=True)
  303. >>> W = Wishart('W', n, [[2, 1], [1, 2]])
  304. >>> X = MatrixSymbol('X', 2, 2)
  305. >>> density(W)(X).doit()
  306. exp(Trace(Matrix([
  307. [-1/3, 1/6],
  308. [ 1/6, -1/3]])*X))*Determinant(X)**(n/2 - 3/2)/(2**n*3**(n/2)*sqrt(pi)*gamma(n/2)*gamma(n/2 - 1/2))
  309. >>> density(W)([[1, 0], [0, 1]]).doit()
  310. exp(-2/3)/(2**n*3**(n/2)*sqrt(pi)*gamma(n/2)*gamma(n/2 - 1/2))
  311. References
  312. ==========
  313. .. [1] https://en.wikipedia.org/wiki/Wishart_distribution
  314. """
  315. if isinstance(scale_matrix, list):
  316. scale_matrix = ImmutableMatrix(scale_matrix)
  317. return rv(symbol, WishartDistribution, (n, scale_matrix))
  318. #-------------------------------------------------------------------------------
  319. # Matrix Normal distribution ---------------------------------------------------
  320. class MatrixNormalDistribution(MatrixDistribution):
  321. _argnames = ('location_matrix', 'scale_matrix_1', 'scale_matrix_2')
  322. @staticmethod
  323. def check(location_matrix, scale_matrix_1, scale_matrix_2):
  324. if not isinstance(scale_matrix_1, MatrixSymbol):
  325. _value_check(scale_matrix_1.is_positive_definite, "The shape "
  326. "matrix must be positive definite.")
  327. if not isinstance(scale_matrix_2, MatrixSymbol):
  328. _value_check(scale_matrix_2.is_positive_definite, "The shape "
  329. "matrix must be positive definite.")
  330. _value_check(scale_matrix_1.is_square, "Scale matrix 1 should be "
  331. "be square matrix")
  332. _value_check(scale_matrix_2.is_square, "Scale matrix 2 should be "
  333. "be square matrix")
  334. n = location_matrix.shape[0]
  335. p = location_matrix.shape[1]
  336. _value_check(scale_matrix_1.shape[0] == n, "Scale matrix 1 should be"
  337. " of shape %s x %s"% (str(n), str(n)))
  338. _value_check(scale_matrix_2.shape[0] == p, "Scale matrix 2 should be"
  339. " of shape %s x %s"% (str(p), str(p)))
  340. @property
  341. def set(self):
  342. n, p = self.location_matrix.shape
  343. return MatrixSet(n, p, S.Reals)
  344. @property
  345. def dimension(self):
  346. return self.location_matrix.shape
  347. def pdf(self, x):
  348. M, U, V = self.location_matrix, self.scale_matrix_1, self.scale_matrix_2
  349. n, p = M.shape
  350. if isinstance(x, list):
  351. x = ImmutableMatrix(x)
  352. if not isinstance(x, (MatrixBase, MatrixSymbol)):
  353. raise ValueError("%s should be an isinstance of Matrix "
  354. "or MatrixSymbol" % str(x))
  355. term1 = Inverse(V)*Transpose(x - M)*Inverse(U)*(x - M)
  356. num = exp(-Trace(term1)/S(2))
  357. den = (2*pi)**(S(n*p)/2) * Determinant(U)**S(p)/2 * Determinant(V)**S(n)/2
  358. return num/den
  359. def MatrixNormal(symbol, location_matrix, scale_matrix_1, scale_matrix_2):
  360. """
  361. Creates a random variable with Matrix Normal Distribution.
  362. The density of the said distribution can be found at [1].
  363. Parameters
  364. ==========
  365. location_matrix: Real ``n x p`` matrix
  366. Represents degrees of freedom
  367. scale_matrix_1: Positive definite matrix
  368. Scale Matrix of shape ``n x n``
  369. scale_matrix_2: Positive definite matrix
  370. Scale Matrix of shape ``p x p``
  371. Returns
  372. =======
  373. RandomSymbol
  374. Examples
  375. ========
  376. >>> from sympy import MatrixSymbol
  377. >>> from sympy.stats import density, MatrixNormal
  378. >>> M = MatrixNormal('M', [[1, 2]], [1], [[1, 0], [0, 1]])
  379. >>> X = MatrixSymbol('X', 1, 2)
  380. >>> density(M)(X).doit()
  381. 2*exp(-Trace((Matrix([
  382. [-1],
  383. [-2]]) + X.T)*(Matrix([[-1, -2]]) + X))/2)/pi
  384. >>> density(M)([[3, 4]]).doit()
  385. 2*exp(-4)/pi
  386. References
  387. ==========
  388. .. [1] https://en.wikipedia.org/wiki/Matrix_normal_distribution
  389. """
  390. if isinstance(location_matrix, list):
  391. location_matrix = ImmutableMatrix(location_matrix)
  392. if isinstance(scale_matrix_1, list):
  393. scale_matrix_1 = ImmutableMatrix(scale_matrix_1)
  394. if isinstance(scale_matrix_2, list):
  395. scale_matrix_2 = ImmutableMatrix(scale_matrix_2)
  396. args = (location_matrix, scale_matrix_1, scale_matrix_2)
  397. return rv(symbol, MatrixNormalDistribution, args)
  398. #-------------------------------------------------------------------------------
  399. # Matrix Student's T distribution ---------------------------------------------------
  400. class MatrixStudentTDistribution(MatrixDistribution):
  401. _argnames = ('nu', 'location_matrix', 'scale_matrix_1', 'scale_matrix_2')
  402. @staticmethod
  403. def check(nu, location_matrix, scale_matrix_1, scale_matrix_2):
  404. if not isinstance(scale_matrix_1, MatrixSymbol):
  405. _value_check(scale_matrix_1.is_positive_definite != False, "The shape "
  406. "matrix must be positive definite.")
  407. if not isinstance(scale_matrix_2, MatrixSymbol):
  408. _value_check(scale_matrix_2.is_positive_definite != False, "The shape "
  409. "matrix must be positive definite.")
  410. _value_check(scale_matrix_1.is_square != False, "Scale matrix 1 should be "
  411. "be square matrix")
  412. _value_check(scale_matrix_2.is_square != False, "Scale matrix 2 should be "
  413. "be square matrix")
  414. n = location_matrix.shape[0]
  415. p = location_matrix.shape[1]
  416. _value_check(scale_matrix_1.shape[0] == p, "Scale matrix 1 should be"
  417. " of shape %s x %s" % (str(p), str(p)))
  418. _value_check(scale_matrix_2.shape[0] == n, "Scale matrix 2 should be"
  419. " of shape %s x %s" % (str(n), str(n)))
  420. _value_check(nu.is_positive != False, "Degrees of freedom must be positive")
  421. @property
  422. def set(self):
  423. n, p = self.location_matrix.shape
  424. return MatrixSet(n, p, S.Reals)
  425. @property
  426. def dimension(self):
  427. return self.location_matrix.shape
  428. def pdf(self, x):
  429. from sympy.matrices.dense import eye
  430. if isinstance(x, list):
  431. x = ImmutableMatrix(x)
  432. if not isinstance(x, (MatrixBase, MatrixSymbol)):
  433. raise ValueError("%s should be an isinstance of Matrix "
  434. "or MatrixSymbol" % str(x))
  435. nu, M, Omega, Sigma = self.nu, self.location_matrix, self.scale_matrix_1, self.scale_matrix_2
  436. n, p = M.shape
  437. K = multigamma((nu + n + p - 1)/2, p) * Determinant(Omega)**(-n/2) * Determinant(Sigma)**(-p/2) \
  438. / ((pi)**(n*p/2) * multigamma((nu + p - 1)/2, p))
  439. return K * (Determinant(eye(n) + Inverse(Sigma)*(x - M)*Inverse(Omega)*Transpose(x - M))) \
  440. **(-(nu + n + p -1)/2)
  441. def MatrixStudentT(symbol, nu, location_matrix, scale_matrix_1, scale_matrix_2):
  442. """
  443. Creates a random variable with Matrix Gamma Distribution.
  444. The density of the said distribution can be found at [1].
  445. Parameters
  446. ==========
  447. nu: Positive Real number
  448. degrees of freedom
  449. location_matrix: Positive definite real square matrix
  450. Location Matrix of shape ``n x p``
  451. scale_matrix_1: Positive definite real square matrix
  452. Scale Matrix of shape ``p x p``
  453. scale_matrix_2: Positive definite real square matrix
  454. Scale Matrix of shape ``n x n``
  455. Returns
  456. =======
  457. RandomSymbol
  458. Examples
  459. ========
  460. >>> from sympy import MatrixSymbol,symbols
  461. >>> from sympy.stats import density, MatrixStudentT
  462. >>> v = symbols('v',positive=True)
  463. >>> M = MatrixStudentT('M', v, [[1, 2]], [[1, 0], [0, 1]], [1])
  464. >>> X = MatrixSymbol('X', 1, 2)
  465. >>> density(M)(X)
  466. gamma(v/2 + 1)*Determinant((Matrix([[-1, -2]]) + X)*(Matrix([
  467. [-1],
  468. [-2]]) + X.T) + Matrix([[1]]))**(-v/2 - 1)/(pi**1.0*gamma(v/2)*Determinant(Matrix([[1]]))**1.0*Determinant(Matrix([
  469. [1, 0],
  470. [0, 1]]))**0.5)
  471. References
  472. ==========
  473. .. [1] https://en.wikipedia.org/wiki/Matrix_t-distribution
  474. """
  475. if isinstance(location_matrix, list):
  476. location_matrix = ImmutableMatrix(location_matrix)
  477. if isinstance(scale_matrix_1, list):
  478. scale_matrix_1 = ImmutableMatrix(scale_matrix_1)
  479. if isinstance(scale_matrix_2, list):
  480. scale_matrix_2 = ImmutableMatrix(scale_matrix_2)
  481. args = (nu, location_matrix, scale_matrix_1, scale_matrix_2)
  482. return rv(symbol, MatrixStudentTDistribution, args)