图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

475 lines
16 KiB

  1. import os
  2. from os.path import join
  3. import sys
  4. import numpy as np
  5. from numpy.testing import (assert_equal, assert_allclose, assert_array_equal,
  6. assert_raises)
  7. import pytest
  8. from numpy.random import (
  9. Generator, MT19937, PCG64, PCG64DXSM, Philox, RandomState, SeedSequence,
  10. SFC64, default_rng
  11. )
  12. from numpy.random._common import interface
  13. try:
  14. import cffi # noqa: F401
  15. MISSING_CFFI = False
  16. except ImportError:
  17. MISSING_CFFI = True
  18. try:
  19. import ctypes # noqa: F401
  20. MISSING_CTYPES = False
  21. except ImportError:
  22. MISSING_CTYPES = False
  23. if sys.flags.optimize > 1:
  24. # no docstrings present to inspect when PYTHONOPTIMIZE/Py_OptimizeFlag > 1
  25. # cffi cannot succeed
  26. MISSING_CFFI = True
  27. pwd = os.path.dirname(os.path.abspath(__file__))
  28. def assert_state_equal(actual, target):
  29. for key in actual:
  30. if isinstance(actual[key], dict):
  31. assert_state_equal(actual[key], target[key])
  32. elif isinstance(actual[key], np.ndarray):
  33. assert_array_equal(actual[key], target[key])
  34. else:
  35. assert actual[key] == target[key]
  36. def uniform32_from_uint64(x):
  37. x = np.uint64(x)
  38. upper = np.array(x >> np.uint64(32), dtype=np.uint32)
  39. lower = np.uint64(0xffffffff)
  40. lower = np.array(x & lower, dtype=np.uint32)
  41. joined = np.column_stack([lower, upper]).ravel()
  42. out = (joined >> np.uint32(9)) * (1.0 / 2 ** 23)
  43. return out.astype(np.float32)
  44. def uniform32_from_uint53(x):
  45. x = np.uint64(x) >> np.uint64(16)
  46. x = np.uint32(x & np.uint64(0xffffffff))
  47. out = (x >> np.uint32(9)) * (1.0 / 2 ** 23)
  48. return out.astype(np.float32)
  49. def uniform32_from_uint32(x):
  50. return (x >> np.uint32(9)) * (1.0 / 2 ** 23)
  51. def uniform32_from_uint(x, bits):
  52. if bits == 64:
  53. return uniform32_from_uint64(x)
  54. elif bits == 53:
  55. return uniform32_from_uint53(x)
  56. elif bits == 32:
  57. return uniform32_from_uint32(x)
  58. else:
  59. raise NotImplementedError
  60. def uniform_from_uint(x, bits):
  61. if bits in (64, 63, 53):
  62. return uniform_from_uint64(x)
  63. elif bits == 32:
  64. return uniform_from_uint32(x)
  65. def uniform_from_uint64(x):
  66. return (x >> np.uint64(11)) * (1.0 / 9007199254740992.0)
  67. def uniform_from_uint32(x):
  68. out = np.empty(len(x) // 2)
  69. for i in range(0, len(x), 2):
  70. a = x[i] >> 5
  71. b = x[i + 1] >> 6
  72. out[i // 2] = (a * 67108864.0 + b) / 9007199254740992.0
  73. return out
  74. def uniform_from_dsfmt(x):
  75. return x.view(np.double) - 1.0
  76. def gauss_from_uint(x, n, bits):
  77. if bits in (64, 63):
  78. doubles = uniform_from_uint64(x)
  79. elif bits == 32:
  80. doubles = uniform_from_uint32(x)
  81. else: # bits == 'dsfmt'
  82. doubles = uniform_from_dsfmt(x)
  83. gauss = []
  84. loc = 0
  85. x1 = x2 = 0.0
  86. while len(gauss) < n:
  87. r2 = 2
  88. while r2 >= 1.0 or r2 == 0.0:
  89. x1 = 2.0 * doubles[loc] - 1.0
  90. x2 = 2.0 * doubles[loc + 1] - 1.0
  91. r2 = x1 * x1 + x2 * x2
  92. loc += 2
  93. f = np.sqrt(-2.0 * np.log(r2) / r2)
  94. gauss.append(f * x2)
  95. gauss.append(f * x1)
  96. return gauss[:n]
  97. def test_seedsequence():
  98. from numpy.random.bit_generator import (ISeedSequence,
  99. ISpawnableSeedSequence,
  100. SeedlessSeedSequence)
  101. s1 = SeedSequence(range(10), spawn_key=(1, 2), pool_size=6)
  102. s1.spawn(10)
  103. s2 = SeedSequence(**s1.state)
  104. assert_equal(s1.state, s2.state)
  105. assert_equal(s1.n_children_spawned, s2.n_children_spawned)
  106. # The interfaces cannot be instantiated themselves.
  107. assert_raises(TypeError, ISeedSequence)
  108. assert_raises(TypeError, ISpawnableSeedSequence)
  109. dummy = SeedlessSeedSequence()
  110. assert_raises(NotImplementedError, dummy.generate_state, 10)
  111. assert len(dummy.spawn(10)) == 10
  112. class Base:
  113. dtype = np.uint64
  114. data2 = data1 = {}
  115. @classmethod
  116. def setup_class(cls):
  117. cls.bit_generator = PCG64
  118. cls.bits = 64
  119. cls.dtype = np.uint64
  120. cls.seed_error_type = TypeError
  121. cls.invalid_init_types = []
  122. cls.invalid_init_values = []
  123. @classmethod
  124. def _read_csv(cls, filename):
  125. with open(filename) as csv:
  126. seed = csv.readline()
  127. seed = seed.split(',')
  128. seed = [int(s.strip(), 0) for s in seed[1:]]
  129. data = []
  130. for line in csv:
  131. data.append(int(line.split(',')[-1].strip(), 0))
  132. return {'seed': seed, 'data': np.array(data, dtype=cls.dtype)}
  133. def test_raw(self):
  134. bit_generator = self.bit_generator(*self.data1['seed'])
  135. uints = bit_generator.random_raw(1000)
  136. assert_equal(uints, self.data1['data'])
  137. bit_generator = self.bit_generator(*self.data1['seed'])
  138. uints = bit_generator.random_raw()
  139. assert_equal(uints, self.data1['data'][0])
  140. bit_generator = self.bit_generator(*self.data2['seed'])
  141. uints = bit_generator.random_raw(1000)
  142. assert_equal(uints, self.data2['data'])
  143. def test_random_raw(self):
  144. bit_generator = self.bit_generator(*self.data1['seed'])
  145. uints = bit_generator.random_raw(output=False)
  146. assert uints is None
  147. uints = bit_generator.random_raw(1000, output=False)
  148. assert uints is None
  149. def test_gauss_inv(self):
  150. n = 25
  151. rs = RandomState(self.bit_generator(*self.data1['seed']))
  152. gauss = rs.standard_normal(n)
  153. assert_allclose(gauss,
  154. gauss_from_uint(self.data1['data'], n, self.bits))
  155. rs = RandomState(self.bit_generator(*self.data2['seed']))
  156. gauss = rs.standard_normal(25)
  157. assert_allclose(gauss,
  158. gauss_from_uint(self.data2['data'], n, self.bits))
  159. def test_uniform_double(self):
  160. rs = Generator(self.bit_generator(*self.data1['seed']))
  161. vals = uniform_from_uint(self.data1['data'], self.bits)
  162. uniforms = rs.random(len(vals))
  163. assert_allclose(uniforms, vals)
  164. assert_equal(uniforms.dtype, np.float64)
  165. rs = Generator(self.bit_generator(*self.data2['seed']))
  166. vals = uniform_from_uint(self.data2['data'], self.bits)
  167. uniforms = rs.random(len(vals))
  168. assert_allclose(uniforms, vals)
  169. assert_equal(uniforms.dtype, np.float64)
  170. def test_uniform_float(self):
  171. rs = Generator(self.bit_generator(*self.data1['seed']))
  172. vals = uniform32_from_uint(self.data1['data'], self.bits)
  173. uniforms = rs.random(len(vals), dtype=np.float32)
  174. assert_allclose(uniforms, vals)
  175. assert_equal(uniforms.dtype, np.float32)
  176. rs = Generator(self.bit_generator(*self.data2['seed']))
  177. vals = uniform32_from_uint(self.data2['data'], self.bits)
  178. uniforms = rs.random(len(vals), dtype=np.float32)
  179. assert_allclose(uniforms, vals)
  180. assert_equal(uniforms.dtype, np.float32)
  181. def test_repr(self):
  182. rs = Generator(self.bit_generator(*self.data1['seed']))
  183. assert 'Generator' in repr(rs)
  184. assert f'{id(rs):#x}'.upper().replace('X', 'x') in repr(rs)
  185. def test_str(self):
  186. rs = Generator(self.bit_generator(*self.data1['seed']))
  187. assert 'Generator' in str(rs)
  188. assert str(self.bit_generator.__name__) in str(rs)
  189. assert f'{id(rs):#x}'.upper().replace('X', 'x') not in str(rs)
  190. def test_pickle(self):
  191. import pickle
  192. bit_generator = self.bit_generator(*self.data1['seed'])
  193. state = bit_generator.state
  194. bitgen_pkl = pickle.dumps(bit_generator)
  195. reloaded = pickle.loads(bitgen_pkl)
  196. reloaded_state = reloaded.state
  197. assert_array_equal(Generator(bit_generator).standard_normal(1000),
  198. Generator(reloaded).standard_normal(1000))
  199. assert bit_generator is not reloaded
  200. assert_state_equal(reloaded_state, state)
  201. ss = SeedSequence(100)
  202. aa = pickle.loads(pickle.dumps(ss))
  203. assert_equal(ss.state, aa.state)
  204. def test_invalid_state_type(self):
  205. bit_generator = self.bit_generator(*self.data1['seed'])
  206. with pytest.raises(TypeError):
  207. bit_generator.state = {'1'}
  208. def test_invalid_state_value(self):
  209. bit_generator = self.bit_generator(*self.data1['seed'])
  210. state = bit_generator.state
  211. state['bit_generator'] = 'otherBitGenerator'
  212. with pytest.raises(ValueError):
  213. bit_generator.state = state
  214. def test_invalid_init_type(self):
  215. bit_generator = self.bit_generator
  216. for st in self.invalid_init_types:
  217. with pytest.raises(TypeError):
  218. bit_generator(*st)
  219. def test_invalid_init_values(self):
  220. bit_generator = self.bit_generator
  221. for st in self.invalid_init_values:
  222. with pytest.raises((ValueError, OverflowError)):
  223. bit_generator(*st)
  224. def test_benchmark(self):
  225. bit_generator = self.bit_generator(*self.data1['seed'])
  226. bit_generator._benchmark(1)
  227. bit_generator._benchmark(1, 'double')
  228. with pytest.raises(ValueError):
  229. bit_generator._benchmark(1, 'int32')
  230. @pytest.mark.skipif(MISSING_CFFI, reason='cffi not available')
  231. def test_cffi(self):
  232. bit_generator = self.bit_generator(*self.data1['seed'])
  233. cffi_interface = bit_generator.cffi
  234. assert isinstance(cffi_interface, interface)
  235. other_cffi_interface = bit_generator.cffi
  236. assert other_cffi_interface is cffi_interface
  237. @pytest.mark.skipif(MISSING_CTYPES, reason='ctypes not available')
  238. def test_ctypes(self):
  239. bit_generator = self.bit_generator(*self.data1['seed'])
  240. ctypes_interface = bit_generator.ctypes
  241. assert isinstance(ctypes_interface, interface)
  242. other_ctypes_interface = bit_generator.ctypes
  243. assert other_ctypes_interface is ctypes_interface
  244. def test_getstate(self):
  245. bit_generator = self.bit_generator(*self.data1['seed'])
  246. state = bit_generator.state
  247. alt_state = bit_generator.__getstate__()
  248. assert_state_equal(state, alt_state)
  249. class TestPhilox(Base):
  250. @classmethod
  251. def setup_class(cls):
  252. cls.bit_generator = Philox
  253. cls.bits = 64
  254. cls.dtype = np.uint64
  255. cls.data1 = cls._read_csv(
  256. join(pwd, './data/philox-testset-1.csv'))
  257. cls.data2 = cls._read_csv(
  258. join(pwd, './data/philox-testset-2.csv'))
  259. cls.seed_error_type = TypeError
  260. cls.invalid_init_types = []
  261. cls.invalid_init_values = [(1, None, 1), (-1,), (None, None, 2 ** 257 + 1)]
  262. def test_set_key(self):
  263. bit_generator = self.bit_generator(*self.data1['seed'])
  264. state = bit_generator.state
  265. keyed = self.bit_generator(counter=state['state']['counter'],
  266. key=state['state']['key'])
  267. assert_state_equal(bit_generator.state, keyed.state)
  268. class TestPCG64(Base):
  269. @classmethod
  270. def setup_class(cls):
  271. cls.bit_generator = PCG64
  272. cls.bits = 64
  273. cls.dtype = np.uint64
  274. cls.data1 = cls._read_csv(join(pwd, './data/pcg64-testset-1.csv'))
  275. cls.data2 = cls._read_csv(join(pwd, './data/pcg64-testset-2.csv'))
  276. cls.seed_error_type = (ValueError, TypeError)
  277. cls.invalid_init_types = [(3.2,), ([None],), (1, None)]
  278. cls.invalid_init_values = [(-1,)]
  279. def test_advance_symmetry(self):
  280. rs = Generator(self.bit_generator(*self.data1['seed']))
  281. state = rs.bit_generator.state
  282. step = -0x9e3779b97f4a7c150000000000000000
  283. rs.bit_generator.advance(step)
  284. val_neg = rs.integers(10)
  285. rs.bit_generator.state = state
  286. rs.bit_generator.advance(2**128 + step)
  287. val_pos = rs.integers(10)
  288. rs.bit_generator.state = state
  289. rs.bit_generator.advance(10 * 2**128 + step)
  290. val_big = rs.integers(10)
  291. assert val_neg == val_pos
  292. assert val_big == val_pos
  293. def test_advange_large(self):
  294. rs = Generator(self.bit_generator(38219308213743))
  295. pcg = rs.bit_generator
  296. state = pcg.state["state"]
  297. initial_state = 287608843259529770491897792873167516365
  298. assert state["state"] == initial_state
  299. pcg.advance(sum(2**i for i in (96, 64, 32, 16, 8, 4, 2, 1)))
  300. state = pcg.state["state"]
  301. advanced_state = 135275564607035429730177404003164635391
  302. assert state["state"] == advanced_state
  303. class TestPCG64DXSM(Base):
  304. @classmethod
  305. def setup_class(cls):
  306. cls.bit_generator = PCG64DXSM
  307. cls.bits = 64
  308. cls.dtype = np.uint64
  309. cls.data1 = cls._read_csv(join(pwd, './data/pcg64dxsm-testset-1.csv'))
  310. cls.data2 = cls._read_csv(join(pwd, './data/pcg64dxsm-testset-2.csv'))
  311. cls.seed_error_type = (ValueError, TypeError)
  312. cls.invalid_init_types = [(3.2,), ([None],), (1, None)]
  313. cls.invalid_init_values = [(-1,)]
  314. def test_advance_symmetry(self):
  315. rs = Generator(self.bit_generator(*self.data1['seed']))
  316. state = rs.bit_generator.state
  317. step = -0x9e3779b97f4a7c150000000000000000
  318. rs.bit_generator.advance(step)
  319. val_neg = rs.integers(10)
  320. rs.bit_generator.state = state
  321. rs.bit_generator.advance(2**128 + step)
  322. val_pos = rs.integers(10)
  323. rs.bit_generator.state = state
  324. rs.bit_generator.advance(10 * 2**128 + step)
  325. val_big = rs.integers(10)
  326. assert val_neg == val_pos
  327. assert val_big == val_pos
  328. def test_advange_large(self):
  329. rs = Generator(self.bit_generator(38219308213743))
  330. pcg = rs.bit_generator
  331. state = pcg.state
  332. initial_state = 287608843259529770491897792873167516365
  333. assert state["state"]["state"] == initial_state
  334. pcg.advance(sum(2**i for i in (96, 64, 32, 16, 8, 4, 2, 1)))
  335. state = pcg.state["state"]
  336. advanced_state = 277778083536782149546677086420637664879
  337. assert state["state"] == advanced_state
  338. class TestMT19937(Base):
  339. @classmethod
  340. def setup_class(cls):
  341. cls.bit_generator = MT19937
  342. cls.bits = 32
  343. cls.dtype = np.uint32
  344. cls.data1 = cls._read_csv(join(pwd, './data/mt19937-testset-1.csv'))
  345. cls.data2 = cls._read_csv(join(pwd, './data/mt19937-testset-2.csv'))
  346. cls.seed_error_type = ValueError
  347. cls.invalid_init_types = []
  348. cls.invalid_init_values = [(-1,)]
  349. def test_seed_float_array(self):
  350. assert_raises(TypeError, self.bit_generator, np.array([np.pi]))
  351. assert_raises(TypeError, self.bit_generator, np.array([-np.pi]))
  352. assert_raises(TypeError, self.bit_generator, np.array([np.pi, -np.pi]))
  353. assert_raises(TypeError, self.bit_generator, np.array([0, np.pi]))
  354. assert_raises(TypeError, self.bit_generator, [np.pi])
  355. assert_raises(TypeError, self.bit_generator, [0, np.pi])
  356. def test_state_tuple(self):
  357. rs = Generator(self.bit_generator(*self.data1['seed']))
  358. bit_generator = rs.bit_generator
  359. state = bit_generator.state
  360. desired = rs.integers(2 ** 16)
  361. tup = (state['bit_generator'], state['state']['key'],
  362. state['state']['pos'])
  363. bit_generator.state = tup
  364. actual = rs.integers(2 ** 16)
  365. assert_equal(actual, desired)
  366. tup = tup + (0, 0.0)
  367. bit_generator.state = tup
  368. actual = rs.integers(2 ** 16)
  369. assert_equal(actual, desired)
  370. class TestSFC64(Base):
  371. @classmethod
  372. def setup_class(cls):
  373. cls.bit_generator = SFC64
  374. cls.bits = 64
  375. cls.dtype = np.uint64
  376. cls.data1 = cls._read_csv(
  377. join(pwd, './data/sfc64-testset-1.csv'))
  378. cls.data2 = cls._read_csv(
  379. join(pwd, './data/sfc64-testset-2.csv'))
  380. cls.seed_error_type = (ValueError, TypeError)
  381. cls.invalid_init_types = [(3.2,), ([None],), (1, None)]
  382. cls.invalid_init_values = [(-1,)]
  383. class TestDefaultRNG:
  384. def test_seed(self):
  385. for args in [(), (None,), (1234,), ([1234, 5678],)]:
  386. rg = default_rng(*args)
  387. assert isinstance(rg.bit_generator, PCG64)
  388. def test_passthrough(self):
  389. bg = Philox()
  390. rg = default_rng(bg)
  391. assert rg.bit_generator is bg
  392. rg2 = default_rng(rg)
  393. assert rg2 is rg
  394. assert rg2.bit_generator is bg