图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

390 lines
15 KiB

  1. import pytest
  2. from numpy.core import array, arange, printoptions
  3. import numpy.polynomial as poly
  4. from numpy.testing import assert_equal, assert_
  5. # For testing polynomial printing with object arrays
  6. from fractions import Fraction
  7. from decimal import Decimal
  8. class TestStrUnicodeSuperSubscripts:
  9. @pytest.fixture(scope='class', autouse=True)
  10. def use_unicode(self):
  11. poly.set_default_printstyle('unicode')
  12. @pytest.mark.parametrize(('inp', 'tgt'), (
  13. ([1, 2, 3], "1.0 + 2.0·x¹ + 3.0·x²"),
  14. ([-1, 0, 3, -1], "-1.0 + 0.0·x¹ + 3.0·x² - 1.0·x³"),
  15. (arange(12), ("0.0 + 1.0·x¹ + 2.0·x² + 3.0·x³ + 4.0·x⁴ + 5.0·x⁵ + "
  16. "6.0·x⁶ + 7.0·x⁷ +\n8.0·x⁸ + 9.0·x⁹ + 10.0·x¹⁰ + "
  17. "11.0·x¹¹")),
  18. ))
  19. def test_polynomial_str(self, inp, tgt):
  20. res = str(poly.Polynomial(inp))
  21. assert_equal(res, tgt)
  22. @pytest.mark.parametrize(('inp', 'tgt'), (
  23. ([1, 2, 3], "1.0 + 2.0·T₁(x) + 3.0·T₂(x)"),
  24. ([-1, 0, 3, -1], "-1.0 + 0.0·T₁(x) + 3.0·T₂(x) - 1.0·T₃(x)"),
  25. (arange(12), ("0.0 + 1.0·T₁(x) + 2.0·T₂(x) + 3.0·T₃(x) + 4.0·T₄(x) + "
  26. "5.0·T₅(x) +\n6.0·T₆(x) + 7.0·T₇(x) + 8.0·T₈(x) + "
  27. "9.0·T₉(x) + 10.0·T₁₀(x) + 11.0·T₁₁(x)")),
  28. ))
  29. def test_chebyshev_str(self, inp, tgt):
  30. res = str(poly.Chebyshev(inp))
  31. assert_equal(res, tgt)
  32. @pytest.mark.parametrize(('inp', 'tgt'), (
  33. ([1, 2, 3], "1.0 + 2.0·P₁(x) + 3.0·P₂(x)"),
  34. ([-1, 0, 3, -1], "-1.0 + 0.0·P₁(x) + 3.0·P₂(x) - 1.0·P₃(x)"),
  35. (arange(12), ("0.0 + 1.0·P₁(x) + 2.0·P₂(x) + 3.0·P₃(x) + 4.0·P₄(x) + "
  36. "5.0·P₅(x) +\n6.0·P₆(x) + 7.0·P₇(x) + 8.0·P₈(x) + "
  37. "9.0·P₉(x) + 10.0·P₁₀(x) + 11.0·P₁₁(x)")),
  38. ))
  39. def test_legendre_str(self, inp, tgt):
  40. res = str(poly.Legendre(inp))
  41. assert_equal(res, tgt)
  42. @pytest.mark.parametrize(('inp', 'tgt'), (
  43. ([1, 2, 3], "1.0 + 2.0·H₁(x) + 3.0·H₂(x)"),
  44. ([-1, 0, 3, -1], "-1.0 + 0.0·H₁(x) + 3.0·H₂(x) - 1.0·H₃(x)"),
  45. (arange(12), ("0.0 + 1.0·H₁(x) + 2.0·H₂(x) + 3.0·H₃(x) + 4.0·H₄(x) + "
  46. "5.0·H₅(x) +\n6.0·H₆(x) + 7.0·H₇(x) + 8.0·H₈(x) + "
  47. "9.0·H₉(x) + 10.0·H₁₀(x) + 11.0·H₁₁(x)")),
  48. ))
  49. def test_hermite_str(self, inp, tgt):
  50. res = str(poly.Hermite(inp))
  51. assert_equal(res, tgt)
  52. @pytest.mark.parametrize(('inp', 'tgt'), (
  53. ([1, 2, 3], "1.0 + 2.0·He₁(x) + 3.0·He₂(x)"),
  54. ([-1, 0, 3, -1], "-1.0 + 0.0·He₁(x) + 3.0·He₂(x) - 1.0·He₃(x)"),
  55. (arange(12), ("0.0 + 1.0·He₁(x) + 2.0·He₂(x) + 3.0·He₃(x) + "
  56. "4.0·He₄(x) + 5.0·He₅(x) +\n6.0·He₆(x) + 7.0·He₇(x) + "
  57. "8.0·He₈(x) + 9.0·He₉(x) + 10.0·He₁₀(x) +\n"
  58. "11.0·He₁₁(x)")),
  59. ))
  60. def test_hermiteE_str(self, inp, tgt):
  61. res = str(poly.HermiteE(inp))
  62. assert_equal(res, tgt)
  63. @pytest.mark.parametrize(('inp', 'tgt'), (
  64. ([1, 2, 3], "1.0 + 2.0·L₁(x) + 3.0·L₂(x)"),
  65. ([-1, 0, 3, -1], "-1.0 + 0.0·L₁(x) + 3.0·L₂(x) - 1.0·L₃(x)"),
  66. (arange(12), ("0.0 + 1.0·L₁(x) + 2.0·L₂(x) + 3.0·L₃(x) + 4.0·L₄(x) + "
  67. "5.0·L₅(x) +\n6.0·L₆(x) + 7.0·L₇(x) + 8.0·L₈(x) + "
  68. "9.0·L₉(x) + 10.0·L₁₀(x) + 11.0·L₁₁(x)")),
  69. ))
  70. def test_laguerre_str(self, inp, tgt):
  71. res = str(poly.Laguerre(inp))
  72. assert_equal(res, tgt)
  73. class TestStrAscii:
  74. @pytest.fixture(scope='class', autouse=True)
  75. def use_ascii(self):
  76. poly.set_default_printstyle('ascii')
  77. @pytest.mark.parametrize(('inp', 'tgt'), (
  78. ([1, 2, 3], "1.0 + 2.0 x**1 + 3.0 x**2"),
  79. ([-1, 0, 3, -1], "-1.0 + 0.0 x**1 + 3.0 x**2 - 1.0 x**3"),
  80. (arange(12), ("0.0 + 1.0 x**1 + 2.0 x**2 + 3.0 x**3 + 4.0 x**4 + "
  81. "5.0 x**5 + 6.0 x**6 +\n7.0 x**7 + 8.0 x**8 + "
  82. "9.0 x**9 + 10.0 x**10 + 11.0 x**11")),
  83. ))
  84. def test_polynomial_str(self, inp, tgt):
  85. res = str(poly.Polynomial(inp))
  86. assert_equal(res, tgt)
  87. @pytest.mark.parametrize(('inp', 'tgt'), (
  88. ([1, 2, 3], "1.0 + 2.0 T_1(x) + 3.0 T_2(x)"),
  89. ([-1, 0, 3, -1], "-1.0 + 0.0 T_1(x) + 3.0 T_2(x) - 1.0 T_3(x)"),
  90. (arange(12), ("0.0 + 1.0 T_1(x) + 2.0 T_2(x) + 3.0 T_3(x) + "
  91. "4.0 T_4(x) + 5.0 T_5(x) +\n6.0 T_6(x) + 7.0 T_7(x) + "
  92. "8.0 T_8(x) + 9.0 T_9(x) + 10.0 T_10(x) +\n"
  93. "11.0 T_11(x)")),
  94. ))
  95. def test_chebyshev_str(self, inp, tgt):
  96. res = str(poly.Chebyshev(inp))
  97. assert_equal(res, tgt)
  98. @pytest.mark.parametrize(('inp', 'tgt'), (
  99. ([1, 2, 3], "1.0 + 2.0 P_1(x) + 3.0 P_2(x)"),
  100. ([-1, 0, 3, -1], "-1.0 + 0.0 P_1(x) + 3.0 P_2(x) - 1.0 P_3(x)"),
  101. (arange(12), ("0.0 + 1.0 P_1(x) + 2.0 P_2(x) + 3.0 P_3(x) + "
  102. "4.0 P_4(x) + 5.0 P_5(x) +\n6.0 P_6(x) + 7.0 P_7(x) + "
  103. "8.0 P_8(x) + 9.0 P_9(x) + 10.0 P_10(x) +\n"
  104. "11.0 P_11(x)")),
  105. ))
  106. def test_legendre_str(self, inp, tgt):
  107. res = str(poly.Legendre(inp))
  108. assert_equal(res, tgt)
  109. @pytest.mark.parametrize(('inp', 'tgt'), (
  110. ([1, 2, 3], "1.0 + 2.0 H_1(x) + 3.0 H_2(x)"),
  111. ([-1, 0, 3, -1], "-1.0 + 0.0 H_1(x) + 3.0 H_2(x) - 1.0 H_3(x)"),
  112. (arange(12), ("0.0 + 1.0 H_1(x) + 2.0 H_2(x) + 3.0 H_3(x) + "
  113. "4.0 H_4(x) + 5.0 H_5(x) +\n6.0 H_6(x) + 7.0 H_7(x) + "
  114. "8.0 H_8(x) + 9.0 H_9(x) + 10.0 H_10(x) +\n"
  115. "11.0 H_11(x)")),
  116. ))
  117. def test_hermite_str(self, inp, tgt):
  118. res = str(poly.Hermite(inp))
  119. assert_equal(res, tgt)
  120. @pytest.mark.parametrize(('inp', 'tgt'), (
  121. ([1, 2, 3], "1.0 + 2.0 He_1(x) + 3.0 He_2(x)"),
  122. ([-1, 0, 3, -1], "-1.0 + 0.0 He_1(x) + 3.0 He_2(x) - 1.0 He_3(x)"),
  123. (arange(12), ("0.0 + 1.0 He_1(x) + 2.0 He_2(x) + 3.0 He_3(x) + "
  124. "4.0 He_4(x) +\n5.0 He_5(x) + 6.0 He_6(x) + "
  125. "7.0 He_7(x) + 8.0 He_8(x) + 9.0 He_9(x) +\n"
  126. "10.0 He_10(x) + 11.0 He_11(x)")),
  127. ))
  128. def test_hermiteE_str(self, inp, tgt):
  129. res = str(poly.HermiteE(inp))
  130. assert_equal(res, tgt)
  131. @pytest.mark.parametrize(('inp', 'tgt'), (
  132. ([1, 2, 3], "1.0 + 2.0 L_1(x) + 3.0 L_2(x)"),
  133. ([-1, 0, 3, -1], "-1.0 + 0.0 L_1(x) + 3.0 L_2(x) - 1.0 L_3(x)"),
  134. (arange(12), ("0.0 + 1.0 L_1(x) + 2.0 L_2(x) + 3.0 L_3(x) + "
  135. "4.0 L_4(x) + 5.0 L_5(x) +\n6.0 L_6(x) + 7.0 L_7(x) + "
  136. "8.0 L_8(x) + 9.0 L_9(x) + 10.0 L_10(x) +\n"
  137. "11.0 L_11(x)")),
  138. ))
  139. def test_laguerre_str(self, inp, tgt):
  140. res = str(poly.Laguerre(inp))
  141. assert_equal(res, tgt)
  142. class TestLinebreaking:
  143. @pytest.fixture(scope='class', autouse=True)
  144. def use_ascii(self):
  145. poly.set_default_printstyle('ascii')
  146. def test_single_line_one_less(self):
  147. # With 'ascii' style, len(str(p)) is default linewidth - 1 (i.e. 74)
  148. p = poly.Polynomial([123456789, 123456789, 123456789, 1234, 1])
  149. assert_equal(len(str(p)), 74)
  150. assert_equal(str(p), (
  151. '123456789.0 + 123456789.0 x**1 + 123456789.0 x**2 + '
  152. '1234.0 x**3 + 1.0 x**4'
  153. ))
  154. def test_num_chars_is_linewidth(self):
  155. # len(str(p)) == default linewidth == 75
  156. p = poly.Polynomial([123456789, 123456789, 123456789, 1234, 10])
  157. assert_equal(len(str(p)), 75)
  158. assert_equal(str(p), (
  159. '123456789.0 + 123456789.0 x**1 + 123456789.0 x**2 + '
  160. '1234.0 x**3 +\n10.0 x**4'
  161. ))
  162. def test_first_linebreak_multiline_one_less_than_linewidth(self):
  163. # Multiline str where len(first_line) + len(next_term) == lw - 1 == 74
  164. p = poly.Polynomial(
  165. [123456789, 123456789, 123456789, 12, 1, 123456789]
  166. )
  167. assert_equal(len(str(p).split('\n')[0]), 74)
  168. assert_equal(str(p), (
  169. '123456789.0 + 123456789.0 x**1 + 123456789.0 x**2 + '
  170. '12.0 x**3 + 1.0 x**4 +\n123456789.0 x**5'
  171. ))
  172. def test_first_linebreak_multiline_on_linewidth(self):
  173. # First line is one character longer than previous test
  174. p = poly.Polynomial(
  175. [123456789, 123456789, 123456789, 123, 1, 123456789]
  176. )
  177. assert_equal(str(p), (
  178. '123456789.0 + 123456789.0 x**1 + 123456789.0 x**2 + '
  179. '123.0 x**3 +\n1.0 x**4 + 123456789.0 x**5'
  180. ))
  181. @pytest.mark.parametrize(('lw', 'tgt'), (
  182. (75, ('0.0 + 10.0 x**1 + 200.0 x**2 + 3000.0 x**3 + 40000.0 x**4 +\n'
  183. '500000.0 x**5 + 600000.0 x**6 + 70000.0 x**7 + 8000.0 x**8 + '
  184. '900.0 x**9')),
  185. (45, ('0.0 + 10.0 x**1 + 200.0 x**2 + 3000.0 x**3 +\n40000.0 x**4 + '
  186. '500000.0 x**5 +\n600000.0 x**6 + 70000.0 x**7 + 8000.0 x**8 +\n'
  187. '900.0 x**9')),
  188. (132, ('0.0 + 10.0 x**1 + 200.0 x**2 + 3000.0 x**3 + 40000.0 x**4 + '
  189. '500000.0 x**5 + 600000.0 x**6 + 70000.0 x**7 + 8000.0 x**8 + '
  190. '900.0 x**9')),
  191. ))
  192. def test_linewidth_printoption(self, lw, tgt):
  193. p = poly.Polynomial(
  194. [0, 10, 200, 3000, 40000, 500000, 600000, 70000, 8000, 900]
  195. )
  196. with printoptions(linewidth=lw):
  197. assert_equal(str(p), tgt)
  198. for line in str(p).split('\n'):
  199. assert_(len(line) < lw)
  200. def test_set_default_printoptions():
  201. p = poly.Polynomial([1, 2, 3])
  202. c = poly.Chebyshev([1, 2, 3])
  203. poly.set_default_printstyle('ascii')
  204. assert_equal(str(p), "1.0 + 2.0 x**1 + 3.0 x**2")
  205. assert_equal(str(c), "1.0 + 2.0 T_1(x) + 3.0 T_2(x)")
  206. poly.set_default_printstyle('unicode')
  207. assert_equal(str(p), "1.0 + 2.0·x¹ + 3.0·x²")
  208. assert_equal(str(c), "1.0 + 2.0·T₁(x) + 3.0·T₂(x)")
  209. with pytest.raises(ValueError):
  210. poly.set_default_printstyle('invalid_input')
  211. def test_complex_coefficients():
  212. """Test both numpy and built-in complex."""
  213. coefs = [0+1j, 1+1j, -2+2j, 3+0j]
  214. # numpy complex
  215. p1 = poly.Polynomial(coefs)
  216. # Python complex
  217. p2 = poly.Polynomial(array(coefs, dtype=object))
  218. poly.set_default_printstyle('unicode')
  219. assert_equal(str(p1), "1j + (1+1j)·x¹ - (2-2j)·x² + (3+0j)·x³")
  220. assert_equal(str(p2), "1j + (1+1j)·x¹ + (-2+2j)·x² + (3+0j)·x³")
  221. poly.set_default_printstyle('ascii')
  222. assert_equal(str(p1), "1j + (1+1j) x**1 - (2-2j) x**2 + (3+0j) x**3")
  223. assert_equal(str(p2), "1j + (1+1j) x**1 + (-2+2j) x**2 + (3+0j) x**3")
  224. @pytest.mark.parametrize(('coefs', 'tgt'), (
  225. (array([Fraction(1, 2), Fraction(3, 4)], dtype=object), (
  226. "1/2 + 3/4·x¹"
  227. )),
  228. (array([1, 2, Fraction(5, 7)], dtype=object), (
  229. "1 + 2·x¹ + 5/7·x²"
  230. )),
  231. (array([Decimal('1.00'), Decimal('2.2'), 3], dtype=object), (
  232. "1.00 + 2.2·x¹ + 3·x²"
  233. )),
  234. ))
  235. def test_numeric_object_coefficients(coefs, tgt):
  236. p = poly.Polynomial(coefs)
  237. poly.set_default_printstyle('unicode')
  238. assert_equal(str(p), tgt)
  239. @pytest.mark.parametrize(('coefs', 'tgt'), (
  240. (array([1, 2, 'f'], dtype=object), '1 + 2·x¹ + f·x²'),
  241. (array([1, 2, [3, 4]], dtype=object), '1 + 2·x¹ + [3, 4]·x²'),
  242. ))
  243. def test_nonnumeric_object_coefficients(coefs, tgt):
  244. """
  245. Test coef fallback for object arrays of non-numeric coefficients.
  246. """
  247. p = poly.Polynomial(coefs)
  248. poly.set_default_printstyle('unicode')
  249. assert_equal(str(p), tgt)
  250. class TestFormat:
  251. def test_format_unicode(self):
  252. poly.set_default_printstyle('ascii')
  253. p = poly.Polynomial([1, 2, 0, -1])
  254. assert_equal(format(p, 'unicode'), "1.0 + 2.0·x¹ + 0.0·x² - 1.0·x³")
  255. def test_format_ascii(self):
  256. poly.set_default_printstyle('unicode')
  257. p = poly.Polynomial([1, 2, 0, -1])
  258. assert_equal(
  259. format(p, 'ascii'), "1.0 + 2.0 x**1 + 0.0 x**2 - 1.0 x**3"
  260. )
  261. def test_empty_formatstr(self):
  262. poly.set_default_printstyle('ascii')
  263. p = poly.Polynomial([1, 2, 3])
  264. assert_equal(format(p), "1.0 + 2.0 x**1 + 3.0 x**2")
  265. assert_equal(f"{p}", "1.0 + 2.0 x**1 + 3.0 x**2")
  266. def test_bad_formatstr(self):
  267. p = poly.Polynomial([1, 2, 0, -1])
  268. with pytest.raises(ValueError):
  269. format(p, '.2f')
  270. class TestRepr:
  271. def test_polynomial_str(self):
  272. res = repr(poly.Polynomial([0, 1]))
  273. tgt = 'Polynomial([0., 1.], domain=[-1, 1], window=[-1, 1])'
  274. assert_equal(res, tgt)
  275. def test_chebyshev_str(self):
  276. res = repr(poly.Chebyshev([0, 1]))
  277. tgt = 'Chebyshev([0., 1.], domain=[-1, 1], window=[-1, 1])'
  278. assert_equal(res, tgt)
  279. def test_legendre_repr(self):
  280. res = repr(poly.Legendre([0, 1]))
  281. tgt = 'Legendre([0., 1.], domain=[-1, 1], window=[-1, 1])'
  282. assert_equal(res, tgt)
  283. def test_hermite_repr(self):
  284. res = repr(poly.Hermite([0, 1]))
  285. tgt = 'Hermite([0., 1.], domain=[-1, 1], window=[-1, 1])'
  286. assert_equal(res, tgt)
  287. def test_hermiteE_repr(self):
  288. res = repr(poly.HermiteE([0, 1]))
  289. tgt = 'HermiteE([0., 1.], domain=[-1, 1], window=[-1, 1])'
  290. assert_equal(res, tgt)
  291. def test_laguerre_repr(self):
  292. res = repr(poly.Laguerre([0, 1]))
  293. tgt = 'Laguerre([0., 1.], domain=[0, 1], window=[0, 1])'
  294. assert_equal(res, tgt)
  295. class TestLatexRepr:
  296. """Test the latex repr used by Jupyter"""
  297. def as_latex(self, obj):
  298. # right now we ignore the formatting of scalars in our tests, since
  299. # it makes them too verbose. Ideally, the formatting of scalars will
  300. # be fixed such that tests below continue to pass
  301. obj._repr_latex_scalar = lambda x: str(x)
  302. try:
  303. return obj._repr_latex_()
  304. finally:
  305. del obj._repr_latex_scalar
  306. def test_simple_polynomial(self):
  307. # default input
  308. p = poly.Polynomial([1, 2, 3])
  309. assert_equal(self.as_latex(p),
  310. r'$x \mapsto 1.0 + 2.0\,x + 3.0\,x^{2}$')
  311. # translated input
  312. p = poly.Polynomial([1, 2, 3], domain=[-2, 0])
  313. assert_equal(self.as_latex(p),
  314. r'$x \mapsto 1.0 + 2.0\,\left(1.0 + x\right) + 3.0\,\left(1.0 + x\right)^{2}$')
  315. # scaled input
  316. p = poly.Polynomial([1, 2, 3], domain=[-0.5, 0.5])
  317. assert_equal(self.as_latex(p),
  318. r'$x \mapsto 1.0 + 2.0\,\left(2.0x\right) + 3.0\,\left(2.0x\right)^{2}$')
  319. # affine input
  320. p = poly.Polynomial([1, 2, 3], domain=[-1, 0])
  321. assert_equal(self.as_latex(p),
  322. r'$x \mapsto 1.0 + 2.0\,\left(1.0 + 2.0x\right) + 3.0\,\left(1.0 + 2.0x\right)^{2}$')
  323. def test_basis_func(self):
  324. p = poly.Chebyshev([1, 2, 3])
  325. assert_equal(self.as_latex(p),
  326. r'$x \mapsto 1.0\,{T}_{0}(x) + 2.0\,{T}_{1}(x) + 3.0\,{T}_{2}(x)$')
  327. # affine input - check no surplus parens are added
  328. p = poly.Chebyshev([1, 2, 3], domain=[-1, 0])
  329. assert_equal(self.as_latex(p),
  330. r'$x \mapsto 1.0\,{T}_{0}(1.0 + 2.0x) + 2.0\,{T}_{1}(1.0 + 2.0x) + 3.0\,{T}_{2}(1.0 + 2.0x)$')
    def test_multichar_basis_func(self):
        # HermiteE's basis symbol is two characters ("He"). The expected
        # output keeps it as a single braced LaTeX group, {He}, before the
        # subscript.
        p = poly.HermiteE([1, 2, 3])
        assert_equal(self.as_latex(p),
            r'$x \mapsto 1.0\,{He}_{0}(x) + 2.0\,{He}_{1}(x) + 3.0\,{He}_{2}(x)$')