图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

55 lines
1.8 KiB

  1. from sympy.concrete.summations import Sum
  2. from sympy.functions.elementary.exponential import log
  3. from sympy.functions.elementary.miscellaneous import sqrt
  4. from sympy.utilities.lambdify import lambdify
  5. from sympy.abc import x, i, a, b
  6. from sympy.codegen.numpy_nodes import logaddexp
  7. from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions
  8. from sympy.testing.pytest import skip
  9. from sympy.external import import_module
  # Optional dependency handle for CuPy. Tests that need a working CuPy
  # check ``if not cp`` and skip themselves when it is unavailable.
  cp = import_module("cupy")
  def test_cupy_print():
      """Check that CuPyPrinter renders SymPy functions as ``cupy`` calls.

      Both SymPy expressions and plain strings (e.g. ``"acos(x)"``) are
      passed to ``doprint``; each must produce the exact expected code.
      """
      printer = CuPyPrinter()
      expectations = (
          (logaddexp(a, b), 'cupy.logaddexp(a, b)'),
          (sqrt(x), 'cupy.sqrt(x)'),
          (log(x), 'cupy.log(x)'),
          ("acos(x)", 'cupy.arccos(x)'),
          ("exp(x)", 'cupy.exp(x)'),
          # Abs is expected to print as the Python builtin, not a cupy call.
          ("Abs(x)", 'abs(x)'),
      )
      for expr, expected in expectations:
          assert printer.doprint(expr) == expected
  def test_not_cupy_print():
      """Check that an unknown function is flagged as unsupported in the output."""
      # NOTE(review): newer SymPy releases may raise for unsupported functions
      # instead of annotating the output; confirm against the pinned version.
      printed = CuPyPrinter().doprint("abcd(x)")
      assert "Not supported" in printed
  def test_cupy_sum():
      """Check that lambdified ``Sum`` expressions evaluate correctly on CuPy arrays.

      Skipped when CuPy could not be imported (``cp`` is falsy).
      """
      if not cp:
          skip("CuPy not installed")

      # Shared inputs: integer summation bounds and a CuPy sample grid.
      a_, b_ = 0, 10
      x_ = cp.linspace(-1, +1, 10)

      # sum_{i=a}^{b} x**i, generated through the CuPy backend.
      s = Sum(x ** i, (i, a, b))
      f = lambdify((a, b, x), s, 'cupy')
      assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))

      # sum_{i=a}^{b} i*x. Check it with the CuPy backend under test, and
      # also with the NumPy printer applied to CuPy arrays.
      s = Sum(i * x, (i, a, b))
      expected = sum(i_ * x_ for i_ in range(a_, b_ + 1))
      for module in ('cupy', 'numpy'):
          f = lambdify((a, b, x), s, module)
          assert cp.allclose(f(a_, b_, x_), expected)
  def test_cupy_known_funcs_consts():
      """Spot-check entries of the CuPy known-constant and known-function tables."""
      expected_constants = {
          'NaN': 'cupy.nan',
          'EulerGamma': 'cupy.euler_gamma',
      }
      for name, code in expected_constants.items():
          assert _cupy_known_constants[name] == code

      expected_functions = {
          'acos': 'cupy.arccos',
          'log': 'cupy.log',
      }
      for name, code in expected_functions.items():
          assert _cupy_known_functions[name] == code
  def test_cupy_print_methods():
      """Check that CuPyPrinter exposes ``_print_acos`` and ``_print_log``."""
      printer = CuPyPrinter()
      for method_name in ('_print_acos', '_print_log'):
          assert hasattr(printer, method_name)