m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

170 lines
4.3 KiB

6 months ago
  1. """ Generic Rules for SymPy
  2. This file assumes knowledge of Basic and little else.
  3. """
  4. from sympy.utilities.iterables import sift
  5. from .util import new
  6. # Functions that create rules
  7. def rm_id(isid, new=new):
  8. """ Create a rule to remove identities.
  9. isid - fn :: x -> Bool --- whether or not this element is an identity.
  10. Examples
  11. ========
  12. >>> from sympy.strategies import rm_id
  13. >>> from sympy import Basic, S
  14. >>> remove_zeros = rm_id(lambda x: x==0)
  15. >>> remove_zeros(Basic(S(1), S(0), S(2)))
  16. Basic(1, 2)
  17. >>> remove_zeros(Basic(S(0), S(0))) # If only identites then we keep one
  18. Basic(0)
  19. See Also:
  20. unpack
  21. """
  22. def ident_remove(expr):
  23. """ Remove identities """
  24. ids = list(map(isid, expr.args))
  25. if sum(ids) == 0: # No identities. Common case
  26. return expr
  27. elif sum(ids) != len(ids): # there is at least one non-identity
  28. return new(expr.__class__,
  29. *[arg for arg, x in zip(expr.args, ids) if not x])
  30. else:
  31. return new(expr.__class__, expr.args[0])
  32. return ident_remove
  33. def glom(key, count, combine):
  34. """ Create a rule to conglomerate identical args.
  35. Examples
  36. ========
  37. >>> from sympy.strategies import glom
  38. >>> from sympy import Add
  39. >>> from sympy.abc import x
  40. >>> key = lambda x: x.as_coeff_Mul()[1]
  41. >>> count = lambda x: x.as_coeff_Mul()[0]
  42. >>> combine = lambda cnt, arg: cnt * arg
  43. >>> rl = glom(key, count, combine)
  44. >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
  45. 3*x + 5
  46. Wait, how are key, count and combine supposed to work?
  47. >>> key(2*x)
  48. x
  49. >>> count(2*x)
  50. 2
  51. >>> combine(2, x)
  52. 2*x
  53. """
  54. def conglomerate(expr):
  55. """ Conglomerate together identical args x + x -> 2x """
  56. groups = sift(expr.args, key)
  57. counts = {k: sum(map(count, args)) for k, args in groups.items()}
  58. newargs = [combine(cnt, mat) for mat, cnt in counts.items()]
  59. if set(newargs) != set(expr.args):
  60. return new(type(expr), *newargs)
  61. else:
  62. return expr
  63. return conglomerate
  64. def sort(key, new=new):
  65. """ Create a rule to sort by a key function.
  66. Examples
  67. ========
  68. >>> from sympy.strategies import sort
  69. >>> from sympy import Basic, S
  70. >>> sort_rl = sort(str)
  71. >>> sort_rl(Basic(S(3), S(1), S(2)))
  72. Basic(1, 2, 3)
  73. """
  74. def sort_rl(expr):
  75. return new(expr.__class__, *sorted(expr.args, key=key))
  76. return sort_rl
  77. def distribute(A, B):
  78. """ Turns an A containing Bs into a B of As
  79. where A, B are container types
  80. >>> from sympy.strategies import distribute
  81. >>> from sympy import Add, Mul, symbols
  82. >>> x, y = symbols('x,y')
  83. >>> dist = distribute(Mul, Add)
  84. >>> expr = Mul(2, x+y, evaluate=False)
  85. >>> expr
  86. 2*(x + y)
  87. >>> dist(expr)
  88. 2*x + 2*y
  89. """
  90. def distribute_rl(expr):
  91. for i, arg in enumerate(expr.args):
  92. if isinstance(arg, B):
  93. first, b, tail = expr.args[:i], expr.args[i], expr.args[i+1:]
  94. return B(*[A(*(first + (arg,) + tail)) for arg in b.args])
  95. return expr
  96. return distribute_rl
  97. def subs(a, b):
  98. """ Replace expressions exactly """
  99. def subs_rl(expr):
  100. if expr == a:
  101. return b
  102. else:
  103. return expr
  104. return subs_rl
  105. # Functions that are rules
  106. def unpack(expr):
  107. """ Rule to unpack singleton args
  108. >>> from sympy.strategies import unpack
  109. >>> from sympy import Basic, S
  110. >>> unpack(Basic(S(2)))
  111. 2
  112. """
  113. if len(expr.args) == 1:
  114. return expr.args[0]
  115. else:
  116. return expr
  117. def flatten(expr, new=new):
  118. """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e)) """
  119. cls = expr.__class__
  120. args = []
  121. for arg in expr.args:
  122. if arg.__class__ == cls:
  123. args.extend(arg.args)
  124. else:
  125. args.append(arg)
  126. return new(expr.__class__, *args)
  127. def rebuild(expr):
  128. """ Rebuild a SymPy tree.
  129. Explanation
  130. ===========
  131. This function recursively calls constructors in the expression tree.
  132. This forces canonicalization and removes ugliness introduced by the use of
  133. Basic.__new__
  134. """
  135. if expr.is_Atom:
  136. return expr
  137. else:
  138. return expr.func(*list(map(rebuild, expr.args)))