图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

121 lines
3.5 KiB

  1. """Tests for polyutils module.
  2. """
  3. import numpy as np
  4. import numpy.polynomial.polyutils as pu
  5. from numpy.testing import (
  6. assert_almost_equal, assert_raises, assert_equal, assert_,
  7. )
  8. class TestMisc:
  9. def test_trimseq(self):
  10. for i in range(5):
  11. tgt = [1]
  12. res = pu.trimseq([1] + [0]*5)
  13. assert_equal(res, tgt)
  14. def test_as_series(self):
  15. # check exceptions
  16. assert_raises(ValueError, pu.as_series, [[]])
  17. assert_raises(ValueError, pu.as_series, [[[1, 2]]])
  18. assert_raises(ValueError, pu.as_series, [[1], ['a']])
  19. # check common types
  20. types = ['i', 'd', 'O']
  21. for i in range(len(types)):
  22. for j in range(i):
  23. ci = np.ones(1, types[i])
  24. cj = np.ones(1, types[j])
  25. [resi, resj] = pu.as_series([ci, cj])
  26. assert_(resi.dtype.char == resj.dtype.char)
  27. assert_(resj.dtype.char == types[i])
  28. def test_trimcoef(self):
  29. coef = [2, -1, 1, 0]
  30. # Test exceptions
  31. assert_raises(ValueError, pu.trimcoef, coef, -1)
  32. # Test results
  33. assert_equal(pu.trimcoef(coef), coef[:-1])
  34. assert_equal(pu.trimcoef(coef, 1), coef[:-3])
  35. assert_equal(pu.trimcoef(coef, 2), [0])
  36. def test_vander_nd_exception(self):
  37. # n_dims != len(points)
  38. assert_raises(ValueError, pu._vander_nd, (), (1, 2, 3), [90])
  39. # n_dims != len(degrees)
  40. assert_raises(ValueError, pu._vander_nd, (), (), [90.65])
  41. # n_dims == 0
  42. assert_raises(ValueError, pu._vander_nd, (), (), [])
  43. def test_div_zerodiv(self):
  44. # c2[-1] == 0
  45. assert_raises(ZeroDivisionError, pu._div, pu._div, (1, 2, 3), [0])
  46. def test_pow_too_large(self):
  47. # power > maxpower
  48. assert_raises(ValueError, pu._pow, (), [1, 2, 3], 5, 4)
  49. class TestDomain:
  50. def test_getdomain(self):
  51. # test for real values
  52. x = [1, 10, 3, -1]
  53. tgt = [-1, 10]
  54. res = pu.getdomain(x)
  55. assert_almost_equal(res, tgt)
  56. # test for complex values
  57. x = [1 + 1j, 1 - 1j, 0, 2]
  58. tgt = [-1j, 2 + 1j]
  59. res = pu.getdomain(x)
  60. assert_almost_equal(res, tgt)
  61. def test_mapdomain(self):
  62. # test for real values
  63. dom1 = [0, 4]
  64. dom2 = [1, 3]
  65. tgt = dom2
  66. res = pu.mapdomain(dom1, dom1, dom2)
  67. assert_almost_equal(res, tgt)
  68. # test for complex values
  69. dom1 = [0 - 1j, 2 + 1j]
  70. dom2 = [-2, 2]
  71. tgt = dom2
  72. x = dom1
  73. res = pu.mapdomain(x, dom1, dom2)
  74. assert_almost_equal(res, tgt)
  75. # test for multidimensional arrays
  76. dom1 = [0, 4]
  77. dom2 = [1, 3]
  78. tgt = np.array([dom2, dom2])
  79. x = np.array([dom1, dom1])
  80. res = pu.mapdomain(x, dom1, dom2)
  81. assert_almost_equal(res, tgt)
  82. # test that subtypes are preserved.
  83. class MyNDArray(np.ndarray):
  84. pass
  85. dom1 = [0, 4]
  86. dom2 = [1, 3]
  87. x = np.array([dom1, dom1]).view(MyNDArray)
  88. res = pu.mapdomain(x, dom1, dom2)
  89. assert_(isinstance(res, MyNDArray))
  90. def test_mapparms(self):
  91. # test for real values
  92. dom1 = [0, 4]
  93. dom2 = [1, 3]
  94. tgt = [1, .5]
  95. res = pu. mapparms(dom1, dom2)
  96. assert_almost_equal(res, tgt)
  97. # test for complex values
  98. dom1 = [0 - 1j, 2 + 1j]
  99. dom2 = [-2, 2]
  100. tgt = [-1 + 1j, 1 - 1j]
  101. res = pu.mapparms(dom1, dom2)
  102. assert_almost_equal(res, tgt)