m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

998 lines
38 KiB

7 months ago
  1. import warnings
  2. import pytest
  3. import numpy as np
  4. from numpy.lib.nanfunctions import _nan_mask, _replace_nan
  5. from numpy.testing import (
  6. assert_, assert_equal, assert_almost_equal, assert_no_warnings,
  7. assert_raises, assert_array_equal, suppress_warnings
  8. )
  9. # Test data
  10. _ndat = np.array([[0.6244, np.nan, 0.2692, 0.0116, np.nan, 0.1170],
  11. [0.5351, -0.9403, np.nan, 0.2100, 0.4759, 0.2833],
  12. [np.nan, np.nan, np.nan, 0.1042, np.nan, -0.5954],
  13. [0.1610, np.nan, np.nan, 0.1859, 0.3146, np.nan]])
  14. # Rows of _ndat with nans removed
  15. _rdat = [np.array([0.6244, 0.2692, 0.0116, 0.1170]),
  16. np.array([0.5351, -0.9403, 0.2100, 0.4759, 0.2833]),
  17. np.array([0.1042, -0.5954]),
  18. np.array([0.1610, 0.1859, 0.3146])]
  19. # Rows of _ndat with nans converted to ones
  20. _ndat_ones = np.array([[0.6244, 1.0, 0.2692, 0.0116, 1.0, 0.1170],
  21. [0.5351, -0.9403, 1.0, 0.2100, 0.4759, 0.2833],
  22. [1.0, 1.0, 1.0, 0.1042, 1.0, -0.5954],
  23. [0.1610, 1.0, 1.0, 0.1859, 0.3146, 1.0]])
  24. # Rows of _ndat with nans converted to zeros
  25. _ndat_zeros = np.array([[0.6244, 0.0, 0.2692, 0.0116, 0.0, 0.1170],
  26. [0.5351, -0.9403, 0.0, 0.2100, 0.4759, 0.2833],
  27. [0.0, 0.0, 0.0, 0.1042, 0.0, -0.5954],
  28. [0.1610, 0.0, 0.0, 0.1859, 0.3146, 0.0]])
  29. class TestNanFunctions_MinMax:
  30. nanfuncs = [np.nanmin, np.nanmax]
  31. stdfuncs = [np.min, np.max]
  32. def test_mutation(self):
  33. # Check that passed array is not modified.
  34. ndat = _ndat.copy()
  35. for f in self.nanfuncs:
  36. f(ndat)
  37. assert_equal(ndat, _ndat)
  38. def test_keepdims(self):
  39. mat = np.eye(3)
  40. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  41. for axis in [None, 0, 1]:
  42. tgt = rf(mat, axis=axis, keepdims=True)
  43. res = nf(mat, axis=axis, keepdims=True)
  44. assert_(res.ndim == tgt.ndim)
  45. def test_out(self):
  46. mat = np.eye(3)
  47. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  48. resout = np.zeros(3)
  49. tgt = rf(mat, axis=1)
  50. res = nf(mat, axis=1, out=resout)
  51. assert_almost_equal(res, resout)
  52. assert_almost_equal(res, tgt)
  53. def test_dtype_from_input(self):
  54. codes = 'efdgFDG'
  55. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  56. for c in codes:
  57. mat = np.eye(3, dtype=c)
  58. tgt = rf(mat, axis=1).dtype.type
  59. res = nf(mat, axis=1).dtype.type
  60. assert_(res is tgt)
  61. # scalar case
  62. tgt = rf(mat, axis=None).dtype.type
  63. res = nf(mat, axis=None).dtype.type
  64. assert_(res is tgt)
  65. def test_result_values(self):
  66. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  67. tgt = [rf(d) for d in _rdat]
  68. res = nf(_ndat, axis=1)
  69. assert_almost_equal(res, tgt)
  70. def test_allnans(self):
  71. mat = np.array([np.nan]*9).reshape(3, 3)
  72. for f in self.nanfuncs:
  73. for axis in [None, 0, 1]:
  74. with warnings.catch_warnings(record=True) as w:
  75. warnings.simplefilter('always')
  76. assert_(np.isnan(f(mat, axis=axis)).all())
  77. assert_(len(w) == 1, 'no warning raised')
  78. assert_(issubclass(w[0].category, RuntimeWarning))
  79. # Check scalars
  80. with warnings.catch_warnings(record=True) as w:
  81. warnings.simplefilter('always')
  82. assert_(np.isnan(f(np.nan)))
  83. assert_(len(w) == 1, 'no warning raised')
  84. assert_(issubclass(w[0].category, RuntimeWarning))
  85. def test_masked(self):
  86. mat = np.ma.fix_invalid(_ndat)
  87. msk = mat._mask.copy()
  88. for f in [np.nanmin]:
  89. res = f(mat, axis=1)
  90. tgt = f(_ndat, axis=1)
  91. assert_equal(res, tgt)
  92. assert_equal(mat._mask, msk)
  93. assert_(not np.isinf(mat).any())
  94. def test_scalar(self):
  95. for f in self.nanfuncs:
  96. assert_(f(0.) == 0.)
  97. def test_subclass(self):
  98. class MyNDArray(np.ndarray):
  99. pass
  100. # Check that it works and that type and
  101. # shape are preserved
  102. mine = np.eye(3).view(MyNDArray)
  103. for f in self.nanfuncs:
  104. res = f(mine, axis=0)
  105. assert_(isinstance(res, MyNDArray))
  106. assert_(res.shape == (3,))
  107. res = f(mine, axis=1)
  108. assert_(isinstance(res, MyNDArray))
  109. assert_(res.shape == (3,))
  110. res = f(mine)
  111. assert_(res.shape == ())
  112. # check that rows of nan are dealt with for subclasses (#4628)
  113. mine[1] = np.nan
  114. for f in self.nanfuncs:
  115. with warnings.catch_warnings(record=True) as w:
  116. warnings.simplefilter('always')
  117. res = f(mine, axis=0)
  118. assert_(isinstance(res, MyNDArray))
  119. assert_(not np.any(np.isnan(res)))
  120. assert_(len(w) == 0)
  121. with warnings.catch_warnings(record=True) as w:
  122. warnings.simplefilter('always')
  123. res = f(mine, axis=1)
  124. assert_(isinstance(res, MyNDArray))
  125. assert_(np.isnan(res[1]) and not np.isnan(res[0])
  126. and not np.isnan(res[2]))
  127. assert_(len(w) == 1, 'no warning raised')
  128. assert_(issubclass(w[0].category, RuntimeWarning))
  129. with warnings.catch_warnings(record=True) as w:
  130. warnings.simplefilter('always')
  131. res = f(mine)
  132. assert_(res.shape == ())
  133. assert_(res != np.nan)
  134. assert_(len(w) == 0)
  135. def test_object_array(self):
  136. arr = np.array([[1.0, 2.0], [np.nan, 4.0], [np.nan, np.nan]], dtype=object)
  137. assert_equal(np.nanmin(arr), 1.0)
  138. assert_equal(np.nanmin(arr, axis=0), [1.0, 2.0])
  139. with warnings.catch_warnings(record=True) as w:
  140. warnings.simplefilter('always')
  141. # assert_equal does not work on object arrays of nan
  142. assert_equal(list(np.nanmin(arr, axis=1)), [1.0, 4.0, np.nan])
  143. assert_(len(w) == 1, 'no warning raised')
  144. assert_(issubclass(w[0].category, RuntimeWarning))
  145. class TestNanFunctions_ArgminArgmax:
  146. nanfuncs = [np.nanargmin, np.nanargmax]
  147. def test_mutation(self):
  148. # Check that passed array is not modified.
  149. ndat = _ndat.copy()
  150. for f in self.nanfuncs:
  151. f(ndat)
  152. assert_equal(ndat, _ndat)
  153. def test_result_values(self):
  154. for f, fcmp in zip(self.nanfuncs, [np.greater, np.less]):
  155. for row in _ndat:
  156. with suppress_warnings() as sup:
  157. sup.filter(RuntimeWarning, "invalid value encountered in")
  158. ind = f(row)
  159. val = row[ind]
  160. # comparing with NaN is tricky as the result
  161. # is always false except for NaN != NaN
  162. assert_(not np.isnan(val))
  163. assert_(not fcmp(val, row).any())
  164. assert_(not np.equal(val, row[:ind]).any())
  165. def test_allnans(self):
  166. mat = np.array([np.nan]*9).reshape(3, 3)
  167. for f in self.nanfuncs:
  168. for axis in [None, 0, 1]:
  169. assert_raises(ValueError, f, mat, axis=axis)
  170. assert_raises(ValueError, f, np.nan)
  171. def test_empty(self):
  172. mat = np.zeros((0, 3))
  173. for f in self.nanfuncs:
  174. for axis in [0, None]:
  175. assert_raises(ValueError, f, mat, axis=axis)
  176. for axis in [1]:
  177. res = f(mat, axis=axis)
  178. assert_equal(res, np.zeros(0))
  179. def test_scalar(self):
  180. for f in self.nanfuncs:
  181. assert_(f(0.) == 0.)
  182. def test_subclass(self):
  183. class MyNDArray(np.ndarray):
  184. pass
  185. # Check that it works and that type and
  186. # shape are preserved
  187. mine = np.eye(3).view(MyNDArray)
  188. for f in self.nanfuncs:
  189. res = f(mine, axis=0)
  190. assert_(isinstance(res, MyNDArray))
  191. assert_(res.shape == (3,))
  192. res = f(mine, axis=1)
  193. assert_(isinstance(res, MyNDArray))
  194. assert_(res.shape == (3,))
  195. res = f(mine)
  196. assert_(res.shape == ())
  197. class TestNanFunctions_IntTypes:
  198. int_types = (np.int8, np.int16, np.int32, np.int64, np.uint8,
  199. np.uint16, np.uint32, np.uint64)
  200. mat = np.array([127, 39, 93, 87, 46])
  201. def integer_arrays(self):
  202. for dtype in self.int_types:
  203. yield self.mat.astype(dtype)
  204. def test_nanmin(self):
  205. tgt = np.min(self.mat)
  206. for mat in self.integer_arrays():
  207. assert_equal(np.nanmin(mat), tgt)
  208. def test_nanmax(self):
  209. tgt = np.max(self.mat)
  210. for mat in self.integer_arrays():
  211. assert_equal(np.nanmax(mat), tgt)
  212. def test_nanargmin(self):
  213. tgt = np.argmin(self.mat)
  214. for mat in self.integer_arrays():
  215. assert_equal(np.nanargmin(mat), tgt)
  216. def test_nanargmax(self):
  217. tgt = np.argmax(self.mat)
  218. for mat in self.integer_arrays():
  219. assert_equal(np.nanargmax(mat), tgt)
  220. def test_nansum(self):
  221. tgt = np.sum(self.mat)
  222. for mat in self.integer_arrays():
  223. assert_equal(np.nansum(mat), tgt)
  224. def test_nanprod(self):
  225. tgt = np.prod(self.mat)
  226. for mat in self.integer_arrays():
  227. assert_equal(np.nanprod(mat), tgt)
  228. def test_nancumsum(self):
  229. tgt = np.cumsum(self.mat)
  230. for mat in self.integer_arrays():
  231. assert_equal(np.nancumsum(mat), tgt)
  232. def test_nancumprod(self):
  233. tgt = np.cumprod(self.mat)
  234. for mat in self.integer_arrays():
  235. assert_equal(np.nancumprod(mat), tgt)
  236. def test_nanmean(self):
  237. tgt = np.mean(self.mat)
  238. for mat in self.integer_arrays():
  239. assert_equal(np.nanmean(mat), tgt)
  240. def test_nanvar(self):
  241. tgt = np.var(self.mat)
  242. for mat in self.integer_arrays():
  243. assert_equal(np.nanvar(mat), tgt)
  244. tgt = np.var(mat, ddof=1)
  245. for mat in self.integer_arrays():
  246. assert_equal(np.nanvar(mat, ddof=1), tgt)
  247. def test_nanstd(self):
  248. tgt = np.std(self.mat)
  249. for mat in self.integer_arrays():
  250. assert_equal(np.nanstd(mat), tgt)
  251. tgt = np.std(self.mat, ddof=1)
  252. for mat in self.integer_arrays():
  253. assert_equal(np.nanstd(mat, ddof=1), tgt)
  254. class SharedNanFunctionsTestsMixin:
  255. def test_mutation(self):
  256. # Check that passed array is not modified.
  257. ndat = _ndat.copy()
  258. for f in self.nanfuncs:
  259. f(ndat)
  260. assert_equal(ndat, _ndat)
  261. def test_keepdims(self):
  262. mat = np.eye(3)
  263. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  264. for axis in [None, 0, 1]:
  265. tgt = rf(mat, axis=axis, keepdims=True)
  266. res = nf(mat, axis=axis, keepdims=True)
  267. assert_(res.ndim == tgt.ndim)
  268. def test_out(self):
  269. mat = np.eye(3)
  270. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  271. resout = np.zeros(3)
  272. tgt = rf(mat, axis=1)
  273. res = nf(mat, axis=1, out=resout)
  274. assert_almost_equal(res, resout)
  275. assert_almost_equal(res, tgt)
  276. def test_dtype_from_dtype(self):
  277. mat = np.eye(3)
  278. codes = 'efdgFDG'
  279. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  280. for c in codes:
  281. with suppress_warnings() as sup:
  282. if nf in {np.nanstd, np.nanvar} and c in 'FDG':
  283. # Giving the warning is a small bug, see gh-8000
  284. sup.filter(np.ComplexWarning)
  285. tgt = rf(mat, dtype=np.dtype(c), axis=1).dtype.type
  286. res = nf(mat, dtype=np.dtype(c), axis=1).dtype.type
  287. assert_(res is tgt)
  288. # scalar case
  289. tgt = rf(mat, dtype=np.dtype(c), axis=None).dtype.type
  290. res = nf(mat, dtype=np.dtype(c), axis=None).dtype.type
  291. assert_(res is tgt)
  292. def test_dtype_from_char(self):
  293. mat = np.eye(3)
  294. codes = 'efdgFDG'
  295. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  296. for c in codes:
  297. with suppress_warnings() as sup:
  298. if nf in {np.nanstd, np.nanvar} and c in 'FDG':
  299. # Giving the warning is a small bug, see gh-8000
  300. sup.filter(np.ComplexWarning)
  301. tgt = rf(mat, dtype=c, axis=1).dtype.type
  302. res = nf(mat, dtype=c, axis=1).dtype.type
  303. assert_(res is tgt)
  304. # scalar case
  305. tgt = rf(mat, dtype=c, axis=None).dtype.type
  306. res = nf(mat, dtype=c, axis=None).dtype.type
  307. assert_(res is tgt)
  308. def test_dtype_from_input(self):
  309. codes = 'efdgFDG'
  310. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  311. for c in codes:
  312. mat = np.eye(3, dtype=c)
  313. tgt = rf(mat, axis=1).dtype.type
  314. res = nf(mat, axis=1).dtype.type
  315. assert_(res is tgt, "res %s, tgt %s" % (res, tgt))
  316. # scalar case
  317. tgt = rf(mat, axis=None).dtype.type
  318. res = nf(mat, axis=None).dtype.type
  319. assert_(res is tgt)
  320. def test_result_values(self):
  321. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  322. tgt = [rf(d) for d in _rdat]
  323. res = nf(_ndat, axis=1)
  324. assert_almost_equal(res, tgt)
  325. def test_scalar(self):
  326. for f in self.nanfuncs:
  327. assert_(f(0.) == 0.)
  328. def test_subclass(self):
  329. class MyNDArray(np.ndarray):
  330. pass
  331. # Check that it works and that type and
  332. # shape are preserved
  333. array = np.eye(3)
  334. mine = array.view(MyNDArray)
  335. for f in self.nanfuncs:
  336. expected_shape = f(array, axis=0).shape
  337. res = f(mine, axis=0)
  338. assert_(isinstance(res, MyNDArray))
  339. assert_(res.shape == expected_shape)
  340. expected_shape = f(array, axis=1).shape
  341. res = f(mine, axis=1)
  342. assert_(isinstance(res, MyNDArray))
  343. assert_(res.shape == expected_shape)
  344. expected_shape = f(array).shape
  345. res = f(mine)
  346. assert_(isinstance(res, MyNDArray))
  347. assert_(res.shape == expected_shape)
  348. class TestNanFunctions_SumProd(SharedNanFunctionsTestsMixin):
  349. nanfuncs = [np.nansum, np.nanprod]
  350. stdfuncs = [np.sum, np.prod]
  351. def test_allnans(self):
  352. # Check for FutureWarning
  353. with warnings.catch_warnings(record=True) as w:
  354. warnings.simplefilter('always')
  355. res = np.nansum([np.nan]*3, axis=None)
  356. assert_(res == 0, 'result is not 0')
  357. assert_(len(w) == 0, 'warning raised')
  358. # Check scalar
  359. res = np.nansum(np.nan)
  360. assert_(res == 0, 'result is not 0')
  361. assert_(len(w) == 0, 'warning raised')
  362. # Check there is no warning for not all-nan
  363. np.nansum([0]*3, axis=None)
  364. assert_(len(w) == 0, 'unwanted warning raised')
  365. def test_empty(self):
  366. for f, tgt_value in zip([np.nansum, np.nanprod], [0, 1]):
  367. mat = np.zeros((0, 3))
  368. tgt = [tgt_value]*3
  369. res = f(mat, axis=0)
  370. assert_equal(res, tgt)
  371. tgt = []
  372. res = f(mat, axis=1)
  373. assert_equal(res, tgt)
  374. tgt = tgt_value
  375. res = f(mat, axis=None)
  376. assert_equal(res, tgt)
  377. class TestNanFunctions_CumSumProd(SharedNanFunctionsTestsMixin):
  378. nanfuncs = [np.nancumsum, np.nancumprod]
  379. stdfuncs = [np.cumsum, np.cumprod]
  380. def test_allnans(self):
  381. for f, tgt_value in zip(self.nanfuncs, [0, 1]):
  382. # Unlike other nan-functions, sum/prod/cumsum/cumprod don't warn on all nan input
  383. with assert_no_warnings():
  384. res = f([np.nan]*3, axis=None)
  385. tgt = tgt_value*np.ones((3))
  386. assert_(np.array_equal(res, tgt), 'result is not %s * np.ones((3))' % (tgt_value))
  387. # Check scalar
  388. res = f(np.nan)
  389. tgt = tgt_value*np.ones((1))
  390. assert_(np.array_equal(res, tgt), 'result is not %s * np.ones((1))' % (tgt_value))
  391. # Check there is no warning for not all-nan
  392. f([0]*3, axis=None)
  393. def test_empty(self):
  394. for f, tgt_value in zip(self.nanfuncs, [0, 1]):
  395. mat = np.zeros((0, 3))
  396. tgt = tgt_value*np.ones((0, 3))
  397. res = f(mat, axis=0)
  398. assert_equal(res, tgt)
  399. tgt = mat
  400. res = f(mat, axis=1)
  401. assert_equal(res, tgt)
  402. tgt = np.zeros((0))
  403. res = f(mat, axis=None)
  404. assert_equal(res, tgt)
  405. def test_keepdims(self):
  406. for f, g in zip(self.nanfuncs, self.stdfuncs):
  407. mat = np.eye(3)
  408. for axis in [None, 0, 1]:
  409. tgt = f(mat, axis=axis, out=None)
  410. res = g(mat, axis=axis, out=None)
  411. assert_(res.ndim == tgt.ndim)
  412. for f in self.nanfuncs:
  413. d = np.ones((3, 5, 7, 11))
  414. # Randomly set some elements to NaN:
  415. rs = np.random.RandomState(0)
  416. d[rs.rand(*d.shape) < 0.5] = np.nan
  417. res = f(d, axis=None)
  418. assert_equal(res.shape, (1155,))
  419. for axis in np.arange(4):
  420. res = f(d, axis=axis)
  421. assert_equal(res.shape, (3, 5, 7, 11))
  422. def test_result_values(self):
  423. for axis in (-2, -1, 0, 1, None):
  424. tgt = np.cumprod(_ndat_ones, axis=axis)
  425. res = np.nancumprod(_ndat, axis=axis)
  426. assert_almost_equal(res, tgt)
  427. tgt = np.cumsum(_ndat_zeros,axis=axis)
  428. res = np.nancumsum(_ndat, axis=axis)
  429. assert_almost_equal(res, tgt)
  430. def test_out(self):
  431. mat = np.eye(3)
  432. for nf, rf in zip(self.nanfuncs, self.stdfuncs):
  433. resout = np.eye(3)
  434. for axis in (-2, -1, 0, 1):
  435. tgt = rf(mat, axis=axis)
  436. res = nf(mat, axis=axis, out=resout)
  437. assert_almost_equal(res, resout)
  438. assert_almost_equal(res, tgt)
  439. class TestNanFunctions_MeanVarStd(SharedNanFunctionsTestsMixin):
  440. nanfuncs = [np.nanmean, np.nanvar, np.nanstd]
  441. stdfuncs = [np.mean, np.var, np.std]
  442. def test_dtype_error(self):
  443. for f in self.nanfuncs:
  444. for dtype in [np.bool_, np.int_, np.object_]:
  445. assert_raises(TypeError, f, _ndat, axis=1, dtype=dtype)
  446. def test_out_dtype_error(self):
  447. for f in self.nanfuncs:
  448. for dtype in [np.bool_, np.int_, np.object_]:
  449. out = np.empty(_ndat.shape[0], dtype=dtype)
  450. assert_raises(TypeError, f, _ndat, axis=1, out=out)
  451. def test_ddof(self):
  452. nanfuncs = [np.nanvar, np.nanstd]
  453. stdfuncs = [np.var, np.std]
  454. for nf, rf in zip(nanfuncs, stdfuncs):
  455. for ddof in [0, 1]:
  456. tgt = [rf(d, ddof=ddof) for d in _rdat]
  457. res = nf(_ndat, axis=1, ddof=ddof)
  458. assert_almost_equal(res, tgt)
  459. def test_ddof_too_big(self):
  460. nanfuncs = [np.nanvar, np.nanstd]
  461. stdfuncs = [np.var, np.std]
  462. dsize = [len(d) for d in _rdat]
  463. for nf, rf in zip(nanfuncs, stdfuncs):
  464. for ddof in range(5):
  465. with suppress_warnings() as sup:
  466. sup.record(RuntimeWarning)
  467. sup.filter(np.ComplexWarning)
  468. tgt = [ddof >= d for d in dsize]
  469. res = nf(_ndat, axis=1, ddof=ddof)
  470. assert_equal(np.isnan(res), tgt)
  471. if any(tgt):
  472. assert_(len(sup.log) == 1)
  473. else:
  474. assert_(len(sup.log) == 0)
  475. def test_allnans(self):
  476. mat = np.array([np.nan]*9).reshape(3, 3)
  477. for f in self.nanfuncs:
  478. for axis in [None, 0, 1]:
  479. with warnings.catch_warnings(record=True) as w:
  480. warnings.simplefilter('always')
  481. assert_(np.isnan(f(mat, axis=axis)).all())
  482. assert_(len(w) == 1)
  483. assert_(issubclass(w[0].category, RuntimeWarning))
  484. # Check scalar
  485. assert_(np.isnan(f(np.nan)))
  486. assert_(len(w) == 2)
  487. assert_(issubclass(w[0].category, RuntimeWarning))
  488. def test_empty(self):
  489. mat = np.zeros((0, 3))
  490. for f in self.nanfuncs:
  491. for axis in [0, None]:
  492. with warnings.catch_warnings(record=True) as w:
  493. warnings.simplefilter('always')
  494. assert_(np.isnan(f(mat, axis=axis)).all())
  495. assert_(len(w) == 1)
  496. assert_(issubclass(w[0].category, RuntimeWarning))
  497. for axis in [1]:
  498. with warnings.catch_warnings(record=True) as w:
  499. warnings.simplefilter('always')
  500. assert_equal(f(mat, axis=axis), np.zeros([]))
  501. assert_(len(w) == 0)
  502. _TIME_UNITS = (
  503. "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as"
  504. )
  505. # All `inexact` + `timdelta64` type codes
  506. _TYPE_CODES = list(np.typecodes["AllFloat"])
  507. _TYPE_CODES += [f"m8[{unit}]" for unit in _TIME_UNITS]
  508. class TestNanFunctions_Median:
  509. def test_mutation(self):
  510. # Check that passed array is not modified.
  511. ndat = _ndat.copy()
  512. np.nanmedian(ndat)
  513. assert_equal(ndat, _ndat)
  514. def test_keepdims(self):
  515. mat = np.eye(3)
  516. for axis in [None, 0, 1]:
  517. tgt = np.median(mat, axis=axis, out=None, overwrite_input=False)
  518. res = np.nanmedian(mat, axis=axis, out=None, overwrite_input=False)
  519. assert_(res.ndim == tgt.ndim)
  520. d = np.ones((3, 5, 7, 11))
  521. # Randomly set some elements to NaN:
  522. w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
  523. w = w.astype(np.intp)
  524. d[tuple(w)] = np.nan
  525. with suppress_warnings() as sup:
  526. sup.filter(RuntimeWarning)
  527. res = np.nanmedian(d, axis=None, keepdims=True)
  528. assert_equal(res.shape, (1, 1, 1, 1))
  529. res = np.nanmedian(d, axis=(0, 1), keepdims=True)
  530. assert_equal(res.shape, (1, 1, 7, 11))
  531. res = np.nanmedian(d, axis=(0, 3), keepdims=True)
  532. assert_equal(res.shape, (1, 5, 7, 1))
  533. res = np.nanmedian(d, axis=(1,), keepdims=True)
  534. assert_equal(res.shape, (3, 1, 7, 11))
  535. res = np.nanmedian(d, axis=(0, 1, 2, 3), keepdims=True)
  536. assert_equal(res.shape, (1, 1, 1, 1))
  537. res = np.nanmedian(d, axis=(0, 1, 3), keepdims=True)
  538. assert_equal(res.shape, (1, 1, 7, 1))
  539. def test_out(self):
  540. mat = np.random.rand(3, 3)
  541. nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)
  542. resout = np.zeros(3)
  543. tgt = np.median(mat, axis=1)
  544. res = np.nanmedian(nan_mat, axis=1, out=resout)
  545. assert_almost_equal(res, resout)
  546. assert_almost_equal(res, tgt)
  547. # 0-d output:
  548. resout = np.zeros(())
  549. tgt = np.median(mat, axis=None)
  550. res = np.nanmedian(nan_mat, axis=None, out=resout)
  551. assert_almost_equal(res, resout)
  552. assert_almost_equal(res, tgt)
  553. res = np.nanmedian(nan_mat, axis=(0, 1), out=resout)
  554. assert_almost_equal(res, resout)
  555. assert_almost_equal(res, tgt)
  556. def test_small_large(self):
  557. # test the small and large code paths, current cutoff 400 elements
  558. for s in [5, 20, 51, 200, 1000]:
  559. d = np.random.randn(4, s)
  560. # Randomly set some elements to NaN:
  561. w = np.random.randint(0, d.size, size=d.size // 5)
  562. d.ravel()[w] = np.nan
  563. d[:,0] = 1. # ensure at least one good value
  564. # use normal median without nans to compare
  565. tgt = []
  566. for x in d:
  567. nonan = np.compress(~np.isnan(x), x)
  568. tgt.append(np.median(nonan, overwrite_input=True))
  569. assert_array_equal(np.nanmedian(d, axis=-1), tgt)
  570. def test_result_values(self):
  571. tgt = [np.median(d) for d in _rdat]
  572. res = np.nanmedian(_ndat, axis=1)
  573. assert_almost_equal(res, tgt)
  574. @pytest.mark.parametrize("axis", [None, 0, 1])
  575. @pytest.mark.parametrize("dtype", _TYPE_CODES)
  576. def test_allnans(self, dtype, axis):
  577. mat = np.full((3, 3), np.nan).astype(dtype)
  578. with suppress_warnings() as sup:
  579. sup.record(RuntimeWarning)
  580. output = np.nanmedian(mat, axis=axis)
  581. assert output.dtype == mat.dtype
  582. assert np.isnan(output).all()
  583. if axis is None:
  584. assert_(len(sup.log) == 1)
  585. else:
  586. assert_(len(sup.log) == 3)
  587. # Check scalar
  588. scalar = np.array(np.nan).astype(dtype)[()]
  589. output_scalar = np.nanmedian(scalar)
  590. assert output_scalar.dtype == scalar.dtype
  591. assert np.isnan(output_scalar)
  592. if axis is None:
  593. assert_(len(sup.log) == 2)
  594. else:
  595. assert_(len(sup.log) == 4)
  596. def test_empty(self):
  597. mat = np.zeros((0, 3))
  598. for axis in [0, None]:
  599. with warnings.catch_warnings(record=True) as w:
  600. warnings.simplefilter('always')
  601. assert_(np.isnan(np.nanmedian(mat, axis=axis)).all())
  602. assert_(len(w) == 1)
  603. assert_(issubclass(w[0].category, RuntimeWarning))
  604. for axis in [1]:
  605. with warnings.catch_warnings(record=True) as w:
  606. warnings.simplefilter('always')
  607. assert_equal(np.nanmedian(mat, axis=axis), np.zeros([]))
  608. assert_(len(w) == 0)
  609. def test_scalar(self):
  610. assert_(np.nanmedian(0.) == 0.)
  611. def test_extended_axis_invalid(self):
  612. d = np.ones((3, 5, 7, 11))
  613. assert_raises(np.AxisError, np.nanmedian, d, axis=-5)
  614. assert_raises(np.AxisError, np.nanmedian, d, axis=(0, -5))
  615. assert_raises(np.AxisError, np.nanmedian, d, axis=4)
  616. assert_raises(np.AxisError, np.nanmedian, d, axis=(0, 4))
  617. assert_raises(ValueError, np.nanmedian, d, axis=(1, 1))
  618. def test_float_special(self):
  619. with suppress_warnings() as sup:
  620. sup.filter(RuntimeWarning)
  621. for inf in [np.inf, -np.inf]:
  622. a = np.array([[inf, np.nan], [np.nan, np.nan]])
  623. assert_equal(np.nanmedian(a, axis=0), [inf, np.nan])
  624. assert_equal(np.nanmedian(a, axis=1), [inf, np.nan])
  625. assert_equal(np.nanmedian(a), inf)
  626. # minimum fill value check
  627. a = np.array([[np.nan, np.nan, inf],
  628. [np.nan, np.nan, inf]])
  629. assert_equal(np.nanmedian(a), inf)
  630. assert_equal(np.nanmedian(a, axis=0), [np.nan, np.nan, inf])
  631. assert_equal(np.nanmedian(a, axis=1), inf)
  632. # no mask path
  633. a = np.array([[inf, inf], [inf, inf]])
  634. assert_equal(np.nanmedian(a, axis=1), inf)
  635. a = np.array([[inf, 7, -inf, -9],
  636. [-10, np.nan, np.nan, 5],
  637. [4, np.nan, np.nan, inf]],
  638. dtype=np.float32)
  639. if inf > 0:
  640. assert_equal(np.nanmedian(a, axis=0), [4., 7., -inf, 5.])
  641. assert_equal(np.nanmedian(a), 4.5)
  642. else:
  643. assert_equal(np.nanmedian(a, axis=0), [-10., 7., -inf, -9.])
  644. assert_equal(np.nanmedian(a), -2.5)
  645. assert_equal(np.nanmedian(a, axis=-1), [-1., -2.5, inf])
  646. for i in range(0, 10):
  647. for j in range(1, 10):
  648. a = np.array([([np.nan] * i) + ([inf] * j)] * 2)
  649. assert_equal(np.nanmedian(a), inf)
  650. assert_equal(np.nanmedian(a, axis=1), inf)
  651. assert_equal(np.nanmedian(a, axis=0),
  652. ([np.nan] * i) + [inf] * j)
  653. a = np.array([([np.nan] * i) + ([-inf] * j)] * 2)
  654. assert_equal(np.nanmedian(a), -inf)
  655. assert_equal(np.nanmedian(a, axis=1), -inf)
  656. assert_equal(np.nanmedian(a, axis=0),
  657. ([np.nan] * i) + [-inf] * j)
  658. class TestNanFunctions_Percentile:
  659. def test_mutation(self):
  660. # Check that passed array is not modified.
  661. ndat = _ndat.copy()
  662. np.nanpercentile(ndat, 30)
  663. assert_equal(ndat, _ndat)
  664. def test_keepdims(self):
  665. mat = np.eye(3)
  666. for axis in [None, 0, 1]:
  667. tgt = np.percentile(mat, 70, axis=axis, out=None,
  668. overwrite_input=False)
  669. res = np.nanpercentile(mat, 70, axis=axis, out=None,
  670. overwrite_input=False)
  671. assert_(res.ndim == tgt.ndim)
  672. d = np.ones((3, 5, 7, 11))
  673. # Randomly set some elements to NaN:
  674. w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
  675. w = w.astype(np.intp)
  676. d[tuple(w)] = np.nan
  677. with suppress_warnings() as sup:
  678. sup.filter(RuntimeWarning)
  679. res = np.nanpercentile(d, 90, axis=None, keepdims=True)
  680. assert_equal(res.shape, (1, 1, 1, 1))
  681. res = np.nanpercentile(d, 90, axis=(0, 1), keepdims=True)
  682. assert_equal(res.shape, (1, 1, 7, 11))
  683. res = np.nanpercentile(d, 90, axis=(0, 3), keepdims=True)
  684. assert_equal(res.shape, (1, 5, 7, 1))
  685. res = np.nanpercentile(d, 90, axis=(1,), keepdims=True)
  686. assert_equal(res.shape, (3, 1, 7, 11))
  687. res = np.nanpercentile(d, 90, axis=(0, 1, 2, 3), keepdims=True)
  688. assert_equal(res.shape, (1, 1, 1, 1))
  689. res = np.nanpercentile(d, 90, axis=(0, 1, 3), keepdims=True)
  690. assert_equal(res.shape, (1, 1, 7, 1))
  691. def test_out(self):
  692. mat = np.random.rand(3, 3)
  693. nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)
  694. resout = np.zeros(3)
  695. tgt = np.percentile(mat, 42, axis=1)
  696. res = np.nanpercentile(nan_mat, 42, axis=1, out=resout)
  697. assert_almost_equal(res, resout)
  698. assert_almost_equal(res, tgt)
  699. # 0-d output:
  700. resout = np.zeros(())
  701. tgt = np.percentile(mat, 42, axis=None)
  702. res = np.nanpercentile(nan_mat, 42, axis=None, out=resout)
  703. assert_almost_equal(res, resout)
  704. assert_almost_equal(res, tgt)
  705. res = np.nanpercentile(nan_mat, 42, axis=(0, 1), out=resout)
  706. assert_almost_equal(res, resout)
  707. assert_almost_equal(res, tgt)
  708. def test_result_values(self):
  709. tgt = [np.percentile(d, 28) for d in _rdat]
  710. res = np.nanpercentile(_ndat, 28, axis=1)
  711. assert_almost_equal(res, tgt)
  712. # Transpose the array to fit the output convention of numpy.percentile
  713. tgt = np.transpose([np.percentile(d, (28, 98)) for d in _rdat])
  714. res = np.nanpercentile(_ndat, (28, 98), axis=1)
  715. assert_almost_equal(res, tgt)
  716. def test_allnans(self):
  717. mat = np.array([np.nan]*9).reshape(3, 3)
  718. for axis in [None, 0, 1]:
  719. with warnings.catch_warnings(record=True) as w:
  720. warnings.simplefilter('always')
  721. assert_(np.isnan(np.nanpercentile(mat, 60, axis=axis)).all())
  722. if axis is None:
  723. assert_(len(w) == 1)
  724. else:
  725. assert_(len(w) == 3)
  726. assert_(issubclass(w[0].category, RuntimeWarning))
  727. # Check scalar
  728. assert_(np.isnan(np.nanpercentile(np.nan, 60)))
  729. if axis is None:
  730. assert_(len(w) == 2)
  731. else:
  732. assert_(len(w) == 4)
  733. assert_(issubclass(w[0].category, RuntimeWarning))
  734. def test_empty(self):
  735. mat = np.zeros((0, 3))
  736. for axis in [0, None]:
  737. with warnings.catch_warnings(record=True) as w:
  738. warnings.simplefilter('always')
  739. assert_(np.isnan(np.nanpercentile(mat, 40, axis=axis)).all())
  740. assert_(len(w) == 1)
  741. assert_(issubclass(w[0].category, RuntimeWarning))
  742. for axis in [1]:
  743. with warnings.catch_warnings(record=True) as w:
  744. warnings.simplefilter('always')
  745. assert_equal(np.nanpercentile(mat, 40, axis=axis), np.zeros([]))
  746. assert_(len(w) == 0)
  747. def test_scalar(self):
  748. assert_equal(np.nanpercentile(0., 100), 0.)
  749. a = np.arange(6)
  750. r = np.nanpercentile(a, 50, axis=0)
  751. assert_equal(r, 2.5)
  752. assert_(np.isscalar(r))
  753. def test_extended_axis_invalid(self):
  754. d = np.ones((3, 5, 7, 11))
  755. assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=-5)
  756. assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=(0, -5))
  757. assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=4)
  758. assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=(0, 4))
  759. assert_raises(ValueError, np.nanpercentile, d, q=5, axis=(1, 1))
  760. def test_multiple_percentiles(self):
  761. perc = [50, 100]
  762. mat = np.ones((4, 3))
  763. nan_mat = np.nan * mat
  764. # For checking consistency in higher dimensional case
  765. large_mat = np.ones((3, 4, 5))
  766. large_mat[:, 0:2:4, :] = 0
  767. large_mat[:, :, 3:] *= 2
  768. for axis in [None, 0, 1]:
  769. for keepdim in [False, True]:
  770. with suppress_warnings() as sup:
  771. sup.filter(RuntimeWarning, "All-NaN slice encountered")
  772. val = np.percentile(mat, perc, axis=axis, keepdims=keepdim)
  773. nan_val = np.nanpercentile(nan_mat, perc, axis=axis,
  774. keepdims=keepdim)
  775. assert_equal(nan_val.shape, val.shape)
  776. val = np.percentile(large_mat, perc, axis=axis,
  777. keepdims=keepdim)
  778. nan_val = np.nanpercentile(large_mat, perc, axis=axis,
  779. keepdims=keepdim)
  780. assert_equal(nan_val, val)
  781. megamat = np.ones((3, 4, 5, 6))
  782. assert_equal(np.nanpercentile(megamat, perc, axis=(1, 2)).shape, (2, 3, 6))
  783. class TestNanFunctions_Quantile:
  784. # most of this is already tested by TestPercentile
  785. def test_regression(self):
  786. ar = np.arange(24).reshape(2, 3, 4).astype(float)
  787. ar[0][1] = np.nan
  788. assert_equal(np.nanquantile(ar, q=0.5), np.nanpercentile(ar, q=50))
  789. assert_equal(np.nanquantile(ar, q=0.5, axis=0),
  790. np.nanpercentile(ar, q=50, axis=0))
  791. assert_equal(np.nanquantile(ar, q=0.5, axis=1),
  792. np.nanpercentile(ar, q=50, axis=1))
  793. assert_equal(np.nanquantile(ar, q=[0.5], axis=1),
  794. np.nanpercentile(ar, q=[50], axis=1))
  795. assert_equal(np.nanquantile(ar, q=[0.25, 0.5, 0.75], axis=1),
  796. np.nanpercentile(ar, q=[25, 50, 75], axis=1))
  797. def test_basic(self):
  798. x = np.arange(8) * 0.5
  799. assert_equal(np.nanquantile(x, 0), 0.)
  800. assert_equal(np.nanquantile(x, 1), 3.5)
  801. assert_equal(np.nanquantile(x, 0.5), 1.75)
  802. def test_no_p_overwrite(self):
  803. # this is worth retesting, because quantile does not make a copy
  804. p0 = np.array([0, 0.75, 0.25, 0.5, 1.0])
  805. p = p0.copy()
  806. np.nanquantile(np.arange(100.), p, interpolation="midpoint")
  807. assert_array_equal(p, p0)
  808. p0 = p0.tolist()
  809. p = p.tolist()
  810. np.nanquantile(np.arange(100.), p, interpolation="midpoint")
  811. assert_array_equal(p, p0)
  812. @pytest.mark.parametrize("arr, expected", [
  813. # array of floats with some nans
  814. (np.array([np.nan, 5.0, np.nan, np.inf]),
  815. np.array([False, True, False, True])),
  816. # int64 array that can't possibly have nans
  817. (np.array([1, 5, 7, 9], dtype=np.int64),
  818. True),
  819. # bool array that can't possibly have nans
  820. (np.array([False, True, False, True]),
  821. True),
  822. # 2-D complex array with nans
  823. (np.array([[np.nan, 5.0],
  824. [np.nan, np.inf]], dtype=np.complex64),
  825. np.array([[False, True],
  826. [False, True]])),
  827. ])
  828. def test__nan_mask(arr, expected):
  829. for out in [None, np.empty(arr.shape, dtype=np.bool_)]:
  830. actual = _nan_mask(arr, out=out)
  831. assert_equal(actual, expected)
  832. # the above won't distinguish between True proper
  833. # and an array of True values; we want True proper
  834. # for types that can't possibly contain NaN
  835. if type(expected) is not np.ndarray:
  836. assert actual is True
  837. def test__replace_nan():
  838. """ Test that _replace_nan returns the original array if there are no
  839. NaNs, not a copy.
  840. """
  841. for dtype in [np.bool_, np.int32, np.int64]:
  842. arr = np.array([0, 1], dtype=dtype)
  843. result, mask = _replace_nan(arr, 0)
  844. assert mask is None
  845. # do not make a copy if there are no nans
  846. assert result is arr
  847. for dtype in [np.float32, np.float64]:
  848. arr = np.array([0, 1], dtype=dtype)
  849. result, mask = _replace_nan(arr, 2)
  850. assert (mask == False).all()
  851. # mask is not None, so we make a copy
  852. assert result is not arr
  853. assert_equal(result, arr)
  854. arr_nan = np.array([0, 1, np.nan], dtype=dtype)
  855. result_nan, mask_nan = _replace_nan(arr_nan, 2)
  856. assert_equal(mask_nan, np.array([False, False, True]))
  857. assert result_nan is not arr_nan
  858. assert_equal(result_nan, np.array([0, 1, 2]))
  859. assert np.isnan(arr_nan[-1])