图片解析应用
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

104 lines
3.2 KiB

  1. import numpy as np
  2. import numpy.core as nx
  3. import numpy.lib.ufunclike as ufl
  4. from numpy.testing import (
  5. assert_, assert_equal, assert_array_equal, assert_warns, assert_raises
  6. )
  7. class TestUfunclike:
  8. def test_isposinf(self):
  9. a = nx.array([nx.inf, -nx.inf, nx.nan, 0.0, 3.0, -3.0])
  10. out = nx.zeros(a.shape, bool)
  11. tgt = nx.array([True, False, False, False, False, False])
  12. res = ufl.isposinf(a)
  13. assert_equal(res, tgt)
  14. res = ufl.isposinf(a, out)
  15. assert_equal(res, tgt)
  16. assert_equal(out, tgt)
  17. a = a.astype(np.complex_)
  18. with assert_raises(TypeError):
  19. ufl.isposinf(a)
  20. def test_isneginf(self):
  21. a = nx.array([nx.inf, -nx.inf, nx.nan, 0.0, 3.0, -3.0])
  22. out = nx.zeros(a.shape, bool)
  23. tgt = nx.array([False, True, False, False, False, False])
  24. res = ufl.isneginf(a)
  25. assert_equal(res, tgt)
  26. res = ufl.isneginf(a, out)
  27. assert_equal(res, tgt)
  28. assert_equal(out, tgt)
  29. a = a.astype(np.complex_)
  30. with assert_raises(TypeError):
  31. ufl.isneginf(a)
  32. def test_fix(self):
  33. a = nx.array([[1.0, 1.1, 1.5, 1.8], [-1.0, -1.1, -1.5, -1.8]])
  34. out = nx.zeros(a.shape, float)
  35. tgt = nx.array([[1., 1., 1., 1.], [-1., -1., -1., -1.]])
  36. res = ufl.fix(a)
  37. assert_equal(res, tgt)
  38. res = ufl.fix(a, out)
  39. assert_equal(res, tgt)
  40. assert_equal(out, tgt)
  41. assert_equal(ufl.fix(3.14), 3)
  42. def test_fix_with_subclass(self):
  43. class MyArray(nx.ndarray):
  44. def __new__(cls, data, metadata=None):
  45. res = nx.array(data, copy=True).view(cls)
  46. res.metadata = metadata
  47. return res
  48. def __array_wrap__(self, obj, context=None):
  49. if isinstance(obj, MyArray):
  50. obj.metadata = self.metadata
  51. return obj
  52. def __array_finalize__(self, obj):
  53. self.metadata = getattr(obj, 'metadata', None)
  54. return self
  55. a = nx.array([1.1, -1.1])
  56. m = MyArray(a, metadata='foo')
  57. f = ufl.fix(m)
  58. assert_array_equal(f, nx.array([1, -1]))
  59. assert_(isinstance(f, MyArray))
  60. assert_equal(f.metadata, 'foo')
  61. # check 0d arrays don't decay to scalars
  62. m0d = m[0,...]
  63. m0d.metadata = 'bar'
  64. f0d = ufl.fix(m0d)
  65. assert_(isinstance(f0d, MyArray))
  66. assert_equal(f0d.metadata, 'bar')
  67. def test_deprecated(self):
  68. # NumPy 1.13.0, 2017-04-26
  69. assert_warns(DeprecationWarning, ufl.fix, [1, 2], y=nx.empty(2))
  70. assert_warns(DeprecationWarning, ufl.isposinf, [1, 2], y=nx.empty(2))
  71. assert_warns(DeprecationWarning, ufl.isneginf, [1, 2], y=nx.empty(2))
  72. def test_scalar(self):
  73. x = np.inf
  74. actual = np.isposinf(x)
  75. expected = np.True_
  76. assert_equal(actual, expected)
  77. assert_equal(type(actual), type(expected))
  78. x = -3.4
  79. actual = np.fix(x)
  80. expected = np.float64(-3.0)
  81. assert_equal(actual, expected)
  82. assert_equal(type(actual), type(expected))
  83. out = np.array(0.0)
  84. actual = np.fix(x, out=out)
  85. assert_(actual is out)