图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

52 lines
1.6 KiB

  1. """ Optimizations of the expression tree representation for better CSE
  2. opportunities.
  3. """
  4. from sympy.core import Add, Basic, Mul
  5. from sympy.core.singleton import S
  6. from sympy.core.sorting import default_sort_key
  7. from sympy.core.traversal import preorder_traversal
  8. def sub_pre(e):
  9. """ Replace y - x with -(x - y) if -1 can be extracted from y - x.
  10. """
  11. # replacing Add, A, from which -1 can be extracted with -1*-A
  12. adds = [a for a in e.atoms(Add) if a.could_extract_minus_sign()]
  13. reps = {}
  14. ignore = set()
  15. for a in adds:
  16. na = -a
  17. if na.is_Mul: # e.g. MatExpr
  18. ignore.add(a)
  19. continue
  20. reps[a] = Mul._from_args([S.NegativeOne, na])
  21. e = e.xreplace(reps)
  22. # repeat again for persisting Adds but mark these with a leading 1, -1
  23. # e.g. y - x -> 1*-1*(x - y)
  24. if isinstance(e, Basic):
  25. negs = {}
  26. for a in sorted(e.atoms(Add), key=default_sort_key):
  27. if a in ignore:
  28. continue
  29. if a in reps:
  30. negs[a] = reps[a]
  31. elif a.could_extract_minus_sign():
  32. negs[a] = Mul._from_args([S.One, S.NegativeOne, -a])
  33. e = e.xreplace(negs)
  34. return e
  35. def sub_post(e):
  36. """ Replace 1*-1*x with -x.
  37. """
  38. replacements = []
  39. for node in preorder_traversal(e):
  40. if isinstance(node, Mul) and \
  41. node.args[0] is S.One and node.args[1] is S.NegativeOne:
  42. replacements.append((node, -Mul._from_args(node.args[2:])))
  43. for node, replacement in replacements:
  44. e = e.xreplace({node: replacement})
  45. return e