图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

310 lines
9.5 KiB

  1. import sympy
  2. import tempfile
  3. import os
  4. from sympy.core.mod import Mod
  5. from sympy.core.relational import Eq
  6. from sympy.core.symbol import symbols
  7. from sympy.external import import_module
  8. from sympy.tensor import IndexedBase, Idx
  9. from sympy.utilities.autowrap import autowrap, ufuncify, CodeWrapError
  10. from sympy.testing.pytest import skip
  11. numpy = import_module('numpy', min_module_version='1.6.1')
  12. Cython = import_module('Cython', min_module_version='0.15.1')
  13. f2py = import_module('numpy.f2py', import_kwargs={'fromlist': ['f2py']})
  14. f2pyworks = False
  15. if f2py:
  16. try:
  17. autowrap(symbols('x'), 'f95', 'f2py')
  18. except (CodeWrapError, ImportError, OSError):
  19. f2pyworks = False
  20. else:
  21. f2pyworks = True
  22. a, b, c = symbols('a b c')
  23. n, m, d = symbols('n m d', integer=True)
  24. A, B, C = symbols('A B C', cls=IndexedBase)
  25. i = Idx('i', m)
  26. j = Idx('j', n)
  27. k = Idx('k', d)
  28. def has_module(module):
  29. """
  30. Return True if module exists, otherwise run skip().
  31. module should be a string.
  32. """
  33. # To give a string of the module name to skip(), this function takes a
  34. # string. So we don't waste time running import_module() more than once,
  35. # just map the three modules tested here in this dict.
  36. modnames = {'numpy': numpy, 'Cython': Cython, 'f2py': f2py}
  37. if modnames[module]:
  38. if module == 'f2py' and not f2pyworks:
  39. skip("Couldn't run f2py.")
  40. return True
  41. skip("Couldn't import %s." % module)
  42. #
  43. # test runners used by several language-backend combinations
  44. #
  45. def runtest_autowrap_twice(language, backend):
  46. f = autowrap((((a + b)/c)**5).expand(), language, backend)
  47. g = autowrap((((a + b)/c)**4).expand(), language, backend)
  48. # check that autowrap updates the module name. Else, g gives the same as f
  49. assert f(1, -2, 1) == -1.0
  50. assert g(1, -2, 1) == 1.0
  51. def runtest_autowrap_trace(language, backend):
  52. has_module('numpy')
  53. trace = autowrap(A[i, i], language, backend)
  54. assert trace(numpy.eye(100)) == 100
  55. def runtest_autowrap_matrix_vector(language, backend):
  56. has_module('numpy')
  57. x, y = symbols('x y', cls=IndexedBase)
  58. expr = Eq(y[i], A[i, j]*x[j])
  59. mv = autowrap(expr, language, backend)
  60. # compare with numpy's dot product
  61. M = numpy.random.rand(10, 20)
  62. x = numpy.random.rand(20)
  63. y = numpy.dot(M, x)
  64. assert numpy.sum(numpy.abs(y - mv(M, x))) < 1e-13
  65. def runtest_autowrap_matrix_matrix(language, backend):
  66. has_module('numpy')
  67. expr = Eq(C[i, j], A[i, k]*B[k, j])
  68. matmat = autowrap(expr, language, backend)
  69. # compare with numpy's dot product
  70. M1 = numpy.random.rand(10, 20)
  71. M2 = numpy.random.rand(20, 15)
  72. M3 = numpy.dot(M1, M2)
  73. assert numpy.sum(numpy.abs(M3 - matmat(M1, M2))) < 1e-13
  74. def runtest_ufuncify(language, backend):
  75. has_module('numpy')
  76. a, b, c = symbols('a b c')
  77. fabc = ufuncify([a, b, c], a*b + c, backend=backend)
  78. facb = ufuncify([a, c, b], a*b + c, backend=backend)
  79. grid = numpy.linspace(-2, 2, 50)
  80. b = numpy.linspace(-5, 4, 50)
  81. c = numpy.linspace(-1, 1, 50)
  82. expected = grid*b + c
  83. numpy.testing.assert_allclose(fabc(grid, b, c), expected)
  84. numpy.testing.assert_allclose(facb(grid, c, b), expected)
  85. def runtest_issue_10274(language, backend):
  86. expr = (a - b + c)**(13)
  87. tmp = tempfile.mkdtemp()
  88. f = autowrap(expr, language, backend, tempdir=tmp,
  89. helpers=('helper', a - b + c, (a, b, c)))
  90. assert f(1, 1, 1) == 1
  91. for file in os.listdir(tmp):
  92. if file.startswith("wrapped_code_") and file.endswith(".c"):
  93. fil = open(tmp + '/' + file)
  94. lines = fil.readlines()
  95. assert lines[0] == "/******************************************************************************\n"
  96. assert "Code generated with SymPy " + sympy.__version__ in lines[1]
  97. assert lines[2:] == [
  98. " * *\n",
  99. " * See http://www.sympy.org/ for more information. *\n",
  100. " * *\n",
  101. " * This file is part of 'autowrap' *\n",
  102. " ******************************************************************************/\n",
  103. "#include " + '"' + file[:-1]+ 'h"' + "\n",
  104. "#include <math.h>\n",
  105. "\n",
  106. "double helper(double a, double b, double c) {\n",
  107. "\n",
  108. " double helper_result;\n",
  109. " helper_result = a - b + c;\n",
  110. " return helper_result;\n",
  111. "\n",
  112. "}\n",
  113. "\n",
  114. "double autofunc(double a, double b, double c) {\n",
  115. "\n",
  116. " double autofunc_result;\n",
  117. " autofunc_result = pow(helper(a, b, c), 13);\n",
  118. " return autofunc_result;\n",
  119. "\n",
  120. "}\n",
  121. ]
  122. def runtest_issue_15337(language, backend):
  123. has_module('numpy')
  124. # NOTE : autowrap was originally designed to only accept an iterable for
  125. # the kwarg "helpers", but in issue 10274 the user mistakenly thought that
  126. # if there was only a single helper it did not need to be passed via an
  127. # iterable that wrapped the helper tuple. There were no tests for this
  128. # behavior so when the code was changed to accept a single tuple it broke
  129. # the original behavior. These tests below ensure that both now work.
  130. a, b, c, d, e = symbols('a, b, c, d, e')
  131. expr = (a - b + c - d + e)**13
  132. exp_res = (1. - 2. + 3. - 4. + 5.)**13
  133. f = autowrap(expr, language, backend, args=(a, b, c, d, e),
  134. helpers=('f1', a - b + c, (a, b, c)))
  135. numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res)
  136. f = autowrap(expr, language, backend, args=(a, b, c, d, e),
  137. helpers=(('f1', a - b, (a, b)), ('f2', c - d, (c, d))))
  138. numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res)
  139. def test_issue_15230():
  140. has_module('f2py')
  141. x, y = symbols('x, y')
  142. expr = Mod(x, 3.0) - Mod(y, -2.0)
  143. f = autowrap(expr, args=[x, y], language='F95')
  144. exp_res = float(expr.xreplace({x: 3.5, y: 2.7}).evalf())
  145. assert abs(f(3.5, 2.7) - exp_res) < 1e-14
  146. x, y = symbols('x, y', integer=True)
  147. expr = Mod(x, 3) - Mod(y, -2)
  148. f = autowrap(expr, args=[x, y], language='F95')
  149. assert f(3, 2) == expr.xreplace({x: 3, y: 2})
  150. #
  151. # tests of language-backend combinations
  152. #
  153. # f2py
  154. def test_wrap_twice_f95_f2py():
  155. has_module('f2py')
  156. runtest_autowrap_twice('f95', 'f2py')
  157. def test_autowrap_trace_f95_f2py():
  158. has_module('f2py')
  159. runtest_autowrap_trace('f95', 'f2py')
  160. def test_autowrap_matrix_vector_f95_f2py():
  161. has_module('f2py')
  162. runtest_autowrap_matrix_vector('f95', 'f2py')
  163. def test_autowrap_matrix_matrix_f95_f2py():
  164. has_module('f2py')
  165. runtest_autowrap_matrix_matrix('f95', 'f2py')
  166. def test_ufuncify_f95_f2py():
  167. has_module('f2py')
  168. runtest_ufuncify('f95', 'f2py')
  169. def test_issue_15337_f95_f2py():
  170. has_module('f2py')
  171. runtest_issue_15337('f95', 'f2py')
  172. # Cython
  173. def test_wrap_twice_c_cython():
  174. has_module('Cython')
  175. runtest_autowrap_twice('C', 'cython')
  176. def test_autowrap_trace_C_Cython():
  177. has_module('Cython')
  178. runtest_autowrap_trace('C99', 'cython')
  179. def test_autowrap_matrix_vector_C_cython():
  180. has_module('Cython')
  181. runtest_autowrap_matrix_vector('C99', 'cython')
  182. def test_autowrap_matrix_matrix_C_cython():
  183. has_module('Cython')
  184. runtest_autowrap_matrix_matrix('C99', 'cython')
  185. def test_ufuncify_C_Cython():
  186. has_module('Cython')
  187. runtest_ufuncify('C99', 'cython')
  188. def test_issue_10274_C_cython():
  189. has_module('Cython')
  190. runtest_issue_10274('C89', 'cython')
  191. def test_issue_15337_C_cython():
  192. has_module('Cython')
  193. runtest_issue_15337('C89', 'cython')
  194. def test_autowrap_custom_printer():
  195. has_module('Cython')
  196. from sympy.core.numbers import pi
  197. from sympy.utilities.codegen import C99CodeGen
  198. from sympy.printing.c import C99CodePrinter
  199. class PiPrinter(C99CodePrinter):
  200. def _print_Pi(self, expr):
  201. return "S_PI"
  202. printer = PiPrinter()
  203. gen = C99CodeGen(printer=printer)
  204. gen.preprocessor_statements.append('#include "shortpi.h"')
  205. expr = pi * a
  206. expected = (
  207. '#include "%s"\n'
  208. '#include <math.h>\n'
  209. '#include "shortpi.h"\n'
  210. '\n'
  211. 'double autofunc(double a) {\n'
  212. '\n'
  213. ' double autofunc_result;\n'
  214. ' autofunc_result = S_PI*a;\n'
  215. ' return autofunc_result;\n'
  216. '\n'
  217. '}\n'
  218. )
  219. tmpdir = tempfile.mkdtemp()
  220. # write a trivial header file to use in the generated code
  221. open(os.path.join(tmpdir, 'shortpi.h'), 'w').write('#define S_PI 3.14')
  222. func = autowrap(expr, backend='cython', tempdir=tmpdir, code_gen=gen)
  223. assert func(4.2) == 3.14 * 4.2
  224. # check that the generated code is correct
  225. for filename in os.listdir(tmpdir):
  226. if filename.startswith('wrapped_code') and filename.endswith('.c'):
  227. with open(os.path.join(tmpdir, filename)) as f:
  228. lines = f.readlines()
  229. expected = expected % filename.replace('.c', '.h')
  230. assert ''.join(lines[7:]) == expected
  231. # Numpy
  232. def test_ufuncify_numpy():
  233. # This test doesn't use Cython, but if Cython works, then there is a valid
  234. # C compiler, which is needed.
  235. has_module('Cython')
  236. runtest_ufuncify('C99', 'numpy')