图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

818 lines
28 KiB

  1. import pickle
  2. from functools import partial
  3. import numpy as np
  4. import pytest
  5. from numpy.testing import assert_equal, assert_, assert_array_equal
  6. from numpy.random import (Generator, MT19937, PCG64, PCG64DXSM, Philox, SFC64)
  7. @pytest.fixture(scope='module',
  8. params=(np.bool_, np.int8, np.int16, np.int32, np.int64,
  9. np.uint8, np.uint16, np.uint32, np.uint64))
  10. def dtype(request):
  11. return request.param
  12. def params_0(f):
  13. val = f()
  14. assert_(np.isscalar(val))
  15. val = f(10)
  16. assert_(val.shape == (10,))
  17. val = f((10, 10))
  18. assert_(val.shape == (10, 10))
  19. val = f((10, 10, 10))
  20. assert_(val.shape == (10, 10, 10))
  21. val = f(size=(5, 5))
  22. assert_(val.shape == (5, 5))
  23. def params_1(f, bounded=False):
  24. a = 5.0
  25. b = np.arange(2.0, 12.0)
  26. c = np.arange(2.0, 102.0).reshape((10, 10))
  27. d = np.arange(2.0, 1002.0).reshape((10, 10, 10))
  28. e = np.array([2.0, 3.0])
  29. g = np.arange(2.0, 12.0).reshape((1, 10, 1))
  30. if bounded:
  31. a = 0.5
  32. b = b / (1.5 * b.max())
  33. c = c / (1.5 * c.max())
  34. d = d / (1.5 * d.max())
  35. e = e / (1.5 * e.max())
  36. g = g / (1.5 * g.max())
  37. # Scalar
  38. f(a)
  39. # Scalar - size
  40. f(a, size=(10, 10))
  41. # 1d
  42. f(b)
  43. # 2d
  44. f(c)
  45. # 3d
  46. f(d)
  47. # 1d size
  48. f(b, size=10)
  49. # 2d - size - broadcast
  50. f(e, size=(10, 2))
  51. # 3d - size
  52. f(g, size=(10, 10, 10))
  53. def comp_state(state1, state2):
  54. identical = True
  55. if isinstance(state1, dict):
  56. for key in state1:
  57. identical &= comp_state(state1[key], state2[key])
  58. elif type(state1) != type(state2):
  59. identical &= type(state1) == type(state2)
  60. else:
  61. if (isinstance(state1, (list, tuple, np.ndarray)) and isinstance(
  62. state2, (list, tuple, np.ndarray))):
  63. for s1, s2 in zip(state1, state2):
  64. identical &= comp_state(s1, s2)
  65. else:
  66. identical &= state1 == state2
  67. return identical
  68. def warmup(rg, n=None):
  69. if n is None:
  70. n = 11 + np.random.randint(0, 20)
  71. rg.standard_normal(n)
  72. rg.standard_normal(n)
  73. rg.standard_normal(n, dtype=np.float32)
  74. rg.standard_normal(n, dtype=np.float32)
  75. rg.integers(0, 2 ** 24, n, dtype=np.uint64)
  76. rg.integers(0, 2 ** 48, n, dtype=np.uint64)
  77. rg.standard_gamma(11.0, n)
  78. rg.standard_gamma(11.0, n, dtype=np.float32)
  79. rg.random(n, dtype=np.float64)
  80. rg.random(n, dtype=np.float32)
  81. class RNG:
  82. @classmethod
  83. def setup_class(cls):
  84. # Overridden in test classes. Place holder to silence IDE noise
  85. cls.bit_generator = PCG64
  86. cls.advance = None
  87. cls.seed = [12345]
  88. cls.rg = Generator(cls.bit_generator(*cls.seed))
  89. cls.initial_state = cls.rg.bit_generator.state
  90. cls.seed_vector_bits = 64
  91. cls._extra_setup()
  92. @classmethod
  93. def _extra_setup(cls):
  94. cls.vec_1d = np.arange(2.0, 102.0)
  95. cls.vec_2d = np.arange(2.0, 102.0)[None, :]
  96. cls.mat = np.arange(2.0, 102.0, 0.01).reshape((100, 100))
  97. cls.seed_error = TypeError
  98. def _reset_state(self):
  99. self.rg.bit_generator.state = self.initial_state
  100. def test_init(self):
  101. rg = Generator(self.bit_generator())
  102. state = rg.bit_generator.state
  103. rg.standard_normal(1)
  104. rg.standard_normal(1)
  105. rg.bit_generator.state = state
  106. new_state = rg.bit_generator.state
  107. assert_(comp_state(state, new_state))
  108. def test_advance(self):
  109. state = self.rg.bit_generator.state
  110. if hasattr(self.rg.bit_generator, 'advance'):
  111. self.rg.bit_generator.advance(self.advance)
  112. assert_(not comp_state(state, self.rg.bit_generator.state))
  113. else:
  114. bitgen_name = self.rg.bit_generator.__class__.__name__
  115. pytest.skip(f'Advance is not supported by {bitgen_name}')
  116. def test_jump(self):
  117. state = self.rg.bit_generator.state
  118. if hasattr(self.rg.bit_generator, 'jumped'):
  119. bit_gen2 = self.rg.bit_generator.jumped()
  120. jumped_state = bit_gen2.state
  121. assert_(not comp_state(state, jumped_state))
  122. self.rg.random(2 * 3 * 5 * 7 * 11 * 13 * 17)
  123. self.rg.bit_generator.state = state
  124. bit_gen3 = self.rg.bit_generator.jumped()
  125. rejumped_state = bit_gen3.state
  126. assert_(comp_state(jumped_state, rejumped_state))
  127. else:
  128. bitgen_name = self.rg.bit_generator.__class__.__name__
  129. if bitgen_name not in ('SFC64',):
  130. raise AttributeError(f'no "jumped" in {bitgen_name}')
  131. pytest.skip(f'Jump is not supported by {bitgen_name}')
  132. def test_uniform(self):
  133. r = self.rg.uniform(-1.0, 0.0, size=10)
  134. assert_(len(r) == 10)
  135. assert_((r > -1).all())
  136. assert_((r <= 0).all())
  137. def test_uniform_array(self):
  138. r = self.rg.uniform(np.array([-1.0] * 10), 0.0, size=10)
  139. assert_(len(r) == 10)
  140. assert_((r > -1).all())
  141. assert_((r <= 0).all())
  142. r = self.rg.uniform(np.array([-1.0] * 10),
  143. np.array([0.0] * 10), size=10)
  144. assert_(len(r) == 10)
  145. assert_((r > -1).all())
  146. assert_((r <= 0).all())
  147. r = self.rg.uniform(-1.0, np.array([0.0] * 10), size=10)
  148. assert_(len(r) == 10)
  149. assert_((r > -1).all())
  150. assert_((r <= 0).all())
  151. def test_random(self):
  152. assert_(len(self.rg.random(10)) == 10)
  153. params_0(self.rg.random)
  154. def test_standard_normal_zig(self):
  155. assert_(len(self.rg.standard_normal(10)) == 10)
  156. def test_standard_normal(self):
  157. assert_(len(self.rg.standard_normal(10)) == 10)
  158. params_0(self.rg.standard_normal)
  159. def test_standard_gamma(self):
  160. assert_(len(self.rg.standard_gamma(10, 10)) == 10)
  161. assert_(len(self.rg.standard_gamma(np.array([10] * 10), 10)) == 10)
  162. params_1(self.rg.standard_gamma)
  163. def test_standard_exponential(self):
  164. assert_(len(self.rg.standard_exponential(10)) == 10)
  165. params_0(self.rg.standard_exponential)
  166. def test_standard_exponential_float(self):
  167. randoms = self.rg.standard_exponential(10, dtype='float32')
  168. assert_(len(randoms) == 10)
  169. assert randoms.dtype == np.float32
  170. params_0(partial(self.rg.standard_exponential, dtype='float32'))
  171. def test_standard_exponential_float_log(self):
  172. randoms = self.rg.standard_exponential(10, dtype='float32',
  173. method='inv')
  174. assert_(len(randoms) == 10)
  175. assert randoms.dtype == np.float32
  176. params_0(partial(self.rg.standard_exponential, dtype='float32',
  177. method='inv'))
  178. def test_standard_cauchy(self):
  179. assert_(len(self.rg.standard_cauchy(10)) == 10)
  180. params_0(self.rg.standard_cauchy)
  181. def test_standard_t(self):
  182. assert_(len(self.rg.standard_t(10, 10)) == 10)
  183. params_1(self.rg.standard_t)
  184. def test_binomial(self):
  185. assert_(self.rg.binomial(10, .5) >= 0)
  186. assert_(self.rg.binomial(1000, .5) >= 0)
  187. def test_reset_state(self):
  188. state = self.rg.bit_generator.state
  189. int_1 = self.rg.integers(2**31)
  190. self.rg.bit_generator.state = state
  191. int_2 = self.rg.integers(2**31)
  192. assert_(int_1 == int_2)
  193. def test_entropy_init(self):
  194. rg = Generator(self.bit_generator())
  195. rg2 = Generator(self.bit_generator())
  196. assert_(not comp_state(rg.bit_generator.state,
  197. rg2.bit_generator.state))
  198. def test_seed(self):
  199. rg = Generator(self.bit_generator(*self.seed))
  200. rg2 = Generator(self.bit_generator(*self.seed))
  201. rg.random()
  202. rg2.random()
  203. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  204. def test_reset_state_gauss(self):
  205. rg = Generator(self.bit_generator(*self.seed))
  206. rg.standard_normal()
  207. state = rg.bit_generator.state
  208. n1 = rg.standard_normal(size=10)
  209. rg2 = Generator(self.bit_generator())
  210. rg2.bit_generator.state = state
  211. n2 = rg2.standard_normal(size=10)
  212. assert_array_equal(n1, n2)
  213. def test_reset_state_uint32(self):
  214. rg = Generator(self.bit_generator(*self.seed))
  215. rg.integers(0, 2 ** 24, 120, dtype=np.uint32)
  216. state = rg.bit_generator.state
  217. n1 = rg.integers(0, 2 ** 24, 10, dtype=np.uint32)
  218. rg2 = Generator(self.bit_generator())
  219. rg2.bit_generator.state = state
  220. n2 = rg2.integers(0, 2 ** 24, 10, dtype=np.uint32)
  221. assert_array_equal(n1, n2)
  222. def test_reset_state_float(self):
  223. rg = Generator(self.bit_generator(*self.seed))
  224. rg.random(dtype='float32')
  225. state = rg.bit_generator.state
  226. n1 = rg.random(size=10, dtype='float32')
  227. rg2 = Generator(self.bit_generator())
  228. rg2.bit_generator.state = state
  229. n2 = rg2.random(size=10, dtype='float32')
  230. assert_((n1 == n2).all())
  231. def test_shuffle(self):
  232. original = np.arange(200, 0, -1)
  233. permuted = self.rg.permutation(original)
  234. assert_((original != permuted).any())
  235. def test_permutation(self):
  236. original = np.arange(200, 0, -1)
  237. permuted = self.rg.permutation(original)
  238. assert_((original != permuted).any())
  239. def test_beta(self):
  240. vals = self.rg.beta(2.0, 2.0, 10)
  241. assert_(len(vals) == 10)
  242. vals = self.rg.beta(np.array([2.0] * 10), 2.0)
  243. assert_(len(vals) == 10)
  244. vals = self.rg.beta(2.0, np.array([2.0] * 10))
  245. assert_(len(vals) == 10)
  246. vals = self.rg.beta(np.array([2.0] * 10), np.array([2.0] * 10))
  247. assert_(len(vals) == 10)
  248. vals = self.rg.beta(np.array([2.0] * 10), np.array([[2.0]] * 10))
  249. assert_(vals.shape == (10, 10))
  250. def test_bytes(self):
  251. vals = self.rg.bytes(10)
  252. assert_(len(vals) == 10)
  253. def test_chisquare(self):
  254. vals = self.rg.chisquare(2.0, 10)
  255. assert_(len(vals) == 10)
  256. params_1(self.rg.chisquare)
  257. def test_exponential(self):
  258. vals = self.rg.exponential(2.0, 10)
  259. assert_(len(vals) == 10)
  260. params_1(self.rg.exponential)
  261. def test_f(self):
  262. vals = self.rg.f(3, 1000, 10)
  263. assert_(len(vals) == 10)
  264. def test_gamma(self):
  265. vals = self.rg.gamma(3, 2, 10)
  266. assert_(len(vals) == 10)
  267. def test_geometric(self):
  268. vals = self.rg.geometric(0.5, 10)
  269. assert_(len(vals) == 10)
  270. params_1(self.rg.exponential, bounded=True)
  271. def test_gumbel(self):
  272. vals = self.rg.gumbel(2.0, 2.0, 10)
  273. assert_(len(vals) == 10)
  274. def test_laplace(self):
  275. vals = self.rg.laplace(2.0, 2.0, 10)
  276. assert_(len(vals) == 10)
  277. def test_logitic(self):
  278. vals = self.rg.logistic(2.0, 2.0, 10)
  279. assert_(len(vals) == 10)
  280. def test_logseries(self):
  281. vals = self.rg.logseries(0.5, 10)
  282. assert_(len(vals) == 10)
  283. def test_negative_binomial(self):
  284. vals = self.rg.negative_binomial(10, 0.2, 10)
  285. assert_(len(vals) == 10)
  286. def test_noncentral_chisquare(self):
  287. vals = self.rg.noncentral_chisquare(10, 2, 10)
  288. assert_(len(vals) == 10)
  289. def test_noncentral_f(self):
  290. vals = self.rg.noncentral_f(3, 1000, 2, 10)
  291. assert_(len(vals) == 10)
  292. vals = self.rg.noncentral_f(np.array([3] * 10), 1000, 2)
  293. assert_(len(vals) == 10)
  294. vals = self.rg.noncentral_f(3, np.array([1000] * 10), 2)
  295. assert_(len(vals) == 10)
  296. vals = self.rg.noncentral_f(3, 1000, np.array([2] * 10))
  297. assert_(len(vals) == 10)
  298. def test_normal(self):
  299. vals = self.rg.normal(10, 0.2, 10)
  300. assert_(len(vals) == 10)
  301. def test_pareto(self):
  302. vals = self.rg.pareto(3.0, 10)
  303. assert_(len(vals) == 10)
  304. def test_poisson(self):
  305. vals = self.rg.poisson(10, 10)
  306. assert_(len(vals) == 10)
  307. vals = self.rg.poisson(np.array([10] * 10))
  308. assert_(len(vals) == 10)
  309. params_1(self.rg.poisson)
  310. def test_power(self):
  311. vals = self.rg.power(0.2, 10)
  312. assert_(len(vals) == 10)
  313. def test_integers(self):
  314. vals = self.rg.integers(10, 20, 10)
  315. assert_(len(vals) == 10)
  316. def test_rayleigh(self):
  317. vals = self.rg.rayleigh(0.2, 10)
  318. assert_(len(vals) == 10)
  319. params_1(self.rg.rayleigh, bounded=True)
  320. def test_vonmises(self):
  321. vals = self.rg.vonmises(10, 0.2, 10)
  322. assert_(len(vals) == 10)
  323. def test_wald(self):
  324. vals = self.rg.wald(1.0, 1.0, 10)
  325. assert_(len(vals) == 10)
  326. def test_weibull(self):
  327. vals = self.rg.weibull(1.0, 10)
  328. assert_(len(vals) == 10)
  329. def test_zipf(self):
  330. vals = self.rg.zipf(10, 10)
  331. assert_(len(vals) == 10)
  332. vals = self.rg.zipf(self.vec_1d)
  333. assert_(len(vals) == 100)
  334. vals = self.rg.zipf(self.vec_2d)
  335. assert_(vals.shape == (1, 100))
  336. vals = self.rg.zipf(self.mat)
  337. assert_(vals.shape == (100, 100))
  338. def test_hypergeometric(self):
  339. vals = self.rg.hypergeometric(25, 25, 20)
  340. assert_(np.isscalar(vals))
  341. vals = self.rg.hypergeometric(np.array([25] * 10), 25, 20)
  342. assert_(vals.shape == (10,))
  343. def test_triangular(self):
  344. vals = self.rg.triangular(-5, 0, 5)
  345. assert_(np.isscalar(vals))
  346. vals = self.rg.triangular(-5, np.array([0] * 10), 5)
  347. assert_(vals.shape == (10,))
  348. def test_multivariate_normal(self):
  349. mean = [0, 0]
  350. cov = [[1, 0], [0, 100]] # diagonal covariance
  351. x = self.rg.multivariate_normal(mean, cov, 5000)
  352. assert_(x.shape == (5000, 2))
  353. x_zig = self.rg.multivariate_normal(mean, cov, 5000)
  354. assert_(x.shape == (5000, 2))
  355. x_inv = self.rg.multivariate_normal(mean, cov, 5000)
  356. assert_(x.shape == (5000, 2))
  357. assert_((x_zig != x_inv).any())
  358. def test_multinomial(self):
  359. vals = self.rg.multinomial(100, [1.0 / 3, 2.0 / 3])
  360. assert_(vals.shape == (2,))
  361. vals = self.rg.multinomial(100, [1.0 / 3, 2.0 / 3], size=10)
  362. assert_(vals.shape == (10, 2))
  363. def test_dirichlet(self):
  364. s = self.rg.dirichlet((10, 5, 3), 20)
  365. assert_(s.shape == (20, 3))
  366. def test_pickle(self):
  367. pick = pickle.dumps(self.rg)
  368. unpick = pickle.loads(pick)
  369. assert_((type(self.rg) == type(unpick)))
  370. assert_(comp_state(self.rg.bit_generator.state,
  371. unpick.bit_generator.state))
  372. pick = pickle.dumps(self.rg)
  373. unpick = pickle.loads(pick)
  374. assert_((type(self.rg) == type(unpick)))
  375. assert_(comp_state(self.rg.bit_generator.state,
  376. unpick.bit_generator.state))
  377. def test_seed_array(self):
  378. if self.seed_vector_bits is None:
  379. bitgen_name = self.bit_generator.__name__
  380. pytest.skip(f'Vector seeding is not supported by {bitgen_name}')
  381. if self.seed_vector_bits == 32:
  382. dtype = np.uint32
  383. else:
  384. dtype = np.uint64
  385. seed = np.array([1], dtype=dtype)
  386. bg = self.bit_generator(seed)
  387. state1 = bg.state
  388. bg = self.bit_generator(1)
  389. state2 = bg.state
  390. assert_(comp_state(state1, state2))
  391. seed = np.arange(4, dtype=dtype)
  392. bg = self.bit_generator(seed)
  393. state1 = bg.state
  394. bg = self.bit_generator(seed[0])
  395. state2 = bg.state
  396. assert_(not comp_state(state1, state2))
  397. seed = np.arange(1500, dtype=dtype)
  398. bg = self.bit_generator(seed)
  399. state1 = bg.state
  400. bg = self.bit_generator(seed[0])
  401. state2 = bg.state
  402. assert_(not comp_state(state1, state2))
  403. seed = 2 ** np.mod(np.arange(1500, dtype=dtype),
  404. self.seed_vector_bits - 1) + 1
  405. bg = self.bit_generator(seed)
  406. state1 = bg.state
  407. bg = self.bit_generator(seed[0])
  408. state2 = bg.state
  409. assert_(not comp_state(state1, state2))
  410. def test_uniform_float(self):
  411. rg = Generator(self.bit_generator(12345))
  412. warmup(rg)
  413. state = rg.bit_generator.state
  414. r1 = rg.random(11, dtype=np.float32)
  415. rg2 = Generator(self.bit_generator())
  416. warmup(rg2)
  417. rg2.bit_generator.state = state
  418. r2 = rg2.random(11, dtype=np.float32)
  419. assert_array_equal(r1, r2)
  420. assert_equal(r1.dtype, np.float32)
  421. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  422. def test_gamma_floats(self):
  423. rg = Generator(self.bit_generator())
  424. warmup(rg)
  425. state = rg.bit_generator.state
  426. r1 = rg.standard_gamma(4.0, 11, dtype=np.float32)
  427. rg2 = Generator(self.bit_generator())
  428. warmup(rg2)
  429. rg2.bit_generator.state = state
  430. r2 = rg2.standard_gamma(4.0, 11, dtype=np.float32)
  431. assert_array_equal(r1, r2)
  432. assert_equal(r1.dtype, np.float32)
  433. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  434. def test_normal_floats(self):
  435. rg = Generator(self.bit_generator())
  436. warmup(rg)
  437. state = rg.bit_generator.state
  438. r1 = rg.standard_normal(11, dtype=np.float32)
  439. rg2 = Generator(self.bit_generator())
  440. warmup(rg2)
  441. rg2.bit_generator.state = state
  442. r2 = rg2.standard_normal(11, dtype=np.float32)
  443. assert_array_equal(r1, r2)
  444. assert_equal(r1.dtype, np.float32)
  445. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  446. def test_normal_zig_floats(self):
  447. rg = Generator(self.bit_generator())
  448. warmup(rg)
  449. state = rg.bit_generator.state
  450. r1 = rg.standard_normal(11, dtype=np.float32)
  451. rg2 = Generator(self.bit_generator())
  452. warmup(rg2)
  453. rg2.bit_generator.state = state
  454. r2 = rg2.standard_normal(11, dtype=np.float32)
  455. assert_array_equal(r1, r2)
  456. assert_equal(r1.dtype, np.float32)
  457. assert_(comp_state(rg.bit_generator.state, rg2.bit_generator.state))
  458. def test_output_fill(self):
  459. rg = self.rg
  460. state = rg.bit_generator.state
  461. size = (31, 7, 97)
  462. existing = np.empty(size)
  463. rg.bit_generator.state = state
  464. rg.standard_normal(out=existing)
  465. rg.bit_generator.state = state
  466. direct = rg.standard_normal(size=size)
  467. assert_equal(direct, existing)
  468. sized = np.empty(size)
  469. rg.bit_generator.state = state
  470. rg.standard_normal(out=sized, size=sized.shape)
  471. existing = np.empty(size, dtype=np.float32)
  472. rg.bit_generator.state = state
  473. rg.standard_normal(out=existing, dtype=np.float32)
  474. rg.bit_generator.state = state
  475. direct = rg.standard_normal(size=size, dtype=np.float32)
  476. assert_equal(direct, existing)
  477. def test_output_filling_uniform(self):
  478. rg = self.rg
  479. state = rg.bit_generator.state
  480. size = (31, 7, 97)
  481. existing = np.empty(size)
  482. rg.bit_generator.state = state
  483. rg.random(out=existing)
  484. rg.bit_generator.state = state
  485. direct = rg.random(size=size)
  486. assert_equal(direct, existing)
  487. existing = np.empty(size, dtype=np.float32)
  488. rg.bit_generator.state = state
  489. rg.random(out=existing, dtype=np.float32)
  490. rg.bit_generator.state = state
  491. direct = rg.random(size=size, dtype=np.float32)
  492. assert_equal(direct, existing)
  493. def test_output_filling_exponential(self):
  494. rg = self.rg
  495. state = rg.bit_generator.state
  496. size = (31, 7, 97)
  497. existing = np.empty(size)
  498. rg.bit_generator.state = state
  499. rg.standard_exponential(out=existing)
  500. rg.bit_generator.state = state
  501. direct = rg.standard_exponential(size=size)
  502. assert_equal(direct, existing)
  503. existing = np.empty(size, dtype=np.float32)
  504. rg.bit_generator.state = state
  505. rg.standard_exponential(out=existing, dtype=np.float32)
  506. rg.bit_generator.state = state
  507. direct = rg.standard_exponential(size=size, dtype=np.float32)
  508. assert_equal(direct, existing)
  509. def test_output_filling_gamma(self):
  510. rg = self.rg
  511. state = rg.bit_generator.state
  512. size = (31, 7, 97)
  513. existing = np.zeros(size)
  514. rg.bit_generator.state = state
  515. rg.standard_gamma(1.0, out=existing)
  516. rg.bit_generator.state = state
  517. direct = rg.standard_gamma(1.0, size=size)
  518. assert_equal(direct, existing)
  519. existing = np.zeros(size, dtype=np.float32)
  520. rg.bit_generator.state = state
  521. rg.standard_gamma(1.0, out=existing, dtype=np.float32)
  522. rg.bit_generator.state = state
  523. direct = rg.standard_gamma(1.0, size=size, dtype=np.float32)
  524. assert_equal(direct, existing)
  525. def test_output_filling_gamma_broadcast(self):
  526. rg = self.rg
  527. state = rg.bit_generator.state
  528. size = (31, 7, 97)
  529. mu = np.arange(97.0) + 1.0
  530. existing = np.zeros(size)
  531. rg.bit_generator.state = state
  532. rg.standard_gamma(mu, out=existing)
  533. rg.bit_generator.state = state
  534. direct = rg.standard_gamma(mu, size=size)
  535. assert_equal(direct, existing)
  536. existing = np.zeros(size, dtype=np.float32)
  537. rg.bit_generator.state = state
  538. rg.standard_gamma(mu, out=existing, dtype=np.float32)
  539. rg.bit_generator.state = state
  540. direct = rg.standard_gamma(mu, size=size, dtype=np.float32)
  541. assert_equal(direct, existing)
  542. def test_output_fill_error(self):
  543. rg = self.rg
  544. size = (31, 7, 97)
  545. existing = np.empty(size)
  546. with pytest.raises(TypeError):
  547. rg.standard_normal(out=existing, dtype=np.float32)
  548. with pytest.raises(ValueError):
  549. rg.standard_normal(out=existing[::3])
  550. existing = np.empty(size, dtype=np.float32)
  551. with pytest.raises(TypeError):
  552. rg.standard_normal(out=existing, dtype=np.float64)
  553. existing = np.zeros(size, dtype=np.float32)
  554. with pytest.raises(TypeError):
  555. rg.standard_gamma(1.0, out=existing, dtype=np.float64)
  556. with pytest.raises(ValueError):
  557. rg.standard_gamma(1.0, out=existing[::3], dtype=np.float32)
  558. existing = np.zeros(size, dtype=np.float64)
  559. with pytest.raises(TypeError):
  560. rg.standard_gamma(1.0, out=existing, dtype=np.float32)
  561. with pytest.raises(ValueError):
  562. rg.standard_gamma(1.0, out=existing[::3])
  563. def test_integers_broadcast(self, dtype):
  564. if dtype == np.bool_:
  565. upper = 2
  566. lower = 0
  567. else:
  568. info = np.iinfo(dtype)
  569. upper = int(info.max) + 1
  570. lower = info.min
  571. self._reset_state()
  572. a = self.rg.integers(lower, [upper] * 10, dtype=dtype)
  573. self._reset_state()
  574. b = self.rg.integers([lower] * 10, upper, dtype=dtype)
  575. assert_equal(a, b)
  576. self._reset_state()
  577. c = self.rg.integers(lower, upper, size=10, dtype=dtype)
  578. assert_equal(a, c)
  579. self._reset_state()
  580. d = self.rg.integers(np.array(
  581. [lower] * 10), np.array([upper], dtype=object), size=10,
  582. dtype=dtype)
  583. assert_equal(a, d)
  584. self._reset_state()
  585. e = self.rg.integers(
  586. np.array([lower] * 10), np.array([upper] * 10), size=10,
  587. dtype=dtype)
  588. assert_equal(a, e)
  589. self._reset_state()
  590. a = self.rg.integers(0, upper, size=10, dtype=dtype)
  591. self._reset_state()
  592. b = self.rg.integers([upper] * 10, dtype=dtype)
  593. assert_equal(a, b)
  594. def test_integers_numpy(self, dtype):
  595. high = np.array([1])
  596. low = np.array([0])
  597. out = self.rg.integers(low, high, dtype=dtype)
  598. assert out.shape == (1,)
  599. out = self.rg.integers(low[0], high, dtype=dtype)
  600. assert out.shape == (1,)
  601. out = self.rg.integers(low, high[0], dtype=dtype)
  602. assert out.shape == (1,)
  603. def test_integers_broadcast_errors(self, dtype):
  604. if dtype == np.bool_:
  605. upper = 2
  606. lower = 0
  607. else:
  608. info = np.iinfo(dtype)
  609. upper = int(info.max) + 1
  610. lower = info.min
  611. with pytest.raises(ValueError):
  612. self.rg.integers(lower, [upper + 1] * 10, dtype=dtype)
  613. with pytest.raises(ValueError):
  614. self.rg.integers(lower - 1, [upper] * 10, dtype=dtype)
  615. with pytest.raises(ValueError):
  616. self.rg.integers([lower - 1], [upper] * 10, dtype=dtype)
  617. with pytest.raises(ValueError):
  618. self.rg.integers([0], [0], dtype=dtype)
  619. class TestMT19937(RNG):
  620. @classmethod
  621. def setup_class(cls):
  622. cls.bit_generator = MT19937
  623. cls.advance = None
  624. cls.seed = [2 ** 21 + 2 ** 16 + 2 ** 5 + 1]
  625. cls.rg = Generator(cls.bit_generator(*cls.seed))
  626. cls.initial_state = cls.rg.bit_generator.state
  627. cls.seed_vector_bits = 32
  628. cls._extra_setup()
  629. cls.seed_error = ValueError
  630. def test_numpy_state(self):
  631. nprg = np.random.RandomState()
  632. nprg.standard_normal(99)
  633. state = nprg.get_state()
  634. self.rg.bit_generator.state = state
  635. state2 = self.rg.bit_generator.state
  636. assert_((state[1] == state2['state']['key']).all())
  637. assert_((state[2] == state2['state']['pos']))
  638. class TestPhilox(RNG):
  639. @classmethod
  640. def setup_class(cls):
  641. cls.bit_generator = Philox
  642. cls.advance = 2**63 + 2**31 + 2**15 + 1
  643. cls.seed = [12345]
  644. cls.rg = Generator(cls.bit_generator(*cls.seed))
  645. cls.initial_state = cls.rg.bit_generator.state
  646. cls.seed_vector_bits = 64
  647. cls._extra_setup()
  648. class TestSFC64(RNG):
  649. @classmethod
  650. def setup_class(cls):
  651. cls.bit_generator = SFC64
  652. cls.advance = None
  653. cls.seed = [12345]
  654. cls.rg = Generator(cls.bit_generator(*cls.seed))
  655. cls.initial_state = cls.rg.bit_generator.state
  656. cls.seed_vector_bits = 192
  657. cls._extra_setup()
  658. class TestPCG64(RNG):
  659. @classmethod
  660. def setup_class(cls):
  661. cls.bit_generator = PCG64
  662. cls.advance = 2**63 + 2**31 + 2**15 + 1
  663. cls.seed = [12345]
  664. cls.rg = Generator(cls.bit_generator(*cls.seed))
  665. cls.initial_state = cls.rg.bit_generator.state
  666. cls.seed_vector_bits = 64
  667. cls._extra_setup()
  668. class TestPCG64DXSM(RNG):
  669. @classmethod
  670. def setup_class(cls):
  671. cls.bit_generator = PCG64DXSM
  672. cls.advance = 2**63 + 2**31 + 2**15 + 1
  673. cls.seed = [12345]
  674. cls.rg = Generator(cls.bit_generator(*cls.seed))
  675. cls.initial_state = cls.rg.bit_generator.state
  676. cls.seed_vector_bits = 64
  677. cls._extra_setup()
  678. class TestDefaultRNG(RNG):
  679. @classmethod
  680. def setup_class(cls):
  681. # This will duplicate some tests that directly instantiate a fresh
  682. # Generator(), but that's okay.
  683. cls.bit_generator = PCG64
  684. cls.advance = 2**63 + 2**31 + 2**15 + 1
  685. cls.seed = [12345]
  686. cls.rg = np.random.default_rng(*cls.seed)
  687. cls.initial_state = cls.rg.bit_generator.state
  688. cls.seed_vector_bits = 64
  689. cls._extra_setup()
  690. def test_default_is_pcg64(self):
  691. # In order to change the default BitGenerator, we'll go through
  692. # a deprecation cycle to move to a different function.
  693. assert_(isinstance(self.rg.bit_generator, PCG64))
  694. def test_seed(self):
  695. np.random.default_rng()
  696. np.random.default_rng(None)
  697. np.random.default_rng(12345)
  698. np.random.default_rng(0)
  699. np.random.default_rng(43660444402423911716352051725018508569)
  700. np.random.default_rng([43660444402423911716352051725018508569,
  701. 279705150948142787361475340226491943209])
  702. with pytest.raises(ValueError):
  703. np.random.default_rng(-1)
  704. with pytest.raises(ValueError):
  705. np.random.default_rng([12345, -1])