m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

518 lines
18 KiB

7 months ago
  1. import pytest
  2. import numpy as np
  3. from numpy.testing import (
  4. assert_, assert_equal, assert_array_equal, assert_almost_equal,
  5. assert_array_almost_equal, assert_raises, assert_raises_regex,
  6. assert_warns
  7. )
  8. from numpy.lib.index_tricks import (
  9. mgrid, ogrid, ndenumerate, fill_diagonal, diag_indices, diag_indices_from,
  10. index_exp, ndindex, r_, s_, ix_
  11. )
  12. class TestRavelUnravelIndex:
  13. def test_basic(self):
  14. assert_equal(np.unravel_index(2, (2, 2)), (1, 0))
  15. # test that new shape argument works properly
  16. assert_equal(np.unravel_index(indices=2,
  17. shape=(2, 2)),
  18. (1, 0))
  19. # test that an invalid second keyword argument
  20. # is properly handled, including the old name `dims`.
  21. with assert_raises(TypeError):
  22. np.unravel_index(indices=2, hape=(2, 2))
  23. with assert_raises(TypeError):
  24. np.unravel_index(2, hape=(2, 2))
  25. with assert_raises(TypeError):
  26. np.unravel_index(254, ims=(17, 94))
  27. with assert_raises(TypeError):
  28. np.unravel_index(254, dims=(17, 94))
  29. assert_equal(np.ravel_multi_index((1, 0), (2, 2)), 2)
  30. assert_equal(np.unravel_index(254, (17, 94)), (2, 66))
  31. assert_equal(np.ravel_multi_index((2, 66), (17, 94)), 254)
  32. assert_raises(ValueError, np.unravel_index, -1, (2, 2))
  33. assert_raises(TypeError, np.unravel_index, 0.5, (2, 2))
  34. assert_raises(ValueError, np.unravel_index, 4, (2, 2))
  35. assert_raises(ValueError, np.ravel_multi_index, (-3, 1), (2, 2))
  36. assert_raises(ValueError, np.ravel_multi_index, (2, 1), (2, 2))
  37. assert_raises(ValueError, np.ravel_multi_index, (0, -3), (2, 2))
  38. assert_raises(ValueError, np.ravel_multi_index, (0, 2), (2, 2))
  39. assert_raises(TypeError, np.ravel_multi_index, (0.1, 0.), (2, 2))
  40. assert_equal(np.unravel_index((2*3 + 1)*6 + 4, (4, 3, 6)), [2, 1, 4])
  41. assert_equal(
  42. np.ravel_multi_index([2, 1, 4], (4, 3, 6)), (2*3 + 1)*6 + 4)
  43. arr = np.array([[3, 6, 6], [4, 5, 1]])
  44. assert_equal(np.ravel_multi_index(arr, (7, 6)), [22, 41, 37])
  45. assert_equal(
  46. np.ravel_multi_index(arr, (7, 6), order='F'), [31, 41, 13])
  47. assert_equal(
  48. np.ravel_multi_index(arr, (4, 6), mode='clip'), [22, 23, 19])
  49. assert_equal(np.ravel_multi_index(arr, (4, 4), mode=('clip', 'wrap')),
  50. [12, 13, 13])
  51. assert_equal(np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9)), 1621)
  52. assert_equal(np.unravel_index(np.array([22, 41, 37]), (7, 6)),
  53. [[3, 6, 6], [4, 5, 1]])
  54. assert_equal(
  55. np.unravel_index(np.array([31, 41, 13]), (7, 6), order='F'),
  56. [[3, 6, 6], [4, 5, 1]])
  57. assert_equal(np.unravel_index(1621, (6, 7, 8, 9)), [3, 1, 4, 1])
  58. def test_empty_indices(self):
  59. msg1 = 'indices must be integral: the provided empty sequence was'
  60. msg2 = 'only int indices permitted'
  61. assert_raises_regex(TypeError, msg1, np.unravel_index, [], (10, 3, 5))
  62. assert_raises_regex(TypeError, msg1, np.unravel_index, (), (10, 3, 5))
  63. assert_raises_regex(TypeError, msg2, np.unravel_index, np.array([]),
  64. (10, 3, 5))
  65. assert_equal(np.unravel_index(np.array([],dtype=int), (10, 3, 5)),
  66. [[], [], []])
  67. assert_raises_regex(TypeError, msg1, np.ravel_multi_index, ([], []),
  68. (10, 3))
  69. assert_raises_regex(TypeError, msg1, np.ravel_multi_index, ([], ['abc']),
  70. (10, 3))
  71. assert_raises_regex(TypeError, msg2, np.ravel_multi_index,
  72. (np.array([]), np.array([])), (5, 3))
  73. assert_equal(np.ravel_multi_index(
  74. (np.array([], dtype=int), np.array([], dtype=int)), (5, 3)), [])
  75. assert_equal(np.ravel_multi_index(np.array([[], []], dtype=int),
  76. (5, 3)), [])
  77. def test_big_indices(self):
  78. # ravel_multi_index for big indices (issue #7546)
  79. if np.intp == np.int64:
  80. arr = ([1, 29], [3, 5], [3, 117], [19, 2],
  81. [2379, 1284], [2, 2], [0, 1])
  82. assert_equal(
  83. np.ravel_multi_index(arr, (41, 7, 120, 36, 2706, 8, 6)),
  84. [5627771580, 117259570957])
  85. # test unravel_index for big indices (issue #9538)
  86. assert_raises(ValueError, np.unravel_index, 1, (2**32-1, 2**31+1))
  87. # test overflow checking for too big array (issue #7546)
  88. dummy_arr = ([0],[0])
  89. half_max = np.iinfo(np.intp).max // 2
  90. assert_equal(
  91. np.ravel_multi_index(dummy_arr, (half_max, 2)), [0])
  92. assert_raises(ValueError,
  93. np.ravel_multi_index, dummy_arr, (half_max+1, 2))
  94. assert_equal(
  95. np.ravel_multi_index(dummy_arr, (half_max, 2), order='F'), [0])
  96. assert_raises(ValueError,
  97. np.ravel_multi_index, dummy_arr, (half_max+1, 2), order='F')
  98. def test_dtypes(self):
  99. # Test with different data types
  100. for dtype in [np.int16, np.uint16, np.int32,
  101. np.uint32, np.int64, np.uint64]:
  102. coords = np.array(
  103. [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0]], dtype=dtype)
  104. shape = (5, 8)
  105. uncoords = 8*coords[0]+coords[1]
  106. assert_equal(np.ravel_multi_index(coords, shape), uncoords)
  107. assert_equal(coords, np.unravel_index(uncoords, shape))
  108. uncoords = coords[0]+5*coords[1]
  109. assert_equal(
  110. np.ravel_multi_index(coords, shape, order='F'), uncoords)
  111. assert_equal(coords, np.unravel_index(uncoords, shape, order='F'))
  112. coords = np.array(
  113. [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0], [1, 3, 1, 0, 9, 5]],
  114. dtype=dtype)
  115. shape = (5, 8, 10)
  116. uncoords = 10*(8*coords[0]+coords[1])+coords[2]
  117. assert_equal(np.ravel_multi_index(coords, shape), uncoords)
  118. assert_equal(coords, np.unravel_index(uncoords, shape))
  119. uncoords = coords[0]+5*(coords[1]+8*coords[2])
  120. assert_equal(
  121. np.ravel_multi_index(coords, shape, order='F'), uncoords)
  122. assert_equal(coords, np.unravel_index(uncoords, shape, order='F'))
  123. def test_clipmodes(self):
  124. # Test clipmodes
  125. assert_equal(
  126. np.ravel_multi_index([5, 1, -1, 2], (4, 3, 7, 12), mode='wrap'),
  127. np.ravel_multi_index([1, 1, 6, 2], (4, 3, 7, 12)))
  128. assert_equal(np.ravel_multi_index([5, 1, -1, 2], (4, 3, 7, 12),
  129. mode=(
  130. 'wrap', 'raise', 'clip', 'raise')),
  131. np.ravel_multi_index([1, 1, 0, 2], (4, 3, 7, 12)))
  132. assert_raises(
  133. ValueError, np.ravel_multi_index, [5, 1, -1, 2], (4, 3, 7, 12))
  134. def test_writeability(self):
  135. # See gh-7269
  136. x, y = np.unravel_index([1, 2, 3], (4, 5))
  137. assert_(x.flags.writeable)
  138. assert_(y.flags.writeable)
  139. def test_0d(self):
  140. # gh-580
  141. x = np.unravel_index(0, ())
  142. assert_equal(x, ())
  143. assert_raises_regex(ValueError, "0d array", np.unravel_index, [0], ())
  144. assert_raises_regex(
  145. ValueError, "out of bounds", np.unravel_index, [1], ())
  146. @pytest.mark.parametrize("mode", ["clip", "wrap", "raise"])
  147. def test_empty_array_ravel(self, mode):
  148. res = np.ravel_multi_index(
  149. np.zeros((3, 0), dtype=np.intp), (2, 1, 0), mode=mode)
  150. assert(res.shape == (0,))
  151. with assert_raises(ValueError):
  152. np.ravel_multi_index(
  153. np.zeros((3, 1), dtype=np.intp), (2, 1, 0), mode=mode)
  154. def test_empty_array_unravel(self):
  155. res = np.unravel_index(np.zeros(0, dtype=np.intp), (2, 1, 0))
  156. # res is a tuple of three empty arrays
  157. assert(len(res) == 3)
  158. assert(all(a.shape == (0,) for a in res))
  159. with assert_raises(ValueError):
  160. np.unravel_index([1], (2, 1, 0))
  161. class TestGrid:
  162. def test_basic(self):
  163. a = mgrid[-1:1:10j]
  164. b = mgrid[-1:1:0.1]
  165. assert_(a.shape == (10,))
  166. assert_(b.shape == (20,))
  167. assert_(a[0] == -1)
  168. assert_almost_equal(a[-1], 1)
  169. assert_(b[0] == -1)
  170. assert_almost_equal(b[1]-b[0], 0.1, 11)
  171. assert_almost_equal(b[-1], b[0]+19*0.1, 11)
  172. assert_almost_equal(a[1]-a[0], 2.0/9.0, 11)
  173. def test_linspace_equivalence(self):
  174. y, st = np.linspace(2, 10, retstep=True)
  175. assert_almost_equal(st, 8/49.0)
  176. assert_array_almost_equal(y, mgrid[2:10:50j], 13)
  177. def test_nd(self):
  178. c = mgrid[-1:1:10j, -2:2:10j]
  179. d = mgrid[-1:1:0.1, -2:2:0.2]
  180. assert_(c.shape == (2, 10, 10))
  181. assert_(d.shape == (2, 20, 20))
  182. assert_array_equal(c[0][0, :], -np.ones(10, 'd'))
  183. assert_array_equal(c[1][:, 0], -2*np.ones(10, 'd'))
  184. assert_array_almost_equal(c[0][-1, :], np.ones(10, 'd'), 11)
  185. assert_array_almost_equal(c[1][:, -1], 2*np.ones(10, 'd'), 11)
  186. assert_array_almost_equal(d[0, 1, :] - d[0, 0, :],
  187. 0.1*np.ones(20, 'd'), 11)
  188. assert_array_almost_equal(d[1, :, 1] - d[1, :, 0],
  189. 0.2*np.ones(20, 'd'), 11)
  190. def test_sparse(self):
  191. grid_full = mgrid[-1:1:10j, -2:2:10j]
  192. grid_sparse = ogrid[-1:1:10j, -2:2:10j]
  193. # sparse grids can be made dense by broadcasting
  194. grid_broadcast = np.broadcast_arrays(*grid_sparse)
  195. for f, b in zip(grid_full, grid_broadcast):
  196. assert_equal(f, b)
  197. @pytest.mark.parametrize("start, stop, step, expected", [
  198. (None, 10, 10j, (200, 10)),
  199. (-10, 20, None, (1800, 30)),
  200. ])
  201. def test_mgrid_size_none_handling(self, start, stop, step, expected):
  202. # regression test None value handling for
  203. # start and step values used by mgrid;
  204. # internally, this aims to cover previously
  205. # unexplored code paths in nd_grid()
  206. grid = mgrid[start:stop:step, start:stop:step]
  207. # need a smaller grid to explore one of the
  208. # untested code paths
  209. grid_small = mgrid[start:stop:step]
  210. assert_equal(grid.size, expected[0])
  211. assert_equal(grid_small.size, expected[1])
  212. def test_accepts_npfloating(self):
  213. # regression test for #16466
  214. grid64 = mgrid[0.1:0.33:0.1, ]
  215. grid32 = mgrid[np.float32(0.1):np.float32(0.33):np.float32(0.1), ]
  216. assert_(grid32.dtype == np.float64)
  217. assert_array_almost_equal(grid64, grid32)
  218. # different code path for single slice
  219. grid64 = mgrid[0.1:0.33:0.1]
  220. grid32 = mgrid[np.float32(0.1):np.float32(0.33):np.float32(0.1)]
  221. assert_(grid32.dtype == np.float64)
  222. assert_array_almost_equal(grid64, grid32)
  223. def test_accepts_npcomplexfloating(self):
  224. # Related to #16466
  225. assert_array_almost_equal(
  226. mgrid[0.1:0.3:3j, ], mgrid[0.1:0.3:np.complex64(3j), ]
  227. )
  228. # different code path for single slice
  229. assert_array_almost_equal(
  230. mgrid[0.1:0.3:3j], mgrid[0.1:0.3:np.complex64(3j)]
  231. )
  232. class TestConcatenator:
  233. def test_1d(self):
  234. assert_array_equal(r_[1, 2, 3, 4, 5, 6], np.array([1, 2, 3, 4, 5, 6]))
  235. b = np.ones(5)
  236. c = r_[b, 0, 0, b]
  237. assert_array_equal(c, [1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1])
  238. def test_mixed_type(self):
  239. g = r_[10.1, 1:10]
  240. assert_(g.dtype == 'f8')
  241. def test_more_mixed_type(self):
  242. g = r_[-10.1, np.array([1]), np.array([2, 3, 4]), 10.0]
  243. assert_(g.dtype == 'f8')
  244. def test_complex_step(self):
  245. # Regression test for #12262
  246. g = r_[0:36:100j]
  247. assert_(g.shape == (100,))
  248. # Related to #16466
  249. g = r_[0:36:np.complex64(100j)]
  250. assert_(g.shape == (100,))
  251. def test_2d(self):
  252. b = np.random.rand(5, 5)
  253. c = np.random.rand(5, 5)
  254. d = r_['1', b, c] # append columns
  255. assert_(d.shape == (5, 10))
  256. assert_array_equal(d[:, :5], b)
  257. assert_array_equal(d[:, 5:], c)
  258. d = r_[b, c]
  259. assert_(d.shape == (10, 5))
  260. assert_array_equal(d[:5, :], b)
  261. assert_array_equal(d[5:, :], c)
  262. def test_0d(self):
  263. assert_equal(r_[0, np.array(1), 2], [0, 1, 2])
  264. assert_equal(r_[[0, 1, 2], np.array(3)], [0, 1, 2, 3])
  265. assert_equal(r_[np.array(0), [1, 2, 3]], [0, 1, 2, 3])
  266. class TestNdenumerate:
  267. def test_basic(self):
  268. a = np.array([[1, 2], [3, 4]])
  269. assert_equal(list(ndenumerate(a)),
  270. [((0, 0), 1), ((0, 1), 2), ((1, 0), 3), ((1, 1), 4)])
  271. class TestIndexExpression:
  272. def test_regression_1(self):
  273. # ticket #1196
  274. a = np.arange(2)
  275. assert_equal(a[:-1], a[s_[:-1]])
  276. assert_equal(a[:-1], a[index_exp[:-1]])
  277. def test_simple_1(self):
  278. a = np.random.rand(4, 5, 6)
  279. assert_equal(a[:, :3, [1, 2]], a[index_exp[:, :3, [1, 2]]])
  280. assert_equal(a[:, :3, [1, 2]], a[s_[:, :3, [1, 2]]])
  281. class TestIx_:
  282. def test_regression_1(self):
  283. # Test empty untyped inputs create outputs of indexing type, gh-5804
  284. a, = np.ix_(range(0))
  285. assert_equal(a.dtype, np.intp)
  286. a, = np.ix_([])
  287. assert_equal(a.dtype, np.intp)
  288. # but if the type is specified, don't change it
  289. a, = np.ix_(np.array([], dtype=np.float32))
  290. assert_equal(a.dtype, np.float32)
  291. def test_shape_and_dtype(self):
  292. sizes = (4, 5, 3, 2)
  293. # Test both lists and arrays
  294. for func in (range, np.arange):
  295. arrays = np.ix_(*[func(sz) for sz in sizes])
  296. for k, (a, sz) in enumerate(zip(arrays, sizes)):
  297. assert_equal(a.shape[k], sz)
  298. assert_(all(sh == 1 for j, sh in enumerate(a.shape) if j != k))
  299. assert_(np.issubdtype(a.dtype, np.integer))
  300. def test_bool(self):
  301. bool_a = [True, False, True, True]
  302. int_a, = np.nonzero(bool_a)
  303. assert_equal(np.ix_(bool_a)[0], int_a)
  304. def test_1d_only(self):
  305. idx2d = [[1, 2, 3], [4, 5, 6]]
  306. assert_raises(ValueError, np.ix_, idx2d)
  307. def test_repeated_input(self):
  308. length_of_vector = 5
  309. x = np.arange(length_of_vector)
  310. out = ix_(x, x)
  311. assert_equal(out[0].shape, (length_of_vector, 1))
  312. assert_equal(out[1].shape, (1, length_of_vector))
  313. # check that input shape is not modified
  314. assert_equal(x.shape, (length_of_vector,))
  315. def test_c_():
  316. a = np.c_[np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])]
  317. assert_equal(a, [[1, 2, 3, 0, 0, 4, 5, 6]])
  318. class TestFillDiagonal:
  319. def test_basic(self):
  320. a = np.zeros((3, 3), int)
  321. fill_diagonal(a, 5)
  322. assert_array_equal(
  323. a, np.array([[5, 0, 0],
  324. [0, 5, 0],
  325. [0, 0, 5]])
  326. )
  327. def test_tall_matrix(self):
  328. a = np.zeros((10, 3), int)
  329. fill_diagonal(a, 5)
  330. assert_array_equal(
  331. a, np.array([[5, 0, 0],
  332. [0, 5, 0],
  333. [0, 0, 5],
  334. [0, 0, 0],
  335. [0, 0, 0],
  336. [0, 0, 0],
  337. [0, 0, 0],
  338. [0, 0, 0],
  339. [0, 0, 0],
  340. [0, 0, 0]])
  341. )
  342. def test_tall_matrix_wrap(self):
  343. a = np.zeros((10, 3), int)
  344. fill_diagonal(a, 5, True)
  345. assert_array_equal(
  346. a, np.array([[5, 0, 0],
  347. [0, 5, 0],
  348. [0, 0, 5],
  349. [0, 0, 0],
  350. [5, 0, 0],
  351. [0, 5, 0],
  352. [0, 0, 5],
  353. [0, 0, 0],
  354. [5, 0, 0],
  355. [0, 5, 0]])
  356. )
  357. def test_wide_matrix(self):
  358. a = np.zeros((3, 10), int)
  359. fill_diagonal(a, 5)
  360. assert_array_equal(
  361. a, np.array([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  362. [0, 5, 0, 0, 0, 0, 0, 0, 0, 0],
  363. [0, 0, 5, 0, 0, 0, 0, 0, 0, 0]])
  364. )
  365. def test_operate_4d_array(self):
  366. a = np.zeros((3, 3, 3, 3), int)
  367. fill_diagonal(a, 4)
  368. i = np.array([0, 1, 2])
  369. assert_equal(np.where(a != 0), (i, i, i, i))
  370. def test_low_dim_handling(self):
  371. # raise error with low dimensionality
  372. a = np.zeros(3, int)
  373. with assert_raises_regex(ValueError, "at least 2-d"):
  374. fill_diagonal(a, 5)
  375. def test_hetero_shape_handling(self):
  376. # raise error with high dimensionality and
  377. # shape mismatch
  378. a = np.zeros((3,3,7,3), int)
  379. with assert_raises_regex(ValueError, "equal length"):
  380. fill_diagonal(a, 2)
  381. def test_diag_indices():
  382. di = diag_indices(4)
  383. a = np.array([[1, 2, 3, 4],
  384. [5, 6, 7, 8],
  385. [9, 10, 11, 12],
  386. [13, 14, 15, 16]])
  387. a[di] = 100
  388. assert_array_equal(
  389. a, np.array([[100, 2, 3, 4],
  390. [5, 100, 7, 8],
  391. [9, 10, 100, 12],
  392. [13, 14, 15, 100]])
  393. )
  394. # Now, we create indices to manipulate a 3-d array:
  395. d3 = diag_indices(2, 3)
  396. # And use it to set the diagonal of a zeros array to 1:
  397. a = np.zeros((2, 2, 2), int)
  398. a[d3] = 1
  399. assert_array_equal(
  400. a, np.array([[[1, 0],
  401. [0, 0]],
  402. [[0, 0],
  403. [0, 1]]])
  404. )
  405. class TestDiagIndicesFrom:
  406. def test_diag_indices_from(self):
  407. x = np.random.random((4, 4))
  408. r, c = diag_indices_from(x)
  409. assert_array_equal(r, np.arange(4))
  410. assert_array_equal(c, np.arange(4))
  411. def test_error_small_input(self):
  412. x = np.ones(7)
  413. with assert_raises_regex(ValueError, "at least 2-d"):
  414. diag_indices_from(x)
  415. def test_error_shape_mismatch(self):
  416. x = np.zeros((3, 3, 2, 3), int)
  417. with assert_raises_regex(ValueError, "equal length"):
  418. diag_indices_from(x)
  419. def test_ndindex():
  420. x = list(ndindex(1, 2, 3))
  421. expected = [ix for ix, e in ndenumerate(np.zeros((1, 2, 3)))]
  422. assert_array_equal(x, expected)
  423. x = list(ndindex((1, 2, 3)))
  424. assert_array_equal(x, expected)
  425. # Test use of scalars and tuples
  426. x = list(ndindex((3,)))
  427. assert_array_equal(x, list(ndindex(3)))
  428. # Make sure size argument is optional
  429. x = list(ndindex())
  430. assert_equal(x, [()])
  431. x = list(ndindex(()))
  432. assert_equal(x, [()])
  433. # Make sure 0-sized ndindex works correctly
  434. x = list(ndindex(*[0]))
  435. assert_equal(x, [])