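# Page residue captured along with this file (not part of the test module):
#   图片解析应用 ("image parsing application")
#   You cannot select more than 25 topics. Topics must start with a letter or
#   number, can include dashes ('-') and can be up to 35 characters long.
#
#   836 lines, 34 KiB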

from sympy.core.add import Add
from sympy.core.function import (Function, Lambda, diff)
from sympy.core.mod import Mod
from sympy.core import (Catalan, EulerGamma, GoldenRatio)
from sympy.core.numbers import (E, Float, I, Integer, Rational, pi)
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, symbols)
from sympy.functions.combinatorial.factorials import factorial
from sympy.functions.elementary.complexes import (conjugate, sign)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (atan2, cos, sin)
from sympy.functions.special.gamma_functions import gamma
from sympy.integrals.integrals import Integral
from sympy.sets.fancysets import Range
from sympy.codegen import For, Assignment, aug_assign
from sympy.codegen.ast import (Declaration, Variable, float32, float64,
                               value_const, real, bool_, While,
                               FunctionPrototype, FunctionDefinition,
                               integer, Return)
from sympy.core.expr import UnevaluatedExpr
from sympy.core.relational import Relational
from sympy.logic.boolalg import And, Or, Not, Equivalent, Xor
from sympy.matrices import Matrix, MatrixSymbol
from sympy.printing.fortran import fcode, FCodePrinter
from sympy.tensor import IndexedBase, Idx
from sympy.utilities.lambdify import implemented_function
from sympy.testing.pytest import raises, warns_deprecated_sympy
  30. def test_UnevaluatedExpr():
  31. p, q, r = symbols("p q r", real=True)
  32. q_r = UnevaluatedExpr(q + r)
  33. expr = abs(exp(p+q_r))
  34. assert fcode(expr, source_format="free") == "exp(p + (q + r))"
  35. x, y, z = symbols("x y z")
  36. y_z = UnevaluatedExpr(y + z)
  37. expr2 = abs(exp(x+y_z))
  38. assert fcode(expr2, human=False)[2].lstrip() == "exp(re(x) + re(y + z))"
  39. assert fcode(expr2, user_functions={"re": "realpart"}).lstrip() == "exp(realpart(x) + realpart(y + z))"
  40. def test_printmethod():
  41. x = symbols('x')
  42. class nint(Function):
  43. def _fcode(self, printer):
  44. return "nint(%s)" % printer._print(self.args[0])
  45. assert fcode(nint(x)) == " nint(x)"
  46. def test_fcode_sign(): #issue 12267
  47. x=symbols('x')
  48. y=symbols('y', integer=True)
  49. z=symbols('z', complex=True)
  50. assert fcode(sign(x), standard=95, source_format='free') == "merge(0d0, dsign(1d0, x), x == 0d0)"
  51. assert fcode(sign(y), standard=95, source_format='free') == "merge(0, isign(1, y), y == 0)"
  52. assert fcode(sign(z), standard=95, source_format='free') == "merge(cmplx(0d0, 0d0), z/abs(z), abs(z) == 0d0)"
  53. raises(NotImplementedError, lambda: fcode(sign(x)))
  54. def test_fcode_Pow():
  55. x, y = symbols('x,y')
  56. n = symbols('n', integer=True)
  57. assert fcode(x**3) == " x**3"
  58. assert fcode(x**(y**3)) == " x**(y**3)"
  59. assert fcode(1/(sin(x)*3.5)**(x - y**x)/(x**2 + y)) == \
  60. " (3.5d0*sin(x))**(-x + y**x)/(x**2 + y)"
  61. assert fcode(sqrt(x)) == ' sqrt(x)'
  62. assert fcode(sqrt(n)) == ' sqrt(dble(n))'
  63. assert fcode(x**0.5) == ' sqrt(x)'
  64. assert fcode(sqrt(x)) == ' sqrt(x)'
  65. assert fcode(sqrt(10)) == ' sqrt(10.0d0)'
  66. assert fcode(x**-1.0) == ' 1d0/x'
  67. assert fcode(x**-2.0, 'y', source_format='free') == 'y = x**(-2.0d0)' # 2823
  68. assert fcode(x**Rational(3, 7)) == ' x**(3.0d0/7.0d0)'
  69. def test_fcode_Rational():
  70. x = symbols('x')
  71. assert fcode(Rational(3, 7)) == " 3.0d0/7.0d0"
  72. assert fcode(Rational(18, 9)) == " 2"
  73. assert fcode(Rational(3, -7)) == " -3.0d0/7.0d0"
  74. assert fcode(Rational(-3, -7)) == " 3.0d0/7.0d0"
  75. assert fcode(x + Rational(3, 7)) == " x + 3.0d0/7.0d0"
  76. assert fcode(Rational(3, 7)*x) == " (3.0d0/7.0d0)*x"
  77. def test_fcode_Integer():
  78. assert fcode(Integer(67)) == " 67"
  79. assert fcode(Integer(-1)) == " -1"
  80. def test_fcode_Float():
  81. assert fcode(Float(42.0)) == " 42.0000000000000d0"
  82. assert fcode(Float(-1e20)) == " -1.00000000000000d+20"
  83. def test_fcode_functions():
  84. x, y = symbols('x,y')
  85. assert fcode(sin(x) ** cos(y)) == " sin(x)**cos(y)"
  86. raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=66))
  87. raises(NotImplementedError, lambda: fcode(x % y, standard=66))
  88. raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=77))
  89. raises(NotImplementedError, lambda: fcode(x % y, standard=77))
  90. for standard in [90, 95, 2003, 2008]:
  91. assert fcode(Mod(x, y), standard=standard) == " modulo(x, y)"
  92. assert fcode(x % y, standard=standard) == " modulo(x, y)"
  93. def test_case():
  94. ob = FCodePrinter()
  95. x,x_,x__,y,X,X_,Y = symbols('x,x_,x__,y,X,X_,Y')
  96. assert fcode(exp(x_) + sin(x*y) + cos(X*Y)) == \
  97. ' exp(x_) + sin(x*y) + cos(X__*Y_)'
  98. assert fcode(exp(x__) + 2*x*Y*X_**Rational(7, 2)) == \
  99. ' 2*X_**(7.0d0/2.0d0)*Y*x + exp(x__)'
  100. assert fcode(exp(x_) + sin(x*y) + cos(X*Y), name_mangling=False) == \
  101. ' exp(x_) + sin(x*y) + cos(X*Y)'
  102. assert fcode(x - cos(X), name_mangling=False) == ' x - cos(X)'
  103. assert ob.doprint(X*sin(x) + x_, assign_to='me') == ' me = X*sin(x_) + x__'
  104. assert ob.doprint(X*sin(x), assign_to='mu') == ' mu = X*sin(x_)'
  105. assert ob.doprint(x_, assign_to='ad') == ' ad = x__'
  106. n, m = symbols('n,m', integer=True)
  107. A = IndexedBase('A')
  108. x = IndexedBase('x')
  109. y = IndexedBase('y')
  110. i = Idx('i', m)
  111. I = Idx('I', n)
  112. assert fcode(A[i, I]*x[I], assign_to=y[i], source_format='free') == (
  113. "do i = 1, m\n"
  114. " y(i) = 0\n"
  115. "end do\n"
  116. "do i = 1, m\n"
  117. " do I_ = 1, n\n"
  118. " y(i) = A(i, I_)*x(I_) + y(i)\n"
  119. " end do\n"
  120. "end do" )
  121. #issue 6814
  122. def test_fcode_functions_with_integers():
  123. x= symbols('x')
  124. log10_17 = log(10).evalf(17)
  125. loglog10_17 = '0.8340324452479558d0'
  126. assert fcode(x * log(10)) == " x*%sd0" % log10_17
  127. assert fcode(x * log(10)) == " x*%sd0" % log10_17
  128. assert fcode(x * log(S(10))) == " x*%sd0" % log10_17
  129. assert fcode(log(S(10))) == " %sd0" % log10_17
  130. assert fcode(exp(10)) == " %sd0" % exp(10).evalf(17)
  131. assert fcode(x * log(log(10))) == " x*%s" % loglog10_17
  132. assert fcode(x * log(log(S(10)))) == " x*%s" % loglog10_17
  133. def test_fcode_NumberSymbol():
  134. prec = 17
  135. p = FCodePrinter()
  136. assert fcode(Catalan) == ' parameter (Catalan = %sd0)\n Catalan' % Catalan.evalf(prec)
  137. assert fcode(EulerGamma) == ' parameter (EulerGamma = %sd0)\n EulerGamma' % EulerGamma.evalf(prec)
  138. assert fcode(E) == ' parameter (E = %sd0)\n E' % E.evalf(prec)
  139. assert fcode(GoldenRatio) == ' parameter (GoldenRatio = %sd0)\n GoldenRatio' % GoldenRatio.evalf(prec)
  140. assert fcode(pi) == ' parameter (pi = %sd0)\n pi' % pi.evalf(prec)
  141. assert fcode(
  142. pi, precision=5) == ' parameter (pi = %sd0)\n pi' % pi.evalf(5)
  143. assert fcode(Catalan, human=False) == ({
  144. (Catalan, p._print(Catalan.evalf(prec)))}, set(), ' Catalan')
  145. assert fcode(EulerGamma, human=False) == ({(EulerGamma, p._print(
  146. EulerGamma.evalf(prec)))}, set(), ' EulerGamma')
  147. assert fcode(E, human=False) == (
  148. {(E, p._print(E.evalf(prec)))}, set(), ' E')
  149. assert fcode(GoldenRatio, human=False) == ({(GoldenRatio, p._print(
  150. GoldenRatio.evalf(prec)))}, set(), ' GoldenRatio')
  151. assert fcode(pi, human=False) == (
  152. {(pi, p._print(pi.evalf(prec)))}, set(), ' pi')
  153. assert fcode(pi, precision=5, human=False) == (
  154. {(pi, p._print(pi.evalf(5)))}, set(), ' pi')
  155. def test_fcode_complex():
  156. assert fcode(I) == " cmplx(0,1)"
  157. x = symbols('x')
  158. assert fcode(4*I) == " cmplx(0,4)"
  159. assert fcode(3 + 4*I) == " cmplx(3,4)"
  160. assert fcode(3 + 4*I + x) == " cmplx(3,4) + x"
  161. assert fcode(I*x) == " cmplx(0,1)*x"
  162. assert fcode(3 + 4*I - x) == " cmplx(3,4) - x"
  163. x = symbols('x', imaginary=True)
  164. assert fcode(5*x) == " 5*x"
  165. assert fcode(I*x) == " cmplx(0,1)*x"
  166. assert fcode(3 + x) == " x + 3"
  167. def test_implicit():
  168. x, y = symbols('x,y')
  169. assert fcode(sin(x)) == " sin(x)"
  170. assert fcode(atan2(x, y)) == " atan2(x, y)"
  171. assert fcode(conjugate(x)) == " conjg(x)"
  172. def test_not_fortran():
  173. x = symbols('x')
  174. g = Function('g')
  175. gamma_f = fcode(gamma(x))
  176. assert gamma_f == "C Not supported in Fortran:\nC gamma\n gamma(x)"
  177. assert fcode(Integral(sin(x))) == "C Not supported in Fortran:\nC Integral\n Integral(sin(x), x)"
  178. assert fcode(g(x)) == "C Not supported in Fortran:\nC g\n g(x)"
  179. def test_user_functions():
  180. x = symbols('x')
  181. assert fcode(sin(x), user_functions={"sin": "zsin"}) == " zsin(x)"
  182. x = symbols('x')
  183. assert fcode(
  184. gamma(x), user_functions={"gamma": "mygamma"}) == " mygamma(x)"
  185. g = Function('g')
  186. assert fcode(g(x), user_functions={"g": "great"}) == " great(x)"
  187. n = symbols('n', integer=True)
  188. assert fcode(
  189. factorial(n), user_functions={"factorial": "fct"}) == " fct(n)"
  190. def test_inline_function():
  191. x = symbols('x')
  192. g = implemented_function('g', Lambda(x, 2*x))
  193. assert fcode(g(x)) == " 2*x"
  194. g = implemented_function('g', Lambda(x, 2*pi/x))
  195. assert fcode(g(x)) == (
  196. " parameter (pi = %sd0)\n"
  197. " 2*pi/x"
  198. ) % pi.evalf(17)
  199. A = IndexedBase('A')
  200. i = Idx('i', symbols('n', integer=True))
  201. g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
  202. assert fcode(g(A[i]), assign_to=A[i]) == (
  203. " do i = 1, n\n"
  204. " A(i) = (A(i) + 1)*(A(i) + 2)*A(i)\n"
  205. " end do"
  206. )
  207. def test_assign_to():
  208. x = symbols('x')
  209. assert fcode(sin(x), assign_to="s") == " s = sin(x)"
  210. def test_line_wrapping():
  211. x, y = symbols('x,y')
  212. assert fcode(((x + y)**10).expand(), assign_to="var") == (
  213. " var = x**10 + 10*x**9*y + 45*x**8*y**2 + 120*x**7*y**3 + 210*x**6*\n"
  214. " @ y**4 + 252*x**5*y**5 + 210*x**4*y**6 + 120*x**3*y**7 + 45*x**2*y\n"
  215. " @ **8 + 10*x*y**9 + y**10"
  216. )
  217. e = [x**i for i in range(11)]
  218. assert fcode(Add(*e)) == (
  219. " x**10 + x**9 + x**8 + x**7 + x**6 + x**5 + x**4 + x**3 + x**2 + x\n"
  220. " @ + 1"
  221. )
  222. def test_fcode_precedence():
  223. x, y = symbols("x y")
  224. assert fcode(And(x < y, y < x + 1), source_format="free") == \
  225. "x < y .and. y < x + 1"
  226. assert fcode(Or(x < y, y < x + 1), source_format="free") == \
  227. "x < y .or. y < x + 1"
  228. assert fcode(Xor(x < y, y < x + 1, evaluate=False),
  229. source_format="free") == "x < y .neqv. y < x + 1"
  230. assert fcode(Equivalent(x < y, y < x + 1), source_format="free") == \
  231. "x < y .eqv. y < x + 1"
  232. def test_fcode_Logical():
  233. x, y, z = symbols("x y z")
  234. # unary Not
  235. assert fcode(Not(x), source_format="free") == ".not. x"
  236. # binary And
  237. assert fcode(And(x, y), source_format="free") == "x .and. y"
  238. assert fcode(And(x, Not(y)), source_format="free") == "x .and. .not. y"
  239. assert fcode(And(Not(x), y), source_format="free") == "y .and. .not. x"
  240. assert fcode(And(Not(x), Not(y)), source_format="free") == \
  241. ".not. x .and. .not. y"
  242. assert fcode(Not(And(x, y), evaluate=False), source_format="free") == \
  243. ".not. (x .and. y)"
  244. # binary Or
  245. assert fcode(Or(x, y), source_format="free") == "x .or. y"
  246. assert fcode(Or(x, Not(y)), source_format="free") == "x .or. .not. y"
  247. assert fcode(Or(Not(x), y), source_format="free") == "y .or. .not. x"
  248. assert fcode(Or(Not(x), Not(y)), source_format="free") == \
  249. ".not. x .or. .not. y"
  250. assert fcode(Not(Or(x, y), evaluate=False), source_format="free") == \
  251. ".not. (x .or. y)"
  252. # mixed And/Or
  253. assert fcode(And(Or(y, z), x), source_format="free") == "x .and. (y .or. z)"
  254. assert fcode(And(Or(z, x), y), source_format="free") == "y .and. (x .or. z)"
  255. assert fcode(And(Or(x, y), z), source_format="free") == "z .and. (x .or. y)"
  256. assert fcode(Or(And(y, z), x), source_format="free") == "x .or. y .and. z"
  257. assert fcode(Or(And(z, x), y), source_format="free") == "y .or. x .and. z"
  258. assert fcode(Or(And(x, y), z), source_format="free") == "z .or. x .and. y"
  259. # trinary And
  260. assert fcode(And(x, y, z), source_format="free") == "x .and. y .and. z"
  261. assert fcode(And(x, y, Not(z)), source_format="free") == \
  262. "x .and. y .and. .not. z"
  263. assert fcode(And(x, Not(y), z), source_format="free") == \
  264. "x .and. z .and. .not. y"
  265. assert fcode(And(Not(x), y, z), source_format="free") == \
  266. "y .and. z .and. .not. x"
  267. assert fcode(Not(And(x, y, z), evaluate=False), source_format="free") == \
  268. ".not. (x .and. y .and. z)"
  269. # trinary Or
  270. assert fcode(Or(x, y, z), source_format="free") == "x .or. y .or. z"
  271. assert fcode(Or(x, y, Not(z)), source_format="free") == \
  272. "x .or. y .or. .not. z"
  273. assert fcode(Or(x, Not(y), z), source_format="free") == \
  274. "x .or. z .or. .not. y"
  275. assert fcode(Or(Not(x), y, z), source_format="free") == \
  276. "y .or. z .or. .not. x"
  277. assert fcode(Not(Or(x, y, z), evaluate=False), source_format="free") == \
  278. ".not. (x .or. y .or. z)"
  279. def test_fcode_Xlogical():
  280. x, y, z = symbols("x y z")
  281. # binary Xor
  282. assert fcode(Xor(x, y, evaluate=False), source_format="free") == \
  283. "x .neqv. y"
  284. assert fcode(Xor(x, Not(y), evaluate=False), source_format="free") == \
  285. "x .neqv. .not. y"
  286. assert fcode(Xor(Not(x), y, evaluate=False), source_format="free") == \
  287. "y .neqv. .not. x"
  288. assert fcode(Xor(Not(x), Not(y), evaluate=False),
  289. source_format="free") == ".not. x .neqv. .not. y"
  290. assert fcode(Not(Xor(x, y, evaluate=False), evaluate=False),
  291. source_format="free") == ".not. (x .neqv. y)"
  292. # binary Equivalent
  293. assert fcode(Equivalent(x, y), source_format="free") == "x .eqv. y"
  294. assert fcode(Equivalent(x, Not(y)), source_format="free") == \
  295. "x .eqv. .not. y"
  296. assert fcode(Equivalent(Not(x), y), source_format="free") == \
  297. "y .eqv. .not. x"
  298. assert fcode(Equivalent(Not(x), Not(y)), source_format="free") == \
  299. ".not. x .eqv. .not. y"
  300. assert fcode(Not(Equivalent(x, y), evaluate=False),
  301. source_format="free") == ".not. (x .eqv. y)"
  302. # mixed And/Equivalent
  303. assert fcode(Equivalent(And(y, z), x), source_format="free") == \
  304. "x .eqv. y .and. z"
  305. assert fcode(Equivalent(And(z, x), y), source_format="free") == \
  306. "y .eqv. x .and. z"
  307. assert fcode(Equivalent(And(x, y), z), source_format="free") == \
  308. "z .eqv. x .and. y"
  309. assert fcode(And(Equivalent(y, z), x), source_format="free") == \
  310. "x .and. (y .eqv. z)"
  311. assert fcode(And(Equivalent(z, x), y), source_format="free") == \
  312. "y .and. (x .eqv. z)"
  313. assert fcode(And(Equivalent(x, y), z), source_format="free") == \
  314. "z .and. (x .eqv. y)"
  315. # mixed Or/Equivalent
  316. assert fcode(Equivalent(Or(y, z), x), source_format="free") == \
  317. "x .eqv. y .or. z"
  318. assert fcode(Equivalent(Or(z, x), y), source_format="free") == \
  319. "y .eqv. x .or. z"
  320. assert fcode(Equivalent(Or(x, y), z), source_format="free") == \
  321. "z .eqv. x .or. y"
  322. assert fcode(Or(Equivalent(y, z), x), source_format="free") == \
  323. "x .or. (y .eqv. z)"
  324. assert fcode(Or(Equivalent(z, x), y), source_format="free") == \
  325. "y .or. (x .eqv. z)"
  326. assert fcode(Or(Equivalent(x, y), z), source_format="free") == \
  327. "z .or. (x .eqv. y)"
  328. # mixed Xor/Equivalent
  329. assert fcode(Equivalent(Xor(y, z, evaluate=False), x),
  330. source_format="free") == "x .eqv. (y .neqv. z)"
  331. assert fcode(Equivalent(Xor(z, x, evaluate=False), y),
  332. source_format="free") == "y .eqv. (x .neqv. z)"
  333. assert fcode(Equivalent(Xor(x, y, evaluate=False), z),
  334. source_format="free") == "z .eqv. (x .neqv. y)"
  335. assert fcode(Xor(Equivalent(y, z), x, evaluate=False),
  336. source_format="free") == "x .neqv. (y .eqv. z)"
  337. assert fcode(Xor(Equivalent(z, x), y, evaluate=False),
  338. source_format="free") == "y .neqv. (x .eqv. z)"
  339. assert fcode(Xor(Equivalent(x, y), z, evaluate=False),
  340. source_format="free") == "z .neqv. (x .eqv. y)"
  341. # mixed And/Xor
  342. assert fcode(Xor(And(y, z), x, evaluate=False), source_format="free") == \
  343. "x .neqv. y .and. z"
  344. assert fcode(Xor(And(z, x), y, evaluate=False), source_format="free") == \
  345. "y .neqv. x .and. z"
  346. assert fcode(Xor(And(x, y), z, evaluate=False), source_format="free") == \
  347. "z .neqv. x .and. y"
  348. assert fcode(And(Xor(y, z, evaluate=False), x), source_format="free") == \
  349. "x .and. (y .neqv. z)"
  350. assert fcode(And(Xor(z, x, evaluate=False), y), source_format="free") == \
  351. "y .and. (x .neqv. z)"
  352. assert fcode(And(Xor(x, y, evaluate=False), z), source_format="free") == \
  353. "z .and. (x .neqv. y)"
  354. # mixed Or/Xor
  355. assert fcode(Xor(Or(y, z), x, evaluate=False), source_format="free") == \
  356. "x .neqv. y .or. z"
  357. assert fcode(Xor(Or(z, x), y, evaluate=False), source_format="free") == \
  358. "y .neqv. x .or. z"
  359. assert fcode(Xor(Or(x, y), z, evaluate=False), source_format="free") == \
  360. "z .neqv. x .or. y"
  361. assert fcode(Or(Xor(y, z, evaluate=False), x), source_format="free") == \
  362. "x .or. (y .neqv. z)"
  363. assert fcode(Or(Xor(z, x, evaluate=False), y), source_format="free") == \
  364. "y .or. (x .neqv. z)"
  365. assert fcode(Or(Xor(x, y, evaluate=False), z), source_format="free") == \
  366. "z .or. (x .neqv. y)"
  367. # trinary Xor
  368. assert fcode(Xor(x, y, z, evaluate=False), source_format="free") == \
  369. "x .neqv. y .neqv. z"
  370. assert fcode(Xor(x, y, Not(z), evaluate=False), source_format="free") == \
  371. "x .neqv. y .neqv. .not. z"
  372. assert fcode(Xor(x, Not(y), z, evaluate=False), source_format="free") == \
  373. "x .neqv. z .neqv. .not. y"
  374. assert fcode(Xor(Not(x), y, z, evaluate=False), source_format="free") == \
  375. "y .neqv. z .neqv. .not. x"
  376. def test_fcode_Relational():
  377. x, y = symbols("x y")
  378. assert fcode(Relational(x, y, "=="), source_format="free") == "x == y"
  379. assert fcode(Relational(x, y, "!="), source_format="free") == "x /= y"
  380. assert fcode(Relational(x, y, ">="), source_format="free") == "x >= y"
  381. assert fcode(Relational(x, y, "<="), source_format="free") == "x <= y"
  382. assert fcode(Relational(x, y, ">"), source_format="free") == "x > y"
  383. assert fcode(Relational(x, y, "<"), source_format="free") == "x < y"
  384. def test_fcode_Piecewise():
  385. x = symbols('x')
  386. expr = Piecewise((x, x < 1), (x**2, True))
  387. # Check that inline conditional (merge) fails if standard isn't 95+
  388. raises(NotImplementedError, lambda: fcode(expr))
  389. code = fcode(expr, standard=95)
  390. expected = " merge(x, x**2, x < 1)"
  391. assert code == expected
  392. assert fcode(Piecewise((x, x < 1), (x**2, True)), assign_to="var") == (
  393. " if (x < 1) then\n"
  394. " var = x\n"
  395. " else\n"
  396. " var = x**2\n"
  397. " end if"
  398. )
  399. a = cos(x)/x
  400. b = sin(x)/x
  401. for i in range(10):
  402. a = diff(a, x)
  403. b = diff(b, x)
  404. expected = (
  405. " if (x < 0) then\n"
  406. " weird_name = -cos(x)/x + 10*sin(x)/x**2 + 90*cos(x)/x**3 - 720*\n"
  407. " @ sin(x)/x**4 - 5040*cos(x)/x**5 + 30240*sin(x)/x**6 + 151200*cos(x\n"
  408. " @ )/x**7 - 604800*sin(x)/x**8 - 1814400*cos(x)/x**9 + 3628800*sin(x\n"
  409. " @ )/x**10 + 3628800*cos(x)/x**11\n"
  410. " else\n"
  411. " weird_name = -sin(x)/x - 10*cos(x)/x**2 + 90*sin(x)/x**3 + 720*\n"
  412. " @ cos(x)/x**4 - 5040*sin(x)/x**5 - 30240*cos(x)/x**6 + 151200*sin(x\n"
  413. " @ )/x**7 + 604800*cos(x)/x**8 - 1814400*sin(x)/x**9 - 3628800*cos(x\n"
  414. " @ )/x**10 + 3628800*sin(x)/x**11\n"
  415. " end if"
  416. )
  417. code = fcode(Piecewise((a, x < 0), (b, True)), assign_to="weird_name")
  418. assert code == expected
  419. code = fcode(Piecewise((x, x < 1), (x**2, x > 1), (sin(x), True)), standard=95)
  420. expected = " merge(x, merge(x**2, sin(x), x > 1), x < 1)"
  421. assert code == expected
  422. # Check that Piecewise without a True (default) condition error
  423. expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
  424. raises(ValueError, lambda: fcode(expr))
  425. def test_wrap_fortran():
  426. # "########################################################################"
  427. printer = FCodePrinter()
  428. lines = [
  429. "C This is a long comment on a single line that must be wrapped properly to produce nice output",
  430. " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
  431. " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
  432. " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
  433. " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
  434. " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
  435. " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
  436. " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
  437. " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
  438. " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
  439. " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
  440. " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
  441. " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
  442. " this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly",
  443. " this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly",
  444. ]
  445. wrapped_lines = printer._wrap_fortran(lines)
  446. expected_lines = [
  447. "C This is a long comment on a single line that must be wrapped",
  448. "C properly to produce nice output",
  449. " this = is + a + long + and + nasty + fortran + statement + that *",
  450. " @ must + be + wrapped + properly",
  451. " this = is + a + long + and + nasty + fortran + statement + that *",
  452. " @ must + be + wrapped + properly",
  453. " this = is + a + long + and + nasty + fortran + statement + that",
  454. " @ * must + be + wrapped + properly",
  455. " this = is + a + long + and + nasty + fortran + statement + that*",
  456. " @ must + be + wrapped + properly",
  457. " this = is + a + long + and + nasty + fortran + statement + that*",
  458. " @ must + be + wrapped + properly",
  459. " this = is + a + long + and + nasty + fortran + statement + that",
  460. " @ *must + be + wrapped + properly",
  461. " this = is + a + long + and + nasty + fortran + statement +",
  462. " @ that*must + be + wrapped + properly",
  463. " this = is + a + long + and + nasty + fortran + statement + that**",
  464. " @ must + be + wrapped + properly",
  465. " this = is + a + long + and + nasty + fortran + statement + that**",
  466. " @ must + be + wrapped + properly",
  467. " this = is + a + long + and + nasty + fortran + statement + that",
  468. " @ **must + be + wrapped + properly",
  469. " this = is + a + long + and + nasty + fortran + statement + that",
  470. " @ **must + be + wrapped + properly",
  471. " this = is + a + long + and + nasty + fortran + statement +",
  472. " @ that**must + be + wrapped + properly",
  473. " this = is + a + long + and + nasty + fortran + statement(that)/",
  474. " @ must + be + wrapped + properly",
  475. " this = is + a + long + and + nasty + fortran + statement(that)",
  476. " @ /must + be + wrapped + properly",
  477. ]
  478. for line in wrapped_lines:
  479. assert len(line) <= 72
  480. for w, e in zip(wrapped_lines, expected_lines):
  481. assert w == e
  482. assert len(wrapped_lines) == len(expected_lines)
  483. def test_wrap_fortran_keep_d0():
  484. printer = FCodePrinter()
  485. lines = [
  486. ' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0',
  487. ' this_variable_is_very_long_because_we_try_to_test_line_break =1.0d0',
  488. ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
  489. ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
  490. ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
  491. ' this_variable_is_very_long_because_we_try_to_test_line_break = 10.0d0'
  492. ]
  493. expected = [
  494. ' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0',
  495. ' this_variable_is_very_long_because_we_try_to_test_line_break =',
  496. ' @ 1.0d0',
  497. ' this_variable_is_very_long_because_we_try_to_test_line_break =',
  498. ' @ 1.0d0',
  499. ' this_variable_is_very_long_because_we_try_to_test_line_break =',
  500. ' @ 1.0d0',
  501. ' this_variable_is_very_long_because_we_try_to_test_line_break =',
  502. ' @ 1.0d0',
  503. ' this_variable_is_very_long_because_we_try_to_test_line_break =',
  504. ' @ 10.0d0'
  505. ]
  506. assert printer._wrap_fortran(lines) == expected
  507. def test_settings():
  508. raises(TypeError, lambda: fcode(S(4), method="garbage"))
  509. def test_free_form_code_line():
  510. x, y = symbols('x,y')
  511. assert fcode(cos(x) + sin(y), source_format='free') == "sin(y) + cos(x)"
  512. def test_free_form_continuation_line():
  513. x, y = symbols('x,y')
  514. result = fcode(((cos(x) + sin(y))**(7)).expand(), source_format='free')
  515. expected = (
  516. 'sin(y)**7 + 7*sin(y)**6*cos(x) + 21*sin(y)**5*cos(x)**2 + 35*sin(y)**4* &\n'
  517. ' cos(x)**3 + 35*sin(y)**3*cos(x)**4 + 21*sin(y)**2*cos(x)**5 + 7* &\n'
  518. ' sin(y)*cos(x)**6 + cos(x)**7'
  519. )
  520. assert result == expected
  521. def test_free_form_comment_line():
  522. printer = FCodePrinter({'source_format': 'free'})
  523. lines = [ "! This is a long comment on a single line that must be wrapped properly to produce nice output"]
  524. expected = [
  525. '! This is a long comment on a single line that must be wrapped properly',
  526. '! to produce nice output']
  527. assert printer._wrap_fortran(lines) == expected
  528. def test_loops():
  529. n, m = symbols('n,m', integer=True)
  530. A = IndexedBase('A')
  531. x = IndexedBase('x')
  532. y = IndexedBase('y')
  533. i = Idx('i', m)
  534. j = Idx('j', n)
  535. expected = (
  536. 'do i = 1, m\n'
  537. ' y(i) = 0\n'
  538. 'end do\n'
  539. 'do i = 1, m\n'
  540. ' do j = 1, n\n'
  541. ' y(i) = %(rhs)s\n'
  542. ' end do\n'
  543. 'end do'
  544. )
  545. code = fcode(A[i, j]*x[j], assign_to=y[i], source_format='free')
  546. assert (code == expected % {'rhs': 'y(i) + A(i, j)*x(j)'} or
  547. code == expected % {'rhs': 'y(i) + x(j)*A(i, j)'} or
  548. code == expected % {'rhs': 'x(j)*A(i, j) + y(i)'} or
  549. code == expected % {'rhs': 'A(i, j)*x(j) + y(i)'})
  550. def test_dummy_loops():
  551. i, m = symbols('i m', integer=True, cls=Dummy)
  552. x = IndexedBase('x')
  553. y = IndexedBase('y')
  554. i = Idx(i, m)
  555. expected = (
  556. 'do i_%(icount)i = 1, m_%(mcount)i\n'
  557. ' y(i_%(icount)i) = x(i_%(icount)i)\n'
  558. 'end do'
  559. ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
  560. code = fcode(x[i], assign_to=y[i], source_format='free')
  561. assert code == expected
  562. def test_fcode_Indexed_without_looking_for_contraction():
  563. len_y = 5
  564. y = IndexedBase('y', shape=(len_y,))
  565. x = IndexedBase('x', shape=(len_y,))
  566. Dy = IndexedBase('Dy', shape=(len_y-1,))
  567. i = Idx('i', len_y-1)
  568. e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
  569. code0 = fcode(e.rhs, assign_to=e.lhs, contract=False)
  570. assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
  571. def test_derived_classes():
  572. class MyFancyFCodePrinter(FCodePrinter):
  573. _default_settings = FCodePrinter._default_settings.copy()
  574. printer = MyFancyFCodePrinter()
  575. x = symbols('x')
  576. assert printer.doprint(sin(x), "bork") == " bork = sin(x)"
  577. def test_indent():
  578. codelines = (
  579. 'subroutine test(a)\n'
  580. 'integer :: a, i, j\n'
  581. '\n'
  582. 'do\n'
  583. 'do \n'
  584. 'do j = 1, 5\n'
  585. 'if (a>b) then\n'
  586. 'if(b>0) then\n'
  587. 'a = 3\n'
  588. 'donot_indent_me = 2\n'
  589. 'do_not_indent_me_either = 2\n'
  590. 'ifIam_indented_something_went_wrong = 2\n'
  591. 'if_I_am_indented_something_went_wrong = 2\n'
  592. 'end should not be unindented here\n'
  593. 'end if\n'
  594. 'endif\n'
  595. 'end do\n'
  596. 'end do\n'
  597. 'enddo\n'
  598. 'end subroutine\n'
  599. '\n'
  600. 'subroutine test2(a)\n'
  601. 'integer :: a\n'
  602. 'do\n'
  603. 'a = a + 1\n'
  604. 'end do \n'
  605. 'end subroutine\n'
  606. )
  607. expected = (
  608. 'subroutine test(a)\n'
  609. 'integer :: a, i, j\n'
  610. '\n'
  611. 'do\n'
  612. ' do \n'
  613. ' do j = 1, 5\n'
  614. ' if (a>b) then\n'
  615. ' if(b>0) then\n'
  616. ' a = 3\n'
  617. ' donot_indent_me = 2\n'
  618. ' do_not_indent_me_either = 2\n'
  619. ' ifIam_indented_something_went_wrong = 2\n'
  620. ' if_I_am_indented_something_went_wrong = 2\n'
  621. ' end should not be unindented here\n'
  622. ' end if\n'
  623. ' endif\n'
  624. ' end do\n'
  625. ' end do\n'
  626. 'enddo\n'
  627. 'end subroutine\n'
  628. '\n'
  629. 'subroutine test2(a)\n'
  630. 'integer :: a\n'
  631. 'do\n'
  632. ' a = a + 1\n'
  633. 'end do \n'
  634. 'end subroutine\n'
  635. )
  636. p = FCodePrinter({'source_format': 'free'})
  637. result = p.indent_code(codelines)
  638. assert result == expected
  639. def test_Matrix_printing():
  640. x, y, z = symbols('x,y,z')
  641. # Test returning a Matrix
  642. mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
  643. A = MatrixSymbol('A', 3, 1)
  644. assert fcode(mat, A) == (
  645. " A(1, 1) = x*y\n"
  646. " if (y > 0) then\n"
  647. " A(2, 1) = x + 2\n"
  648. " else\n"
  649. " A(2, 1) = y\n"
  650. " end if\n"
  651. " A(3, 1) = sin(z)")
  652. # Test using MatrixElements in expressions
  653. expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
  654. assert fcode(expr, standard=95) == (
  655. " merge(2*A(3, 1), A(3, 1), x > 0) + sin(A(2, 1)) + A(1, 1)")
  656. # Test using MatrixElements in a Matrix
  657. q = MatrixSymbol('q', 5, 1)
  658. M = MatrixSymbol('M', 3, 3)
  659. m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
  660. [q[1,0] + q[2,0], q[3, 0], 5],
  661. [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
  662. assert fcode(m, M) == (
  663. " M(1, 1) = sin(q(2, 1))\n"
  664. " M(2, 1) = q(2, 1) + q(3, 1)\n"
  665. " M(3, 1) = 2*q(5, 1)/q(2, 1)\n"
  666. " M(1, 2) = 0\n"
  667. " M(2, 2) = q(4, 1)\n"
  668. " M(3, 2) = sqrt(q(1, 1)) + 4\n"
  669. " M(1, 3) = cos(q(3, 1))\n"
  670. " M(2, 3) = 5\n"
  671. " M(3, 3) = 0")
  672. def test_fcode_For():
  673. x, y = symbols('x y')
  674. f = For(x, Range(0, 10, 2), [Assignment(y, x * y)])
  675. sol = fcode(f)
  676. assert sol == (" do x = 0, 10, 2\n"
  677. " y = x*y\n"
  678. " end do")
  679. def test_fcode_Declaration():
  680. def check(expr, ref, **kwargs):
  681. assert fcode(expr, standard=95, source_format='free', **kwargs) == ref
  682. i = symbols('i', integer=True)
  683. var1 = Variable.deduced(i)
  684. dcl1 = Declaration(var1)
  685. check(dcl1, "integer*4 :: i")
  686. x, y = symbols('x y')
  687. var2 = Variable(x, float32, value=42, attrs={value_const})
  688. dcl2b = Declaration(var2)
  689. check(dcl2b, 'real*4, parameter :: x = 42')
  690. var3 = Variable(y, type=bool_)
  691. dcl3 = Declaration(var3)
  692. check(dcl3, 'logical :: y')
  693. check(float32, "real*4")
  694. check(float64, "real*8")
  695. check(real, "real*4", type_aliases={real: float32})
  696. check(real, "real*8", type_aliases={real: float64})
  697. def test_MatrixElement_printing():
  698. # test cases for issue #11821
  699. A = MatrixSymbol("A", 1, 3)
  700. B = MatrixSymbol("B", 1, 3)
  701. C = MatrixSymbol("C", 1, 3)
  702. assert(fcode(A[0, 0]) == " A(1, 1)")
  703. assert(fcode(3 * A[0, 0]) == " 3*A(1, 1)")
  704. F = C[0, 0].subs(C, A - B)
  705. assert(fcode(F) == " (A - B)(1, 1)")
  706. def test_aug_assign():
  707. x = symbols('x')
  708. assert fcode(aug_assign(x, '+', 1), source_format='free') == 'x = x + 1'
  709. def test_While():
  710. x = symbols('x')
  711. assert fcode(While(abs(x) > 1, [aug_assign(x, '-', 1)]), source_format='free') == (
  712. 'do while (abs(x) > 1)\n'
  713. ' x = x - 1\n'
  714. 'end do'
  715. )
  716. def test_FunctionPrototype_print():
  717. x = symbols('x')
  718. n = symbols('n', integer=True)
  719. vx = Variable(x, type=real)
  720. vn = Variable(n, type=integer)
  721. fp1 = FunctionPrototype(real, 'power', [vx, vn])
  722. # Should be changed to proper test once multi-line generation is working
  723. # see https://github.com/sympy/sympy/issues/15824
  724. raises(NotImplementedError, lambda: fcode(fp1))
  725. def test_FunctionDefinition_print():
  726. x = symbols('x')
  727. n = symbols('n', integer=True)
  728. vx = Variable(x, type=real)
  729. vn = Variable(n, type=integer)
  730. body = [Assignment(x, x**n), Return(x)]
  731. fd1 = FunctionDefinition(real, 'power', [vx, vn], body)
  732. # Should be changed to proper test once multi-line generation is working
  733. # see https://github.com/sympy/sympy/issues/15824
  734. raises(NotImplementedError, lambda: fcode(fd1))
  735. def test_fcode_submodule():
  736. # Test the compatibility sympy.printing.fcode module imports
  737. with warns_deprecated_sympy():
  738. import sympy.printing.fcode # noqa:F401