m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

311 lines
8.8 KiB

6 months ago
  1. from .basic import Basic
  2. from .sorting import ordered
  3. from .sympify import sympify
  4. from sympy.utilities.iterables import iterable
  5. def iterargs(expr):
  6. """Yield the args of a Basic object in a breadth-first traversal.
  7. Depth-traversal stops if `arg.args` is either empty or is not
  8. an iterable.
  9. Examples
  10. ========
  11. >>> from sympy import Integral, Function
  12. >>> from sympy.abc import x
  13. >>> f = Function('f')
  14. >>> from sympy.core.traversal import iterargs
  15. >>> list(iterargs(Integral(f(x), (f(x), 1))))
  16. [Integral(f(x), (f(x), 1)), f(x), (f(x), 1), x, f(x), 1, x]
  17. See Also
  18. ========
  19. iterfreeargs, preorder_traversal
  20. """
  21. args = [expr]
  22. for i in args:
  23. yield i
  24. try:
  25. args.extend(i.args)
  26. except TypeError:
  27. pass # for cases like f being an arg
  28. def iterfreeargs(expr, _first=True):
  29. """Yield the args of a Basic object in a breadth-first traversal.
  30. Depth-traversal stops if `arg.args` is either empty or is not
  31. an iterable. The bound objects of an expression will be returned
  32. as canonical variables.
  33. Examples
  34. ========
  35. >>> from sympy import Integral, Function
  36. >>> from sympy.abc import x
  37. >>> f = Function('f')
  38. >>> from sympy.core.traversal import iterfreeargs
  39. >>> list(iterfreeargs(Integral(f(x), (f(x), 1))))
  40. [Integral(f(x), (f(x), 1)), 1]
  41. See Also
  42. ========
  43. iterargs, preorder_traversal
  44. """
  45. args = [expr]
  46. for i in args:
  47. yield i
  48. if _first and hasattr(i, 'bound_symbols'):
  49. void = i.canonical_variables.values()
  50. for i in iterfreeargs(i.as_dummy(), _first=False):
  51. if not i.has(*void):
  52. yield i
  53. try:
  54. args.extend(i.args)
  55. except TypeError:
  56. pass # for cases like f being an arg
  57. class preorder_traversal:
  58. """
  59. Do a pre-order traversal of a tree.
  60. This iterator recursively yields nodes that it has visited in a pre-order
  61. fashion. That is, it yields the current node then descends through the
  62. tree breadth-first to yield all of a node's children's pre-order
  63. traversal.
  64. For an expression, the order of the traversal depends on the order of
  65. .args, which in many cases can be arbitrary.
  66. Parameters
  67. ==========
  68. node : SymPy expression
  69. The expression to traverse.
  70. keys : (default None) sort key(s)
  71. The key(s) used to sort args of Basic objects. When None, args of Basic
  72. objects are processed in arbitrary order. If key is defined, it will
  73. be passed along to ordered() as the only key(s) to use to sort the
  74. arguments; if ``key`` is simply True then the default keys of ordered
  75. will be used.
  76. Yields
  77. ======
  78. subtree : SymPy expression
  79. All of the subtrees in the tree.
  80. Examples
  81. ========
  82. >>> from sympy import preorder_traversal, symbols
  83. >>> x, y, z = symbols('x y z')
  84. The nodes are returned in the order that they are encountered unless key
  85. is given; simply passing key=True will guarantee that the traversal is
  86. unique.
  87. >>> list(preorder_traversal((x + y)*z, keys=None)) # doctest: +SKIP
  88. [z*(x + y), z, x + y, y, x]
  89. >>> list(preorder_traversal((x + y)*z, keys=True))
  90. [z*(x + y), z, x + y, x, y]
  91. """
  92. def __init__(self, node, keys=None):
  93. self._skip_flag = False
  94. self._pt = self._preorder_traversal(node, keys)
  95. def _preorder_traversal(self, node, keys):
  96. yield node
  97. if self._skip_flag:
  98. self._skip_flag = False
  99. return
  100. if isinstance(node, Basic):
  101. if not keys and hasattr(node, '_argset'):
  102. # LatticeOp keeps args as a set. We should use this if we
  103. # don't care about the order, to prevent unnecessary sorting.
  104. args = node._argset
  105. else:
  106. args = node.args
  107. if keys:
  108. if keys != True:
  109. args = ordered(args, keys, default=False)
  110. else:
  111. args = ordered(args)
  112. for arg in args:
  113. yield from self._preorder_traversal(arg, keys)
  114. elif iterable(node):
  115. for item in node:
  116. yield from self._preorder_traversal(item, keys)
  117. def skip(self):
  118. """
  119. Skip yielding current node's (last yielded node's) subtrees.
  120. Examples
  121. ========
  122. >>> from sympy import preorder_traversal, symbols
  123. >>> x, y, z = symbols('x y z')
  124. >>> pt = preorder_traversal((x + y*z)*z)
  125. >>> for i in pt:
  126. ... print(i)
  127. ... if i == x + y*z:
  128. ... pt.skip()
  129. z*(x + y*z)
  130. z
  131. x + y*z
  132. """
  133. self._skip_flag = True
  134. def __next__(self):
  135. return next(self._pt)
  136. def __iter__(self):
  137. return self
  138. def use(expr, func, level=0, args=(), kwargs={}):
  139. """
  140. Use ``func`` to transform ``expr`` at the given level.
  141. Examples
  142. ========
  143. >>> from sympy import use, expand
  144. >>> from sympy.abc import x, y
  145. >>> f = (x + y)**2*x + 1
  146. >>> use(f, expand, level=2)
  147. x*(x**2 + 2*x*y + y**2) + 1
  148. >>> expand(f)
  149. x**3 + 2*x**2*y + x*y**2 + 1
  150. """
  151. def _use(expr, level):
  152. if not level:
  153. return func(expr, *args, **kwargs)
  154. else:
  155. if expr.is_Atom:
  156. return expr
  157. else:
  158. level -= 1
  159. _args = []
  160. for arg in expr.args:
  161. _args.append(_use(arg, level))
  162. return expr.__class__(*_args)
  163. return _use(sympify(expr), level)
  164. def walk(e, *target):
  165. """Iterate through the args that are the given types (target) and
  166. return a list of the args that were traversed; arguments
  167. that are not of the specified types are not traversed.
  168. Examples
  169. ========
  170. >>> from sympy.core.traversal import walk
  171. >>> from sympy import Min, Max
  172. >>> from sympy.abc import x, y, z
  173. >>> list(walk(Min(x, Max(y, Min(1, z))), Min))
  174. [Min(x, Max(y, Min(1, z)))]
  175. >>> list(walk(Min(x, Max(y, Min(1, z))), Min, Max))
  176. [Min(x, Max(y, Min(1, z))), Max(y, Min(1, z)), Min(1, z)]
  177. See Also
  178. ========
  179. bottom_up
  180. """
  181. if isinstance(e, target):
  182. yield e
  183. for i in e.args:
  184. yield from walk(i, *target)
  185. def bottom_up(rv, F, atoms=False, nonbasic=False):
  186. """Apply ``F`` to all expressions in an expression tree from the
  187. bottom up. If ``atoms`` is True, apply ``F`` even if there are no args;
  188. if ``nonbasic`` is True, try to apply ``F`` to non-Basic objects.
  189. """
  190. args = getattr(rv, 'args', None)
  191. if args is not None:
  192. if args:
  193. args = tuple([bottom_up(a, F, atoms, nonbasic) for a in args])
  194. if args != rv.args:
  195. rv = rv.func(*args)
  196. rv = F(rv)
  197. elif atoms:
  198. rv = F(rv)
  199. else:
  200. if nonbasic:
  201. try:
  202. rv = F(rv)
  203. except TypeError:
  204. pass
  205. return rv
  206. def postorder_traversal(node, keys=None):
  207. """
  208. Do a postorder traversal of a tree.
  209. This generator recursively yields nodes that it has visited in a postorder
  210. fashion. That is, it descends through the tree depth-first to yield all of
  211. a node's children's postorder traversal before yielding the node itself.
  212. Parameters
  213. ==========
  214. node : SymPy expression
  215. The expression to traverse.
  216. keys : (default None) sort key(s)
  217. The key(s) used to sort args of Basic objects. When None, args of Basic
  218. objects are processed in arbitrary order. If key is defined, it will
  219. be passed along to ordered() as the only key(s) to use to sort the
  220. arguments; if ``key`` is simply True then the default keys of
  221. ``ordered`` will be used (node count and default_sort_key).
  222. Yields
  223. ======
  224. subtree : SymPy expression
  225. All of the subtrees in the tree.
  226. Examples
  227. ========
  228. >>> from sympy import postorder_traversal
  229. >>> from sympy.abc import w, x, y, z
  230. The nodes are returned in the order that they are encountered unless key
  231. is given; simply passing key=True will guarantee that the traversal is
  232. unique.
  233. >>> list(postorder_traversal(w + (x + y)*z)) # doctest: +SKIP
  234. [z, y, x, x + y, z*(x + y), w, w + z*(x + y)]
  235. >>> list(postorder_traversal(w + (x + y)*z, keys=True))
  236. [w, z, x, y, x + y, z*(x + y), w + z*(x + y)]
  237. """
  238. if isinstance(node, Basic):
  239. args = node.args
  240. if keys:
  241. if keys != True:
  242. args = ordered(args, keys, default=False)
  243. else:
  244. args = ordered(args)
  245. for arg in args:
  246. yield from postorder_traversal(arg, keys)
  247. elif iterable(node):
  248. for item in node:
  249. yield from postorder_traversal(item, keys)
  250. yield node