m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

63 lines
1.8 KiB

6 months ago
  1. import shutil
  2. from sympy.external import import_module
  3. from sympy.testing.pytest import skip
  4. from sympy.utilities._compilation.compilation import compile_link_import_strings
  5. from sympy.utilities._compilation.util import may_xfail
  6. numpy = import_module('numpy')
  7. cython = import_module('cython')
  8. _sources1 = [
  9. ('sigmoid.c', r"""
  10. #include <math.h>
  11. void sigmoid(int n, const double * const restrict in,
  12. double * const restrict out, double lim){
  13. for (int i=0; i<n; ++i){
  14. const double x = in[i];
  15. out[i] = x*pow(pow(x/lim, 8)+1, -1./8.);
  16. }
  17. }
  18. """),
  19. ('_sigmoid.pyx', r"""
  20. import numpy as np
  21. cimport numpy as cnp
  22. cdef extern void c_sigmoid "sigmoid" (int, const double * const,
  23. double * const, double)
  24. def sigmoid(double [:] inp, double lim=350.0):
  25. cdef cnp.ndarray[cnp.float64_t, ndim=1] out = np.empty(
  26. inp.size, dtype=np.float64)
  27. c_sigmoid(inp.size, &inp[0], &out[0], lim)
  28. return out
  29. """)
  30. ]
  31. def npy(data, lim=350.0):
  32. return data/((data/lim)**8+1)**(1/8.)
  33. @may_xfail
  34. def test_compile_link_import_strings():
  35. if not numpy:
  36. skip("numpy not installed.")
  37. if not cython:
  38. skip("cython not installed.")
  39. from sympy.utilities._compilation import has_c
  40. if not has_c():
  41. skip("No C compiler found.")
  42. compile_kw = dict(std='c99', include_dirs=[numpy.get_include()])
  43. info = None
  44. try:
  45. mod, info = compile_link_import_strings(_sources1, compile_kwargs=compile_kw)
  46. data = numpy.random.random(1024*1024*8) # 64 MB of RAM needed..
  47. res_mod = mod.sigmoid(data)
  48. res_npy = npy(data)
  49. assert numpy.allclose(res_mod, res_npy)
  50. finally:
  51. if info and info['build_dir']:
  52. shutil.rmtree(info['build_dir'])