图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

639 lines
20 KiB

  1. """
  2. Python code printers
  3. This module contains Python code printers for plain Python as well as NumPy & SciPy enabled code.
  4. """
  5. from collections import defaultdict
  6. from itertools import chain
  7. from sympy.core import S
  8. from .precedence import precedence
  9. from .codeprinter import CodePrinter
  10. _kw = {
  11. 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
  12. 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in',
  13. 'is', 'lambda', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while',
  14. 'with', 'yield', 'None', 'False', 'nonlocal', 'True'
  15. }
  16. _known_functions = {
  17. 'Abs': 'abs',
  18. }
  19. _known_functions_math = {
  20. 'acos': 'acos',
  21. 'acosh': 'acosh',
  22. 'asin': 'asin',
  23. 'asinh': 'asinh',
  24. 'atan': 'atan',
  25. 'atan2': 'atan2',
  26. 'atanh': 'atanh',
  27. 'ceiling': 'ceil',
  28. 'cos': 'cos',
  29. 'cosh': 'cosh',
  30. 'erf': 'erf',
  31. 'erfc': 'erfc',
  32. 'exp': 'exp',
  33. 'expm1': 'expm1',
  34. 'factorial': 'factorial',
  35. 'floor': 'floor',
  36. 'gamma': 'gamma',
  37. 'hypot': 'hypot',
  38. 'loggamma': 'lgamma',
  39. 'log': 'log',
  40. 'ln': 'log',
  41. 'log10': 'log10',
  42. 'log1p': 'log1p',
  43. 'log2': 'log2',
  44. 'sin': 'sin',
  45. 'sinh': 'sinh',
  46. 'Sqrt': 'sqrt',
  47. 'tan': 'tan',
  48. 'tanh': 'tanh'
  49. } # Not used from ``math``: [copysign isclose isfinite isinf isnan ldexp frexp pow modf
  50. # radians trunc fmod fsum gcd degrees fabs]
  51. _known_constants_math = {
  52. 'Exp1': 'e',
  53. 'Pi': 'pi',
  54. 'E': 'e',
  55. 'Infinity': 'inf',
  56. 'NaN': 'nan',
  57. 'ComplexInfinity': 'nan'
  58. }
  59. def _print_known_func(self, expr):
  60. known = self.known_functions[expr.__class__.__name__]
  61. return '{name}({args})'.format(name=self._module_format(known),
  62. args=', '.join(map(lambda arg: self._print(arg), expr.args)))
  63. def _print_known_const(self, expr):
  64. known = self.known_constants[expr.__class__.__name__]
  65. return self._module_format(known)
  66. class AbstractPythonCodePrinter(CodePrinter):
  67. printmethod = "_pythoncode"
  68. language = "Python"
  69. reserved_words = _kw
  70. modules = None # initialized to a set in __init__
  71. tab = ' '
  72. _kf = dict(chain(
  73. _known_functions.items(),
  74. [(k, 'math.' + v) for k, v in _known_functions_math.items()]
  75. ))
  76. _kc = {k: 'math.'+v for k, v in _known_constants_math.items()}
  77. _operators = {'and': 'and', 'or': 'or', 'not': 'not'}
  78. _default_settings = dict(
  79. CodePrinter._default_settings,
  80. user_functions={},
  81. precision=17,
  82. inline=True,
  83. fully_qualified_modules=True,
  84. contract=False,
  85. standard='python3',
  86. )
  87. def __init__(self, settings=None):
  88. super().__init__(settings)
  89. # Python standard handler
  90. std = self._settings['standard']
  91. if std is None:
  92. import sys
  93. std = 'python{}'.format(sys.version_info.major)
  94. if std != 'python3':
  95. raise ValueError('Only Python 3 is supported.')
  96. self.standard = std
  97. self.module_imports = defaultdict(set)
  98. # Known functions and constants handler
  99. self.known_functions = dict(self._kf, **(settings or {}).get(
  100. 'user_functions', {}))
  101. self.known_constants = dict(self._kc, **(settings or {}).get(
  102. 'user_constants', {}))
  103. def _declare_number_const(self, name, value):
  104. return "%s = %s" % (name, value)
  105. def _module_format(self, fqn, register=True):
  106. parts = fqn.split('.')
  107. if register and len(parts) > 1:
  108. self.module_imports['.'.join(parts[:-1])].add(parts[-1])
  109. if self._settings['fully_qualified_modules']:
  110. return fqn
  111. else:
  112. return fqn.split('(')[0].split('[')[0].split('.')[-1]
  113. def _format_code(self, lines):
  114. return lines
  115. def _get_statement(self, codestring):
  116. return "{}".format(codestring)
  117. def _get_comment(self, text):
  118. return " # {}".format(text)
  119. def _expand_fold_binary_op(self, op, args):
  120. """
  121. This method expands a fold on binary operations.
  122. ``functools.reduce`` is an example of a folded operation.
  123. For example, the expression
  124. `A + B + C + D`
  125. is folded into
  126. `((A + B) + C) + D`
  127. """
  128. if len(args) == 1:
  129. return self._print(args[0])
  130. else:
  131. return "%s(%s, %s)" % (
  132. self._module_format(op),
  133. self._expand_fold_binary_op(op, args[:-1]),
  134. self._print(args[-1]),
  135. )
  136. def _expand_reduce_binary_op(self, op, args):
  137. """
  138. This method expands a reductin on binary operations.
  139. Notice: this is NOT the same as ``functools.reduce``.
  140. For example, the expression
  141. `A + B + C + D`
  142. is reduced into:
  143. `(A + B) + (C + D)`
  144. """
  145. if len(args) == 1:
  146. return self._print(args[0])
  147. else:
  148. N = len(args)
  149. Nhalf = N // 2
  150. return "%s(%s, %s)" % (
  151. self._module_format(op),
  152. self._expand_reduce_binary_op(args[:Nhalf]),
  153. self._expand_reduce_binary_op(args[Nhalf:]),
  154. )
  155. def _get_einsum_string(self, subranks, contraction_indices):
  156. letters = self._get_letter_generator_for_einsum()
  157. contraction_string = ""
  158. counter = 0
  159. d = {j: min(i) for i in contraction_indices for j in i}
  160. indices = []
  161. for rank_arg in subranks:
  162. lindices = []
  163. for i in range(rank_arg):
  164. if counter in d:
  165. lindices.append(d[counter])
  166. else:
  167. lindices.append(counter)
  168. counter += 1
  169. indices.append(lindices)
  170. mapping = {}
  171. letters_free = []
  172. letters_dum = []
  173. for i in indices:
  174. for j in i:
  175. if j not in mapping:
  176. l = next(letters)
  177. mapping[j] = l
  178. else:
  179. l = mapping[j]
  180. contraction_string += l
  181. if j in d:
  182. if l not in letters_dum:
  183. letters_dum.append(l)
  184. else:
  185. letters_free.append(l)
  186. contraction_string += ","
  187. contraction_string = contraction_string[:-1]
  188. return contraction_string, letters_free, letters_dum
  189. def _print_NaN(self, expr):
  190. return "float('nan')"
  191. def _print_Infinity(self, expr):
  192. return "float('inf')"
  193. def _print_NegativeInfinity(self, expr):
  194. return "float('-inf')"
  195. def _print_ComplexInfinity(self, expr):
  196. return self._print_NaN(expr)
  197. def _print_Mod(self, expr):
  198. PREC = precedence(expr)
  199. return ('{} % {}'.format(*map(lambda x: self.parenthesize(x, PREC), expr.args)))
  200. def _print_Piecewise(self, expr):
  201. result = []
  202. i = 0
  203. for arg in expr.args:
  204. e = arg.expr
  205. c = arg.cond
  206. if i == 0:
  207. result.append('(')
  208. result.append('(')
  209. result.append(self._print(e))
  210. result.append(')')
  211. result.append(' if ')
  212. result.append(self._print(c))
  213. result.append(' else ')
  214. i += 1
  215. result = result[:-1]
  216. if result[-1] == 'True':
  217. result = result[:-2]
  218. result.append(')')
  219. else:
  220. result.append(' else None)')
  221. return ''.join(result)
  222. def _print_Relational(self, expr):
  223. "Relational printer for Equality and Unequality"
  224. op = {
  225. '==' :'equal',
  226. '!=' :'not_equal',
  227. '<' :'less',
  228. '<=' :'less_equal',
  229. '>' :'greater',
  230. '>=' :'greater_equal',
  231. }
  232. if expr.rel_op in op:
  233. lhs = self._print(expr.lhs)
  234. rhs = self._print(expr.rhs)
  235. return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs)
  236. return super()._print_Relational(expr)
  237. def _print_ITE(self, expr):
  238. from sympy.functions.elementary.piecewise import Piecewise
  239. return self._print(expr.rewrite(Piecewise))
  240. def _print_Sum(self, expr):
  241. loops = (
  242. 'for {i} in range({a}, {b}+1)'.format(
  243. i=self._print(i),
  244. a=self._print(a),
  245. b=self._print(b))
  246. for i, a, b in expr.limits)
  247. return '(builtins.sum({function} {loops}))'.format(
  248. function=self._print(expr.function),
  249. loops=' '.join(loops))
  250. def _print_ImaginaryUnit(self, expr):
  251. return '1j'
  252. def _print_KroneckerDelta(self, expr):
  253. a, b = expr.args
  254. return '(1 if {a} == {b} else 0)'.format(
  255. a = self._print(a),
  256. b = self._print(b)
  257. )
  258. def _print_MatrixBase(self, expr):
  259. name = expr.__class__.__name__
  260. func = self.known_functions.get(name, name)
  261. return "%s(%s)" % (func, self._print(expr.tolist()))
  262. _print_SparseRepMatrix = \
  263. _print_MutableSparseMatrix = \
  264. _print_ImmutableSparseMatrix = \
  265. _print_Matrix = \
  266. _print_DenseMatrix = \
  267. _print_MutableDenseMatrix = \
  268. _print_ImmutableMatrix = \
  269. _print_ImmutableDenseMatrix = \
  270. lambda self, expr: self._print_MatrixBase(expr)
  271. def _indent_codestring(self, codestring):
  272. return '\n'.join([self.tab + line for line in codestring.split('\n')])
  273. def _print_FunctionDefinition(self, fd):
  274. body = '\n'.join(map(lambda arg: self._print(arg), fd.body))
  275. return "def {name}({parameters}):\n{body}".format(
  276. name=self._print(fd.name),
  277. parameters=', '.join([self._print(var.symbol) for var in fd.parameters]),
  278. body=self._indent_codestring(body)
  279. )
  280. def _print_While(self, whl):
  281. body = '\n'.join(map(lambda arg: self._print(arg), whl.body))
  282. return "while {cond}:\n{body}".format(
  283. cond=self._print(whl.condition),
  284. body=self._indent_codestring(body)
  285. )
  286. def _print_Declaration(self, decl):
  287. return '%s = %s' % (
  288. self._print(decl.variable.symbol),
  289. self._print(decl.variable.value)
  290. )
  291. def _print_Return(self, ret):
  292. arg, = ret.args
  293. return 'return %s' % self._print(arg)
  294. def _print_Print(self, prnt):
  295. print_args = ', '.join(map(lambda arg: self._print(arg), prnt.print_args))
  296. if prnt.format_string != None: # Must be '!= None', cannot be 'is not None'
  297. print_args = '{} % ({})'.format(
  298. self._print(prnt.format_string), print_args)
  299. if prnt.file != None: # Must be '!= None', cannot be 'is not None'
  300. print_args += ', file=%s' % self._print(prnt.file)
  301. return 'print(%s)' % print_args
  302. def _print_Stream(self, strm):
  303. if str(strm.name) == 'stdout':
  304. return self._module_format('sys.stdout')
  305. elif str(strm.name) == 'stderr':
  306. return self._module_format('sys.stderr')
  307. else:
  308. return self._print(strm.name)
  309. def _print_NoneToken(self, arg):
  310. return 'None'
  311. def _hprint_Pow(self, expr, rational=False, sqrt='math.sqrt'):
  312. """Printing helper function for ``Pow``
  313. Notes
  314. =====
  315. This only preprocesses the ``sqrt`` as math formatter
  316. Examples
  317. ========
  318. >>> from sympy import sqrt
  319. >>> from sympy.printing.pycode import PythonCodePrinter
  320. >>> from sympy.abc import x
  321. Python code printer automatically looks up ``math.sqrt``.
  322. >>> printer = PythonCodePrinter()
  323. >>> printer._hprint_Pow(sqrt(x), rational=True)
  324. 'x**(1/2)'
  325. >>> printer._hprint_Pow(sqrt(x), rational=False)
  326. 'math.sqrt(x)'
  327. >>> printer._hprint_Pow(1/sqrt(x), rational=True)
  328. 'x**(-1/2)'
  329. >>> printer._hprint_Pow(1/sqrt(x), rational=False)
  330. '1/math.sqrt(x)'
  331. Using sqrt from numpy or mpmath
  332. >>> printer._hprint_Pow(sqrt(x), sqrt='numpy.sqrt')
  333. 'numpy.sqrt(x)'
  334. >>> printer._hprint_Pow(sqrt(x), sqrt='mpmath.sqrt')
  335. 'mpmath.sqrt(x)'
  336. See Also
  337. ========
  338. sympy.printing.str.StrPrinter._print_Pow
  339. """
  340. PREC = precedence(expr)
  341. if expr.exp == S.Half and not rational:
  342. func = self._module_format(sqrt)
  343. arg = self._print(expr.base)
  344. return '{func}({arg})'.format(func=func, arg=arg)
  345. if expr.is_commutative:
  346. if -expr.exp is S.Half and not rational:
  347. func = self._module_format(sqrt)
  348. num = self._print(S.One)
  349. arg = self._print(expr.base)
  350. return "{num}/{func}({arg})".format(
  351. num=num, func=func, arg=arg)
  352. base_str = self.parenthesize(expr.base, PREC, strict=False)
  353. exp_str = self.parenthesize(expr.exp, PREC, strict=False)
  354. return "{}**{}".format(base_str, exp_str)
  355. class PythonCodePrinter(AbstractPythonCodePrinter):
  356. def _print_sign(self, e):
  357. return '(0.0 if {e} == 0 else {f}(1, {e}))'.format(
  358. f=self._module_format('math.copysign'), e=self._print(e.args[0]))
  359. def _print_Not(self, expr):
  360. PREC = precedence(expr)
  361. return self._operators['not'] + self.parenthesize(expr.args[0], PREC)
  362. def _print_Indexed(self, expr):
  363. base = expr.args[0]
  364. index = expr.args[1:]
  365. return "{}[{}]".format(str(base), ", ".join([self._print(ind) for ind in index]))
  366. def _print_Pow(self, expr, rational=False):
  367. return self._hprint_Pow(expr, rational=rational)
  368. def _print_Rational(self, expr):
  369. return '{}/{}'.format(expr.p, expr.q)
  370. def _print_Half(self, expr):
  371. return self._print_Rational(expr)
  372. def _print_frac(self, expr):
  373. from sympy.core.mod import Mod
  374. return self._print_Mod(Mod(expr.args[0], 1))
  375. def _print_Symbol(self, expr):
  376. name = super()._print_Symbol(expr)
  377. if name in self.reserved_words:
  378. if self._settings['error_on_reserved']:
  379. msg = ('This expression includes the symbol "{}" which is a '
  380. 'reserved keyword in this language.')
  381. raise ValueError(msg.format(name))
  382. return name + self._settings['reserved_word_suffix']
  383. elif '{' in name: # Remove curly braces from subscripted variables
  384. return name.replace('{', '').replace('}', '')
  385. else:
  386. return name
  387. _print_lowergamma = CodePrinter._print_not_supported
  388. _print_uppergamma = CodePrinter._print_not_supported
  389. _print_fresnelc = CodePrinter._print_not_supported
  390. _print_fresnels = CodePrinter._print_not_supported
  391. for k in PythonCodePrinter._kf:
  392. setattr(PythonCodePrinter, '_print_%s' % k, _print_known_func)
  393. for k in _known_constants_math:
  394. setattr(PythonCodePrinter, '_print_%s' % k, _print_known_const)
  395. def pycode(expr, **settings):
  396. """ Converts an expr to a string of Python code
  397. Parameters
  398. ==========
  399. expr : Expr
  400. A SymPy expression.
  401. fully_qualified_modules : bool
  402. Whether or not to write out full module names of functions
  403. (``math.sin`` vs. ``sin``). default: ``True``.
  404. standard : str or None, optional
  405. Only 'python3' (default) is supported.
  406. This parameter may be removed in the future.
  407. Examples
  408. ========
  409. >>> from sympy import pycode, tan, Symbol
  410. >>> pycode(tan(Symbol('x')) + 1)
  411. 'math.tan(x) + 1'
  412. """
  413. return PythonCodePrinter(settings).doprint(expr)
  414. _not_in_mpmath = 'log1p log2'.split()
  415. _in_mpmath = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_mpmath]
  416. _known_functions_mpmath = dict(_in_mpmath, **{
  417. 'beta': 'beta',
  418. 'frac': 'frac',
  419. 'fresnelc': 'fresnelc',
  420. 'fresnels': 'fresnels',
  421. 'sign': 'sign',
  422. 'loggamma': 'loggamma',
  423. 'hyper': 'hyper',
  424. 'meijerg': 'meijerg',
  425. 'besselj': 'besselj',
  426. 'bessely': 'bessely',
  427. 'besseli': 'besseli',
  428. 'besselk': 'besselk',
  429. })
  430. _known_constants_mpmath = {
  431. 'Exp1': 'e',
  432. 'Pi': 'pi',
  433. 'GoldenRatio': 'phi',
  434. 'EulerGamma': 'euler',
  435. 'Catalan': 'catalan',
  436. 'NaN': 'nan',
  437. 'Infinity': 'inf',
  438. 'NegativeInfinity': 'ninf'
  439. }
  440. def _unpack_integral_limits(integral_expr):
  441. """ helper function for _print_Integral that
  442. - accepts an Integral expression
  443. - returns a tuple of
  444. - a list variables of integration
  445. - a list of tuples of the upper and lower limits of integration
  446. """
  447. integration_vars = []
  448. limits = []
  449. for integration_range in integral_expr.limits:
  450. if len(integration_range) == 3:
  451. integration_var, lower_limit, upper_limit = integration_range
  452. else:
  453. raise NotImplementedError("Only definite integrals are supported")
  454. integration_vars.append(integration_var)
  455. limits.append((lower_limit, upper_limit))
  456. return integration_vars, limits
  457. class MpmathPrinter(PythonCodePrinter):
  458. """
  459. Lambda printer for mpmath which maintains precision for floats
  460. """
  461. printmethod = "_mpmathcode"
  462. language = "Python with mpmath"
  463. _kf = dict(chain(
  464. _known_functions.items(),
  465. [(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()]
  466. ))
  467. _kc = {k: 'mpmath.'+v for k, v in _known_constants_mpmath.items()}
  468. def _print_Float(self, e):
  469. # XXX: This does not handle setting mpmath.mp.dps. It is assumed that
  470. # the caller of the lambdified function will have set it to sufficient
  471. # precision to match the Floats in the expression.
  472. # Remove 'mpz' if gmpy is installed.
  473. args = str(tuple(map(int, e._mpf_)))
  474. return '{func}({args})'.format(func=self._module_format('mpmath.mpf'), args=args)
  475. def _print_Rational(self, e):
  476. return "{func}({p})/{func}({q})".format(
  477. func=self._module_format('mpmath.mpf'),
  478. q=self._print(e.q),
  479. p=self._print(e.p)
  480. )
  481. def _print_Half(self, e):
  482. return self._print_Rational(e)
  483. def _print_uppergamma(self, e):
  484. return "{}({}, {}, {})".format(
  485. self._module_format('mpmath.gammainc'),
  486. self._print(e.args[0]),
  487. self._print(e.args[1]),
  488. self._module_format('mpmath.inf'))
  489. def _print_lowergamma(self, e):
  490. return "{}({}, 0, {})".format(
  491. self._module_format('mpmath.gammainc'),
  492. self._print(e.args[0]),
  493. self._print(e.args[1]))
  494. def _print_log2(self, e):
  495. return '{0}({1})/{0}(2)'.format(
  496. self._module_format('mpmath.log'), self._print(e.args[0]))
  497. def _print_log1p(self, e):
  498. return '{}({}+1)'.format(
  499. self._module_format('mpmath.log'), self._print(e.args[0]))
  500. def _print_Pow(self, expr, rational=False):
  501. return self._hprint_Pow(expr, rational=rational, sqrt='mpmath.sqrt')
  502. def _print_Integral(self, e):
  503. integration_vars, limits = _unpack_integral_limits(e)
  504. return "{}(lambda {}: {}, {})".format(
  505. self._module_format("mpmath.quad"),
  506. ", ".join(map(self._print, integration_vars)),
  507. self._print(e.args[0]),
  508. ", ".join("(%s, %s)" % tuple(map(self._print, l)) for l in limits))
  509. for k in MpmathPrinter._kf:
  510. setattr(MpmathPrinter, '_print_%s' % k, _print_known_func)
  511. for k in _known_constants_mpmath:
  512. setattr(MpmathPrinter, '_print_%s' % k, _print_known_const)
  513. class SymPyPrinter(AbstractPythonCodePrinter):
  514. language = "Python with SymPy"
  515. def _print_Function(self, expr):
  516. mod = expr.func.__module__ or ''
  517. return '%s(%s)' % (self._module_format(mod + ('.' if mod else '') + expr.func.__name__),
  518. ', '.join(map(lambda arg: self._print(arg), expr.args)))
  519. def _print_Pow(self, expr, rational=False):
  520. return self._hprint_Pow(expr, rational=rational, sqrt='sympy.sqrt')