m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

3376 lines
112 KiB

7 months ago
"""
There are three types of functions implemented in SymPy:

    1) defined functions (in the sense that they can be evaluated) like
       exp or sin; they have a name and a body:
           f = exp
    2) undefined functions, which have a name but no body. Undefined
       functions can be defined using a Function class as follows:
           f = Function('f')
       (the result is an UndefinedFunction instance, i.e. a new class
       deriving from Function)
    3) anonymous functions (or lambda functions), which have a body
       (defined with dummy variables) but have no name:
           f = Lambda(x, exp(x)*x)
           f = Lambda((x, y), exp(x)*y)

    A fourth type of function is the composite, like (sin + cos)(x); these
    work in SymPy core, but are not yet part of SymPy.

    Examples
    ========

    >>> import sympy
    >>> f = sympy.Function("f")
    >>> from sympy.abc import x
    >>> f(x)
    f(x)
    >>> print(sympy.srepr(f(x).func))
    Function('f')
    >>> f(x).args
    (x,)

"""
  28. from typing import Any, Dict as tDict, Optional, Set as tSet, Tuple as tTuple, Union as tUnion
  29. from collections.abc import Iterable
  30. from .add import Add
  31. from .assumptions import ManagedProperties
  32. from .basic import Basic, _atomic
  33. from .cache import cacheit
  34. from .containers import Tuple, Dict
  35. from .decorators import _sympifyit
  36. from .expr import Expr, AtomicExpr
  37. from .logic import fuzzy_and, fuzzy_or, fuzzy_not, FuzzyBool
  38. from .mul import Mul
  39. from .numbers import Rational, Float, Integer
  40. from .operations import LatticeOp
  41. from .parameters import global_parameters
  42. from .rules import Transform
  43. from .singleton import S
  44. from .sympify import sympify
  45. from .sorting import default_sort_key, ordered
  46. from sympy.utilities.exceptions import (sympy_deprecation_warning,
  47. SymPyDeprecationWarning, ignore_warnings)
  48. from sympy.utilities.iterables import (has_dups, sift, iterable,
  49. is_sequence, uniq, topological_sort)
  50. from sympy.utilities.lambdify import MPMATH_TRANSLATIONS
  51. from sympy.utilities.misc import as_int, filldedent, func_name
  52. import mpmath
  53. from mpmath.libmp.libmpf import prec_to_dps
  54. import inspect
  55. from collections import Counter
  56. def _coeff_isneg(a):
  57. """Return True if the leading Number is negative.
  58. Examples
  59. ========
  60. >>> from sympy.core.function import _coeff_isneg
  61. >>> from sympy import S, Symbol, oo, pi
  62. >>> _coeff_isneg(-3*pi)
  63. True
  64. >>> _coeff_isneg(S(3))
  65. False
  66. >>> _coeff_isneg(-oo)
  67. True
  68. >>> _coeff_isneg(Symbol('n', negative=True)) # coeff is 1
  69. False
  70. For matrix expressions:
  71. >>> from sympy import MatrixSymbol, sqrt
  72. >>> A = MatrixSymbol("A", 3, 3)
  73. >>> _coeff_isneg(-sqrt(2)*A)
  74. True
  75. >>> _coeff_isneg(sqrt(2)*A)
  76. False
  77. """
  78. if a.is_MatMul:
  79. a = a.args[0]
  80. if a.is_Mul:
  81. a = a.args[0]
  82. return a.is_Number and a.is_extended_negative
  83. class PoleError(Exception):
  84. pass
  85. class ArgumentIndexError(ValueError):
  86. def __str__(self):
  87. return ("Invalid operation with argument number %s for Function %s" %
  88. (self.args[1], self.args[0]))
  89. class BadSignatureError(TypeError):
  90. '''Raised when a Lambda is created with an invalid signature'''
  91. pass
  92. class BadArgumentsError(TypeError):
  93. '''Raised when a Lambda is called with an incorrect number of arguments'''
  94. pass
  95. # Python 3 version that does not raise a Deprecation warning
  96. def arity(cls):
  97. """Return the arity of the function if it is known, else None.
  98. Explanation
  99. ===========
  100. When default values are specified for some arguments, they are
  101. optional and the arity is reported as a tuple of possible values.
  102. Examples
  103. ========
  104. >>> from sympy import arity, log
  105. >>> arity(lambda x: x)
  106. 1
  107. >>> arity(log)
  108. (1, 2)
  109. >>> arity(lambda *x: sum(x)) is None
  110. True
  111. """
  112. eval_ = getattr(cls, 'eval', cls)
  113. parameters = inspect.signature(eval_).parameters.items()
  114. if [p for _, p in parameters if p.kind == p.VAR_POSITIONAL]:
  115. return
  116. p_or_k = [p for _, p in parameters if p.kind == p.POSITIONAL_OR_KEYWORD]
  117. # how many have no default and how many have a default value
  118. no, yes = map(len, sift(p_or_k,
  119. lambda p:p.default == p.empty, binary=True))
  120. return no if not yes else tuple(range(no, no + yes + 1))
  121. class FunctionClass(ManagedProperties):
  122. """
  123. Base class for function classes. FunctionClass is a subclass of type.
  124. Use Function('<function name>' [ , signature ]) to create
  125. undefined function classes.
  126. """
  127. _new = type.__new__
  128. def __init__(cls, *args, **kwargs):
  129. # honor kwarg value or class-defined value before using
  130. # the number of arguments in the eval function (if present)
  131. nargs = kwargs.pop('nargs', cls.__dict__.get('nargs', arity(cls)))
  132. if nargs is None and 'nargs' not in cls.__dict__:
  133. for supcls in cls.__mro__:
  134. if hasattr(supcls, '_nargs'):
  135. nargs = supcls._nargs
  136. break
  137. else:
  138. continue
  139. # Canonicalize nargs here; change to set in nargs.
  140. if is_sequence(nargs):
  141. if not nargs:
  142. raise ValueError(filldedent('''
  143. Incorrectly specified nargs as %s:
  144. if there are no arguments, it should be
  145. `nargs = 0`;
  146. if there are any number of arguments,
  147. it should be
  148. `nargs = None`''' % str(nargs)))
  149. nargs = tuple(ordered(set(nargs)))
  150. elif nargs is not None:
  151. nargs = (as_int(nargs),)
  152. cls._nargs = nargs
  153. super().__init__(*args, **kwargs)
  154. @property
  155. def __signature__(self):
  156. """
  157. Allow Python 3's inspect.signature to give a useful signature for
  158. Function subclasses.
  159. """
  160. # Python 3 only, but backports (like the one in IPython) still might
  161. # call this.
  162. try:
  163. from inspect import signature
  164. except ImportError:
  165. return None
  166. # TODO: Look at nargs
  167. return signature(self.eval)
  168. @property
  169. def free_symbols(self):
  170. return set()
  171. @property
  172. def xreplace(self):
  173. # Function needs args so we define a property that returns
  174. # a function that takes args...and then use that function
  175. # to return the right value
  176. return lambda rule, **_: rule.get(self, self)
  177. @property
  178. def nargs(self):
  179. """Return a set of the allowed number of arguments for the function.
  180. Examples
  181. ========
  182. >>> from sympy import Function
  183. >>> f = Function('f')
  184. If the function can take any number of arguments, the set of whole
  185. numbers is returned:
  186. >>> Function('f').nargs
  187. Naturals0
  188. If the function was initialized to accept one or more arguments, a
  189. corresponding set will be returned:
  190. >>> Function('f', nargs=1).nargs
  191. {1}
  192. >>> Function('f', nargs=(2, 1)).nargs
  193. {1, 2}
  194. The undefined function, after application, also has the nargs
  195. attribute; the actual number of arguments is always available by
  196. checking the ``args`` attribute:
  197. >>> f = Function('f')
  198. >>> f(1).nargs
  199. Naturals0
  200. >>> len(f(1).args)
  201. 1
  202. """
  203. from sympy.sets.sets import FiniteSet
  204. # XXX it would be nice to handle this in __init__ but there are import
  205. # problems with trying to import FiniteSet there
  206. return FiniteSet(*self._nargs) if self._nargs else S.Naturals0
  207. def __repr__(cls):
  208. return cls.__name__
