图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

167 lines
6.2 KiB

  1. from functools import singledispatch
  2. from sympy.core.symbol import Dummy
  3. from sympy.functions.elementary.exponential import exp
  4. from sympy.utilities.lambdify import lambdify
  5. from sympy.external import import_module
  6. from sympy.stats import DiscreteDistributionHandmade
  7. from sympy.stats.crv import SingleContinuousDistribution
  8. from sympy.stats.crv_types import ChiSquaredDistribution, ExponentialDistribution, GammaDistribution, \
  9. LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution, BetaDistribution, \
  10. StudentTDistribution, CauchyDistribution
  11. from sympy.stats.drv_types import GeometricDistribution, LogarithmicDistribution, NegativeBinomialDistribution, \
  12. PoissonDistribution, SkellamDistribution, YuleSimonDistribution, ZetaDistribution
  13. from sympy.stats.frv import SingleFiniteDistribution
  14. scipy = import_module("scipy", import_kwargs={'fromlist':['stats']})
  15. @singledispatch
  16. def do_sample_scipy(dist, size, seed):
  17. return None
  18. # CRV
  19. @do_sample_scipy.register(SingleContinuousDistribution)
  20. def _(dist: SingleContinuousDistribution, size, seed):
  21. # if we don't need to make a handmade pdf, we won't
  22. import scipy.stats
  23. z = Dummy('z')
  24. handmade_pdf = lambdify(z, dist.pdf(z), ['numpy', 'scipy'])
  25. class scipy_pdf(scipy.stats.rv_continuous):
  26. def _pdf(dist, x):
  27. return handmade_pdf(x)
  28. scipy_rv = scipy_pdf(a=float(dist.set._inf),
  29. b=float(dist.set._sup), name='scipy_pdf')
  30. return scipy_rv.rvs(size=size, random_state=seed)
  31. @do_sample_scipy.register(ChiSquaredDistribution)
  32. def _(dist: ChiSquaredDistribution, size, seed):
  33. # same parametrisation
  34. return scipy.stats.chi2.rvs(df=float(dist.k), size=size, random_state=seed)
  35. @do_sample_scipy.register(ExponentialDistribution)
  36. def _(dist: ExponentialDistribution, size, seed):
  37. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.expon.html#scipy.stats.expon
  38. return scipy.stats.expon.rvs(scale=1 / float(dist.rate), size=size, random_state=seed)
  39. @do_sample_scipy.register(GammaDistribution)
  40. def _(dist: GammaDistribution, size, seed):
  41. # https://stackoverflow.com/questions/42150965/how-to-plot-gamma-distribution-with-alpha-and-beta-parameters-in-python
  42. return scipy.stats.gamma.rvs(a=float(dist.k), scale=float(dist.theta), size=size, random_state=seed)
  43. @do_sample_scipy.register(LogNormalDistribution)
  44. def _(dist: LogNormalDistribution, size, seed):
  45. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.lognorm.html
  46. return scipy.stats.lognorm.rvs(scale=float(exp(dist.mean)), s=float(dist.std), size=size, random_state=seed)
  47. @do_sample_scipy.register(NormalDistribution)
  48. def _(dist: NormalDistribution, size, seed):
  49. return scipy.stats.norm.rvs(loc=float(dist.mean), scale=float(dist.std), size=size, random_state=seed)
  50. @do_sample_scipy.register(ParetoDistribution)
  51. def _(dist: ParetoDistribution, size, seed):
  52. # https://stackoverflow.com/questions/42260519/defining-pareto-distribution-in-python-scipy
  53. return scipy.stats.pareto.rvs(b=float(dist.alpha), scale=float(dist.xm), size=size, random_state=seed)
  54. @do_sample_scipy.register(StudentTDistribution)
  55. def _(dist: StudentTDistribution, size, seed):
  56. return scipy.stats.t.rvs(df=float(dist.nu), size=size, random_state=seed)
  57. @do_sample_scipy.register(UniformDistribution)
  58. def _(dist: UniformDistribution, size, seed):
  59. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.uniform.html
  60. return scipy.stats.uniform.rvs(loc=float(dist.left), scale=float(dist.right - dist.left), size=size, random_state=seed)
  61. @do_sample_scipy.register(BetaDistribution)
  62. def _(dist: BetaDistribution, size, seed):
  63. # same parametrisation
  64. return scipy.stats.beta.rvs(a=float(dist.alpha), b=float(dist.beta), size=size, random_state=seed)
  65. @do_sample_scipy.register(CauchyDistribution)
  66. def _(dist: CauchyDistribution, size, seed):
  67. return scipy.stats.cauchy.rvs(loc=float(dist.x0), scale=float(dist.gamma), size=size, random_state=seed)
  68. # DRV:
  69. @do_sample_scipy.register(DiscreteDistributionHandmade)
  70. def _(dist: DiscreteDistributionHandmade, size, seed):
  71. from scipy.stats import rv_discrete
  72. z = Dummy('z')
  73. handmade_pmf = lambdify(z, dist.pdf(z), ['numpy', 'scipy'])
  74. class scipy_pmf(rv_discrete):
  75. def _pmf(dist, x):
  76. return handmade_pmf(x)
  77. scipy_rv = scipy_pmf(a=float(dist.set._inf), b=float(dist.set._sup),
  78. name='scipy_pmf')
  79. return scipy_rv.rvs(size=size, random_state=seed)
  80. @do_sample_scipy.register(GeometricDistribution)
  81. def _(dist: GeometricDistribution, size, seed):
  82. return scipy.stats.geom.rvs(p=float(dist.p), size=size, random_state=seed)
  83. @do_sample_scipy.register(LogarithmicDistribution)
  84. def _(dist: LogarithmicDistribution, size, seed):
  85. return scipy.stats.logser.rvs(p=float(dist.p), size=size, random_state=seed)
  86. @do_sample_scipy.register(NegativeBinomialDistribution)
  87. def _(dist: NegativeBinomialDistribution, size, seed):
  88. return scipy.stats.nbinom.rvs(n=float(dist.r), p=float(dist.p), size=size, random_state=seed)
  89. @do_sample_scipy.register(PoissonDistribution)
  90. def _(dist: PoissonDistribution, size, seed):
  91. return scipy.stats.poisson.rvs(mu=float(dist.lamda), size=size, random_state=seed)
  92. @do_sample_scipy.register(SkellamDistribution)
  93. def _(dist: SkellamDistribution, size, seed):
  94. return scipy.stats.skellam.rvs(mu1=float(dist.mu1), mu2=float(dist.mu2), size=size, random_state=seed)
  95. @do_sample_scipy.register(YuleSimonDistribution)
  96. def _(dist: YuleSimonDistribution, size, seed):
  97. return scipy.stats.yulesimon.rvs(alpha=float(dist.rho), size=size, random_state=seed)
  98. @do_sample_scipy.register(ZetaDistribution)
  99. def _(dist: ZetaDistribution, size, seed):
  100. return scipy.stats.zipf.rvs(a=float(dist.s), size=size, random_state=seed)
  101. # FRV:
  102. @do_sample_scipy.register(SingleFiniteDistribution)
  103. def _(dist: SingleFiniteDistribution, size, seed):
  104. # scipy can handle with custom distributions
  105. from scipy.stats import rv_discrete
  106. density_ = dist.dict
  107. x, y = [], []
  108. for k, v in density_.items():
  109. x.append(int(k))
  110. y.append(float(v))
  111. scipy_rv = rv_discrete(name='scipy_rv', values=(x, y))
  112. return scipy_rv.rvs(size=size, random_state=seed)