m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

46 lines
1.3 KiB

7 months ago
  1. from operator import mul
  2. from functools import reduce
  3. import numpy as np
  4. from numpy.random import randint
  5. from numpy.lib import Arrayterator
  6. from numpy.testing import assert_
  7. def test():
  8. np.random.seed(np.arange(10))
  9. # Create a random array
  10. ndims = randint(5)+1
  11. shape = tuple(randint(10)+1 for dim in range(ndims))
  12. els = reduce(mul, shape)
  13. a = np.arange(els)
  14. a.shape = shape
  15. buf_size = randint(2*els)
  16. b = Arrayterator(a, buf_size)
  17. # Check that each block has at most ``buf_size`` elements
  18. for block in b:
  19. assert_(len(block.flat) <= (buf_size or els))
  20. # Check that all elements are iterated correctly
  21. assert_(list(b.flat) == list(a.flat))
  22. # Slice arrayterator
  23. start = [randint(dim) for dim in shape]
  24. stop = [randint(dim)+1 for dim in shape]
  25. step = [randint(dim)+1 for dim in shape]
  26. slice_ = tuple(slice(*t) for t in zip(start, stop, step))
  27. c = b[slice_]
  28. d = a[slice_]
  29. # Check that each block has at most ``buf_size`` elements
  30. for block in c:
  31. assert_(len(block.flat) <= (buf_size or els))
  32. # Check that the arrayterator is sliced correctly
  33. assert_(np.all(c.__array__() == d))
  34. # Check that all elements are iterated correctly
  35. assert_(list(c.flat) == list(d.flat))