m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

326 lines
9.5 KiB

7 months ago
  1. # This test file tests SymPy <-> NumPy compatibility.
  2. # Don't test any SymPy features here, just pure interaction with NumPy.
  3. # Always write regular SymPy tests for anything that can be tested in pure
  4. # Python (without NumPy). Here we test everything that a user may need when
  5. # using SymPy with NumPy.
  6. from sympy.external.importtools import version_tuple
  7. from sympy.external import import_module
  8. numpy = import_module('numpy')
  9. if numpy:
  10. array, matrix, ndarray = numpy.array, numpy.matrix, numpy.ndarray
  11. else:
  12. #bin/test will not execute any tests now
  13. disabled = True
  14. from sympy.core.numbers import (Float, Integer, Rational)
  15. from sympy.core.symbol import (Symbol, symbols)
  16. from sympy.functions.elementary.trigonometric import sin
  17. from sympy.matrices.dense import (Matrix, list2numpy, matrix2numpy, symarray)
  18. from sympy.utilities.lambdify import lambdify
  19. import sympy
  20. import mpmath
  21. from sympy.abc import x, y, z
  22. from sympy.utilities.decorator import conserve_mpmath_dps
  23. from sympy.testing.pytest import raises
  24. # First, systematically check that all operations are implemented and don't
  25. # raise an exception.
  26. def test_systematic_basic():
  27. def s(sympy_object, numpy_array):
  28. _ = [sympy_object + numpy_array,
  29. numpy_array + sympy_object,
  30. sympy_object - numpy_array,
  31. numpy_array - sympy_object,
  32. sympy_object * numpy_array,
  33. numpy_array * sympy_object,
  34. sympy_object / numpy_array,
  35. numpy_array / sympy_object,
  36. sympy_object ** numpy_array,
  37. numpy_array ** sympy_object]
  38. x = Symbol("x")
  39. y = Symbol("y")
  40. sympy_objs = [
  41. Rational(2, 3),
  42. Float("1.3"),
  43. x,
  44. y,
  45. pow(x, y)*y,
  46. Integer(5),
  47. Float(5.5),
  48. ]
  49. numpy_objs = [
  50. array([1]),
  51. array([3, 8, -1]),
  52. array([x, x**2, Rational(5)]),
  53. array([x/y*sin(y), 5, Rational(5)]),
  54. ]
  55. for x in sympy_objs:
  56. for y in numpy_objs:
  57. s(x, y)
  58. # Now some random tests that target particular problems and also check
  59. # that the results of the operations are correct.
  60. def test_basics():
  61. one = Rational(1)
  62. zero = Rational(0)
  63. assert array(1) == array(one)
  64. assert array([one]) == array([one])
  65. assert array([x]) == array([x])
  66. assert array(x) == array(Symbol("x"))
  67. assert array(one + x) == array(1 + x)
  68. X = array([one, zero, zero])
  69. assert (X == array([one, zero, zero])).all()
  70. assert (X == array([one, 0, 0])).all()
  71. def test_arrays():
  72. one = Rational(1)
  73. zero = Rational(0)
  74. X = array([one, zero, zero])
  75. Y = one*X
  76. X = array([Symbol("a") + Rational(1, 2)])
  77. Y = X + X
  78. assert Y == array([1 + 2*Symbol("a")])
  79. Y = Y + 1
  80. assert Y == array([2 + 2*Symbol("a")])
  81. Y = X - X
  82. assert Y == array([0])
  83. def test_conversion1():
  84. a = list2numpy([x**2, x])
  85. #looks like an array?
  86. assert isinstance(a, ndarray)
  87. assert a[0] == x**2
  88. assert a[1] == x
  89. assert len(a) == 2
  90. #yes, it's the array
  91. def test_conversion2():
  92. a = 2*list2numpy([x**2, x])
  93. b = list2numpy([2*x**2, 2*x])
  94. assert (a == b).all()
  95. one = Rational(1)
  96. zero = Rational(0)
  97. X = list2numpy([one, zero, zero])
  98. Y = one*X
  99. X = list2numpy([Symbol("a") + Rational(1, 2)])
  100. Y = X + X
  101. assert Y == array([1 + 2*Symbol("a")])
  102. Y = Y + 1
  103. assert Y == array([2 + 2*Symbol("a")])
  104. Y = X - X
  105. assert Y == array([0])
  106. def test_list2numpy():
  107. assert (array([x**2, x]) == list2numpy([x**2, x])).all()
  108. def test_Matrix1():
  109. m = Matrix([[x, x**2], [5, 2/x]])
  110. assert (array(m.subs(x, 2)) == array([[2, 4], [5, 1]])).all()
  111. m = Matrix([[sin(x), x**2], [5, 2/x]])
  112. assert (array(m.subs(x, 2)) == array([[sin(2), 4], [5, 1]])).all()
  113. def test_Matrix2():
  114. m = Matrix([[x, x**2], [5, 2/x]])
  115. assert (matrix(m.subs(x, 2)) == matrix([[2, 4], [5, 1]])).all()
  116. m = Matrix([[sin(x), x**2], [5, 2/x]])
  117. assert (matrix(m.subs(x, 2)) == matrix([[sin(2), 4], [5, 1]])).all()
  118. def test_Matrix3():
  119. a = array([[2, 4], [5, 1]])
  120. assert Matrix(a) == Matrix([[2, 4], [5, 1]])
  121. assert Matrix(a) != Matrix([[2, 4], [5, 2]])
  122. a = array([[sin(2), 4], [5, 1]])
  123. assert Matrix(a) == Matrix([[sin(2), 4], [5, 1]])
  124. assert Matrix(a) != Matrix([[sin(0), 4], [5, 1]])
  125. def test_Matrix4():
  126. a = matrix([[2, 4], [5, 1]])
  127. assert Matrix(a) == Matrix([[2, 4], [5, 1]])
  128. assert Matrix(a) != Matrix([[2, 4], [5, 2]])
  129. a = matrix([[sin(2), 4], [5, 1]])
  130. assert Matrix(a) == Matrix([[sin(2), 4], [5, 1]])
  131. assert Matrix(a) != Matrix([[sin(0), 4], [5, 1]])
  132. def test_Matrix_sum():
  133. M = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]])
  134. m = matrix([[2, 3, 4], [x, 5, 6], [x, y, z**2]])
  135. assert M + m == Matrix([[3, 5, 7], [2*x, y + 5, x + 6], [2*y + x, y - 50, z*x + z**2]])
  136. assert m + M == Matrix([[3, 5, 7], [2*x, y + 5, x + 6], [2*y + x, y - 50, z*x + z**2]])
  137. assert M + m == M.add(m)
  138. def test_Matrix_mul():
  139. M = Matrix([[1, 2, 3], [x, y, x]])
  140. m = matrix([[2, 4], [x, 6], [x, z**2]])
  141. assert M*m == Matrix([
  142. [ 2 + 5*x, 16 + 3*z**2],
  143. [2*x + x*y + x**2, 4*x + 6*y + x*z**2],
  144. ])
  145. assert m*M == Matrix([
  146. [ 2 + 4*x, 4 + 4*y, 6 + 4*x],
  147. [ 7*x, 2*x + 6*y, 9*x],
  148. [x + x*z**2, 2*x + y*z**2, 3*x + x*z**2],
  149. ])
  150. a = array([2])
  151. assert a[0] * M == 2 * M
  152. assert M * a[0] == 2 * M
  153. def test_Matrix_array():
  154. class matarray:
  155. def __array__(self):
  156. from numpy import array
  157. return array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  158. matarr = matarray()
  159. assert Matrix(matarr) == Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  160. def test_matrix2numpy():
  161. a = matrix2numpy(Matrix([[1, x**2], [3*sin(x), 0]]))
  162. assert isinstance(a, ndarray)
  163. assert a.shape == (2, 2)
  164. assert a[0, 0] == 1
  165. assert a[0, 1] == x**2
  166. assert a[1, 0] == 3*sin(x)
  167. assert a[1, 1] == 0
  168. def test_matrix2numpy_conversion():
  169. a = Matrix([[1, 2, sin(x)], [x**2, x, Rational(1, 2)]])
  170. b = array([[1, 2, sin(x)], [x**2, x, Rational(1, 2)]])
  171. assert (matrix2numpy(a) == b).all()
  172. assert matrix2numpy(a).dtype == numpy.dtype('object')
  173. c = matrix2numpy(Matrix([[1, 2], [10, 20]]), dtype='int8')
  174. d = matrix2numpy(Matrix([[1, 2], [10, 20]]), dtype='float64')
  175. assert c.dtype == numpy.dtype('int8')
  176. assert d.dtype == numpy.dtype('float64')
  177. def test_issue_3728():
  178. assert (Rational(1, 2)*array([2*x, 0]) == array([x, 0])).all()
  179. assert (Rational(1, 2) + array(
  180. [2*x, 0]) == array([2*x + Rational(1, 2), Rational(1, 2)])).all()
  181. assert (Float("0.5")*array([2*x, 0]) == array([Float("1.0")*x, 0])).all()
  182. assert (Float("0.5") + array(
  183. [2*x, 0]) == array([2*x + Float("0.5"), Float("0.5")])).all()
  184. @conserve_mpmath_dps
  185. def test_lambdify():
  186. mpmath.mp.dps = 16
  187. sin02 = mpmath.mpf("0.198669330795061215459412627")
  188. f = lambdify(x, sin(x), "numpy")
  189. prec = 1e-15
  190. assert -prec < f(0.2) - sin02 < prec
  191. # if this succeeds, it can't be a numpy function
  192. if version_tuple(numpy.__version__) >= version_tuple('1.17'):
  193. with raises(TypeError):
  194. f(x)
  195. else:
  196. with raises(AttributeError):
  197. f(x)
  198. def test_lambdify_matrix():
  199. f = lambdify(x, Matrix([[x, 2*x], [1, 2]]), [{'ImmutableMatrix': numpy.array}, "numpy"])
  200. assert (f(1) == array([[1, 2], [1, 2]])).all()
  201. def test_lambdify_matrix_multi_input():
  202. M = sympy.Matrix([[x**2, x*y, x*z],
  203. [y*x, y**2, y*z],
  204. [z*x, z*y, z**2]])
  205. f = lambdify((x, y, z), M, [{'ImmutableMatrix': numpy.array}, "numpy"])
  206. xh, yh, zh = 1.0, 2.0, 3.0
  207. expected = array([[xh**2, xh*yh, xh*zh],
  208. [yh*xh, yh**2, yh*zh],
  209. [zh*xh, zh*yh, zh**2]])
  210. actual = f(xh, yh, zh)
  211. assert numpy.allclose(actual, expected)
  212. def test_lambdify_matrix_vec_input():
  213. X = sympy.DeferredVector('X')
  214. M = Matrix([
  215. [X[0]**2, X[0]*X[1], X[0]*X[2]],
  216. [X[1]*X[0], X[1]**2, X[1]*X[2]],
  217. [X[2]*X[0], X[2]*X[1], X[2]**2]])
  218. f = lambdify(X, M, [{'ImmutableMatrix': numpy.array}, "numpy"])
  219. Xh = array([1.0, 2.0, 3.0])
  220. expected = array([[Xh[0]**2, Xh[0]*Xh[1], Xh[0]*Xh[2]],
  221. [Xh[1]*Xh[0], Xh[1]**2, Xh[1]*Xh[2]],
  222. [Xh[2]*Xh[0], Xh[2]*Xh[1], Xh[2]**2]])
  223. actual = f(Xh)
  224. assert numpy.allclose(actual, expected)
  225. def test_lambdify_transl():
  226. from sympy.utilities.lambdify import NUMPY_TRANSLATIONS
  227. for sym, mat in NUMPY_TRANSLATIONS.items():
  228. assert sym in sympy.__dict__
  229. assert mat in numpy.__dict__
  230. def test_symarray():
  231. """Test creation of numpy arrays of SymPy symbols."""
  232. import numpy as np
  233. import numpy.testing as npt
  234. syms = symbols('_0,_1,_2')
  235. s1 = symarray("", 3)
  236. s2 = symarray("", 3)
  237. npt.assert_array_equal(s1, np.array(syms, dtype=object))
  238. assert s1[0] == s2[0]
  239. a = symarray('a', 3)
  240. b = symarray('b', 3)
  241. assert not(a[0] == b[0])
  242. asyms = symbols('a_0,a_1,a_2')
  243. npt.assert_array_equal(a, np.array(asyms, dtype=object))
  244. # Multidimensional checks
  245. a2d = symarray('a', (2, 3))
  246. assert a2d.shape == (2, 3)
  247. a00, a12 = symbols('a_0_0,a_1_2')
  248. assert a2d[0, 0] == a00
  249. assert a2d[1, 2] == a12
  250. a3d = symarray('a', (2, 3, 2))
  251. assert a3d.shape == (2, 3, 2)
  252. a000, a120, a121 = symbols('a_0_0_0,a_1_2_0,a_1_2_1')
  253. assert a3d[0, 0, 0] == a000
  254. assert a3d[1, 2, 0] == a120
  255. assert a3d[1, 2, 1] == a121
  256. def test_vectorize():
  257. assert (numpy.vectorize(
  258. sin)([1, 2, 3]) == numpy.array([sin(1), sin(2), sin(3)])).all()