class Application(Basic, metaclass=FunctionClass):
    """
    Base class for applied functions.

    Explanation
    ===========

    Instances of Application represent the result of applying an application of
    any type to any object.
    """

    is_Function = True

    @cacheit
    def __new__(cls, *args, **options):
        # Sympify the arguments, let ``cls.eval`` short-circuit construction
        # when evaluation is enabled, otherwise build the raw object and
        # attach a set-valued ``nargs`` to it.
        #
        # Raises ValueError for any keyword other than ``evaluate``/``nargs``.
        from sympy.sets.fancysets import Naturals0
        from sympy.sets.sets import FiniteSet

        args = list(map(sympify, args))
        evaluate = options.pop('evaluate', global_parameters.evaluate)
        # WildFunction (and anything else like it) may have nargs defined
        # and we throw that value away here
        options.pop('nargs', None)

        if options:
            raise ValueError("Unknown options: %s" % options)

        if evaluate:
            evaluated = cls.eval(*args)
            if evaluated is not None:
                return evaluated

        obj = super().__new__(cls, *args, **options)

        # make nargs uniform here
        sentinel = object()
        objnargs = getattr(obj, "nargs", sentinel)
        if objnargs is not sentinel:
            # things passing through here:
            #  - functions subclassed from Function (e.g. myfunc(1).nargs)
            #  - functions like cos(1).nargs
            #  - AppliedUndef with given nargs like Function('f', nargs=1)(1).nargs
            # Canonicalize nargs here
            if is_sequence(objnargs):
                nargs = tuple(ordered(set(objnargs)))
            elif objnargs is not None:
                nargs = (as_int(objnargs),)
            else:
                nargs = None
        else:
            # things passing through here:
            #  - WildFunction('f').nargs
            #  - AppliedUndef with no nargs like Function('f')(1).nargs
            nargs = obj._nargs  # note the underscore here
        # convert to FiniteSet; None/empty means any number of arguments
        obj.nargs = FiniteSet(*nargs) if nargs else Naturals0()
        return obj

    @classmethod
    def eval(cls, *args):
        """
        Returns a canonical form of cls applied to arguments args.

        Explanation
        ===========

        The eval() method is called when the class cls is about to be
        instantiated and it should return either some simplified instance
        (possibly of some other class), or if the class cls should be
        unmodified, return None.

        Examples of eval() for the function "sign"
        ---------------------------------------------

        .. code-block:: python

            @classmethod
            def eval(cls, arg):
                if arg is S.NaN:
                    return S.NaN
                if arg.is_zero: return S.Zero
                if arg.is_positive: return S.One
                if arg.is_negative: return S.NegativeOne
                if isinstance(arg, Mul):
                    coeff, terms = arg.as_coeff_Mul(rational=True)
                    if coeff is not S.One:
                        return cls(coeff) * cls(terms)

        """
        return

    @property
    def func(self):
        # the class itself, so that ``self.func(*self.args)`` rebuilds self
        return self.__class__

    def _eval_subs(self, old, new):
        # Replace the function itself (e.g. f -> g in f(x)) when ``old`` is
        # this function's class and ``new`` accepts this many arguments;
        # the arguments are substituted recursively as well. Returns None
        # (falls through) otherwise.
        if (old.is_Function and new.is_Function and
            callable(old) and callable(new) and
            old == self.func and len(self.args) in new.nargs):
            return new(*[i._subs(old, new) for i in self.args])
  291. class Function(Application, Expr):
  292. """
  293. Base class for applied mathematical functions.
  294. It also serves as a constructor for undefined function classes.
  295. Examples
  296. ========
  297. First example shows how to use Function as a constructor for undefined
  298. function classes:
  299. >>> from sympy import Function, Symbol
  300. >>> x = Symbol('x')
  301. >>> f = Function('f')
  302. >>> g = Function('g')(x)
  303. >>> f
  304. f
  305. >>> f(x)
  306. f(x)
  307. >>> g
  308. g(x)
  309. >>> f(x).diff(x)
  310. Derivative(f(x), x)
  311. >>> g.diff(x)
  312. Derivative(g(x), x)
  313. Assumptions can be passed to Function, and if function is initialized with a
  314. Symbol, the function inherits the name and assumptions associated with the Symbol:
  315. >>> f_real = Function('f', real=True)
  316. >>> f_real(x).is_real
  317. True
  318. >>> f_real_inherit = Function(Symbol('f', real=True))
  319. >>> f_real_inherit(x).is_real
  320. True
  321. Note that assumptions on a function are unrelated to the assumptions on
  322. the variable it is called on. If you want to add a relationship, subclass
  323. Function and define the appropriate ``_eval_is_assumption`` methods.
  324. In the following example Function is used as a base class for
  325. ``my_func`` that represents a mathematical function *my_func*. Suppose
  326. that it is well known, that *my_func(0)* is *1* and *my_func* at infinity
  327. goes to *0*, so we want those two simplifications to occur automatically.
  328. Suppose also that *my_func(x)* is real exactly when *x* is real. Here is
  329. an implementation that honours those requirements:
  330. >>> from sympy import Function, S, oo, I, sin
  331. >>> class my_func(Function):
  332. ...
  333. ... @classmethod
  334. ... def eval(cls, x):
  335. ... if x.is_Number:
  336. ... if x.is_zero:
  337. ... return S.One
  338. ... elif x is S.Infinity:
  339. ... return S.Zero
  340. ...
  341. ... def _eval_is_real(self):
  342. ... return self.args[0].is_real
  343. ...
  344. >>> x = S('x')
  345. >>> my_func(0) + sin(0)
  346. 1
  347. >>> my_func(oo)
  348. 0
  349. >>> my_func(3.54).n() # Not yet implemented for my_func.
  350. my_func(3.54)
  351. >>> my_func(I).is_real
  352. False
  353. In order for ``my_func`` to become useful, several other methods would
  354. need to be implemented. See source code of some of the already
  355. implemented functions for more complete examples.
  356. Also, if the function can take more than one argument, then ``nargs``
  357. must be defined, e.g. if ``my_func`` can take one or two arguments
  358. then,
  359. >>> class my_func(Function):
  360. ... nargs = (1, 2)
  361. ...
  362. >>>
  363. """
  364. @property
  365. def _diff_wrt(self):
  366. return False
  367. @cacheit
  368. def __new__(cls, *args, **options):
  369. # Handle calls like Function('f')
  370. if cls is Function:
  371. return UndefinedFunction(*args, **options)
  372. n = len(args)
  373. if n not in cls.nargs:
  374. # XXX: exception message must be in exactly this format to
  375. # make it work with NumPy's functions like vectorize(). See,
  376. # for example, https://github.com/numpy/numpy/issues/1697.
  377. # The ideal solution would be just to attach metadata to
  378. # the exception and change NumPy to take advantage of this.
  379. temp = ('%(name)s takes %(qual)s %(args)s '
  380. 'argument%(plural)s (%(given)s given)')
  381. raise TypeError(temp % {
  382. 'name': cls,
  383. 'qual': 'exactly' if len(cls.nargs) == 1 else 'at least',
  384. 'args': min(cls.nargs),
  385. 'plural': 's'*(min(cls.nargs) != 1),
  386. 'given': n})
  387. evaluate = options.get('evaluate', global_parameters.evaluate)
  388. result = super().__new__(cls, *args, **options)
  389. if evaluate and isinstance(result, cls) and result.args:
  390. pr2 = min(cls._should_evalf(a) for a in result.args)
  391. if pr2 > 0:
  392. pr = max(cls._should_evalf(a) for a in result.args)
  393. result = result.evalf(prec_to_dps(pr))
  394. return result
  395. @classmethod
  396. def _should_evalf(cls, arg):
  397. """
  398. Decide if the function should automatically evalf().
  399. Explanation
  400. ===========
  401. By default (in this implementation), this happens if (and only if) the
  402. ARG is a floating point number.
  403. This function is used by __new__.
  404. Returns the precision to evalf to, or -1 if it shouldn't evalf.
  405. """
  406. if arg.is_Float:
  407. return arg._prec
  408. if not arg.is_Add:
  409. return -1
  410. from .evalf import pure_complex
  411. m = pure_complex(arg)
  412. if m is None or not (m[0].is_Float or m[1].is_Float):
  413. return -1
  414. l = [i._prec for i in m if i.is_Float]
  415. l.append(-1)
  416. return max(l)
  417. @classmethod
  418. def class_key(cls):
  419. from sympy.sets.fancysets import Naturals0
  420. funcs = {
  421. 'exp': 10,
  422. 'log': 11,
  423. 'sin': 20,
  424. 'cos': 21,
  425. 'tan': 22,
  426. 'cot': 23,
  427. 'sinh': 30,
  428. 'cosh': 31,
  429. 'tanh': 32,
  430. 'coth': 33,
  431. 'conjugate': 40,
  432. 're': 41,
  433. 'im': 42,
  434. 'arg': 43,
  435. }
  436. name = cls.__name__
  437. try:
  438. i = funcs[name]
  439. except KeyError:
  440. i = 0 if isinstance(cls.nargs, Naturals0) else 10000
  441. return 4, i, name
  442. def _eval_evalf(self, prec):
  443. def _get_mpmath_func(fname):
  444. """Lookup mpmath function based on name"""
  445. if isinstance(self, AppliedUndef):
  446. # Shouldn't lookup in mpmath but might have ._imp_
  447. return None
  448. if not hasattr(mpmath, fname):
  449. fname = MPMATH_TRANSLATIONS.get(fname, None)
  450. if fname is None:
  451. return None
  452. return getattr(mpmath, fname)
  453. _eval_mpmath = getattr(self, '_eval_mpmath', None)
  454. if _eval_mpmath is None:
  455. func = _get_mpmath_func(self.func.__name__)
  456. args = self.args
  457. else:
  458. func, args = _eval_mpmath()
  459. # Fall-back evaluation
  460. if func is None:
  461. imp = getattr(self, '_imp_', None)
  462. if imp is None:
  463. return None
  464. try:
  465. return Float(imp(*[i.evalf(prec) for i in self.args]), prec)
  466. except (TypeError, ValueError):
  467. return None
  468. # Convert all args to mpf or mpc
  469. # Convert the arguments to *higher* precision than requested for the
  470. # final result.
  471. # XXX + 5 is a guess, it is similar to what is used in evalf.py. Should
  472. # we be more intelligent about it?
  473. try:
  474. args = [arg._to_mpmath(prec + 5) for arg in args]
  475. def bad(m):
  476. from mpmath import mpf, mpc
  477. # the precision of an mpf value is the last element
  478. # if that is 1 (and m[1] is not 1 which would indicate a
  479. # power of 2), then the eval failed; so check that none of
  480. # the arguments failed to compute to a finite precision.
  481. # Note: An mpc value has two parts, the re and imag tuple;
  482. # check each of those parts, too. Anything else is allowed to
  483. # pass
  484. if isinstance(m, mpf):
  485. m = m._mpf_
  486. return m[1] !=1 and m[-1] == 1
  487. elif isinstance(m, mpc):
  488. m, n = m._mpc_
  489. return m[1] !=1 and m[-1] == 1 and \
  490. n[1] !=1 and n[-1] == 1
  491. else:
  492. return False
  493. if any(bad(a) for a in args):
  494. raise ValueError # one or more args failed to compute with significance
  495. except ValueError:
  496. return
  497. with mpmath.workprec(prec):
  498. v = func(*args)
  499. return Expr._from_mpmath(v, prec)
  500. def _eval_derivative(self, s):
  501. # f(x).diff(s) -> x.diff(s) * f.fdiff(1)(s)
  502. i = 0
  503. l = []
  504. for a in self.args:
  505. i += 1
  506. da = a.diff(s)
  507. if da.is_zero:
  508. continue
  509. try:
  510. df = self.fdiff(i)
  511. except ArgumentIndexError:
  512. df = Function.fdiff(self, i)
  513. l.append(df * da)
  514. return Add(*l)
  515. def _eval_is_commutative(self):
  516. return fuzzy_and(a.is_commutative for a in self.args)
  517. def _eval_is_meromorphic(self, x, a):
  518. if not self.args:
  519. return True
  520. if any(arg.has(x) for arg in self.args[1:]):
  521. return False
  522. arg = self.args[0]
  523. if not arg._eval_is_meromorphic(x, a):
  524. return None
  525. return fuzzy_not(type(self).is_singular(arg.subs(x, a)))
  526. _singularities = None # type: tUnion[FuzzyBool, tTuple[Expr, ...]]
  527. @classmethod
  528. def is_singular(cls, a):
  529. """
  530. Tests whether the argument is an essential singularity
  531. or a branch point, or the functions is non-holomorphic.
  532. """
  533. ss = cls._singularities
  534. if ss in (True, None, False):
  535. return ss
  536. return fuzzy_or(a.is_infinite if s is S.ComplexInfinity
  537. else (a - s).is_zero for s in ss)
  538. def as_base_exp(self):
  539. """
  540. Returns the method as the 2-tuple (base, exponent).
  541. """
  542. return self, S.One
  543. def _eval_aseries(self, n, args0, x, logx):
  544. """
  545. Compute an asymptotic expansion around args0, in terms of self.args.
  546. This function is only used internally by _eval_nseries and should not
  547. be called directly; derived classes can overwrite this to implement
  548. asymptotic expansions.
  549. """
  550. raise PoleError(filldedent('''
  551. Asymptotic expansion of %s around %s is
  552. not implemented.''' % (type(self), args0)))
  553. def _eval_nseries(self, x, n, logx, cdir=0):
  554. """
  555. This function does compute series for multivariate functions,
  556. but the expansion is always in terms of *one* variable.
  557. Examples
  558. ========
  559. >>> from sympy import atan2
  560. >>> from sympy.abc import x, y
  561. >>> atan2(x, y).series(x, n=2)
  562. atan2(0, y) + x/y + O(x**2)
  563. >>> atan2(x, y).series(y, n=2)
  564. -y/x + atan2(x, 0) + O(y**2)
  565. This function also computes asymptotic expansions, if necessary
  566. and possible:
  567. >>> from sympy import loggamma
  568. >>> loggamma(1/x)._eval_nseries(x,0,None)
  569. -1/x - log(x)/x + log(x)/2 + O(1)
  570. """
  571. from .symbol import uniquely_named_symbol
  572. from sympy.series.order import Order
  573. from sympy.sets.sets import FiniteSet
  574. args = self.args
  575. args0 = [t.limit(x, 0) for t in args]
  576. if any(t.is_finite is False for t in args0):
  577. from .numbers import oo, zoo, nan
  578. # XXX could use t.as_leading_term(x) here but it's a little
  579. # slower
  580. a = [t.compute_leading_term(x, logx=logx) for t in args]
  581. a0 = [t.limit(x, 0) for t in a]
  582. if any(t.has(oo, -oo, zoo, nan) for t in a0):
  583. return self._eval_aseries(n, args0, x, logx)
  584. # Careful: the argument goes to oo, but only logarithmically so. We
  585. # are supposed to do a power series expansion "around the
  586. # logarithmic term". e.g.
  587. # f(1+x+log(x))
  588. # -> f(1+logx) + x*f'(1+logx) + O(x**2)
  589. # where 'logx' is given in the argument
  590. a = [t._eval_nseries(x, n, logx) for t in args]
  591. z = [r - r0 for (r, r0) in zip(a, a0)]
  592. p = [Dummy() for _ in z]
  593. q = []
  594. v = None
  595. for ai, zi, pi in zip(a0, z, p):
  596. if zi.has(x):
  597. if v is not None:
  598. raise NotImplementedError
  599. q.append(ai + pi)
  600. v = pi
  601. else:
  602. q.append(ai)
  603. e1 = self.func(*q)
  604. if v is None:
  605. return e1
  606. s = e1._eval_nseries(v, n, logx)
  607. o = s.getO()
  608. s = s.removeO()
  609. s = s.subs(v, zi).expand() + Order(o.expr.subs(v, zi), x)
  610. return s
  611. if (self.func.nargs is S.Naturals0
  612. or (self.func.nargs == FiniteSet(1) and args0[0])
  613. or any(c > 1 for c in self.func.nargs)):
  614. e = self
  615. e1 = e.expand()
  616. if e == e1:
  617. #for example when e = sin(x+1) or e = sin(cos(x))
  618. #let's try the general algorithm
  619. if len(e.args) == 1:
  620. # issue 14411
  621. e = e.func(e.args[0].cancel())
  622. term = e.subs(x, S.Zero)
  623. if term.is_finite is False or term is S.NaN:
  624. raise PoleError("Cannot expand %s around 0" % (self))
  625. series = term
  626. fact = S.One
  627. _x = uniquely_named_symbol('xi', self)
  628. e = e.subs(x, _x)
  629. for i in range(n - 1):
  630. i += 1
  631. fact *= Rational(i)
  632. e = e.diff(_x)
  633. subs = e.subs(_x, S.Zero)
  634. if subs is S.NaN:
  635. # try to evaluate a limit if we have to
  636. subs = e.limit(_x, S.Zero)
  637. if subs.is_finite is False:
  638. raise PoleError("Cannot expand %s around 0" % (self))
  639. term = subs*(x**i)/fact
  640. term = term.expand()
  641. series += term
  642. return series + Order(x**n, x)
  643. return e1.nseries(x, n=n, logx=logx)
  644. arg = self.args[0]
  645. l = []
  646. g = None
  647. # try to predict a number of terms needed
  648. nterms = n + 2
  649. cf = Order(arg.as_leading_term(x), x).getn()
  650. if cf != 0:
  651. nterms = (n/cf).ceiling()
  652. for i in range(nterms):
  653. g = self.taylor_term(i, arg, g)
  654. g = g.nseries(x, n=n, logx=logx)
  655. l.append(g)
  656. return Add(*l) + Order(x**n, x)
  657. def fdiff(self, argindex=1):
  658. """
  659. Returns the first derivative of the function.
  660. """
  661. if not (1 <= argindex <= len(self.args)):
  662. raise ArgumentIndexError(self, argindex)
  663. ix = argindex - 1
  664. A = self.args[ix]
  665. if A._diff_wrt:
  666. if len(self.args) == 1 or not A.is_Symbol:
  667. return _derivative_dispatch(self, A)
  668. for i, v in enumerate(self.args):
  669. if i != ix and A in v.free_symbols:
  670. # it can't be in any other argument's free symbols
  671. # issue 8510
  672. break
  673. else:
  674. return _derivative_dispatch(self, A)
  675. # See issue 4624 and issue 4719, 5600 and 8510
  676. D = Dummy('xi_%i' % argindex, dummy_index=hash(A))
  677. args = self.args[:ix] + (D,) + self.args[ix + 1:]
  678. return Subs(Derivative(self.func(*args), D), D, A)
  679. def _eval_as_leading_term(self, x, logx=None, cdir=0):
  680. """Stub that should be overridden by new Functions to return
  681. the first non-zero term in a series if ever an x-dependent
  682. argument whose leading term vanishes as x -> 0 might be encountered.
  683. See, for example, cos._eval_as_leading_term.
  684. """
  685. from sympy.series.order import Order
  686. args = [a.as_leading_term(x, logx=logx) for a in self.args]
  687. o = Order(1, x)
  688. if any(x in a.free_symbols and o.contains(a) for a in args):
  689. # Whereas x and any finite number are contained in O(1, x),
  690. # expressions like 1/x are not. If any arg simplified to a
  691. # vanishing expression as x -> 0 (like x or x**2, but not
  692. # 3, 1/x, etc...) then the _eval_as_leading_term is needed
  693. # to supply the first non-zero term of the series,
  694. #
  695. # e.g. expression leading term
  696. # ---------- ------------
  697. # cos(1/x) cos(1/x)
  698. # cos(cos(x)) cos(1)
  699. # cos(x) 1 <- _eval_as_leading_term needed
  700. # sin(x) x <- _eval_as_leading_term needed
  701. #
  702. raise NotImplementedError(
  703. '%s has no _eval_as_leading_term routine' % self.func)
  704. else:
  705. return self.func(*args)
  706. class AppliedUndef(Function):
  707. """
  708. Base class for expressions resulting from the application of an undefined
  709. function.
  710. """
  711. is_number = False
  712. def __new__(cls, *args, **options):
  713. args = list(map(sympify, args))
  714. u = [a.name for a in args if isinstance(a, UndefinedFunction)]
  715. if u:
  716. raise TypeError('Invalid argument: expecting an expression, not UndefinedFunction%s: %s' % (
  717. 's'*(len(u) > 1), ', '.join(u)))
  718. obj = super().__new__(cls, *args, **options)
  719. return obj
  720. def _eval_as_leading_term(self, x, logx=None, cdir=0):
  721. return self
  722. @property
  723. def _diff_wrt(self):
  724. """
  725. Allow derivatives wrt to undefined functions.
  726. Examples
  727. ========
  728. >>> from sympy import Function, Symbol
  729. >>> f = Function('f')
  730. >>> x = Symbol('x')
  731. >>> f(x)._diff_wrt
  732. True
  733. >>> f(x).diff(x)
  734. Derivative(f(x), x)
  735. """
  736. return True
  737. class UndefSageHelper:
  738. """
  739. Helper to facilitate Sage conversion.
  740. """
  741. def __get__(self, ins, typ):
  742. import sage.all as sage
  743. if ins is None:
  744. return lambda: sage.function(typ.__name__)
  745. else:
  746. args = [arg._sage_() for arg in ins.args]
  747. return lambda : sage.function(ins.__class__.__name__)(*args)
  748. _undef_sage_helper = UndefSageHelper()
  749. class UndefinedFunction(FunctionClass):
  750. """
  751. The (meta)class of undefined functions.
  752. """
  753. def __new__(mcl, name, bases=(AppliedUndef,), __dict__=None, **kwargs):
  754. from .symbol import _filter_assumptions
  755. # Allow Function('f', real=True)
  756. # and/or Function(Symbol('f', real=True))
  757. assumptions, kwargs = _filter_assumptions(kwargs)
  758. if isinstance(name, Symbol):
  759. assumptions = name._merge(assumptions)
  760. name = name.name
  761. elif not isinstance(name, str):
  762. raise TypeError('expecting string or Symbol for name')
  763. else:
  764. commutative = assumptions.get('commutative', None)
  765. assumptions = Symbol(name, **assumptions).assumptions0
  766. if commutative is None:
  767. assumptions.pop('commutative')
  768. __dict__ = __dict__ or {}
  769. # put the `is_*` for into __dict__
  770. __dict__.update({'is_%s' % k: v for k, v in assumptions.items()})
  771. # You can add other attributes, although they do have to be hashable
  772. # (but seriously, if you want to add anything other than assumptions,
  773. # just subclass Function)
  774. __dict__.update(kwargs)
  775. # add back the sanitized assumptions without the is_ prefix
  776. kwargs.update(assumptions)
  777. # Save these for __eq__
  778. __dict__.update({'_kwargs': kwargs})
  779. # do this for pickling
  780. __dict__['__module__'] = None
  781. obj = super().__new__(mcl, name, bases, __dict__)
  782. obj.name = name
  783. obj._sage_ = _undef_sage_helper
  784. return obj
  785. def __instancecheck__(cls, instance):
  786. return cls in type(instance).__mro__
  787. _kwargs = {} # type: tDict[str, Optional[bool]]
  788. def __hash__(self):
  789. return hash((self.class_key(), frozenset(self._kwargs.items())))
  790. def __eq__(self, other):
  791. return (isinstance(other, self.__class__) and
  792. self.class_key() == other.class_key() and
  793. self._kwargs == other._kwargs)
  794. def __ne__(self, other):
  795. return not self == other
  796. @property
  797. def _diff_wrt(self):
  798. return False
  799. # XXX: The type: ignore on WildFunction is because mypy complains:
  800. #
  801. # sympy/core/function.py:939: error: Cannot determine type of 'sort_key' in
  802. # base class 'Expr'
  803. #
  804. # Somehow this is because of the @cacheit decorator but it is not clear how to
  805. # fix it.
  806. class WildFunction(Function, AtomicExpr): # type: ignore
  807. """
  808. A WildFunction function matches any function (with its arguments).
  809. Examples
  810. ========
  811. >>> from sympy import WildFunction, Function, cos
  812. >>> from sympy.abc import x, y
  813. >>> F = WildFunction('F')
  814. >>> f = Function('f')
  815. >>> F.nargs
  816. Naturals0
  817. >>> x.match(F)
  818. >>> F.match(F)
  819. {F_: F_}
  820. >>> f(x).match(F)
  821. {F_: f(x)}
  822. >>> cos(x).match(F)
  823. {F_: cos(x)}
  824. >>> f(x, y).match(F)
  825. {F_: f(x, y)}
  826. To match functions with a given number of arguments, set ``nargs`` to the
  827. desired value at instantiation:
  828. >>> F = WildFunction('F', nargs=2)
  829. >>> F.nargs
  830. {2}
  831. >>> f(x).match(F)
  832. >>> f(x, y).match(F)
  833. {F_: f(x, y)}
  834. To match functions with a range of arguments, set ``nargs`` to a tuple
  835. containing the desired number of arguments, e.g. if ``nargs = (1, 2)``
  836. then functions with 1 or 2 arguments will be matched.
  837. >>> F = WildFunction('F', nargs=(1, 2))
  838. >>> F.nargs
  839. {1, 2}
  840. >>> f(x).match(F)
  841. {F_: f(x)}
  842. >>> f(x, y).match(F)
  843. {F_: f(x, y)}
  844. >>> f(x, y, 1).match(F)
  845. """
  846. # XXX: What is this class attribute used for?
  847. include = set() # type: tSet[Any]
  848. def __init__(cls, name, **assumptions):
  849. from sympy.sets.sets import Set, FiniteSet
  850. cls.name = name
  851. nargs = assumptions.pop('nargs', S.Naturals0)
  852. if not isinstance(nargs, Set):
  853. # Canonicalize nargs here. See also FunctionClass.
  854. if is_sequence(nargs):
  855. nargs = tuple(ordered(set(nargs)))
  856. elif nargs is not None:
  857. nargs = (as_int(nargs),)
  858. nargs = FiniteSet(*nargs)
  859. cls.nargs = nargs
  860. def matches(self, expr, repl_dict=None, old=False):
  861. if not isinstance(expr, (AppliedUndef, Function)):
  862. return None
  863. if len(expr.args) not in self.nargs:
  864. return None
  865. if repl_dict is None:
  866. repl_dict = dict()
  867. else:
  868. repl_dict = repl_dict.copy()
  869. repl_dict[self] = expr
  870. return repl_dict
  871. class Derivative(Expr):
  872. """
  873. Carries out differentiation of the given expression with respect to symbols.
  874. Examples
  875. ========
  876. >>> from sympy import Derivative, Function, symbols, Subs
  877. >>> from sympy.abc import x, y
  878. >>> f, g = symbols('f g', cls=Function)
  879. >>> Derivative(x**2, x, evaluate=True)
  880. 2*x
  881. Denesting of derivatives retains the ordering of variables:
  882. >>> Derivative(Derivative(f(x, y), y), x)
  883. Derivative(f(x, y), y, x)
  884. Contiguously identical symbols are merged into a tuple giving
  885. the symbol and the count:
  886. >>> Derivative(f(x), x, x, y, x)
  887. Derivative(f(x), (x, 2), y, x)
  888. If the derivative cannot be performed, and evaluate is True, the
  889. order of the variables of differentiation will be made canonical:
  890. >>> Derivative(f(x, y), y, x, evaluate=True)
  891. Derivative(f(x, y), x, y)
  892. Derivatives with respect to undefined functions can be calculated:
  893. >>> Derivative(f(x)**2, f(x), evaluate=True)
  894. 2*f(x)
  895. Such derivatives will show up when the chain rule is used to
  896. evalulate a derivative:
  897. >>> f(g(x)).diff(x)
  898. Derivative(f(g(x)), g(x))*Derivative(g(x), x)
  899. Substitution is used to represent derivatives of functions with
  900. arguments that are not symbols or functions:
  901. >>> f(2*x + 3).diff(x) == 2*Subs(f(y).diff(y), y, 2*x + 3)
  902. True
  903. Notes
  904. =====
  905. Simplification of high-order derivatives:
  906. Because there can be a significant amount of simplification that can be
  907. done when multiple differentiations are performed, results will be
  908. automatically simplified in a fairly conservative fashion unless the
  909. keyword ``simplify`` is set to False.
  910. >>> from sympy import sqrt, diff, Function, symbols
  911. >>> from sympy.abc import x, y, z
  912. >>> f, g = symbols('f,g', cls=Function)
  913. >>> e = sqrt((x + 1)**2 + x)
  914. >>> diff(e, (x, 5), simplify=False).count_ops()
  915. 136
  916. >>> diff(e, (x, 5)).count_ops()
  917. 30
  918. Ordering of variables:
  919. If evaluate is set to True and the expression cannot be evaluated, the
  920. list of differentiation symbols will be sorted, that is, the expression is
  921. assumed to have continuous derivatives up to the order asked.
  922. Derivative wrt non-Symbols:
  923. For the most part, one may not differentiate wrt non-symbols.
  924. For example, we do not allow differentiation wrt `x*y` because
  925. there are multiple ways of structurally defining where x*y appears
  926. in an expression: a very strict definition would make
  927. (x*y*z).diff(x*y) == 0. Derivatives wrt defined functions (like
  928. cos(x)) are not allowed, either:
  929. >>> (x*y*z).diff(x*y)
  930. Traceback (most recent call last):
  931. ...
  932. ValueError: Can't calculate derivative wrt x*y.
  933. To make it easier to work with variational calculus, however,
  934. derivatives wrt AppliedUndef and Derivatives are allowed.
  935. For example, in the Euler-Lagrange method one may write
  936. F(t, u, v) where u = f(t) and v = f'(t). These variables can be
  937. written explicitly as functions of time::
  938. >>> from sympy.abc import t
  939. >>> F = Function('F')
  940. >>> U = f(t)
  941. >>> V = U.diff(t)
  942. The derivative wrt f(t) can be obtained directly:
  943. >>> direct = F(t, U, V).diff(U)
  944. When differentiation wrt a non-Symbol is attempted, the non-Symbol
  945. is temporarily converted to a Symbol while the differentiation
  946. is performed and the same answer is obtained:
  947. >>> indirect = F(t, U, V).subs(U, x).diff(x).subs(x, U)
  948. >>> assert direct == indirect
  949. The implication of this non-symbol replacement is that all
  950. functions are treated as independent of other functions and the
  951. symbols are independent of the functions that contain them::
  952. >>> x.diff(f(x))
  953. 0
  954. >>> g(x).diff(f(x))
  955. 0
  956. It also means that derivatives are assumed to depend only
  957. on the variables of differentiation, not on anything contained
  958. within the expression being differentiated::
  959. >>> F = f(x)
  960. >>> Fx = F.diff(x)
  961. >>> Fx.diff(F) # derivative depends on x, not F
  962. 0
  963. >>> Fxx = Fx.diff(x)
  964. >>> Fxx.diff(Fx) # derivative depends on x, not Fx
  965. 0
  966. The last example can be made explicit by showing the replacement
  967. of Fx in Fxx with y:
  968. >>> Fxx.subs(Fx, y)
  969. Derivative(y, x)
  970. Since that in itself will evaluate to zero, differentiating
  971. wrt Fx will also be zero:
  972. >>> _.doit()
  973. 0
  974. Replacing undefined functions with concrete expressions
  975. One must be careful to replace undefined functions with expressions
  976. that contain variables consistent with the function definition and
  977. the variables of differentiation or else insconsistent result will
  978. be obtained. Consider the following example:
  979. >>> eq = f(x)*g(y)
  980. >>> eq.subs(f(x), x*y).diff(x, y).doit()
  981. y*Derivative(g(y), y) + g(y)
  982. >>> eq.diff(x, y).subs(f(x), x*y).doit()
  983. y*Derivative(g(y), y)
  984. The results differ because `f(x)` was replaced with an expression
  985. that involved both variables of differentiation. In the abstract
  986. case, differentiation of `f(x)` by `y` is 0; in the concrete case,
  987. the presence of `y` made that derivative nonvanishing and produced
  988. the extra `g(y)` term.
  989. Defining differentiation for an object
  990. An object must define ._eval_derivative(symbol) method that returns
  991. the differentiation result. This function only needs to consider the
  992. non-trivial case where expr contains symbol and it should call the diff()
  993. method internally (not _eval_derivative); Derivative should be the only
  994. one to call _eval_derivative.
  995. Any class can allow derivatives to be taken with respect to
  996. itself (while indicating its scalar nature). See the
  997. docstring of Expr._diff_wrt.
  998. See Also
  999. ========
  1000. _sort_variable_count
  1001. """
  1002. is_Derivative = True
  1003. @property
  1004. def _diff_wrt(self):
  1005. """An expression may be differentiated wrt a Derivative if
  1006. it is in elementary form.
  1007. Examples
  1008. ========
  1009. >>> from sympy import Function, Derivative, cos
  1010. >>> from sympy.abc import x
  1011. >>> f = Function('f')
  1012. >>> Derivative(f(x), x)._diff_wrt
  1013. True
  1014. >>> Derivative(cos(x), x)._diff_wrt
  1015. False
  1016. >>> Derivative(x + 1, x)._diff_wrt
  1017. False
  1018. A Derivative might be an unevaluated form of what will not be
  1019. a valid variable of differentiation if evaluated. For example,
  1020. >>> Derivative(f(f(x)), x).doit()
  1021. Derivative(f(x), x)*Derivative(f(f(x)), f(x))
  1022. Such an expression will present the same ambiguities as arise
  1023. when dealing with any other product, like ``2*x``, so ``_diff_wrt``
  1024. is False:
  1025. >>> Derivative(f(f(x)), x)._diff_wrt
  1026. False
  1027. """
  1028. return self.expr._diff_wrt and isinstance(self.doit(), Derivative)
  1029. def __new__(cls, expr, *variables, **kwargs):
  1030. expr = sympify(expr)
  1031. symbols_or_none = getattr(expr, "free_symbols", None)
  1032. has_symbol_set = isinstance(symbols_or_none, set)
  1033. if not has_symbol_set:
  1034. raise ValueError(filldedent('''
  1035. Since there are no variables in the expression %s,
  1036. it cannot be differentiated.''' % expr))
  1037. # determine value for variables if it wasn't given
  1038. if not variables:
  1039. variables = expr.free_symbols
  1040. if len(variables) != 1:
  1041. if expr.is_number:
  1042. return S.Zero
  1043. if len(variables) == 0:
  1044. raise ValueError(filldedent('''
  1045. Since there are no variables in the expression,
  1046. the variable(s) of differentiation must be supplied
  1047. to differentiate %s''' % expr))
  1048. else:
  1049. raise ValueError(filldedent('''
  1050. Since there is more than one variable in the
  1051. expression, the variable(s) of differentiation
  1052. must be supplied to differentiate %s''' % expr))
  1053. # Split the list of variables into a list of the variables we are diff
  1054. # wrt, where each element of the list has the form (s, count) where
  1055. # s is the entity to diff wrt and count is the order of the
  1056. # derivative.
  1057. variable_count = []
  1058. array_likes = (tuple, list, Tuple)
  1059. from sympy.tensor.array import Array, NDimArray
  1060. for i, v in enumerate(variables):
  1061. if isinstance(v, UndefinedFunction):
  1062. raise TypeError(
  1063. "cannot differentiate wrt "
  1064. "UndefinedFunction: %s" % v)
  1065. if isinstance(v, array_likes):
  1066. if len(v) == 0:
  1067. # Ignore empty tuples: Derivative(expr, ... , (), ... )
  1068. continue
  1069. if isinstance(v[0], array_likes):
  1070. # Derive by array: Derivative(expr, ... , [[x, y, z]], ... )
  1071. if len(v) == 1:
  1072. v = Array(v[0])
  1073. count = 1
  1074. else:
  1075. v, count = v
  1076. v = Array(v)
  1077. else:
  1078. v, count = v
  1079. if count == 0:
  1080. continue
  1081. variable_count.append(Tuple(v, count))
  1082. continue
  1083. v = sympify(v)
  1084. if isinstance(v, Integer):
  1085. if i == 0:
  1086. raise ValueError("First variable cannot be a number: %i" % v)
  1087. count = v
  1088. prev, prevcount = variable_count[-1]
  1089. if prevcount != 1:
  1090. raise TypeError("tuple {} followed by number {}".format((prev, prevcount), v))
  1091. if count == 0:
  1092. variable_count.pop()
  1093. else:
  1094. variable_count[-1] = Tuple(prev, count)
  1095. else:
  1096. count = 1
  1097. variable_count.append(Tuple(v, count))
  1098. # light evaluation of contiguous, identical
  1099. # items: (x, 1), (x, 1) -> (x, 2)
  1100. merged = []
  1101. for t in variable_count:
  1102. v, c = t
  1103. if c.is_negative:
  1104. raise ValueError(
  1105. 'order of differentiation must be nonnegative')
  1106. if merged and merged[-1][0] == v:
  1107. c += merged[-1][1]
  1108. if not c:
  1109. merged.pop()
  1110. else:
  1111. merged[-1] = Tuple(v, c)
  1112. else:
  1113. merged.append(t)
  1114. variable_count = merged
  1115. # sanity check of variables of differentation; we waited
  1116. # until the counts were computed since some variables may
  1117. # have been removed because the count was 0
  1118. for v, c in variable_count:
  1119. # v must have _diff_wrt True
  1120. if not v._diff_wrt:
  1121. __ = '' # filler to make error message neater
  1122. raise ValueError(filldedent('''
  1123. Can't calculate derivative wrt %s.%s''' % (v,
  1124. __)))
  1125. # We make a special case for 0th derivative, because there is no
  1126. # good way to unambiguously print this.
  1127. if len(variable_count) == 0:
  1128. return expr
  1129. evaluate = kwargs.get('evaluate', False)
  1130. if evaluate:
  1131. if isinstance(expr, Derivative):
  1132. expr = expr.canonical
  1133. variable_count = [
  1134. (v.canonical if isinstance(v, Derivative) else v, c)
  1135. for v, c in variable_count]
  1136. # Look for a quick exit if there are symbols that don't appear in
  1137. # expression at all. Note, this cannot check non-symbols like
  1138. # Derivatives as those can be created by intermediate
  1139. # derivatives.
  1140. zero = False
  1141. free = expr.free_symbols
  1142. from sympy.matrices.expressions.matexpr import MatrixExpr
  1143. for v, c in variable_count:
  1144. vfree = v.free_symbols
  1145. if c.is_positive and vfree:
  1146. if isinstance(v, AppliedUndef):
  1147. # these match exactly since
  1148. # x.diff(f(x)) == g(x).diff(f(x)) == 0
  1149. # and are not created by differentiation
  1150. D = Dummy()
  1151. if not expr.xreplace({v: D}).has(D):
  1152. zero = True
  1153. break
  1154. elif isinstance(v, MatrixExpr):
  1155. zero = False
  1156. break
  1157. elif isinstance(v, Symbol) and v not in free:
  1158. zero = True
  1159. break
  1160. else:
  1161. if not free & vfree:
  1162. # e.g. v is IndexedBase or Matrix
  1163. zero = True
  1164. break
  1165. if zero:
  1166. return cls._get_zero_with_shape_like(expr)
  1167. # make the order of symbols canonical
  1168. #TODO: check if assumption of discontinuous derivatives exist
  1169. variable_count = cls._sort_variable_count(variable_count)
  1170. # denest
  1171. if isinstance(expr, Derivative):
  1172. variable_count = list(expr.variable_count) + variable_count
  1173. expr = expr.expr
  1174. return _derivative_dispatch(expr, *variable_count, **kwargs)
  1175. # we return here if evaluate is False or if there is no
  1176. # _eval_derivative method
  1177. if not evaluate or not hasattr(expr, '_eval_derivative'):
  1178. # return an unevaluated Derivative
  1179. if evaluate and variable_count == [(expr, 1)] and expr.is_scalar:
  1180. # special hack providing evaluation for classes
  1181. # that have defined is_scalar=True but have no
  1182. # _eval_derivative defined
  1183. return S.One
  1184. return Expr.__new__(cls, expr, *variable_count)
  1185. # evaluate the derivative by calling _eval_derivative method
  1186. # of expr for each variable
  1187. # -------------------------------------------------------------
  1188. nderivs = 0 # how many derivatives were performed
  1189. unhandled = []
  1190. from sympy.matrices.common import MatrixCommon
  1191. for i, (v, count) in enumerate(variable_count):
  1192. old_expr = expr
  1193. old_v = None
  1194. is_symbol = v.is_symbol or isinstance(v,
  1195. (Iterable, Tuple, MatrixCommon, NDimArray))
  1196. if not is_symbol:
  1197. old_v = v
  1198. v = Dummy('xi')
  1199. expr = expr.xreplace({old_v: v})
  1200. # Derivatives and UndefinedFunctions are independent
  1201. # of all others
  1202. clashing = not (isinstance(old_v, Derivative) or \
  1203. isinstance(old_v, AppliedUndef))
  1204. if v not in expr.free_symbols and not clashing:
  1205. return expr.diff(v) # expr's version of 0
  1206. if not old_v.is_scalar and not hasattr(
  1207. old_v, '_eval_derivative'):
  1208. # special hack providing evaluation for classes
  1209. # that have defined is_scalar=True but have no
  1210. # _eval_derivative defined
  1211. expr *= old_v.diff(old_v)
  1212. obj = cls._dispatch_eval_derivative_n_times(expr, v, count)
  1213. if obj is not None and obj.is_zero:
  1214. return obj
  1215. nderivs += count
  1216. if old_v is not None:
  1217. if obj is not None:
  1218. # remove the dummy that was used
  1219. obj = obj.subs(v, old_v)
  1220. # restore expr
  1221. expr = old_expr
  1222. if obj is None:
  1223. # we've already checked for quick-exit conditions
  1224. # that give 0 so the remaining variables
  1225. # are contained in the expression but the expression
  1226. # did not compute a derivative so we stop taking
  1227. # derivatives
  1228. unhandled = variable_count[i:]
  1229. break
  1230. expr = obj
  1231. # what we have so far can be made canonical
  1232. expr = expr.replace(
  1233. lambda x: isinstance(x, Derivative),
  1234. lambda x: x.canonical)
  1235. if unhandled:
  1236. if isinstance(expr, Derivative):
  1237. unhandled = list(expr.variable_count) + unhandled
  1238. expr = expr.expr
  1239. expr = Expr.__new__(cls, expr, *unhandled)
  1240. if (nderivs > 1) == True and kwargs.get('simplify', True):
  1241. from .exprtools import factor_terms
  1242. from sympy.simplify.simplify import signsimp
  1243. expr = factor_terms(signsimp(expr))
  1244. return expr
  1245. @property
  1246. def canonical(cls):
  1247. return cls.func(cls.expr,
  1248. *Derivative._sort_variable_count(cls.variable_count))
  1249. @classmethod
  1250. def _sort_variable_count(cls, vc):
  1251. """
  1252. Sort (variable, count) pairs into canonical order while
  1253. retaining order of variables that do not commute during
  1254. differentiation:
  1255. * symbols and functions commute with each other
  1256. * derivatives commute with each other
  1257. * a derivative doesn't commute with anything it contains
  1258. * any other object is not allowed to commute if it has
  1259. free symbols in common with another object
  1260. Examples
  1261. ========
  1262. >>> from sympy import Derivative, Function, symbols
  1263. >>> vsort = Derivative._sort_variable_count
  1264. >>> x, y, z = symbols('x y z')
  1265. >>> f, g, h = symbols('f g h', cls=Function)
  1266. Contiguous items are collapsed into one pair:
  1267. >>> vsort([(x, 1), (x, 1)])
  1268. [(x, 2)]
  1269. >>> vsort([(y, 1), (f(x), 1), (y, 1), (f(x), 1)])
  1270. [(y, 2), (f(x), 2)]
  1271. Ordering is canonical.
  1272. >>> def vsort0(*v):
  1273. ... # docstring helper to
  1274. ... # change vi -> (vi, 0), sort, and return vi vals
  1275. ... return [i[0] for i in vsort([(i, 0) for i in v])]
  1276. >>> vsort0(y, x)
  1277. [x, y]
  1278. >>> vsort0(g(y), g(x), f(y))
  1279. [f(y), g(x), g(y)]
  1280. Symbols are sorted as far to the left as possible but never
  1281. move to the left of a derivative having the same symbol in
  1282. its variables; the same applies to AppliedUndef which are
  1283. always sorted after Symbols:
  1284. >>> dfx = f(x).diff(x)
  1285. >>> assert vsort0(dfx, y) == [y, dfx]
  1286. >>> assert vsort0(dfx, x) == [dfx, x]
  1287. """
  1288. if not vc:
  1289. return []
  1290. vc = list(vc)
  1291. if len(vc) == 1:
  1292. return [Tuple(*vc[0])]
  1293. V = list(range(len(vc)))
  1294. E = []
  1295. v = lambda i: vc[i][0]
  1296. D = Dummy()
  1297. def _block(d, v, wrt=False):
  1298. # return True if v should not come before d else False
  1299. if d == v:
  1300. return wrt
  1301. if d.is_Symbol:
  1302. return False
  1303. if isinstance(d, Derivative):
  1304. # a derivative blocks if any of it's variables contain
  1305. # v; the wrt flag will return True for an exact match
  1306. # and will cause an AppliedUndef to block if v is in
  1307. # the arguments
  1308. if any(_block(k, v, wrt=True)
  1309. for k in d._wrt_variables):
  1310. return True
  1311. return False
  1312. if not wrt and isinstance(d, AppliedUndef):
  1313. return False
  1314. if v.is_Symbol:
  1315. return v in d.free_symbols
  1316. if isinstance(v, AppliedUndef):
  1317. return _block(d.xreplace({v: D}), D)
  1318. return d.free_symbols & v.free_symbols
  1319. for i in range(len(vc)):
  1320. for j in range(i):
  1321. if _block(v(j), v(i)):
  1322. E.append((j,i))
  1323. # this is the default ordering to use in case of ties
  1324. O = dict(zip(ordered(uniq([i for i, c in vc])), range(len(vc))))
  1325. ix = topological_sort((V, E), key=lambda i: O[v(i)])
  1326. # merge counts of contiguously identical items
  1327. merged = []
  1328. for v, c in [vc[i] for i in ix]:
  1329. if merged and merged[-1][0] == v:
  1330. merged[-1][1] += c
  1331. else:
  1332. merged.append([v, c])
  1333. return [Tuple(*i) for i in merged]
  1334. def _eval_is_commutative(self):
  1335. return self.expr.is_commutative
  1336. def _eval_derivative(self, v):
  1337. # If v (the variable of differentiation) is not in
  1338. # self.variables, we might be able to take the derivative.
  1339. if v not in self._wrt_variables:
  1340. dedv = self.expr.diff(v)
  1341. if isinstance(dedv, Derivative):
  1342. return dedv.func(dedv.expr, *(self.variable_count + dedv.variable_count))
  1343. # dedv (d(self.expr)/dv) could have simplified things such that the
  1344. # derivative wrt things in self.variables can now be done. Thus,
  1345. # we set evaluate=True to see if there are any other derivatives
  1346. # that can be done. The most common case is when dedv is a simple
  1347. # number so that the derivative wrt anything else will vanish.
  1348. return self.func(dedv, *self.variables, evaluate=True)
  1349. # In this case v was in self.variables so the derivative wrt v has
  1350. # already been attempted and was not computed, either because it
  1351. # couldn't be or evaluate=False originally.
  1352. variable_count = list(self.variable_count)
  1353. variable_count.append((v, 1))
  1354. return self.func(self.expr, *variable_count, evaluate=False)
  1355. def doit(self, **hints):
  1356. expr = self.expr
  1357. if hints.get('deep', True):
  1358. expr = expr.doit(**hints)
  1359. hints['evaluate'] = True
  1360. rv = self.func(expr, *self.variable_count, **hints)
  1361. if rv!= self and rv.has(Derivative):
  1362. rv = rv.doit(**hints)
  1363. return rv
  1364. @_sympifyit('z0', NotImplementedError)
  1365. def doit_numerically(self, z0):
  1366. """
  1367. Evaluate the derivative at z numerically.
  1368. When we can represent derivatives at a point, this should be folded
  1369. into the normal evalf. For now, we need a special method.
  1370. """
  1371. if len(self.free_symbols) != 1 or len(self.variables) != 1:
  1372. raise NotImplementedError('partials and higher order derivatives')
  1373. z = list(self.free_symbols)[0]
  1374. def eval(x):
  1375. f0 = self.expr.subs(z, Expr._from_mpmath(x, prec=mpmath.mp.prec))
  1376. f0 = f0.evalf(prec_to_dps(mpmath.mp.prec))
  1377. return f0._to_mpmath(mpmath.mp.prec)
  1378. return Expr._from_mpmath(mpmath.diff(eval,
  1379. z0._to_mpmath(mpmath.mp.prec)),
  1380. mpmath.mp.prec)
  1381. @property
  1382. def expr(self):
  1383. return self._args[0]
  1384. @property
  1385. def _wrt_variables(self):
  1386. # return the variables of differentiation without
  1387. # respect to the type of count (int or symbolic)
  1388. return [i[0] for i in self.variable_count]
  1389. @property
  1390. def variables(self):
  1391. # TODO: deprecate? YES, make this 'enumerated_variables' and
  1392. # name _wrt_variables as variables
  1393. # TODO: support for `d^n`?
  1394. rv = []
  1395. for v, count in self.variable_count:
  1396. if not count.is_Integer:
  1397. raise TypeError(filldedent('''
  1398. Cannot give expansion for symbolic count. If you just
  1399. want a list of all variables of differentiation, use
  1400. _wrt_variables.'''))
  1401. rv.extend([v]*count)
  1402. return tuple(rv)
  1403. @property
  1404. def variable_count(self):
  1405. return self._args[1:]
  1406. @property
  1407. def derivative_count(self):
  1408. return sum([count for _, count in self.variable_count], 0)
  1409. @property
  1410. def free_symbols(self):
  1411. ret = self.expr.free_symbols
  1412. # Add symbolic counts to free_symbols
  1413. for _, count in self.variable_count:
  1414. ret.update(count.free_symbols)
  1415. return ret
  1416. @property
  1417. def kind(self):
  1418. return self.args[0].kind
  1419. def _eval_subs(self, old, new):
  1420. # The substitution (old, new) cannot be done inside
  1421. # Derivative(expr, vars) for a variety of reasons
  1422. # as handled below.
  1423. if old in self._wrt_variables:
  1424. # first handle the counts
  1425. expr = self.func(self.expr, *[(v, c.subs(old, new))
  1426. for v, c in self.variable_count])
  1427. if expr != self:
  1428. return expr._eval_subs(old, new)
  1429. # quick exit case
  1430. if not getattr(new, '_diff_wrt', False):
  1431. # case (0): new is not a valid variable of
  1432. # differentiation
  1433. if isinstance(old, Symbol):
  1434. # don't introduce a new symbol if the old will do
  1435. return Subs(self, old, new)
  1436. else:
  1437. xi = Dummy('xi')
  1438. return Subs(self.xreplace({old: xi}), xi, new)
  1439. # If both are Derivatives with the same expr, check if old is
  1440. # equivalent to self or if old is a subderivative of self.
  1441. if old.is_Derivative and old.expr == self.expr:
  1442. if self.canonical == old.canonical:
  1443. return new
  1444. # collections.Counter doesn't have __le__
  1445. def _subset(a, b):
  1446. return all((a[i] <= b[i]) == True for i in a)
  1447. old_vars = Counter(dict(reversed(old.variable_count)))
  1448. self_vars = Counter(dict(reversed(self.variable_count)))
  1449. if _subset(old_vars, self_vars):
  1450. return _derivative_dispatch(new, *(self_vars - old_vars).items()).canonical
  1451. args = list(self.args)
  1452. newargs = list(x._subs(old, new) for x in args)
  1453. if args[0] == old:
  1454. # complete replacement of self.expr
  1455. # we already checked that the new is valid so we know
  1456. # it won't be a problem should it appear in variables
  1457. return _derivative_dispatch(*newargs)
  1458. if newargs[0] != args[0]:
  1459. # case (1) can't change expr by introducing something that is in
  1460. # the _wrt_variables if it was already in the expr
  1461. # e.g.
  1462. # for Derivative(f(x, g(y)), y), x cannot be replaced with
  1463. # anything that has y in it; for f(g(x), g(y)).diff(g(y))
  1464. # g(x) cannot be replaced with anything that has g(y)
  1465. syms = {vi: Dummy() for vi in self._wrt_variables
  1466. if not vi.is_Symbol}
  1467. wrt = {syms.get(vi, vi) for vi in self._wrt_variables}
  1468. forbidden = args[0].xreplace(syms).free_symbols & wrt
  1469. nfree = new.xreplace(syms).free_symbols
  1470. ofree = old.xreplace(syms).free_symbols
  1471. if (nfree - ofree) & forbidden:
  1472. return Subs(self, old, new)
  1473. viter = ((i, j) for ((i, _), (j, _)) in zip(newargs[1:], args[1:]))
  1474. if any(i != j for i, j in viter): # a wrt-variable change
  1475. # case (2) can't change vars by introducing a variable
  1476. # that is contained in expr, e.g.
  1477. # for Derivative(f(z, g(h(x), y)), y), y cannot be changed to
  1478. # x, h(x), or g(h(x), y)
  1479. for a in _atomic(self.expr, recursive=True):
  1480. for i in range(1, len(newargs)):
  1481. vi, _ = newargs[i]
  1482. if a == vi and vi != args[i][0]:
  1483. return Subs(self, old, new)
  1484. # more arg-wise checks
  1485. vc = newargs[1:]
  1486. oldv = self._wrt_variables
  1487. newe = self.expr
  1488. subs = []
  1489. for i, (vi, ci) in enumerate(vc):
  1490. if not vi._diff_wrt:
  1491. # case (3) invalid differentiation expression so
  1492. # create a replacement dummy
  1493. xi = Dummy('xi_%i' % i)
  1494. # replace the old valid variable with the dummy
  1495. # in the expression
  1496. newe = newe.xreplace({oldv[i]: xi})
  1497. # and replace the bad variable with the dummy
  1498. vc[i] = (xi, ci)
  1499. # and record the dummy with the new (invalid)
  1500. # differentiation expression
  1501. subs.append((xi, vi))
  1502. if subs:
  1503. # handle any residual substitution in the expression
  1504. newe = newe._subs(old, new)
  1505. # return the Subs-wrapped derivative
  1506. return Subs(Derivative(newe, *vc), *zip(*subs))
  1507. # everything was ok
  1508. return _derivative_dispatch(*newargs)
  1509. def _eval_lseries(self, x, logx, cdir=0):
  1510. dx = self.variables
  1511. for term in self.expr.lseries(x, logx=logx, cdir=cdir):
  1512. yield self.func(term, *dx)
  1513. def _eval_nseries(self, x, n, logx, cdir=0):
  1514. arg = self.expr.nseries(x, n=n, logx=logx)
  1515. o = arg.getO()
  1516. dx = self.variables
  1517. rv = [self.func(a, *dx) for a in Add.make_args(arg.removeO())]
  1518. if o:
  1519. rv.append(o/x)
  1520. return Add(*rv)
  1521. def _eval_as_leading_term(self, x, logx=None, cdir=0):
  1522. series_gen = self.expr.lseries(x)
  1523. d = S.Zero
  1524. for leading_term in series_gen:
  1525. d = diff(leading_term, *self.variables)
  1526. if d != 0:
  1527. break
  1528. return d
  1529. def as_finite_difference(self, points=1, x0=None, wrt=None):
  1530. """ Expresses a Derivative instance as a finite difference.
  1531. Parameters
  1532. ==========
  1533. points : sequence or coefficient, optional
  1534. If sequence: discrete values (length >= order+1) of the
  1535. independent variable used for generating the finite
  1536. difference weights.
  1537. If it is a coefficient, it will be used as the step-size
  1538. for generating an equidistant sequence of length order+1
  1539. centered around ``x0``. Default: 1 (step-size 1)
  1540. x0 : number or Symbol, optional
  1541. the value of the independent variable (``wrt``) at which the
  1542. derivative is to be approximated. Default: same as ``wrt``.
  1543. wrt : Symbol, optional
  1544. "with respect to" the variable for which the (partial)
  1545. derivative is to be approximated for. If not provided it
  1546. is required that the derivative is ordinary. Default: ``None``.
  1547. Examples
  1548. ========
  1549. >>> from sympy import symbols, Function, exp, sqrt, Symbol
  1550. >>> x, h = symbols('x h')
  1551. >>> f = Function('f')
  1552. >>> f(x).diff(x).as_finite_difference()
  1553. -f(x - 1/2) + f(x + 1/2)
  1554. The default step size and number of points are 1 and
  1555. ``order + 1`` respectively. We can change the step size by
  1556. passing a symbol as a parameter:
  1557. >>> f(x).diff(x).as_finite_difference(h)
  1558. -f(-h/2 + x)/h + f(h/2 + x)/h
  1559. We can also specify the discretized values to be used in a
  1560. sequence:
  1561. >>> f(x).diff(x).as_finite_difference([x, x+h, x+2*h])
  1562. -3*f(x)/(2*h) + 2*f(h + x)/h - f(2*h + x)/(2*h)
  1563. The algorithm is not restricted to use equidistant spacing, nor
  1564. do we need to make the approximation around ``x0``, but we can get
  1565. an expression estimating the derivative at an offset:
  1566. >>> e, sq2 = exp(1), sqrt(2)
  1567. >>> xl = [x-h, x+h, x+e*h]
  1568. >>> f(x).diff(x, 1).as_finite_difference(xl, x+h*sq2) # doctest: +ELLIPSIS
  1569. 2*h*((h + sqrt(2)*h)/(2*h) - (-sqrt(2)*h + h)/(2*h))*f(E*h + x)/...
  1570. To approximate ``Derivative`` around ``x0`` using a non-equidistant
  1571. spacing step, the algorithm supports assignment of undefined
  1572. functions to ``points``:
  1573. >>> dx = Function('dx')
  1574. >>> f(x).diff(x).as_finite_difference(points=dx(x), x0=x-h)
  1575. -f(-h + x - dx(-h + x)/2)/dx(-h + x) + f(-h + x + dx(-h + x)/2)/dx(-h + x)
  1576. Partial derivatives are also supported:
  1577. >>> y = Symbol('y')
  1578. >>> d2fdxdy=f(x,y).diff(x,y)
  1579. >>> d2fdxdy.as_finite_difference(wrt=x)
  1580. -Derivative(f(x - 1/2, y), y) + Derivative(f(x + 1/2, y), y)
  1581. We can apply ``as_finite_difference`` to ``Derivative`` instances in
  1582. compound expressions using ``replace``:
  1583. >>> (1 + 42**f(x).diff(x)).replace(lambda arg: arg.is_Derivative,
  1584. ... lambda arg: arg.as_finite_difference())
  1585. 42**(-f(x - 1/2) + f(x + 1/2)) + 1
  1586. See also
  1587. ========
  1588. sympy.calculus.finite_diff.apply_finite_diff
  1589. sympy.calculus.finite_diff.differentiate_finite
  1590. sympy.calculus.finite_diff.finite_diff_weights
  1591. """
  1592. from sympy.calculus.finite_diff import _as_finite_diff
  1593. return _as_finite_diff(self, points, x0, wrt)
  1594. @classmethod
  1595. def _get_zero_with_shape_like(cls, expr):
  1596. return S.Zero
  1597. @classmethod
  1598. def _dispatch_eval_derivative_n_times(cls, expr, v, count):
  1599. # Evaluate the derivative `n` times. If
  1600. # `_eval_derivative_n_times` is not overridden by the current
  1601. # object, the default in `Basic` will call a loop over
  1602. # `_eval_derivative`:
  1603. return expr._eval_derivative_n_times(v, count)
  1604. def _derivative_dispatch(expr, *variables, **kwargs):
  1605. from sympy.matrices.common import MatrixCommon
  1606. from sympy.matrices.expressions.matexpr import MatrixExpr
  1607. from sympy.tensor.array import NDimArray
  1608. array_types = (MatrixCommon, MatrixExpr, NDimArray, list, tuple, Tuple)
  1609. if isinstance(expr, array_types) or any(isinstance(i[0], array_types) if isinstance(i, (tuple, list, Tuple)) else isinstance(i, array_types) for i in variables):
  1610. from sympy.tensor.array.array_derivatives import ArrayDerivative
  1611. return ArrayDerivative(expr, *variables, **kwargs)
  1612. return Derivative(expr, *variables, **kwargs)
  1613. class Lambda(Expr):
  1614. """
  1615. Lambda(x, expr) represents a lambda function similar to Python's
  1616. 'lambda x: expr'. A function of several variables is written as
  1617. Lambda((x, y, ...), expr).
  1618. Examples
  1619. ========
  1620. A simple example:
  1621. >>> from sympy import Lambda
  1622. >>> from sympy.abc import x
  1623. >>> f = Lambda(x, x**2)
  1624. >>> f(4)
  1625. 16
  1626. For multivariate functions, use:
  1627. >>> from sympy.abc import y, z, t
  1628. >>> f2 = Lambda((x, y, z, t), x + y**z + t**z)
  1629. >>> f2(1, 2, 3, 4)
  1630. 73
  1631. It is also possible to unpack tuple arguments:
  1632. >>> f = Lambda(((x, y), z), x + y + z)
  1633. >>> f((1, 2), 3)
  1634. 6
  1635. A handy shortcut for lots of arguments:
  1636. >>> p = x, y, z
  1637. >>> f = Lambda(p, x + y*z)
  1638. >>> f(*p)
  1639. x + y*z
  1640. """
  1641. is_Function = True
  1642. def __new__(cls, signature, expr):
  1643. if iterable(signature) and not isinstance(signature, (tuple, Tuple)):
  1644. sympy_deprecation_warning(
  1645. """
  1646. Using a non-tuple iterable as the first argument to Lambda
  1647. is deprecated. Use Lambda(tuple(args), expr) instead.
  1648. """,
  1649. deprecated_since_version="1.5",
  1650. active_deprecations_target="deprecated-non-tuple-lambda",
  1651. )
  1652. signature = tuple(signature)
  1653. sig = signature if iterable(signature) else (signature,)
  1654. sig = sympify(sig)
  1655. cls._check_signature(sig)
  1656. if len(sig) == 1 and sig[0] == expr:
  1657. return S.IdentityFunction
  1658. return Expr.__new__(cls, sig, sympify(expr))
  1659. @classmethod
  1660. def _check_signature(cls, sig):
  1661. syms = set()
  1662. def rcheck(args):
  1663. for a in args:
  1664. if a.is_symbol:
  1665. if a in syms:
  1666. raise BadSignatureError("Duplicate symbol %s" % a)
  1667. syms.add(a)
  1668. elif isinstance(a, Tuple):
  1669. rcheck(a)
  1670. else:
  1671. raise BadSignatureError("Lambda signature should be only tuples"
  1672. " and symbols, not %s" % a)
  1673. if not isinstance(sig, Tuple):
  1674. raise BadSignatureError("Lambda signature should be a tuple not %s" % sig)
  1675. # Recurse through the signature:
  1676. rcheck(sig)
  1677. @property
  1678. def signature(self):
  1679. """The expected form of the arguments to be unpacked into variables"""
  1680. return self._args[0]
  1681. @property
  1682. def expr(self):
  1683. """The return value of the function"""
  1684. return self._args[1]
  1685. @property
  1686. def variables(self):
  1687. """The variables used in the internal representation of the function"""
  1688. def _variables(args):
  1689. if isinstance(args, Tuple):
  1690. for arg in args:
  1691. yield from _variables(arg)
  1692. else:
  1693. yield args
  1694. return tuple(_variables(self.signature))
  1695. @property
  1696. def nargs(self):
  1697. from sympy.sets.sets import FiniteSet
  1698. return FiniteSet(len(self.signature))
  1699. bound_symbols = variables
  1700. @property
  1701. def free_symbols(self):
  1702. return self.expr.free_symbols - set(self.variables)
  1703. def __call__(self, *args):
  1704. n = len(args)
  1705. if n not in self.nargs: # Lambda only ever has 1 value in nargs
  1706. # XXX: exception message must be in exactly this format to
  1707. # make it work with NumPy's functions like vectorize(). See,
  1708. # for example, https://github.com/numpy/numpy/issues/1697.
  1709. # The ideal solution would be just to attach metadata to
  1710. # the exception and change NumPy to take advantage of this.
  1711. ## XXX does this apply to Lambda? If not, remove this comment.
  1712. temp = ('%(name)s takes exactly %(args)s '
  1713. 'argument%(plural)s (%(given)s given)')
  1714. raise BadArgumentsError(temp % {
  1715. 'name': self,
  1716. 'args': list(self.nargs)[0],
  1717. 'plural': 's'*(list(self.nargs)[0] != 1),
  1718. 'given': n})
  1719. d = self._match_signature(self.signature, args)
  1720. return self.expr.xreplace(d)
  1721. def _match_signature(self, sig, args):
  1722. symargmap = {}
  1723. def rmatch(pars, args):
  1724. for par, arg in zip(pars, args):
  1725. if par.is_symbol:
  1726. symargmap[par] = arg
  1727. elif isinstance(par, Tuple):
  1728. if not isinstance(arg, (tuple, Tuple)) or len(args) != len(pars):
  1729. raise BadArgumentsError("Can't match %s and %s" % (args, pars))
  1730. rmatch(par, arg)
  1731. rmatch(sig, args)
  1732. return symargmap
  1733. @property
  1734. def is_identity(self):
  1735. """Return ``True`` if this ``Lambda`` is an identity function. """
  1736. return self.signature == self.expr
  1737. def _eval_evalf(self, prec):
  1738. return self.func(self.args[0], self.args[1].evalf(n=prec_to_dps(prec)))
  1739. class Subs(Expr):
  1740. """
  1741. Represents unevaluated substitutions of an expression.
  1742. ``Subs(expr, x, x0)`` represents the expression resulting
  1743. from substituting x with x0 in expr.
  1744. Parameters
  1745. ==========
  1746. expr : Expr
  1747. An expression.
  1748. x : tuple, variable
  1749. A variable or list of distinct variables.
  1750. x0 : tuple or list of tuples
  1751. A point or list of evaluation points
  1752. corresponding to those variables.
  1753. Notes
  1754. =====
  1755. ``Subs`` objects are generally useful to represent unevaluated derivatives
  1756. calculated at a point.
  1757. The variables may be expressions, but they are subjected to the limitations
  1758. of subs(), so it is usually a good practice to use only symbols for
  1759. variables, since in that case there can be no ambiguity.
  1760. There's no automatic expansion - use the method .doit() to effect all
  1761. possible substitutions of the object and also of objects inside the
  1762. expression.
  1763. When evaluating derivatives at a point that is not a symbol, a Subs object
  1764. is returned. One is also able to calculate derivatives of Subs objects - in
  1765. this case the expression is always expanded (for the unevaluated form, use
  1766. Derivative()).
  1767. Examples
  1768. ========
  1769. >>> from sympy import Subs, Function, sin, cos
  1770. >>> from sympy.abc import x, y, z
  1771. >>> f = Function('f')
  1772. Subs are created when a particular substitution cannot be made. The
  1773. x in the derivative cannot be replaced with 0 because 0 is not a
  1774. valid variables of differentiation:
  1775. >>> f(x).diff(x).subs(x, 0)
  1776. Subs(Derivative(f(x), x), x, 0)
  1777. Once f is known, the derivative and evaluation at 0 can be done:
  1778. >>> _.subs(f, sin).doit() == sin(x).diff(x).subs(x, 0) == cos(0)
  1779. True
  1780. Subs can also be created directly with one or more variables:
  1781. >>> Subs(f(x)*sin(y) + z, (x, y), (0, 1))
  1782. Subs(z + f(x)*sin(y), (x, y), (0, 1))
  1783. >>> _.doit()
  1784. z + f(0)*sin(1)
  1785. Notes
  1786. =====
  1787. In order to allow expressions to combine before doit is done, a
  1788. representation of the Subs expression is used internally to make
  1789. expressions that are superficially different compare the same:
  1790. >>> a, b = Subs(x, x, 0), Subs(y, y, 0)
  1791. >>> a + b
  1792. 2*Subs(x, x, 0)
  1793. This can lead to unexpected consequences when using methods
  1794. like `has` that are cached:
  1795. >>> s = Subs(x, x, 0)
  1796. >>> s.has(x), s.has(y)
  1797. (True, False)
  1798. >>> ss = s.subs(x, y)
  1799. >>> ss.has(x), ss.has(y)
  1800. (True, False)
  1801. >>> s, ss
  1802. (Subs(x, x, 0), Subs(y, y, 0))
  1803. """
  1804. def __new__(cls, expr, variables, point, **assumptions):
  1805. if not is_sequence(variables, Tuple):
  1806. variables = [variables]
  1807. variables = Tuple(*variables)
  1808. if has_dups(variables):
  1809. repeated = [str(v) for v, i in Counter(variables).items() if i > 1]
  1810. __ = ', '.join(repeated)
  1811. raise ValueError(filldedent('''
  1812. The following expressions appear more than once: %s
  1813. ''' % __))
  1814. point = Tuple(*(point if is_sequence(point, Tuple) else [point]))
  1815. if len(point) != len(variables):
  1816. raise ValueError('Number of point values must be the same as '
  1817. 'the number of variables.')
  1818. if not point:
  1819. return sympify(expr)
  1820. # denest
  1821. if isinstance(expr, Subs):
  1822. variables = expr.variables + variables
  1823. point = expr.point + point
  1824. expr = expr.expr
  1825. else:
  1826. expr = sympify(expr)
  1827. # use symbols with names equal to the point value (with prepended _)
  1828. # to give a variable-independent expression
  1829. pre = "_"
  1830. pts = sorted(set(point), key=default_sort_key)
  1831. from sympy.printing.str import StrPrinter
  1832. class CustomStrPrinter(StrPrinter):
  1833. def _print_Dummy(self, expr):
  1834. return str(expr) + str(expr.dummy_index)
  1835. def mystr(expr, **settings):
  1836. p = CustomStrPrinter(settings)
  1837. return p.doprint(expr)
  1838. while 1:
  1839. s_pts = {p: Symbol(pre + mystr(p)) for p in pts}
  1840. reps = [(v, s_pts[p])
  1841. for v, p in zip(variables, point)]
  1842. # if any underscore-prepended symbol is already a free symbol
  1843. # and is a variable with a different point value, then there
  1844. # is a clash, e.g. _0 clashes in Subs(_0 + _1, (_0, _1), (1, 0))
  1845. # because the new symbol that would be created is _1 but _1
  1846. # is already mapped to 0 so __0 and __1 are used for the new
  1847. # symbols
  1848. if any(r in expr.free_symbols and
  1849. r in variables and
  1850. Symbol(pre + mystr(point[variables.index(r)])) != r
  1851. for _, r in reps):
  1852. pre += "_"
  1853. continue
  1854. break
  1855. obj = Expr.__new__(cls, expr, Tuple(*variables), point)
  1856. obj._expr = expr.xreplace(dict(reps))
  1857. return obj
  1858. def _eval_is_commutative(self):
  1859. return self.expr.is_commutative
  1860. def doit(self, **hints):
  1861. e, v, p = self.args
  1862. # remove self mappings
  1863. for i, (vi, pi) in enumerate(zip(v, p)):
  1864. if vi == pi:
  1865. v = v[:i] + v[i + 1:]
  1866. p = p[:i] + p[i + 1:]
  1867. if not v:
  1868. return self.expr
  1869. if isinstance(e, Derivative):
  1870. # apply functions first, e.g. f -> cos
  1871. undone = []
  1872. for i, vi in enumerate(v):
  1873. if isinstance(vi, FunctionClass):
  1874. e = e.subs(vi, p[i])
  1875. else:
  1876. undone.append((vi, p[i]))
  1877. if not isinstance(e, Derivative):
  1878. e = e.doit()
  1879. if isinstance(e, Derivative):
  1880. # do Subs that aren't related to differentiation
  1881. undone2 = []
  1882. D = Dummy()
  1883. arg = e.args[0]
  1884. for vi, pi in undone:
  1885. if D not in e.xreplace({vi: D}).free_symbols:
  1886. if arg.has(vi):
  1887. e = e.subs(vi, pi)
  1888. else:
  1889. undone2.append((vi, pi))
  1890. undone = undone2
  1891. # differentiate wrt variables that are present
  1892. wrt = []
  1893. D = Dummy()
  1894. expr = e.expr
  1895. free = expr.free_symbols
  1896. for vi, ci in e.variable_count:
  1897. if isinstance(vi, Symbol) and vi in free:
  1898. expr = expr.diff((vi, ci))
  1899. elif D in expr.subs(vi, D).free_symbols:
  1900. expr = expr.diff((vi, ci))
  1901. else:
  1902. wrt.append((vi, ci))
  1903. # inject remaining subs
  1904. rv = expr.subs(undone)
  1905. # do remaining differentiation *in order given*
  1906. for vc in wrt:
  1907. rv = rv.diff(vc)
  1908. else:
  1909. # inject remaining subs
  1910. rv = e.subs(undone)
  1911. else:
  1912. rv = e.doit(**hints).subs(list(zip(v, p)))
  1913. if hints.get('deep', True) and rv != self:
  1914. rv = rv.doit(**hints)
  1915. return rv
  1916. def evalf(self, prec=None, **options):
  1917. return self.doit().evalf(prec, **options)
  1918. n = evalf # type:ignore
  1919. @property
  1920. def variables(self):
  1921. """The variables to be evaluated"""
  1922. return self._args[1]
  1923. bound_symbols = variables
  1924. @property
  1925. def expr(self):
  1926. """The expression on which the substitution operates"""
  1927. return self._args[0]
  1928. @property
  1929. def point(self):
  1930. """The values for which the variables are to be substituted"""
  1931. return self._args[2]
  1932. @property
  1933. def free_symbols(self):
  1934. return (self.expr.free_symbols - set(self.variables) |
  1935. set(self.point.free_symbols))
  1936. @property
  1937. def expr_free_symbols(self):
  1938. sympy_deprecation_warning("""
  1939. The expr_free_symbols property is deprecated. Use free_symbols to get
  1940. the free symbols of an expression.
  1941. """,
  1942. deprecated_since_version="1.9",
  1943. active_deprecations_target="deprecated-expr-free-symbols")
  1944. # Don't show the warning twice from the recursive call
  1945. with ignore_warnings(SymPyDeprecationWarning):
  1946. return (self.expr.expr_free_symbols - set(self.variables) |
  1947. set(self.point.expr_free_symbols))
  1948. def __eq__(self, other):
  1949. if not isinstance(other, Subs):
  1950. return False
  1951. return self._hashable_content() == other._hashable_content()
  1952. def __ne__(self, other):
  1953. return not(self == other)
  1954. def __hash__(self):
  1955. return super().__hash__()
  1956. def _hashable_content(self):
  1957. return (self._expr.xreplace(self.canonical_variables),
  1958. ) + tuple(ordered([(v, p) for v, p in
  1959. zip(self.variables, self.point) if not self.expr.has(v)]))
  1960. def _eval_subs(self, old, new):
  1961. # Subs doit will do the variables in order; the semantics
  1962. # of subs for Subs is have the following invariant for
  1963. # Subs object foo:
  1964. # foo.doit().subs(reps) == foo.subs(reps).doit()
  1965. pt = list(self.point)
  1966. if old in self.variables:
  1967. if _atomic(new) == {new} and not any(
  1968. i.has(new) for i in self.args):
  1969. # the substitution is neutral
  1970. return self.xreplace({old: new})
  1971. # any occurrence of old before this point will get
  1972. # handled by replacements from here on
  1973. i = self.variables.index(old)
  1974. for j in range(i, len(self.variables)):
  1975. pt[j] = pt[j]._subs(old, new)
  1976. return self.func(self.expr, self.variables, pt)
  1977. v = [i._subs(old, new) for i in self.variables]
  1978. if v != list(self.variables):
  1979. return self.func(self.expr, self.variables + (old,), pt + [new])
  1980. expr = self.expr._subs(old, new)
  1981. pt = [i._subs(old, new) for i in self.point]
  1982. return self.func(expr, v, pt)
  1983. def _eval_derivative(self, s):
  1984. # Apply the chain rule of the derivative on the substitution variables:
  1985. f = self.expr
  1986. vp = V, P = self.variables, self.point
  1987. val = Add.fromiter(p.diff(s)*Subs(f.diff(v), *vp).doit()
  1988. for v, p in zip(V, P))
  1989. # these are all the free symbols in the expr
  1990. efree = f.free_symbols
  1991. # some symbols like IndexedBase include themselves and args
  1992. # as free symbols
  1993. compound = {i for i in efree if len(i.free_symbols) > 1}
  1994. # hide them and see what independent free symbols remain
  1995. dums = {Dummy() for i in compound}
  1996. masked = f.xreplace(dict(zip(compound, dums)))
  1997. ifree = masked.free_symbols - dums
  1998. # include the compound symbols
  1999. free = ifree | compound
  2000. # remove the variables already handled
  2001. free -= set(V)
  2002. # add back any free symbols of remaining compound symbols
  2003. free |= {i for j in free & compound for i in j.free_symbols}
  2004. # if symbols of s are in free then there is more to do
  2005. if free & s.free_symbols:
  2006. val += Subs(f.diff(s), self.variables, self.point).doit()
  2007. return val
  2008. def _eval_nseries(self, x, n, logx, cdir=0):
  2009. if x in self.point:
  2010. # x is the variable being substituted into
  2011. apos = self.point.index(x)
  2012. other = self.variables[apos]
  2013. else:
  2014. other = x
  2015. arg = self.expr.nseries(other, n=n, logx=logx)
  2016. o = arg.getO()
  2017. terms = Add.make_args(arg.removeO())
  2018. rv = Add(*[self.func(a, *self.args[1:]) for a in terms])
  2019. if o:
  2020. rv += o.subs(other, x)
  2021. return rv
  2022. def _eval_as_leading_term(self, x, logx=None, cdir=0):
  2023. if x in self.point:
  2024. ipos = self.point.index(x)
  2025. xvar = self.variables[ipos]
  2026. return self.expr.as_leading_term(xvar)
  2027. if x in self.variables:
  2028. # if `x` is a dummy variable, it means it won't exist after the
  2029. # substitution has been performed:
  2030. return self
  2031. # The variable is independent of the substitution:
  2032. return self.expr.as_leading_term(x)
  2033. def diff(f, *symbols, **kwargs):
  2034. """
  2035. Differentiate f with respect to symbols.
  2036. Explanation
  2037. ===========
  2038. This is just a wrapper to unify .diff() and the Derivative class; its
  2039. interface is similar to that of integrate(). You can use the same
  2040. shortcuts for multiple variables as with Derivative. For example,
  2041. diff(f(x), x, x, x) and diff(f(x), x, 3) both return the third derivative
  2042. of f(x).
  2043. You can pass evaluate=False to get an unevaluated Derivative class. Note
  2044. that if there are 0 symbols (such as diff(f(x), x, 0), then the result will
  2045. be the function (the zeroth derivative), even if evaluate=False.
  2046. Examples
  2047. ========
  2048. >>> from sympy import sin, cos, Function, diff
  2049. >>> from sympy.abc import x, y
  2050. >>> f = Function('f')
  2051. >>> diff(sin(x), x)
  2052. cos(x)
  2053. >>> diff(f(x), x, x, x)
  2054. Derivative(f(x), (x, 3))
  2055. >>> diff(f(x), x, 3)
  2056. Derivative(f(x), (x, 3))
  2057. >>> diff(sin(x)*cos(y), x, 2, y, 2)
  2058. sin(x)*cos(y)
  2059. >>> type(diff(sin(x), x))
  2060. cos
  2061. >>> type(diff(sin(x), x, evaluate=False))
  2062. <class 'sympy.core.function.Derivative'>
  2063. >>> type(diff(sin(x), x, 0))
  2064. sin
  2065. >>> type(diff(sin(x), x, 0, evaluate=False))
  2066. sin
  2067. >>> diff(sin(x))
  2068. cos(x)
  2069. >>> diff(sin(x*y))
  2070. Traceback (most recent call last):
  2071. ...
  2072. ValueError: specify differentiation variables to differentiate sin(x*y)
  2073. Note that ``diff(sin(x))`` syntax is meant only for convenience
  2074. in interactive sessions and should be avoided in library code.
  2075. References
  2076. ==========
  2077. .. [1] http://reference.wolfram.com/legacy/v5_2/Built-inFunctions/AlgebraicComputation/Calculus/D.html
  2078. See Also
  2079. ========
  2080. Derivative
  2081. idiff: computes the derivative implicitly
  2082. """
  2083. if hasattr(f, 'diff'):
  2084. return f.diff(*symbols, **kwargs)
  2085. kwargs.setdefault('evaluate', True)
  2086. return _derivative_dispatch(f, *symbols, **kwargs)
  2087. def expand(e, deep=True, modulus=None, power_base=True, power_exp=True,
  2088. mul=True, log=True, multinomial=True, basic=True, **hints):
  2089. r"""
  2090. Expand an expression using methods given as hints.
  2091. Explanation
  2092. ===========
  2093. Hints evaluated unless explicitly set to False are: ``basic``, ``log``,
  2094. ``multinomial``, ``mul``, ``power_base``, and ``power_exp`` The following
  2095. hints are supported but not applied unless set to True: ``complex``,
  2096. ``func``, and ``trig``. In addition, the following meta-hints are
  2097. supported by some or all of the other hints: ``frac``, ``numer``,
  2098. ``denom``, ``modulus``, and ``force``. ``deep`` is supported by all
  2099. hints. Additionally, subclasses of Expr may define their own hints or
  2100. meta-hints.
  2101. The ``basic`` hint is used for any special rewriting of an object that
  2102. should be done automatically (along with the other hints like ``mul``)
  2103. when expand is called. This is a catch-all hint to handle any sort of
  2104. expansion that may not be described by the existing hint names. To use
  2105. this hint an object should override the ``_eval_expand_basic`` method.
  2106. Objects may also define their own expand methods, which are not run by
  2107. default. See the API section below.
  2108. If ``deep`` is set to ``True`` (the default), things like arguments of
  2109. functions are recursively expanded. Use ``deep=False`` to only expand on
  2110. the top level.
  2111. If the ``force`` hint is used, assumptions about variables will be ignored
  2112. in making the expansion.
  2113. Hints
  2114. =====
  2115. These hints are run by default
  2116. mul
  2117. ---
  2118. Distributes multiplication over addition:
  2119. >>> from sympy import cos, exp, sin
  2120. >>> from sympy.abc import x, y, z
  2121. >>> (y*(x + z)).expand(mul=True)
  2122. x*y + y*z
  2123. multinomial
  2124. -----------
  2125. Expand (x + y + ...)**n where n is a positive integer.
  2126. >>> ((x + y + z)**2).expand(multinomial=True)
  2127. x**2 + 2*x*y + 2*x*z + y**2 + 2*y*z + z**2
  2128. power_exp
  2129. ---------
  2130. Expand addition in exponents into multiplied bases.
  2131. >>> exp(x + y).expand(power_exp=True)
  2132. exp(x)*exp(y)
  2133. >>> (2**(x + y)).expand(power_exp=True)
  2134. 2**x*2**y
  2135. power_base
  2136. ----------
  2137. Split powers of multiplied bases.
  2138. This only happens by default if assumptions allow, or if the
  2139. ``force`` meta-hint is used:
  2140. >>> ((x*y)**z).expand(power_base=True)
  2141. (x*y)**z
  2142. >>> ((x*y)**z).expand(power_base=True, force=True)
  2143. x**z*y**z
  2144. >>> ((2*y)**z).expand(power_base=True)
  2145. 2**z*y**z
  2146. Note that in some cases where this expansion always holds, SymPy performs
  2147. it automatically:
  2148. >>> (x*y)**2
  2149. x**2*y**2
  2150. log
  2151. ---
  2152. Pull out power of an argument as a coefficient and split logs products
  2153. into sums of logs.
  2154. Note that these only work if the arguments of the log function have the
  2155. proper assumptions--the arguments must be positive and the exponents must
  2156. be real--or else the ``force`` hint must be True:
  2157. >>> from sympy import log, symbols
  2158. >>> log(x**2*y).expand(log=True)
  2159. log(x**2*y)
  2160. >>> log(x**2*y).expand(log=True, force=True)
  2161. 2*log(x) + log(y)
  2162. >>> x, y = symbols('x,y', positive=True)
  2163. >>> log(x**2*y).expand(log=True)
  2164. 2*log(x) + log(y)
  2165. basic
  2166. -----
  2167. This hint is intended primarily as a way for custom subclasses to enable
  2168. expansion by default.
  2169. These hints are not run by default:
  2170. complex
  2171. -------
  2172. Split an expression into real and imaginary parts.
  2173. >>> x, y = symbols('x,y')
  2174. >>> (x + y).expand(complex=True)
  2175. re(x) + re(y) + I*im(x) + I*im(y)
  2176. >>> cos(x).expand(complex=True)
  2177. -I*sin(re(x))*sinh(im(x)) + cos(re(x))*cosh(im(x))
  2178. Note that this is just a wrapper around ``as_real_imag()``. Most objects
  2179. that wish to redefine ``_eval_expand_complex()`` should consider
  2180. redefining ``as_real_imag()`` instead.
  2181. func
  2182. ----
  2183. Expand other functions.
  2184. >>> from sympy import gamma
  2185. >>> gamma(x + 1).expand(func=True)
  2186. x*gamma(x)
  2187. trig
  2188. ----
  2189. Do trigonometric expansions.
  2190. >>> cos(x + y).expand(trig=True)
  2191. -sin(x)*sin(y) + cos(x)*cos(y)
  2192. >>> sin(2*x).expand(trig=True)
  2193. 2*sin(x)*cos(x)
  2194. Note that the forms of ``sin(n*x)`` and ``cos(n*x)`` in terms of ``sin(x)``
  2195. and ``cos(x)`` are not unique, due to the identity `\sin^2(x) + \cos^2(x)
  2196. = 1`. The current implementation uses the form obtained from Chebyshev
  2197. polynomials, but this may change. See `this MathWorld article
  2198. <http://mathworld.wolfram.com/Multiple-AngleFormulas.html>`_ for more
  2199. information.
  2200. Notes
  2201. =====
  2202. - You can shut off unwanted methods::
  2203. >>> (exp(x + y)*(x + y)).expand()
  2204. x*exp(x)*exp(y) + y*exp(x)*exp(y)
  2205. >>> (exp(x + y)*(x + y)).expand(power_exp=False)
  2206. x*exp(x + y) + y*exp(x + y)
  2207. >>> (exp(x + y)*(x + y)).expand(mul=False)
  2208. (x + y)*exp(x)*exp(y)
  2209. - Use deep=False to only expand on the top level::
  2210. >>> exp(x + exp(x + y)).expand()
  2211. exp(x)*exp(exp(x)*exp(y))
  2212. >>> exp(x + exp(x + y)).expand(deep=False)
  2213. exp(x)*exp(exp(x + y))
  2214. - Hints are applied in an arbitrary, but consistent order (in the current
  2215. implementation, they are applied in alphabetical order, except
  2216. multinomial comes before mul, but this may change). Because of this,
  2217. some hints may prevent expansion by other hints if they are applied
  2218. first. For example, ``mul`` may distribute multiplications and prevent
  2219. ``log`` and ``power_base`` from expanding them. Also, if ``mul`` is
  2220. applied before ``multinomial`, the expression might not be fully
  2221. distributed. The solution is to use the various ``expand_hint`` helper
  2222. functions or to use ``hint=False`` to this function to finely control
  2223. which hints are applied. Here are some examples::
  2224. >>> from sympy import expand, expand_mul, expand_power_base
  2225. >>> x, y, z = symbols('x,y,z', positive=True)
  2226. >>> expand(log(x*(y + z)))
  2227. log(x) + log(y + z)
  2228. Here, we see that ``log`` was applied before ``mul``. To get the mul
  2229. expanded form, either of the following will work::
  2230. >>> expand_mul(log(x*(y + z)))
  2231. log(x*y + x*z)
  2232. >>> expand(log(x*(y + z)), log=False)
  2233. log(x*y + x*z)
  2234. A similar thing can happen with the ``power_base`` hint::
  2235. >>> expand((x*(y + z))**x)
  2236. (x*y + x*z)**x
  2237. To get the ``power_base`` expanded form, either of the following will
  2238. work::
  2239. >>> expand((x*(y + z))**x, mul=False)
  2240. x**x*(y + z)**x
  2241. >>> expand_power_base((x*(y + z))**x)
  2242. x**x*(y + z)**x
  2243. >>> expand((x + y)*y/x)
  2244. y + y**2/x
  2245. The parts of a rational expression can be targeted::
  2246. >>> expand((x + y)*y/x/(x + 1), frac=True)
  2247. (x*y + y**2)/(x**2 + x)
  2248. >>> expand((x + y)*y/x/(x + 1), numer=True)
  2249. (x*y + y**2)/(x*(x + 1))
  2250. >>> expand((x + y)*y/x/(x + 1), denom=True)
  2251. y*(x + y)/(x**2 + x)
  2252. - The ``modulus`` meta-hint can be used to reduce the coefficients of an
  2253. expression post-expansion::
  2254. >>> expand((3*x + 1)**2)
  2255. 9*x**2 + 6*x + 1
  2256. >>> expand((3*x + 1)**2, modulus=5)
  2257. 4*x**2 + x + 1
  2258. - Either ``expand()`` the function or ``.expand()`` the method can be
  2259. used. Both are equivalent::
  2260. >>> expand((x + 1)**2)
  2261. x**2 + 2*x + 1
  2262. >>> ((x + 1)**2).expand()
  2263. x**2 + 2*x + 1
  2264. API
  2265. ===
  2266. Objects can define their own expand hints by defining
  2267. ``_eval_expand_hint()``. The function should take the form::
  2268. def _eval_expand_hint(self, **hints):
  2269. # Only apply the method to the top-level expression
  2270. ...
  2271. See also the example below. Objects should define ``_eval_expand_hint()``
  2272. methods only if ``hint`` applies to that specific object. The generic
  2273. ``_eval_expand_hint()`` method defined in Expr will handle the no-op case.
  2274. Each hint should be responsible for expanding that hint only.
  2275. Furthermore, the expansion should be applied to the top-level expression
  2276. only. ``expand()`` takes care of the recursion that happens when
  2277. ``deep=True``.
  2278. You should only call ``_eval_expand_hint()`` methods directly if you are
  2279. 100% sure that the object has the method, as otherwise you are liable to
  2280. get unexpected ``AttributeError``s. Note, again, that you do not need to
  2281. recursively apply the hint to args of your object: this is handled
  2282. automatically by ``expand()``. ``_eval_expand_hint()`` should
  2283. generally not be used at all outside of an ``_eval_expand_hint()`` method.
  2284. If you want to apply a specific expansion from within another method, use
  2285. the public ``expand()`` function, method, or ``expand_hint()`` functions.
  2286. In order for expand to work, objects must be rebuildable by their args,
  2287. i.e., ``obj.func(*obj.args) == obj`` must hold.
  2288. Expand methods are passed ``**hints`` so that expand hints may use
  2289. 'metahints'--hints that control how different expand methods are applied.
  2290. For example, the ``force=True`` hint described above that causes
  2291. ``expand(log=True)`` to ignore assumptions is such a metahint. The
  2292. ``deep`` meta-hint is handled exclusively by ``expand()`` and is not
  2293. passed to ``_eval_expand_hint()`` methods.
  2294. Note that expansion hints should generally be methods that perform some
  2295. kind of 'expansion'. For hints that simply rewrite an expression, use the
  2296. .rewrite() API.
  2297. Examples
  2298. ========
  2299. >>> from sympy import Expr, sympify
  2300. >>> class MyClass(Expr):
  2301. ... def __new__(cls, *args):
  2302. ... args = sympify(args)
  2303. ... return Expr.__new__(cls, *args)
  2304. ...
  2305. ... def _eval_expand_double(self, *, force=False, **hints):
  2306. ... '''
  2307. ... Doubles the args of MyClass.
  2308. ...
  2309. ... If there more than four args, doubling is not performed,
  2310. ... unless force=True is also used (False by default).
  2311. ... '''
  2312. ... if not force and len(self.args) > 4:
  2313. ... return self
  2314. ... return self.func(*(self.args + self.args))
  2315. ...
  2316. >>> a = MyClass(1, 2, MyClass(3, 4))
  2317. >>> a
  2318. MyClass(1, 2, MyClass(3, 4))
  2319. >>> a.expand(double=True)
  2320. MyClass(1, 2, MyClass(3, 4, 3, 4), 1, 2, MyClass(3, 4, 3, 4))
  2321. >>> a.expand(double=True, deep=False)
  2322. MyClass(1, 2, MyClass(3, 4), 1, 2, MyClass(3, 4))
  2323. >>> b = MyClass(1, 2, 3, 4, 5)
  2324. >>> b.expand(double=True)
  2325. MyClass(1, 2, 3, 4, 5)
  2326. >>> b.expand(double=True, force=True)
  2327. MyClass(1, 2, 3, 4, 5, 1, 2, 3, 4, 5)
  2328. See Also
  2329. ========
  2330. expand_log, expand_mul, expand_multinomial, expand_complex, expand_trig,
  2331. expand_power_base, expand_power_exp, expand_func, sympy.simplify.hyperexpand.hyperexpand
  2332. """
  2333. # don't modify this; modify the Expr.expand method
  2334. hints['power_base'] = power_base
  2335. hints['power_exp'] = power_exp
  2336. hints['mul'] = mul
  2337. hints['log'] = log
  2338. hints['multinomial'] = multinomial
  2339. hints['basic'] = basic
  2340. return sympify(e).expand(deep=deep, modulus=modulus, **hints)
  2341. # This is a special application of two hints
  2342. def _mexpand(expr, recursive=False):
  2343. # expand multinomials and then expand products; this may not always
  2344. # be sufficient to give a fully expanded expression (see
  2345. # test_issue_8247_8354 in test_arit)
  2346. if expr is None:
  2347. return
  2348. was = None
  2349. while was != expr:
  2350. was, expr = expr, expand_mul(expand_multinomial(expr))
  2351. if not recursive:
  2352. break
  2353. return expr
  2354. # These are simple wrappers around single hints.
  2355. def expand_mul(expr, deep=True):
  2356. """
  2357. Wrapper around expand that only uses the mul hint. See the expand
  2358. docstring for more information.
  2359. Examples
  2360. ========
  2361. >>> from sympy import symbols, expand_mul, exp, log
  2362. >>> x, y = symbols('x,y', positive=True)
  2363. >>> expand_mul(exp(x+y)*(x+y)*log(x*y**2))
  2364. x*exp(x + y)*log(x*y**2) + y*exp(x + y)*log(x*y**2)
  2365. """
  2366. return sympify(expr).expand(deep=deep, mul=True, power_exp=False,
  2367. power_base=False, basic=False, multinomial=False, log=False)
  2368. def expand_multinomial(expr, deep=True):
  2369. """
  2370. Wrapper around expand that only uses the multinomial hint. See the expand
  2371. docstring for more information.
  2372. Examples
  2373. ========
  2374. >>> from sympy import symbols, expand_multinomial, exp
  2375. >>> x, y = symbols('x y', positive=True)
  2376. >>> expand_multinomial((x + exp(x + 1))**2)
  2377. x**2 + 2*x*exp(x + 1) + exp(2*x + 2)
  2378. """
  2379. return sympify(expr).expand(deep=deep, mul=False, power_exp=False,
  2380. power_base=False, basic=False, multinomial=True, log=False)
  2381. def expand_log(expr, deep=True, force=False, factor=False):
  2382. """
  2383. Wrapper around expand that only uses the log hint. See the expand
  2384. docstring for more information.
  2385. Examples
  2386. ========
  2387. >>> from sympy import symbols, expand_log, exp, log
  2388. >>> x, y = symbols('x,y', positive=True)
  2389. >>> expand_log(exp(x+y)*(x+y)*log(x*y**2))
  2390. (x + y)*(log(x) + 2*log(y))*exp(x + y)
  2391. """
  2392. from sympy.functions.elementary.exponential import log
  2393. if factor is False:
  2394. def _handle(x):
  2395. x1 = expand_mul(expand_log(x, deep=deep, force=force, factor=True))
  2396. if x1.count(log) <= x.count(log):
  2397. return x1
  2398. return x
  2399. expr = expr.replace(
  2400. lambda x: x.is_Mul and all(any(isinstance(i, log) and i.args[0].is_Rational
  2401. for i in Mul.make_args(j)) for j in x.as_numer_denom()),
  2402. _handle)
  2403. return sympify(expr).expand(deep=deep, log=True, mul=False,
  2404. power_exp=False, power_base=False, multinomial=False,
  2405. basic=False, force=force, factor=factor)
  2406. def expand_func(expr, deep=True):
  2407. """
  2408. Wrapper around expand that only uses the func hint. See the expand
  2409. docstring for more information.
  2410. Examples
  2411. ========
  2412. >>> from sympy import expand_func, gamma
  2413. >>> from sympy.abc import x
  2414. >>> expand_func(gamma(x + 2))
  2415. x*(x + 1)*gamma(x)
  2416. """
  2417. return sympify(expr).expand(deep=deep, func=True, basic=False,
  2418. log=False, mul=False, power_exp=False, power_base=False, multinomial=False)
  2419. def expand_trig(expr, deep=True):
  2420. """
  2421. Wrapper around expand that only uses the trig hint. See the expand
  2422. docstring for more information.
  2423. Examples
  2424. ========
  2425. >>> from sympy import expand_trig, sin
  2426. >>> from sympy.abc import x, y
  2427. >>> expand_trig(sin(x+y)*(x+y))
  2428. (x + y)*(sin(x)*cos(y) + sin(y)*cos(x))
  2429. """
  2430. return sympify(expr).expand(deep=deep, trig=True, basic=False,
  2431. log=False, mul=False, power_exp=False, power_base=False, multinomial=False)
  2432. def expand_complex(expr, deep=True):
  2433. """
  2434. Wrapper around expand that only uses the complex hint. See the expand
  2435. docstring for more information.
  2436. Examples
  2437. ========
  2438. >>> from sympy import expand_complex, exp, sqrt, I
  2439. >>> from sympy.abc import z
  2440. >>> expand_complex(exp(z))
  2441. I*exp(re(z))*sin(im(z)) + exp(re(z))*cos(im(z))
  2442. >>> expand_complex(sqrt(I))
  2443. sqrt(2)/2 + sqrt(2)*I/2
  2444. See Also
  2445. ========
  2446. sympy.core.expr.Expr.as_real_imag
  2447. """
  2448. return sympify(expr).expand(deep=deep, complex=True, basic=False,
  2449. log=False, mul=False, power_exp=False, power_base=False, multinomial=False)
  2450. def expand_power_base(expr, deep=True, force=False):
  2451. """
  2452. Wrapper around expand that only uses the power_base hint.
  2453. A wrapper to expand(power_base=True) which separates a power with a base
  2454. that is a Mul into a product of powers, without performing any other
  2455. expansions, provided that assumptions about the power's base and exponent
  2456. allow.
  2457. deep=False (default is True) will only apply to the top-level expression.
  2458. force=True (default is False) will cause the expansion to ignore
  2459. assumptions about the base and exponent. When False, the expansion will
  2460. only happen if the base is non-negative or the exponent is an integer.
  2461. >>> from sympy.abc import x, y, z
  2462. >>> from sympy import expand_power_base, sin, cos, exp
  2463. >>> (x*y)**2
  2464. x**2*y**2
  2465. >>> (2*x)**y
  2466. (2*x)**y
  2467. >>> expand_power_base(_)
  2468. 2**y*x**y
  2469. >>> expand_power_base((x*y)**z)
  2470. (x*y)**z
  2471. >>> expand_power_base((x*y)**z, force=True)
  2472. x**z*y**z
  2473. >>> expand_power_base(sin((x*y)**z), deep=False)
  2474. sin((x*y)**z)
  2475. >>> expand_power_base(sin((x*y)**z), force=True)
  2476. sin(x**z*y**z)
  2477. >>> expand_power_base((2*sin(x))**y + (2*cos(x))**y)
  2478. 2**y*sin(x)**y + 2**y*cos(x)**y
  2479. >>> expand_power_base((2*exp(y))**x)
  2480. 2**x*exp(y)**x
  2481. >>> expand_power_base((2*cos(x))**y)
  2482. 2**y*cos(x)**y
  2483. Notice that sums are left untouched. If this is not the desired behavior,
  2484. apply full ``expand()`` to the expression:
  2485. >>> expand_power_base(((x+y)*z)**2)
  2486. z**2*(x + y)**2
  2487. >>> (((x+y)*z)**2).expand()
  2488. x**2*z**2 + 2*x*y*z**2 + y**2*z**2
  2489. >>> expand_power_base((2*y)**(1+z))
  2490. 2**(z + 1)*y**(z + 1)
  2491. >>> ((2*y)**(1+z)).expand()
  2492. 2*2**z*y*y**z
  2493. See Also
  2494. ========
  2495. expand
  2496. """
  2497. return sympify(expr).expand(deep=deep, log=False, mul=False,
  2498. power_exp=False, power_base=True, multinomial=False,
  2499. basic=False, force=force)
  2500. def expand_power_exp(expr, deep=True):
  2501. """
  2502. Wrapper around expand that only uses the power_exp hint.
  2503. See the expand docstring for more information.
  2504. Examples
  2505. ========
  2506. >>> from sympy import expand_power_exp
  2507. >>> from sympy.abc import x, y
  2508. >>> expand_power_exp(x**(y + 2))
  2509. x**2*x**y
  2510. """
  2511. return sympify(expr).expand(deep=deep, complex=False, basic=False,
  2512. log=False, mul=False, power_exp=True, power_base=False, multinomial=False)
  2513. def count_ops(expr, visual=False):
  2514. """
  2515. Return a representation (integer or expression) of the operations in expr.
  2516. Parameters
  2517. ==========
  2518. expr : Expr
  2519. If expr is an iterable, the sum of the op counts of the
  2520. items will be returned.
  2521. visual : bool, optional
  2522. If ``False`` (default) then the sum of the coefficients of the
  2523. visual expression will be returned.
  2524. If ``True`` then the number of each type of operation is shown
  2525. with the core class types (or their virtual equivalent) multiplied by the
  2526. number of times they occur.
  2527. Examples
  2528. ========
  2529. >>> from sympy.abc import a, b, x, y
  2530. >>> from sympy import sin, count_ops
  2531. Although there isn't a SUB object, minus signs are interpreted as
  2532. either negations or subtractions:
  2533. >>> (x - y).count_ops(visual=True)
  2534. SUB
  2535. >>> (-x).count_ops(visual=True)
  2536. NEG
  2537. Here, there are two Adds and a Pow:
  2538. >>> (1 + a + b**2).count_ops(visual=True)
  2539. 2*ADD + POW
  2540. In the following, an Add, Mul, Pow and two functions:
  2541. >>> (sin(x)*x + sin(x)**2).count_ops(visual=True)
  2542. ADD + MUL + POW + 2*SIN
  2543. for a total of 5:
  2544. >>> (sin(x)*x + sin(x)**2).count_ops(visual=False)
  2545. 5
  2546. Note that "what you type" is not always what you get. The expression
  2547. 1/x/y is translated by sympy into 1/(x*y) so it gives a DIV and MUL rather
  2548. than two DIVs:
  2549. >>> (1/x/y).count_ops(visual=True)
  2550. DIV + MUL
  2551. The visual option can be used to demonstrate the difference in
  2552. operations for expressions in different forms. Here, the Horner
  2553. representation is compared with the expanded form of a polynomial:
  2554. >>> eq=x*(1 + x*(2 + x*(3 + x)))
  2555. >>> count_ops(eq.expand(), visual=True) - count_ops(eq, visual=True)
  2556. -MUL + 3*POW
  2557. The count_ops function also handles iterables:
  2558. >>> count_ops([x, sin(x), None, True, x + 2], visual=False)
  2559. 2
  2560. >>> count_ops([x, sin(x), None, True, x + 2], visual=True)
  2561. ADD + SIN
  2562. >>> count_ops({x: sin(x), x + 2: y + 1}, visual=True)
  2563. 2*ADD + SIN
  2564. """
  2565. from .relational import Relational
  2566. from sympy.concrete.summations import Sum
  2567. from sympy.integrals.integrals import Integral
  2568. from sympy.logic.boolalg import BooleanFunction
  2569. from sympy.simplify.radsimp import fraction
  2570. expr = sympify(expr)
  2571. if isinstance(expr, Expr) and not expr.is_Relational:
  2572. ops = []
  2573. args = [expr]
  2574. NEG = Symbol('NEG')
  2575. DIV = Symbol('DIV')
  2576. SUB = Symbol('SUB')
  2577. ADD = Symbol('ADD')
  2578. EXP = Symbol('EXP')
  2579. while args:
  2580. a = args.pop()
  2581. # if the following fails because the object is
  2582. # not Basic type, then the object should be fixed
  2583. # since it is the intention that all args of Basic
  2584. # should themselves be Basic
  2585. if a.is_Rational:
  2586. #-1/3 = NEG + DIV
  2587. if a is not S.One:
  2588. if a.p < 0:
  2589. ops.append(NEG)
  2590. if a.q != 1:
  2591. ops.append(DIV)
  2592. continue
  2593. elif a.is_Mul or a.is_MatMul:
  2594. if _coeff_isneg(a):
  2595. ops.append(NEG)
  2596. if a.args[0] is S.NegativeOne:
  2597. a = a.as_two_terms()[1]
  2598. else:
  2599. a = -a
  2600. n, d = fraction(a)
  2601. if n.is_Integer:
  2602. ops.append(DIV)
  2603. if n < 0:
  2604. ops.append(NEG)
  2605. args.append(d)
  2606. continue # won't be -Mul but could be Add
  2607. elif d is not S.One:
  2608. if not d.is_Integer:
  2609. args.append(d)
  2610. ops.append(DIV)
  2611. args.append(n)
  2612. continue # could be -Mul
  2613. elif a.is_Add or a.is_MatAdd:
  2614. aargs = list(a.args)
  2615. negs = 0
  2616. for i, ai in enumerate(aargs):
  2617. if _coeff_isneg(ai):
  2618. negs += 1
  2619. args.append(-ai)
  2620. if i > 0:
  2621. ops.append(SUB)
  2622. else:
  2623. args.append(ai)
  2624. if i > 0:
  2625. ops.append(ADD)
  2626. if negs == len(aargs): # -x - y = NEG + SUB
  2627. ops.append(NEG)
  2628. elif _coeff_isneg(aargs[0]): # -x + y = SUB, but already recorded ADD
  2629. ops.append(SUB - ADD)
  2630. continue
  2631. if a.is_Pow and a.exp is S.NegativeOne:
  2632. ops.append(DIV)
  2633. args.append(a.base) # won't be -Mul but could be Add
  2634. continue
  2635. if a == S.Exp1:
  2636. ops.append(EXP)
  2637. continue
  2638. if a.is_Pow and a.base == S.Exp1:
  2639. ops.append(EXP)
  2640. args.append(a.exp)
  2641. continue
  2642. if a.is_Mul or isinstance(a, LatticeOp):
  2643. o = Symbol(a.func.__name__.upper())
  2644. # count the args
  2645. ops.append(o*(len(a.args) - 1))
  2646. elif a.args and (
  2647. a.is_Pow or
  2648. a.is_Function or
  2649. isinstance(a, Derivative) or
  2650. isinstance(a, Integral) or
  2651. isinstance(a, Sum)):
  2652. # if it's not in the list above we don't
  2653. # consider a.func something to count, e.g.
  2654. # Tuple, MatrixSymbol, etc...
  2655. if isinstance(a.func, UndefinedFunction):
  2656. o = Symbol("FUNC_" + a.func.__name__.upper())
  2657. else:
  2658. o = Symbol(a.func.__name__.upper())
  2659. ops.append(o)
  2660. if not a.is_Symbol:
  2661. args.extend(a.args)
  2662. elif isinstance(expr, Dict):
  2663. ops = [count_ops(k, visual=visual) +
  2664. count_ops(v, visual=visual) for k, v in expr.items()]
  2665. elif iterable(expr):
  2666. ops = [count_ops(i, visual=visual) for i in expr]
  2667. elif isinstance(expr, (Relational, BooleanFunction)):
  2668. ops = []
  2669. for arg in expr.args:
  2670. ops.append(count_ops(arg, visual=True))
  2671. o = Symbol(func_name(expr, short=True).upper())
  2672. ops.append(o)
  2673. elif not isinstance(expr, Basic):
  2674. ops = []
  2675. else: # it's Basic not isinstance(expr, Expr):
  2676. if not isinstance(expr, Basic):
  2677. raise TypeError("Invalid type of expr")
  2678. else:
  2679. ops = []
  2680. args = [expr]
  2681. while args:
  2682. a = args.pop()
  2683. if a.args:
  2684. o = Symbol(type(a).__name__.upper())
  2685. if a.is_Boolean:
  2686. ops.append(o*(len(a.args)-1))
  2687. else:
  2688. ops.append(o)
  2689. args.extend(a.args)
  2690. if not ops:
  2691. if visual:
  2692. return S.Zero
  2693. return 0
  2694. ops = Add(*ops)
  2695. if visual:
  2696. return ops
  2697. if ops.is_Number:
  2698. return int(ops)
  2699. return sum(int((a.args or [1])[0]) for a in Add.make_args(ops))
  2700. def nfloat(expr, n=15, exponent=False, dkeys=False):
  2701. """Make all Rationals in expr Floats except those in exponents
  2702. (unless the exponents flag is set to True) and those in undefined
  2703. functions. When processing dictionaries, do not modify the keys
  2704. unless ``dkeys=True``.
  2705. Examples
  2706. ========
  2707. >>> from sympy import nfloat, cos, pi, sqrt
  2708. >>> from sympy.abc import x, y
  2709. >>> nfloat(x**4 + x/2 + cos(pi/3) + 1 + sqrt(y))
  2710. x**4 + 0.5*x + sqrt(y) + 1.5
  2711. >>> nfloat(x**4 + sqrt(y), exponent=True)
  2712. x**4.0 + y**0.5
  2713. Container types are not modified:
  2714. >>> type(nfloat((1, 2))) is tuple
  2715. True
  2716. """
  2717. from sympy.matrices.matrices import MatrixBase
  2718. kw = dict(n=n, exponent=exponent, dkeys=dkeys)
  2719. if isinstance(expr, MatrixBase):
  2720. return expr.applyfunc(lambda e: nfloat(e, **kw))
  2721. # handling of iterable containers
  2722. if iterable(expr, exclude=str):
  2723. if isinstance(expr, (dict, Dict)):
  2724. if dkeys:
  2725. args = [tuple(map(lambda i: nfloat(i, **kw), a))
  2726. for a in expr.items()]
  2727. else:
  2728. args = [(k, nfloat(v, **kw)) for k, v in expr.items()]
  2729. if isinstance(expr, dict):
  2730. return type(expr)(args)
  2731. else:
  2732. return expr.func(*args)
  2733. elif isinstance(expr, Basic):
  2734. return expr.func(*[nfloat(a, **kw) for a in expr.args])
  2735. return type(expr)([nfloat(a, **kw) for a in expr])
  2736. rv = sympify(expr)
  2737. if rv.is_Number:
  2738. return Float(rv, n)
  2739. elif rv.is_number:
  2740. # evalf doesn't always set the precision
  2741. rv = rv.n(n)
  2742. if rv.is_Number:
  2743. rv = Float(rv.n(n), n)
  2744. else:
  2745. pass # pure_complex(rv) is likely True
  2746. return rv
  2747. elif rv.is_Atom:
  2748. return rv
  2749. elif rv.is_Relational:
  2750. args_nfloat = (nfloat(arg, **kw) for arg in rv.args)
  2751. return rv.func(*args_nfloat)
  2752. # watch out for RootOf instances that don't like to have
  2753. # their exponents replaced with Dummies and also sometimes have
  2754. # problems with evaluating at low precision (issue 6393)
  2755. from sympy.polys.rootoftools import RootOf
  2756. rv = rv.xreplace({ro: ro.n(n) for ro in rv.atoms(RootOf)})
  2757. from .power import Pow
  2758. if not exponent:
  2759. reps = [(p, Pow(p.base, Dummy())) for p in rv.atoms(Pow)]
  2760. rv = rv.xreplace(dict(reps))
  2761. rv = rv.n(n)
  2762. if not exponent:
  2763. rv = rv.xreplace({d.exp: p.exp for p, d in reps})
  2764. else:
  2765. # Pow._eval_evalf special cases Integer exponents so if
  2766. # exponent is suppose to be handled we have to do so here
  2767. rv = rv.xreplace(Transform(
  2768. lambda x: Pow(x.base, Float(x.exp, n)),
  2769. lambda x: x.is_Pow and x.exp.is_Integer))
  2770. return rv.xreplace(Transform(
  2771. lambda x: x.func(*nfloat(x.args, n, exponent)),
  2772. lambda x: isinstance(x, Function) and not isinstance(x, AppliedUndef)))
  2773. from .symbol import Dummy, Symbol