|
|
""" Generic Rules for SymPy
This file assumes knowledge of Basic and little else. """
from sympy.utilities.iterables import sift from .util import new
# Functions that create rules
def rm_id(isid, new=new):
    """ Build a rule that strips identity elements from an expression's args.

    Parameters
    ==========

    isid : callable
        Predicate ``x -> Bool`` telling whether an argument is an identity.
    new : callable, optional
        Constructor helper, invoked as ``new(cls, *args)`` to build the result.

    Examples
    ========

    >>> from sympy.strategies import rm_id
    >>> from sympy import Basic, S
    >>> remove_zeros = rm_id(lambda x: x==0)
    >>> remove_zeros(Basic(S(1), S(0), S(2)))
    Basic(1, 2)
    >>> remove_zeros(Basic(S(0), S(0))) # If only identities then we keep one
    Basic(0)

    See Also:
        unpack
    """

    def ident_remove(expr):
        """ Drop identity args, keeping a single one if nothing else remains """
        flags = [isid(arg) for arg in expr.args]
        n_ids = sum(flags)

        # Common case: nothing to strip, so hand back the very same object.
        if n_ids == 0:
            return expr

        # Every arg is an identity: keep only the first one.
        if n_ids == len(flags):
            return new(expr.__class__, expr.args[0])

        # Mixed case: keep exactly the args that were not flagged.
        survivors = [arg for arg, is_identity in zip(expr.args, flags)
                     if not is_identity]
        return new(expr.__class__, *survivors)

    return ident_remove
def glom(key, count, combine):
    """ Build a rule that merges args which share the same key.

    Parameters
    ==========

    key : callable
        Maps an arg to the value used to group it with other args.
    count : callable
        Maps an arg to the amount it contributes to its group's total.
    combine : callable
        Called as ``combine(total, key_value)`` to build one new arg per group.

    Examples
    ========

    >>> from sympy.strategies import glom
    >>> from sympy import Add
    >>> from sympy.abc import x

    >>> key     = lambda x: x.as_coeff_Mul()[1]
    >>> count   = lambda x: x.as_coeff_Mul()[0]
    >>> combine = lambda cnt, arg: cnt * arg
    >>> rl = glom(key, count, combine)

    >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
    3*x + 5

    How key, count and combine behave on a single term:

    >>> key(2*x)
    x
    >>> count(2*x)
    2
    >>> combine(2, x)
    2*x
    """

    def conglomerate(expr):
        """ Conglomerate together identical args x + x -> 2x """
        buckets = sift(expr.args, key)

        # First total up every group, then build the replacement args.
        totals = {}
        for group_key, members in buckets.items():
            totals[group_key] = sum(count(member) for member in members)
        newargs = [combine(totals[group_key], group_key)
                   for group_key in totals]

        # Compared as sets, so ordering and repetition are ignored here.
        if set(newargs) == set(expr.args):
            return expr
        return new(type(expr), *newargs)

    return conglomerate
def sort(key, new=new):
    """ Build a rule that reorders an expression's args by a key function.

    Parameters
    ==========

    key : callable
        Sort key applied to each arg.
    new : callable, optional
        Constructor helper, invoked as ``new(cls, *args)`` to build the result.

    Examples
    ========

    >>> from sympy.strategies import sort
    >>> from sympy import Basic, S
    >>> sort_rl = sort(str)
    >>> sort_rl(Basic(S(3), S(1), S(2)))
    Basic(1, 2, 3)
    """

    def sort_rl(expr):
        ordered = list(expr.args)
        ordered.sort(key=key)
        return new(expr.__class__, *ordered)

    return sort_rl
def distribute(A, B):
    """ Turns an A containing Bs into a B of As

    where A, B are container types

    Only the first arg that is an instance of B is distributed over; an
    expression with no such arg is returned unchanged.

    >>> from sympy.strategies import distribute
    >>> from sympy import Add, Mul, symbols
    >>> x, y = symbols('x,y')
    >>> dist = distribute(Mul, Add)
    >>> expr = Mul(2, x+y, evaluate=False)
    >>> expr
    2*(x + y)
    >>> dist(expr)
    2*x + 2*y
    """

    def distribute_rl(expr):
        args = expr.args

        # Position of the first B among the args, if there is one.
        idx = next((i for i, candidate in enumerate(args)
                    if isinstance(candidate, B)), None)
        if idx is None:
            return expr

        head, inner, rest = args[:idx], args[idx], args[idx + 1:]
        # One A per element of the inner B, with that element in B's old slot.
        terms = [A(*(head + (term,) + rest)) for term in inner.args]
        return B(*terms)

    return distribute_rl
def subs(a, b):
    """ Build a rule that swaps an expression equal to ``a`` for ``b``.

    The returned rule compares the whole expression with ``==``; anything
    that does not compare equal to ``a`` is passed through untouched.
    """

    def subs_rl(expr):
        return b if expr == a else expr

    return subs_rl
# Functions that are rules
def unpack(expr):
    """ Rule that replaces a single-arg expression by that one arg.

    Expressions whose number of args is not exactly one are returned as is.

    >>> from sympy.strategies import unpack
    >>> from sympy import Basic, S
    >>> unpack(Basic(S(2)))
    2
    """
    args = expr.args
    return args[0] if len(args) == 1 else expr
def flatten(expr, new=new):
    """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e))

    Only direct children are inspected: a child whose class equals that of
    ``expr`` has its args spliced in place, every other child is kept as is.
    The result is built with ``new(cls, *args)``.
    """
    cls = expr.__class__
    flat = []
    for child in expr.args:
        # Same-class children contribute their own args; others go in whole.
        flat += child.args if child.__class__ == cls else (child,)
    return new(cls, *flat)
def rebuild(expr):
    """ Rebuild a SymPy tree.

    Explanation
    ===========

    Walks the expression tree recursively and calls ``expr.func`` again on
    the rebuilt args of every non-atomic node; atoms are returned unchanged.
    This forces canonicalization and removes ugliness introduced by the use
    of Basic.__new__
    """
    if expr.is_Atom:
        return expr
    constructor = expr.func
    return constructor(*[rebuild(child) for child in expr.args])
|