m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

470 lines
14 KiB

6 months ago
  1. # Tests that require installed backends go into
  2. # sympy/test_external/test_autowrap
  3. import os
  4. import tempfile
  5. import shutil
  6. from io import StringIO
  7. from sympy.core import symbols, Eq
  8. from sympy.utilities.autowrap import (autowrap, binary_function,
  9. CythonCodeWrapper, UfuncifyCodeWrapper, CodeWrapper)
  10. from sympy.utilities.codegen import (
  11. CCodeGen, C99CodeGen, CodeGenArgumentListError, make_routine
  12. )
  13. from sympy.testing.pytest import raises
  14. from sympy.testing.tmpfiles import TmpFileManager
  15. def get_string(dump_fn, routines, prefix="file", **kwargs):
  16. """Wrapper for dump_fn. dump_fn writes its results to a stream object and
  17. this wrapper returns the contents of that stream as a string. This
  18. auxiliary function is used by many tests below.
  19. The header and the empty lines are not generator to facilitate the
  20. testing of the output.
  21. """
  22. output = StringIO()
  23. dump_fn(routines, output, prefix, **kwargs)
  24. source = output.getvalue()
  25. output.close()
  26. return source
  27. def test_cython_wrapper_scalar_function():
  28. x, y, z = symbols('x,y,z')
  29. expr = (x + y)*z
  30. routine = make_routine("test", expr)
  31. code_gen = CythonCodeWrapper(CCodeGen())
  32. source = get_string(code_gen.dump_pyx, [routine])
  33. expected = (
  34. "cdef extern from 'file.h':\n"
  35. " double test(double x, double y, double z)\n"
  36. "\n"
  37. "def test_c(double x, double y, double z):\n"
  38. "\n"
  39. " return test(x, y, z)")
  40. assert source == expected
  41. def test_cython_wrapper_outarg():
  42. from sympy.core.relational import Equality
  43. x, y, z = symbols('x,y,z')
  44. code_gen = CythonCodeWrapper(C99CodeGen())
  45. routine = make_routine("test", Equality(z, x + y))
  46. source = get_string(code_gen.dump_pyx, [routine])
  47. expected = (
  48. "cdef extern from 'file.h':\n"
  49. " void test(double x, double y, double *z)\n"
  50. "\n"
  51. "def test_c(double x, double y):\n"
  52. "\n"
  53. " cdef double z = 0\n"
  54. " test(x, y, &z)\n"
  55. " return z")
  56. assert source == expected
  57. def test_cython_wrapper_inoutarg():
  58. from sympy.core.relational import Equality
  59. x, y, z = symbols('x,y,z')
  60. code_gen = CythonCodeWrapper(C99CodeGen())
  61. routine = make_routine("test", Equality(z, x + y + z))
  62. source = get_string(code_gen.dump_pyx, [routine])
  63. expected = (
  64. "cdef extern from 'file.h':\n"
  65. " void test(double x, double y, double *z)\n"
  66. "\n"
  67. "def test_c(double x, double y, double z):\n"
  68. "\n"
  69. " test(x, y, &z)\n"
  70. " return z")
  71. assert source == expected
  72. def test_cython_wrapper_compile_flags():
  73. from sympy.core.relational import Equality
  74. x, y, z = symbols('x,y,z')
  75. routine = make_routine("test", Equality(z, x + y))
  76. code_gen = CythonCodeWrapper(CCodeGen())
  77. expected = """\
  78. try:
  79. from setuptools import setup
  80. from setuptools import Extension
  81. except ImportError:
  82. from distutils.core import setup
  83. from distutils.extension import Extension
  84. from Cython.Build import cythonize
  85. cy_opts = {}
  86. ext_mods = [Extension(
  87. 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
  88. include_dirs=[],
  89. library_dirs=[],
  90. libraries=[],
  91. extra_compile_args=['-std=c99'],
  92. extra_link_args=[]
  93. )]
  94. setup(ext_modules=cythonize(ext_mods, **cy_opts))
  95. """ % {'num': CodeWrapper._module_counter}
  96. temp_dir = tempfile.mkdtemp()
  97. TmpFileManager.tmp_folder(temp_dir)
  98. setup_file_path = os.path.join(temp_dir, 'setup.py')
  99. code_gen._prepare_files(routine, build_dir=temp_dir)
  100. with open(setup_file_path) as f:
  101. setup_text = f.read()
  102. assert setup_text == expected
  103. code_gen = CythonCodeWrapper(CCodeGen(),
  104. include_dirs=['/usr/local/include', '/opt/booger/include'],
  105. library_dirs=['/user/local/lib'],
  106. libraries=['thelib', 'nilib'],
  107. extra_compile_args=['-slow-math'],
  108. extra_link_args=['-lswamp', '-ltrident'],
  109. cythonize_options={'compiler_directives': {'boundscheck': False}}
  110. )
  111. expected = """\
  112. try:
  113. from setuptools import setup
  114. from setuptools import Extension
  115. except ImportError:
  116. from distutils.core import setup
  117. from distutils.extension import Extension
  118. from Cython.Build import cythonize
  119. cy_opts = {'compiler_directives': {'boundscheck': False}}
  120. ext_mods = [Extension(
  121. 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
  122. include_dirs=['/usr/local/include', '/opt/booger/include'],
  123. library_dirs=['/user/local/lib'],
  124. libraries=['thelib', 'nilib'],
  125. extra_compile_args=['-slow-math', '-std=c99'],
  126. extra_link_args=['-lswamp', '-ltrident']
  127. )]
  128. setup(ext_modules=cythonize(ext_mods, **cy_opts))
  129. """ % {'num': CodeWrapper._module_counter}
  130. code_gen._prepare_files(routine, build_dir=temp_dir)
  131. with open(setup_file_path) as f:
  132. setup_text = f.read()
  133. assert setup_text == expected
  134. expected = """\
  135. try:
  136. from setuptools import setup
  137. from setuptools import Extension
  138. except ImportError:
  139. from distutils.core import setup
  140. from distutils.extension import Extension
  141. from Cython.Build import cythonize
  142. cy_opts = {'compiler_directives': {'boundscheck': False}}
  143. import numpy as np
  144. ext_mods = [Extension(
  145. 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
  146. include_dirs=['/usr/local/include', '/opt/booger/include', np.get_include()],
  147. library_dirs=['/user/local/lib'],
  148. libraries=['thelib', 'nilib'],
  149. extra_compile_args=['-slow-math', '-std=c99'],
  150. extra_link_args=['-lswamp', '-ltrident']
  151. )]
  152. setup(ext_modules=cythonize(ext_mods, **cy_opts))
  153. """ % {'num': CodeWrapper._module_counter}
  154. code_gen._need_numpy = True
  155. code_gen._prepare_files(routine, build_dir=temp_dir)
  156. with open(setup_file_path) as f:
  157. setup_text = f.read()
  158. assert setup_text == expected
  159. TmpFileManager.cleanup()
  160. def test_cython_wrapper_unique_dummyvars():
  161. from sympy.core.relational import Equality
  162. from sympy.core.symbol import Dummy
  163. x, y, z = Dummy('x'), Dummy('y'), Dummy('z')
  164. x_id, y_id, z_id = [str(d.dummy_index) for d in [x, y, z]]
  165. expr = Equality(z, x + y)
  166. routine = make_routine("test", expr)
  167. code_gen = CythonCodeWrapper(CCodeGen())
  168. source = get_string(code_gen.dump_pyx, [routine])
  169. expected_template = (
  170. "cdef extern from 'file.h':\n"
  171. " void test(double x_{x_id}, double y_{y_id}, double *z_{z_id})\n"
  172. "\n"
  173. "def test_c(double x_{x_id}, double y_{y_id}):\n"
  174. "\n"
  175. " cdef double z_{z_id} = 0\n"
  176. " test(x_{x_id}, y_{y_id}, &z_{z_id})\n"
  177. " return z_{z_id}")
  178. expected = expected_template.format(x_id=x_id, y_id=y_id, z_id=z_id)
  179. assert source == expected
  180. def test_autowrap_dummy():
  181. x, y, z = symbols('x y z')
  182. # Uses DummyWrapper to test that codegen works as expected
  183. f = autowrap(x + y, backend='dummy')
  184. assert f() == str(x + y)
  185. assert f.args == "x, y"
  186. assert f.returns == "nameless"
  187. f = autowrap(Eq(z, x + y), backend='dummy')
  188. assert f() == str(x + y)
  189. assert f.args == "x, y"
  190. assert f.returns == "z"
  191. f = autowrap(Eq(z, x + y + z), backend='dummy')
  192. assert f() == str(x + y + z)
  193. assert f.args == "x, y, z"
  194. assert f.returns == "z"
  195. def test_autowrap_args():
  196. x, y, z = symbols('x y z')
  197. raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y),
  198. backend='dummy', args=[x]))
  199. f = autowrap(Eq(z, x + y), backend='dummy', args=[y, x])
  200. assert f() == str(x + y)
  201. assert f.args == "y, x"
  202. assert f.returns == "z"
  203. raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y + z),
  204. backend='dummy', args=[x, y]))
  205. f = autowrap(Eq(z, x + y + z), backend='dummy', args=[y, x, z])
  206. assert f() == str(x + y + z)
  207. assert f.args == "y, x, z"
  208. assert f.returns == "z"
  209. f = autowrap(Eq(z, x + y + z), backend='dummy', args=(y, x, z))
  210. assert f() == str(x + y + z)
  211. assert f.args == "y, x, z"
  212. assert f.returns == "z"
  213. def test_autowrap_store_files():
  214. x, y = symbols('x y')
  215. tmp = tempfile.mkdtemp()
  216. TmpFileManager.tmp_folder(tmp)
  217. f = autowrap(x + y, backend='dummy', tempdir=tmp)
  218. assert f() == str(x + y)
  219. assert os.access(tmp, os.F_OK)
  220. TmpFileManager.cleanup()
  221. def test_autowrap_store_files_issue_gh12939():
  222. x, y = symbols('x y')
  223. tmp = './tmp'
  224. try:
  225. f = autowrap(x + y, backend='dummy', tempdir=tmp)
  226. assert f() == str(x + y)
  227. assert os.access(tmp, os.F_OK)
  228. finally:
  229. shutil.rmtree(tmp)
  230. def test_binary_function():
  231. x, y = symbols('x y')
  232. f = binary_function('f', x + y, backend='dummy')
  233. assert f._imp_() == str(x + y)
  234. def test_ufuncify_source():
  235. x, y, z = symbols('x,y,z')
  236. code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
  237. routine = make_routine("test", x + y + z)
  238. source = get_string(code_wrapper.dump_c, [routine])
  239. expected = """\
  240. #include "Python.h"
  241. #include "math.h"
  242. #include "numpy/ndarraytypes.h"
  243. #include "numpy/ufuncobject.h"
  244. #include "numpy/halffloat.h"
  245. #include "file.h"
  246. static PyMethodDef wrapper_module_%(num)sMethods[] = {
  247. {NULL, NULL, 0, NULL}
  248. };
  249. static void test_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
  250. {
  251. npy_intp i;
  252. npy_intp n = dimensions[0];
  253. char *in0 = args[0];
  254. char *in1 = args[1];
  255. char *in2 = args[2];
  256. char *out0 = args[3];
  257. npy_intp in0_step = steps[0];
  258. npy_intp in1_step = steps[1];
  259. npy_intp in2_step = steps[2];
  260. npy_intp out0_step = steps[3];
  261. for (i = 0; i < n; i++) {
  262. *((double *)out0) = test(*(double *)in0, *(double *)in1, *(double *)in2);
  263. in0 += in0_step;
  264. in1 += in1_step;
  265. in2 += in2_step;
  266. out0 += out0_step;
  267. }
  268. }
  269. PyUFuncGenericFunction test_funcs[1] = {&test_ufunc};
  270. static char test_types[4] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
  271. static void *test_data[1] = {NULL};
  272. #if PY_VERSION_HEX >= 0x03000000
  273. static struct PyModuleDef moduledef = {
  274. PyModuleDef_HEAD_INIT,
  275. "wrapper_module_%(num)s",
  276. NULL,
  277. -1,
  278. wrapper_module_%(num)sMethods,
  279. NULL,
  280. NULL,
  281. NULL,
  282. NULL
  283. };
  284. PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
  285. {
  286. PyObject *m, *d;
  287. PyObject *ufunc0;
  288. m = PyModule_Create(&moduledef);
  289. if (!m) {
  290. return NULL;
  291. }
  292. import_array();
  293. import_umath();
  294. d = PyModule_GetDict(m);
  295. ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
  296. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  297. PyDict_SetItemString(d, "test", ufunc0);
  298. Py_DECREF(ufunc0);
  299. return m;
  300. }
  301. #else
  302. PyMODINIT_FUNC initwrapper_module_%(num)s(void)
  303. {
  304. PyObject *m, *d;
  305. PyObject *ufunc0;
  306. m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
  307. if (m == NULL) {
  308. return;
  309. }
  310. import_array();
  311. import_umath();
  312. d = PyModule_GetDict(m);
  313. ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
  314. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  315. PyDict_SetItemString(d, "test", ufunc0);
  316. Py_DECREF(ufunc0);
  317. }
  318. #endif""" % {'num': CodeWrapper._module_counter}
  319. assert source == expected
  320. def test_ufuncify_source_multioutput():
  321. x, y, z = symbols('x,y,z')
  322. var_symbols = (x, y, z)
  323. expr = x + y**3 + 10*z**2
  324. code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
  325. routines = [make_routine("func{}".format(i), expr.diff(var_symbols[i]), var_symbols) for i in range(len(var_symbols))]
  326. source = get_string(code_wrapper.dump_c, routines, funcname='multitest')
  327. expected = """\
  328. #include "Python.h"
  329. #include "math.h"
  330. #include "numpy/ndarraytypes.h"
  331. #include "numpy/ufuncobject.h"
  332. #include "numpy/halffloat.h"
  333. #include "file.h"
  334. static PyMethodDef wrapper_module_%(num)sMethods[] = {
  335. {NULL, NULL, 0, NULL}
  336. };
  337. static void multitest_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
  338. {
  339. npy_intp i;
  340. npy_intp n = dimensions[0];
  341. char *in0 = args[0];
  342. char *in1 = args[1];
  343. char *in2 = args[2];
  344. char *out0 = args[3];
  345. char *out1 = args[4];
  346. char *out2 = args[5];
  347. npy_intp in0_step = steps[0];
  348. npy_intp in1_step = steps[1];
  349. npy_intp in2_step = steps[2];
  350. npy_intp out0_step = steps[3];
  351. npy_intp out1_step = steps[4];
  352. npy_intp out2_step = steps[5];
  353. for (i = 0; i < n; i++) {
  354. *((double *)out0) = func0(*(double *)in0, *(double *)in1, *(double *)in2);
  355. *((double *)out1) = func1(*(double *)in0, *(double *)in1, *(double *)in2);
  356. *((double *)out2) = func2(*(double *)in0, *(double *)in1, *(double *)in2);
  357. in0 += in0_step;
  358. in1 += in1_step;
  359. in2 += in2_step;
  360. out0 += out0_step;
  361. out1 += out1_step;
  362. out2 += out2_step;
  363. }
  364. }
  365. PyUFuncGenericFunction multitest_funcs[1] = {&multitest_ufunc};
  366. static char multitest_types[6] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
  367. static void *multitest_data[1] = {NULL};
  368. #if PY_VERSION_HEX >= 0x03000000
  369. static struct PyModuleDef moduledef = {
  370. PyModuleDef_HEAD_INIT,
  371. "wrapper_module_%(num)s",
  372. NULL,
  373. -1,
  374. wrapper_module_%(num)sMethods,
  375. NULL,
  376. NULL,
  377. NULL,
  378. NULL
  379. };
  380. PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
  381. {
  382. PyObject *m, *d;
  383. PyObject *ufunc0;
  384. m = PyModule_Create(&moduledef);
  385. if (!m) {
  386. return NULL;
  387. }
  388. import_array();
  389. import_umath();
  390. d = PyModule_GetDict(m);
  391. ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
  392. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  393. PyDict_SetItemString(d, "multitest", ufunc0);
  394. Py_DECREF(ufunc0);
  395. return m;
  396. }
  397. #else
  398. PyMODINIT_FUNC initwrapper_module_%(num)s(void)
  399. {
  400. PyObject *m, *d;
  401. PyObject *ufunc0;
  402. m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
  403. if (m == NULL) {
  404. return;
  405. }
  406. import_array();
  407. import_umath();
  408. d = PyModule_GetDict(m);
  409. ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
  410. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  411. PyDict_SetItemString(d, "multitest", ufunc0);
  412. Py_DECREF(ufunc0);
  413. }
  414. #endif""" % {'num': CodeWrapper._module_counter}
  415. assert source == expected