图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

84 lines
1.9 KiB

  1. import numpy as np
  2. import numba as nb
  3. from numpy.random import PCG64
  4. from timeit import timeit
  5. bit_gen = PCG64()
  6. next_d = bit_gen.cffi.next_double
  7. state_addr = bit_gen.cffi.state_address
  8. def normals(n, state):
  9. out = np.empty(n)
  10. for i in range((n + 1) // 2):
  11. x1 = 2.0 * next_d(state) - 1.0
  12. x2 = 2.0 * next_d(state) - 1.0
  13. r2 = x1 * x1 + x2 * x2
  14. while r2 >= 1.0 or r2 == 0.0:
  15. x1 = 2.0 * next_d(state) - 1.0
  16. x2 = 2.0 * next_d(state) - 1.0
  17. r2 = x1 * x1 + x2 * x2
  18. f = np.sqrt(-2.0 * np.log(r2) / r2)
  19. out[2 * i] = f * x1
  20. if 2 * i + 1 < n:
  21. out[2 * i + 1] = f * x2
  22. return out
  23. # Compile using Numba
  24. normalsj = nb.jit(normals, nopython=True)
  25. # Must use state address not state with numba
  26. n = 10000
  27. def numbacall():
  28. return normalsj(n, state_addr)
  29. rg = np.random.Generator(PCG64())
  30. def numpycall():
  31. return rg.normal(size=n)
  32. # Check that the functions work
  33. r1 = numbacall()
  34. r2 = numpycall()
  35. assert r1.shape == (n,)
  36. assert r1.shape == r2.shape
  37. t1 = timeit(numbacall, number=1000)
  38. print(f'{t1:.2f} secs for {n} PCG64 (Numba/PCG64) gaussian randoms')
  39. t2 = timeit(numpycall, number=1000)
  40. print(f'{t2:.2f} secs for {n} PCG64 (NumPy/PCG64) gaussian randoms')
  41. # example 2
  42. next_u32 = bit_gen.ctypes.next_uint32
  43. ctypes_state = bit_gen.ctypes.state
  44. @nb.jit(nopython=True)
  45. def bounded_uint(lb, ub, state):
  46. mask = delta = ub - lb
  47. mask |= mask >> 1
  48. mask |= mask >> 2
  49. mask |= mask >> 4
  50. mask |= mask >> 8
  51. mask |= mask >> 16
  52. val = next_u32(state) & mask
  53. while val > delta:
  54. val = next_u32(state) & mask
  55. return lb + val
  56. print(bounded_uint(323, 2394691, ctypes_state.value))
  57. @nb.jit(nopython=True)
  58. def bounded_uints(lb, ub, n, state):
  59. out = np.empty(n, dtype=np.uint32)
  60. for i in range(n):
  61. out[i] = bounded_uint(lb, ub, state)
  62. bounded_uints(323, 2394691, 10000000, ctypes_state.value)