图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

81 lines
2.6 KiB

  1. from functools import singledispatch
  2. from sympy.external import import_module
  3. from sympy.stats.crv_types import BetaDistribution, ChiSquaredDistribution, ExponentialDistribution, GammaDistribution, \
  4. LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution
  5. from sympy.stats.drv_types import GeometricDistribution, PoissonDistribution, ZetaDistribution
  6. from sympy.stats.frv_types import BinomialDistribution
  7. numpy = import_module('numpy')
  8. @singledispatch
  9. def do_sample_numpy(dist, size, rand_state):
  10. return None
# Continuous random variables (sympy.stats.crv_types):
  12. @do_sample_numpy.register(BetaDistribution)
  13. def _(dist: BetaDistribution, size, rand_state):
  14. return rand_state.beta(a=float(dist.alpha), b=float(dist.beta), size=size)
  15. @do_sample_numpy.register(ChiSquaredDistribution)
  16. def _(dist: ChiSquaredDistribution, size, rand_state):
  17. return rand_state.chisquare(df=float(dist.k), size=size)
  18. @do_sample_numpy.register(ExponentialDistribution)
  19. def _(dist: ExponentialDistribution, size, rand_state):
  20. return rand_state.exponential(1 / float(dist.rate), size=size)
  21. @do_sample_numpy.register(GammaDistribution)
  22. def _(dist: GammaDistribution, size, rand_state):
  23. return rand_state.gamma(float(dist.k), float(dist.theta), size=size)
  24. @do_sample_numpy.register(LogNormalDistribution)
  25. def _(dist: LogNormalDistribution, size, rand_state):
  26. return rand_state.lognormal(float(dist.mean), float(dist.std), size=size)
  27. @do_sample_numpy.register(NormalDistribution)
  28. def _(dist: NormalDistribution, size, rand_state):
  29. return rand_state.normal(float(dist.mean), float(dist.std), size=size)
  30. @do_sample_numpy.register(ParetoDistribution)
  31. def _(dist: ParetoDistribution, size, rand_state):
  32. return (numpy.random.pareto(a=float(dist.alpha), size=size) + 1) * float(dist.xm)
  33. @do_sample_numpy.register(UniformDistribution)
  34. def _(dist: UniformDistribution, size, rand_state):
  35. return rand_state.uniform(low=float(dist.left), high=float(dist.right), size=size)
# Discrete random variables (sympy.stats.drv_types):
  37. @do_sample_numpy.register(GeometricDistribution)
  38. def _(dist: GeometricDistribution, size, rand_state):
  39. return rand_state.geometric(p=float(dist.p), size=size)
  40. @do_sample_numpy.register(PoissonDistribution)
  41. def _(dist: PoissonDistribution, size, rand_state):
  42. return rand_state.poisson(lam=float(dist.lamda), size=size)
  43. @do_sample_numpy.register(ZetaDistribution)
  44. def _(dist: ZetaDistribution, size, rand_state):
  45. return rand_state.zipf(a=float(dist.s), size=size)
# Finite random variables (sympy.stats.frv_types):
  47. @do_sample_numpy.register(BinomialDistribution)
  48. def _(dist: BinomialDistribution, size, rand_state):
  49. return rand_state.binomial(n=int(dist.n), p=float(dist.p), size=size)