图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

294 lines
8.1 KiB

  1. from sympy.core.basic import Basic
  2. from sympy.core.expr import Expr
  3. from sympy.core.symbol import Symbol
  4. from sympy.core.numbers import Integer, Rational, Float
  5. from sympy.printing.repr import srepr
  6. __all__ = ['dotprint']
  7. default_styles = (
  8. (Basic, {'color': 'blue', 'shape': 'ellipse'}),
  9. (Expr, {'color': 'black'})
  10. )
  11. slotClasses = (Symbol, Integer, Rational, Float)
  12. def purestr(x, with_args=False):
  13. """A string that follows ```obj = type(obj)(*obj.args)``` exactly.
  14. Parameters
  15. ==========
  16. with_args : boolean, optional
  17. If ``True``, there will be a second argument for the return
  18. value, which is a tuple containing ``purestr`` applied to each
  19. of the subnodes.
  20. If ``False``, there will not be a second argument for the
  21. return.
  22. Default is ``False``
  23. Examples
  24. ========
  25. >>> from sympy import Float, Symbol, MatrixSymbol
  26. >>> from sympy import Integer # noqa: F401
  27. >>> from sympy.core.symbol import Str # noqa: F401
  28. >>> from sympy.printing.dot import purestr
  29. Applying ``purestr`` for basic symbolic object:
  30. >>> code = purestr(Symbol('x'))
  31. >>> code
  32. "Symbol('x')"
  33. >>> eval(code) == Symbol('x')
  34. True
  35. For basic numeric object:
  36. >>> purestr(Float(2))
  37. "Float('2.0', precision=53)"
  38. For matrix symbol:
  39. >>> code = purestr(MatrixSymbol('x', 2, 2))
  40. >>> code
  41. "MatrixSymbol(Str('x'), Integer(2), Integer(2))"
  42. >>> eval(code) == MatrixSymbol('x', 2, 2)
  43. True
  44. With ``with_args=True``:
  45. >>> purestr(Float(2), with_args=True)
  46. ("Float('2.0', precision=53)", ())
  47. >>> purestr(MatrixSymbol('x', 2, 2), with_args=True)
  48. ("MatrixSymbol(Str('x'), Integer(2), Integer(2))",
  49. ("Str('x')", 'Integer(2)', 'Integer(2)'))
  50. """
  51. sargs = ()
  52. if not isinstance(x, Basic):
  53. rv = str(x)
  54. elif not x.args:
  55. rv = srepr(x)
  56. else:
  57. args = x.args
  58. sargs = tuple(map(purestr, args))
  59. rv = "%s(%s)"%(type(x).__name__, ', '.join(sargs))
  60. if with_args:
  61. rv = rv, sargs
  62. return rv
  63. def styleof(expr, styles=default_styles):
  64. """ Merge style dictionaries in order
  65. Examples
  66. ========
  67. >>> from sympy import Symbol, Basic, Expr, S
  68. >>> from sympy.printing.dot import styleof
  69. >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
  70. ... (Expr, {'color': 'black'})]
  71. >>> styleof(Basic(S(1)), styles)
  72. {'color': 'blue', 'shape': 'ellipse'}
  73. >>> x = Symbol('x')
  74. >>> styleof(x + 1, styles) # this is an Expr
  75. {'color': 'black', 'shape': 'ellipse'}
  76. """
  77. style = dict()
  78. for typ, sty in styles:
  79. if isinstance(expr, typ):
  80. style.update(sty)
  81. return style
  82. def attrprint(d, delimiter=', '):
  83. """ Print a dictionary of attributes
  84. Examples
  85. ========
  86. >>> from sympy.printing.dot import attrprint
  87. >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'}))
  88. "color"="blue", "shape"="ellipse"
  89. """
  90. return delimiter.join('"%s"="%s"'%item for item in sorted(d.items()))
  91. def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True):
  92. """ String defining a node
  93. Examples
  94. ========
  95. >>> from sympy.printing.dot import dotnode
  96. >>> from sympy.abc import x
  97. >>> print(dotnode(x))
  98. "Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"];
  99. """
  100. style = styleof(expr, styles)
  101. if isinstance(expr, Basic) and not expr.is_Atom:
  102. label = str(expr.__class__.__name__)
  103. else:
  104. label = labelfunc(expr)
  105. style['label'] = label
  106. expr_str = purestr(expr)
  107. if repeat:
  108. expr_str += '_%s' % str(pos)
  109. return '"%s" [%s];' % (expr_str, attrprint(style))
  110. def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True):
  111. """ List of strings for all expr->expr.arg pairs
  112. See the docstring of dotprint for explanations of the options.
  113. Examples
  114. ========
  115. >>> from sympy.printing.dot import dotedges
  116. >>> from sympy.abc import x
  117. >>> for e in dotedges(x+2):
  118. ... print(e)
  119. "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
  120. "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
  121. """
  122. if atom(expr):
  123. return []
  124. else:
  125. expr_str, arg_strs = purestr(expr, with_args=True)
  126. if repeat:
  127. expr_str += '_%s' % str(pos)
  128. arg_strs = ['%s_%s' % (a, str(pos + (i,)))
  129. for i, a in enumerate(arg_strs)]
  130. return ['"%s" -> "%s";' % (expr_str, a) for a in arg_strs]
  131. template = \
  132. """digraph{
  133. # Graph style
  134. %(graphstyle)s
  135. #########
  136. # Nodes #
  137. #########
  138. %(nodes)s
  139. #########
  140. # Edges #
  141. #########
  142. %(edges)s
  143. }"""
  144. _graphstyle = {'rankdir': 'TD', 'ordering': 'out'}
  145. def dotprint(expr,
  146. styles=default_styles, atom=lambda x: not isinstance(x, Basic),
  147. maxdepth=None, repeat=True, labelfunc=str, **kwargs):
  148. """DOT description of a SymPy expression tree
  149. Parameters
  150. ==========
  151. styles : list of lists composed of (Class, mapping), optional
  152. Styles for different classes.
  153. The default is
  154. .. code-block:: python
  155. (
  156. (Basic, {'color': 'blue', 'shape': 'ellipse'}),
  157. (Expr, {'color': 'black'})
  158. )
  159. atom : function, optional
  160. Function used to determine if an arg is an atom.
  161. A good choice is ``lambda x: not x.args``.
  162. The default is ``lambda x: not isinstance(x, Basic)``.
  163. maxdepth : integer, optional
  164. The maximum depth.
  165. The default is ``None``, meaning no limit.
  166. repeat : boolean, optional
  167. Whether to use different nodes for common subexpressions.
  168. The default is ``True``.
  169. For example, for ``x + x*y`` with ``repeat=True``, it will have
  170. two nodes for ``x``; with ``repeat=False``, it will have one
  171. node.
  172. .. warning::
  173. Even if a node appears twice in the same object like ``x`` in
  174. ``Pow(x, x)``, it will still only appear once.
  175. Hence, with ``repeat=False``, the number of arrows out of an
  176. object might not equal the number of args it has.
  177. labelfunc : function, optional
  178. A function to create a label for a given leaf node.
  179. The default is ``str``.
  180. Another good option is ``srepr``.
  181. For example with ``str``, the leaf nodes of ``x + 1`` are labeled,
  182. ``x`` and ``1``. With ``srepr``, they are labeled ``Symbol('x')``
  183. and ``Integer(1)``.
  184. **kwargs : optional
  185. Additional keyword arguments are included as styles for the graph.
  186. Examples
  187. ========
  188. >>> from sympy import dotprint
  189. >>> from sympy.abc import x
  190. >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE
  191. digraph{
  192. <BLANKLINE>
  193. # Graph style
  194. "ordering"="out"
  195. "rankdir"="TD"
  196. <BLANKLINE>
  197. #########
  198. # Nodes #
  199. #########
  200. <BLANKLINE>
  201. "Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"];
  202. "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"];
  203. "Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"];
  204. <BLANKLINE>
  205. #########
  206. # Edges #
  207. #########
  208. <BLANKLINE>
  209. "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
  210. "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
  211. }
  212. """
  213. # repeat works by adding a signature tuple to the end of each node for its
  214. # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the
  215. # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0].
  216. graphstyle = _graphstyle.copy()
  217. graphstyle.update(kwargs)
  218. nodes = []
  219. edges = []
  220. def traverse(e, depth, pos=()):
  221. nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat))
  222. if maxdepth and depth >= maxdepth:
  223. return
  224. edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat))
  225. [traverse(arg, depth+1, pos + (i,)) for i, arg in enumerate(e.args) if not atom(arg)]
  226. traverse(expr, 0)
  227. return template%{'graphstyle': attrprint(graphstyle, delimiter='\n'),
  228. 'nodes': '\n'.join(nodes),
  229. 'edges': '\n'.join(edges)}