图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

404 lines
17 KiB

  1. from sympy.concrete.products import Product
  2. from sympy.concrete.summations import Sum
  3. from sympy.core.numbers import (Rational, oo, pi)
  4. from sympy.core.relational import Eq
  5. from sympy.core.singleton import S
  6. from sympy.core.symbol import symbols
  7. from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial)
  8. from sympy.functions.elementary.complexes import polar_lift
  9. from sympy.functions.elementary.exponential import exp
  10. from sympy.functions.elementary.miscellaneous import sqrt
  11. from sympy.functions.elementary.piecewise import Piecewise
  12. from sympy.functions.special.bessel import besselk
  13. from sympy.functions.special.gamma_functions import gamma
  14. from sympy.matrices.dense import eye
  15. from sympy.matrices.expressions.determinant import Determinant
  16. from sympy.sets.fancysets import Range
  17. from sympy.sets.sets import (Interval, ProductSet)
  18. from sympy.simplify.simplify import simplify
  19. from sympy.tensor.indexed import (Indexed, IndexedBase)
  20. from sympy.core.numbers import comp
  21. from sympy.integrals.integrals import integrate
  22. from sympy.matrices import Matrix, MatrixSymbol
  23. from sympy.matrices.expressions.matexpr import MatrixElement
  24. from sympy.stats import density, median, marginal_distribution, Normal, Laplace, E, sample
  25. from sympy.stats.joint_rv_types import (JointRV, MultivariateNormalDistribution,
  26. JointDistributionHandmade, MultivariateT, NormalGamma,
  27. GeneralizedMultivariateLogGammaOmega as GMVLGO, MultivariateBeta,
  28. GeneralizedMultivariateLogGamma as GMVLG, MultivariateEwens,
  29. Multinomial, NegativeMultinomial, MultivariateNormal,
  30. MultivariateLaplace)
  31. from sympy.testing.pytest import raises, XFAIL, skip
  32. from sympy.external import import_module
  33. from sympy.abc import x, y
  34. def test_Normal():
  35. m = Normal('A', [1, 2], [[1, 0], [0, 1]])
  36. A = MultivariateNormal('A', [1, 2], [[1, 0], [0, 1]])
  37. assert m == A
  38. assert density(m)(1, 2) == 1/(2*pi)
  39. assert m.pspace.distribution.set == ProductSet(S.Reals, S.Reals)
  40. raises (ValueError, lambda:m[2])
  41. n = Normal('B', [1, 2, 3], [[1, 0, 0], [0, 1, 0], [0, 0, 1]])
  42. p = Normal('C', Matrix([1, 2]), Matrix([[1, 0], [0, 1]]))
  43. assert density(m)(x, y) == density(p)(x, y)
  44. assert marginal_distribution(n, 0, 1)(1, 2) == 1/(2*pi)
  45. raises(ValueError, lambda: marginal_distribution(m))
  46. assert integrate(density(m)(x, y), (x, -oo, oo), (y, -oo, oo)).evalf() == 1
  47. N = Normal('N', [1, 2], [[x, 0], [0, y]])
  48. assert density(N)(0, 0) == exp(-((4*x + y)/(2*x*y)))/(2*pi*sqrt(x*y))
  49. raises (ValueError, lambda: Normal('M', [1, 2], [[1, 1], [1, -1]]))
  50. # symbolic
  51. n = symbols('n', integer=True, positive=True)
  52. mu = MatrixSymbol('mu', n, 1)
  53. sigma = MatrixSymbol('sigma', n, n)
  54. X = Normal('X', mu, sigma)
  55. assert density(X) == MultivariateNormalDistribution(mu, sigma)
  56. raises (NotImplementedError, lambda: median(m))
  57. # Below tests should work after issue #17267 is resolved
  58. # assert E(X) == mu
  59. # assert variance(X) == sigma
  60. # test symbolic multivariate normal densities
  61. n = 3
  62. Sg = MatrixSymbol('Sg', n, n)
  63. mu = MatrixSymbol('mu', n, 1)
  64. obs = MatrixSymbol('obs', n, 1)
  65. X = MultivariateNormal('X', mu, Sg)
  66. density_X = density(X)
  67. eval_a = density_X(obs).subs({Sg: eye(3),
  68. mu: Matrix([0, 0, 0]), obs: Matrix([0, 0, 0])}).doit()
  69. eval_b = density_X(0, 0, 0).subs({Sg: eye(3), mu: Matrix([0, 0, 0])}).doit()
  70. assert eval_a == sqrt(2)/(4*pi**Rational(3/2))
  71. assert eval_b == sqrt(2)/(4*pi**Rational(3/2))
  72. n = symbols('n', integer=True, positive=True)
  73. Sg = MatrixSymbol('Sg', n, n)
  74. mu = MatrixSymbol('mu', n, 1)
  75. obs = MatrixSymbol('obs', n, 1)
  76. X = MultivariateNormal('X', mu, Sg)
  77. density_X_at_obs = density(X)(obs)
  78. expected_density = MatrixElement(
  79. exp((S(1)/2) * (mu.T - obs.T) * Sg**(-1) * (-mu + obs)) / \
  80. sqrt((2*pi)**n * Determinant(Sg)), 0, 0)
  81. assert density_X_at_obs == expected_density
  82. def test_MultivariateTDist():
  83. t1 = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2)
  84. assert(density(t1))(1, 1) == 1/(8*pi)
  85. assert t1.pspace.distribution.set == ProductSet(S.Reals, S.Reals)
  86. assert integrate(density(t1)(x, y), (x, -oo, oo), \
  87. (y, -oo, oo)).evalf() == 1
  88. raises(ValueError, lambda: MultivariateT('T', [1, 2], [[1, 1], [1, -1]], 1))
  89. t2 = MultivariateT('t2', [1, 2], [[x, 0], [0, y]], 1)
  90. assert density(t2)(1, 2) == 1/(2*pi*sqrt(x*y))
  91. def test_multivariate_laplace():
  92. raises(ValueError, lambda: Laplace('T', [1, 2], [[1, 2], [2, 1]]))
  93. L = Laplace('L', [1, 0], [[1, 0], [0, 1]])
  94. L2 = MultivariateLaplace('L2', [1, 0], [[1, 0], [0, 1]])
  95. assert density(L)(2, 3) == exp(2)*besselk(0, sqrt(39))/pi
  96. L1 = Laplace('L1', [1, 2], [[x, 0], [0, y]])
  97. assert density(L1)(0, 1) == \
  98. exp(2/y)*besselk(0, sqrt((2 + 4/y + 1/x)/y))/(pi*sqrt(x*y))
  99. assert L.pspace.distribution.set == ProductSet(S.Reals, S.Reals)
  100. assert L.pspace.distribution == L2.pspace.distribution
  101. def test_NormalGamma():
  102. ng = NormalGamma('G', 1, 2, 3, 4)
  103. assert density(ng)(1, 1) == 32*exp(-4)/sqrt(pi)
  104. assert ng.pspace.distribution.set == ProductSet(S.Reals, Interval(0, oo))
  105. raises(ValueError, lambda:NormalGamma('G', 1, 2, 3, -1))
  106. assert marginal_distribution(ng, 0)(1) == \
  107. 3*sqrt(10)*gamma(Rational(7, 4))/(10*sqrt(pi)*gamma(Rational(5, 4)))
  108. assert marginal_distribution(ng, y)(1) == exp(Rational(-1, 4))/128
  109. assert marginal_distribution(ng,[0,1])(x) == x**2*exp(-x/4)/128
  110. def test_GeneralizedMultivariateLogGammaDistribution():
  111. h = S.Half
  112. omega = Matrix([[1, h, h, h],
  113. [h, 1, h, h],
  114. [h, h, 1, h],
  115. [h, h, h, 1]])
  116. v, l, mu = (4, [1, 2, 3, 4], [1, 2, 3, 4])
  117. y_1, y_2, y_3, y_4 = symbols('y_1:5', real=True)
  118. delta = symbols('d', positive=True)
  119. G = GMVLGO('G', omega, v, l, mu)
  120. Gd = GMVLG('Gd', delta, v, l, mu)
  121. dend = ("d**4*Sum(4*24**(-n - 4)*(1 - d)**n*exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 "
  122. "+ 4*y_4) - exp(y_1) - exp(2*y_2)/2 - exp(3*y_3)/3 - exp(4*y_4)/4)/"
  123. "(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))")
  124. assert str(density(Gd)(y_1, y_2, y_3, y_4)) == dend
  125. den = ("5*2**(2/3)*5**(1/3)*Sum(4*24**(-n - 4)*(-2**(2/3)*5**(1/3)/4 + 1)**n*"
  126. "exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 + 4*y_4) - exp(y_1) - exp(2*y_2)/2 - "
  127. "exp(3*y_3)/3 - exp(4*y_4)/4)/(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))/64")
  128. assert str(density(G)(y_1, y_2, y_3, y_4)) == den
  129. marg = ("5*2**(2/3)*5**(1/3)*exp(4*y_1)*exp(-exp(y_1))*Integral(exp(-exp(4*G[3])"
  130. "/4)*exp(16*G[3])*Integral(exp(-exp(3*G[2])/3)*exp(12*G[2])*Integral(exp("
  131. "-exp(2*G[1])/2)*exp(8*G[1])*Sum((-1/4)**n*(-4 + 2**(2/3)*5**(1/3"
  132. "))**n*exp(n*y_1)*exp(2*n*G[1])*exp(3*n*G[2])*exp(4*n*G[3])/(24**n*gamma(n + 1)"
  133. "*gamma(n + 4)**3), (n, 0, oo)), (G[1], -oo, oo)), (G[2], -oo, oo)), (G[3]"
  134. ", -oo, oo))/5308416")
  135. assert str(marginal_distribution(G, G[0])(y_1)) == marg
  136. omega_f1 = Matrix([[1, h, h]])
  137. omega_f2 = Matrix([[1, h, h, h],
  138. [h, 1, 2, h],
  139. [h, h, 1, h],
  140. [h, h, h, 1]])
  141. omega_f3 = Matrix([[6, h, h, h],
  142. [h, 1, 2, h],
  143. [h, h, 1, h],
  144. [h, h, h, 1]])
  145. v_f = symbols("v_f", positive=False, real=True)
  146. l_f = [1, 2, v_f, 4]
  147. m_f = [v_f, 2, 3, 4]
  148. omega_f4 = Matrix([[1, h, h, h, h],
  149. [h, 1, h, h, h],
  150. [h, h, 1, h, h],
  151. [h, h, h, 1, h],
  152. [h, h, h, h, 1]])
  153. l_f1 = [1, 2, 3, 4, 5]
  154. omega_f5 = Matrix([[1]])
  155. mu_f5 = l_f5 = [1]
  156. raises(ValueError, lambda: GMVLGO('G', omega_f1, v, l, mu))
  157. raises(ValueError, lambda: GMVLGO('G', omega_f2, v, l, mu))
  158. raises(ValueError, lambda: GMVLGO('G', omega_f3, v, l, mu))
  159. raises(ValueError, lambda: GMVLGO('G', omega, v_f, l, mu))
  160. raises(ValueError, lambda: GMVLGO('G', omega, v, l_f, mu))
  161. raises(ValueError, lambda: GMVLGO('G', omega, v, l, m_f))
  162. raises(ValueError, lambda: GMVLGO('G', omega_f4, v, l, mu))
  163. raises(ValueError, lambda: GMVLGO('G', omega, v, l_f1, mu))
  164. raises(ValueError, lambda: GMVLGO('G', omega_f5, v, l_f5, mu_f5))
  165. raises(ValueError, lambda: GMVLG('G', Rational(3, 2), v, l, mu))
  166. def test_MultivariateBeta():
  167. a1, a2 = symbols('a1, a2', positive=True)
  168. a1_f, a2_f = symbols('a1, a2', positive=False, real=True)
  169. mb = MultivariateBeta('B', [a1, a2])
  170. mb_c = MultivariateBeta('C', a1, a2)
  171. assert density(mb)(1, 2) == S(2)**(a2 - 1)*gamma(a1 + a2)/\
  172. (gamma(a1)*gamma(a2))
  173. assert marginal_distribution(mb_c, 0)(3) == S(3)**(a1 - 1)*gamma(a1 + a2)/\
  174. (a2*gamma(a1)*gamma(a2))
  175. raises(ValueError, lambda: MultivariateBeta('b1', [a1_f, a2]))
  176. raises(ValueError, lambda: MultivariateBeta('b2', [a1, a2_f]))
  177. raises(ValueError, lambda: MultivariateBeta('b3', [0, 0]))
  178. raises(ValueError, lambda: MultivariateBeta('b4', [a1_f, a2_f]))
  179. assert mb.pspace.distribution.set == ProductSet(Interval(0, 1), Interval(0, 1))
  180. def test_MultivariateEwens():
  181. n, theta, i = symbols('n theta i', positive=True)
  182. # tests for integer dimensions
  183. theta_f = symbols('t_f', negative=True)
  184. a = symbols('a_1:4', positive = True, integer = True)
  185. ed = MultivariateEwens('E', 3, theta)
  186. assert density(ed)(a[0], a[1], a[2]) == Piecewise((6*2**(-a[1])*3**(-a[2])*
  187. theta**a[0]*theta**a[1]*theta**a[2]/
  188. (theta*(theta + 1)*(theta + 2)*
  189. factorial(a[0])*factorial(a[1])*
  190. factorial(a[2])), Eq(a[0] + 2*a[1] +
  191. 3*a[2], 3)), (0, True))
  192. assert marginal_distribution(ed, ed[1])(a[1]) == Piecewise((6*2**(-a[1])*
  193. theta**a[1]/((theta + 1)*
  194. (theta + 2)*factorial(a[1])),
  195. Eq(2*a[1] + 1, 3)), (0, True))
  196. raises(ValueError, lambda: MultivariateEwens('e1', 5, theta_f))
  197. assert ed.pspace.distribution.set == ProductSet(Range(0, 4, 1),
  198. Range(0, 2, 1), Range(0, 2, 1))
  199. # tests for symbolic dimensions
  200. eds = MultivariateEwens('E', n, theta)
  201. a = IndexedBase('a')
  202. j, k = symbols('j, k')
  203. den = Piecewise((factorial(n)*Product(theta**a[j]*(j + 1)**(-a[j])/
  204. factorial(a[j]), (j, 0, n - 1))/RisingFactorial(theta, n),
  205. Eq(n, Sum((k + 1)*a[k], (k, 0, n - 1)))), (0, True))
  206. assert density(eds)(a).dummy_eq(den)
  207. def test_Multinomial():
  208. n, x1, x2, x3, x4 = symbols('n, x1, x2, x3, x4', nonnegative=True, integer=True)
  209. p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True)
  210. p1_f, n_f = symbols('p1_f, n_f', negative=True)
  211. M = Multinomial('M', n, [p1, p2, p3, p4])
  212. C = Multinomial('C', 3, p1, p2, p3)
  213. f = factorial
  214. assert density(M)(x1, x2, x3, x4) == Piecewise((p1**x1*p2**x2*p3**x3*p4**x4*
  215. f(n)/(f(x1)*f(x2)*f(x3)*f(x4)),
  216. Eq(n, x1 + x2 + x3 + x4)), (0, True))
  217. assert marginal_distribution(C, C[0])(x1).subs(x1, 1) ==\
  218. 3*p1*p2**2 +\
  219. 6*p1*p2*p3 +\
  220. 3*p1*p3**2
  221. raises(ValueError, lambda: Multinomial('b1', 5, [p1, p2, p3, p1_f]))
  222. raises(ValueError, lambda: Multinomial('b2', n_f, [p1, p2, p3, p4]))
  223. raises(ValueError, lambda: Multinomial('b3', n, 0.5, 0.4, 0.3, 0.1))
  224. def test_NegativeMultinomial():
  225. k0, x1, x2, x3, x4 = symbols('k0, x1, x2, x3, x4', nonnegative=True, integer=True)
  226. p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True)
  227. p1_f = symbols('p1_f', negative=True)
  228. N = NegativeMultinomial('N', 4, [p1, p2, p3, p4])
  229. C = NegativeMultinomial('C', 4, 0.1, 0.2, 0.3)
  230. g = gamma
  231. f = factorial
  232. assert simplify(density(N)(x1, x2, x3, x4) -
  233. p1**x1*p2**x2*p3**x3*p4**x4*(-p1 - p2 - p3 - p4 + 1)**4*g(x1 + x2 +
  234. x3 + x4 + 4)/(6*f(x1)*f(x2)*f(x3)*f(x4))) is S.Zero
  235. assert comp(marginal_distribution(C, C[0])(1).evalf(), 0.33, .01)
  236. raises(ValueError, lambda: NegativeMultinomial('b1', 5, [p1, p2, p3, p1_f]))
  237. raises(ValueError, lambda: NegativeMultinomial('b2', k0, 0.5, 0.4, 0.3, 0.4))
  238. assert N.pspace.distribution.set == ProductSet(Range(0, oo, 1),
  239. Range(0, oo, 1), Range(0, oo, 1), Range(0, oo, 1))
  240. def test_JointPSpace_marginal_distribution():
  241. T = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2)
  242. got = marginal_distribution(T, T[1])(x)
  243. ans = sqrt(2)*(x**2/2 + 1)/(4*polar_lift(x**2/2 + 1)**(S(5)/2))
  244. assert got == ans, got
  245. assert integrate(marginal_distribution(T, 1)(x), (x, -oo, oo)) == 1
  246. t = MultivariateT('T', [0, 0, 0], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], 3)
  247. assert comp(marginal_distribution(t, 0)(1).evalf(), 0.2, .01)
  248. def test_JointRV():
  249. x1, x2 = (Indexed('x', i) for i in (1, 2))
  250. pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi)
  251. X = JointRV('x', pdf)
  252. assert density(X)(1, 2) == exp(-2)/(2*pi)
  253. assert isinstance(X.pspace.distribution, JointDistributionHandmade)
  254. assert marginal_distribution(X, 0)(2) == sqrt(2)*exp(Rational(-1, 2))/(2*sqrt(pi))
  255. def test_expectation():
  256. m = Normal('A', [x, y], [[1, 0], [0, 1]])
  257. assert simplify(E(m[1])) == y
  258. @XFAIL
  259. def test_joint_vector_expectation():
  260. m = Normal('A', [x, y], [[1, 0], [0, 1]])
  261. assert E(m) == (x, y)
  262. def test_sample_numpy():
  263. distribs_numpy = [
  264. MultivariateNormal("M", [3, 4], [[2, 1], [1, 2]]),
  265. MultivariateBeta("B", [0.4, 5, 15, 50, 203]),
  266. Multinomial("N", 50, [0.3, 0.2, 0.1, 0.25, 0.15])
  267. ]
  268. size = 3
  269. numpy = import_module('numpy')
  270. if not numpy:
  271. skip('Numpy is not installed. Abort tests for _sample_numpy.')
  272. else:
  273. for X in distribs_numpy:
  274. samps = sample(X, size=size, library='numpy')
  275. for sam in samps:
  276. assert tuple(sam) in X.pspace.distribution.set
  277. N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1)
  278. raises(NotImplementedError, lambda: sample(N_c, library='numpy'))
  279. def test_sample_scipy():
  280. distribs_scipy = [
  281. MultivariateNormal("M", [0, 0], [[0.1, 0.025], [0.025, 0.1]]),
  282. MultivariateBeta("B", [0.4, 5, 15]),
  283. Multinomial("N", 8, [0.3, 0.2, 0.1, 0.4])
  284. ]
  285. size = 3
  286. scipy = import_module('scipy')
  287. if not scipy:
  288. skip('Scipy not installed. Abort tests for _sample_scipy.')
  289. else:
  290. for X in distribs_scipy:
  291. samps = sample(X, size=size)
  292. samps2 = sample(X, size=(2, 2))
  293. for sam in samps:
  294. assert tuple(sam) in X.pspace.distribution.set
  295. for i in range(2):
  296. for j in range(2):
  297. assert tuple(samps2[i][j]) in X.pspace.distribution.set
  298. N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1)
  299. raises(NotImplementedError, lambda: sample(N_c))
  300. def test_sample_pymc3():
  301. distribs_pymc3 = [
  302. MultivariateNormal("M", [5, 2], [[1, 0], [0, 1]]),
  303. MultivariateBeta("B", [0.4, 5, 15]),
  304. Multinomial("N", 4, [0.3, 0.2, 0.1, 0.4])
  305. ]
  306. size = 3
  307. pymc3 = import_module('pymc3')
  308. if not pymc3:
  309. skip('PyMC3 is not installed. Abort tests for _sample_pymc3.')
  310. else:
  311. for X in distribs_pymc3:
  312. samps = sample(X, size=size, library='pymc3')
  313. for sam in samps:
  314. assert tuple(sam.flatten()) in X.pspace.distribution.set
  315. N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1)
  316. raises(NotImplementedError, lambda: sample(N_c, library='pymc3'))
  317. def test_sample_seed():
  318. x1, x2 = (Indexed('x', i) for i in (1, 2))
  319. pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi)
  320. X = JointRV('x', pdf)
  321. libraries = ['scipy', 'numpy', 'pymc3']
  322. for lib in libraries:
  323. try:
  324. imported_lib = import_module(lib)
  325. if imported_lib:
  326. s0, s1, s2 = [], [], []
  327. s0 = sample(X, size=10, library=lib, seed=0)
  328. s1 = sample(X, size=10, library=lib, seed=0)
  329. s2 = sample(X, size=10, library=lib, seed=1)
  330. assert all(s0 == s1)
  331. assert all(s1 != s2)
  332. except NotImplementedError:
  333. continue
  334. def test_issue_21057():
  335. m = Normal("x", [0, 0], [[0, 0], [0, 0]])
  336. n = MultivariateNormal("x", [0, 0], [[0, 0], [0, 0]])
  337. p = Normal("x", [0, 0], [[0, 0], [0, 1]])
  338. assert m == n
  339. libraries = ['scipy', 'numpy', 'pymc3']
  340. for library in libraries:
  341. try:
  342. imported_lib = import_module(library)
  343. if imported_lib:
  344. s1 = sample(m, size=8, library=library)
  345. s2 = sample(n, size=8, library=library)
  346. s3 = sample(p, size=8, library=library)
  347. assert tuple(s1.flatten()) == tuple(s2.flatten())
  348. for s in s3:
  349. assert tuple(s.flatten()) in p.pspace.distribution.set
  350. except NotImplementedError:
  351. continue