|
|
from sympy.core.basic import Basic from sympy.core.expr import Expr from sympy.core.symbol import Symbol from sympy.core.numbers import Integer, Rational, Float from sympy.printing.repr import srepr
__all__ = ['dotprint']
default_styles = ( (Basic, {'color': 'blue', 'shape': 'ellipse'}), (Expr, {'color': 'black'}) )
slotClasses = (Symbol, Integer, Rational, Float) def purestr(x, with_args=False): """A string that follows ```obj = type(obj)(*obj.args)``` exactly.
Parameters ==========
with_args : boolean, optional If ``True``, there will be a second argument for the return value, which is a tuple containing ``purestr`` applied to each of the subnodes.
If ``False``, there will not be a second argument for the return.
Default is ``False``
Examples ========
>>> from sympy import Float, Symbol, MatrixSymbol >>> from sympy import Integer # noqa: F401 >>> from sympy.core.symbol import Str # noqa: F401 >>> from sympy.printing.dot import purestr
Applying ``purestr`` for basic symbolic object: >>> code = purestr(Symbol('x')) >>> code "Symbol('x')" >>> eval(code) == Symbol('x') True
For basic numeric object: >>> purestr(Float(2)) "Float('2.0', precision=53)"
For matrix symbol: >>> code = purestr(MatrixSymbol('x', 2, 2)) >>> code "MatrixSymbol(Str('x'), Integer(2), Integer(2))" >>> eval(code) == MatrixSymbol('x', 2, 2) True
With ``with_args=True``: >>> purestr(Float(2), with_args=True) ("Float('2.0', precision=53)", ()) >>> purestr(MatrixSymbol('x', 2, 2), with_args=True) ("MatrixSymbol(Str('x'), Integer(2), Integer(2))", ("Str('x')", 'Integer(2)', 'Integer(2)')) """
sargs = () if not isinstance(x, Basic): rv = str(x) elif not x.args: rv = srepr(x) else: args = x.args sargs = tuple(map(purestr, args)) rv = "%s(%s)"%(type(x).__name__, ', '.join(sargs)) if with_args: rv = rv, sargs return rv
def styleof(expr, styles=default_styles): """ Merge style dictionaries in order
Examples ========
>>> from sympy import Symbol, Basic, Expr, S >>> from sympy.printing.dot import styleof >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}), ... (Expr, {'color': 'black'})]
>>> styleof(Basic(S(1)), styles) {'color': 'blue', 'shape': 'ellipse'}
>>> x = Symbol('x') >>> styleof(x + 1, styles) # this is an Expr {'color': 'black', 'shape': 'ellipse'} """
style = dict() for typ, sty in styles: if isinstance(expr, typ): style.update(sty) return style
def attrprint(d, delimiter=', '): """ Print a dictionary of attributes
Examples ========
>>> from sympy.printing.dot import attrprint >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'})) "color"="blue", "shape"="ellipse" """
return delimiter.join('"%s"="%s"'%item for item in sorted(d.items()))
def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True): """ String defining a node
Examples ========
>>> from sympy.printing.dot import dotnode >>> from sympy.abc import x >>> print(dotnode(x)) "Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"]; """
style = styleof(expr, styles)
if isinstance(expr, Basic) and not expr.is_Atom: label = str(expr.__class__.__name__) else: label = labelfunc(expr) style['label'] = label expr_str = purestr(expr) if repeat: expr_str += '_%s' % str(pos) return '"%s" [%s];' % (expr_str, attrprint(style))
def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True): """ List of strings for all expr->expr.arg pairs
See the docstring of dotprint for explanations of the options.
Examples ========
>>> from sympy.printing.dot import dotedges >>> from sympy.abc import x >>> for e in dotedges(x+2): ... print(e) "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)"; "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)"; """
if atom(expr): return [] else: expr_str, arg_strs = purestr(expr, with_args=True) if repeat: expr_str += '_%s' % str(pos) arg_strs = ['%s_%s' % (a, str(pos + (i,))) for i, a in enumerate(arg_strs)] return ['"%s" -> "%s";' % (expr_str, a) for a in arg_strs]
template = \ """digraph{
# Graph style %(graphstyle)s
######### # Nodes # #########
%(nodes)s
######### # Edges # #########
%(edges)s }"""
_graphstyle = {'rankdir': 'TD', 'ordering': 'out'}
def dotprint(expr, styles=default_styles, atom=lambda x: not isinstance(x, Basic), maxdepth=None, repeat=True, labelfunc=str, **kwargs): """DOT description of a SymPy expression tree
Parameters ==========
styles : list of lists composed of (Class, mapping), optional Styles for different classes.
The default is
.. code-block:: python
( (Basic, {'color': 'blue', 'shape': 'ellipse'}), (Expr, {'color': 'black'}) )
atom : function, optional Function used to determine if an arg is an atom.
A good choice is ``lambda x: not x.args``.
The default is ``lambda x: not isinstance(x, Basic)``.
maxdepth : integer, optional The maximum depth.
The default is ``None``, meaning no limit.
repeat : boolean, optional Whether to use different nodes for common subexpressions.
The default is ``True``.
For example, for ``x + x*y`` with ``repeat=True``, it will have two nodes for ``x``; with ``repeat=False``, it will have one node.
.. warning:: Even if a node appears twice in the same object like ``x`` in ``Pow(x, x)``, it will still only appear once. Hence, with ``repeat=False``, the number of arrows out of an object might not equal the number of args it has.
labelfunc : function, optional A function to create a label for a given leaf node.
The default is ``str``.
Another good option is ``srepr``.
For example with ``str``, the leaf nodes of ``x + 1`` are labeled, ``x`` and ``1``. With ``srepr``, they are labeled ``Symbol('x')`` and ``Integer(1)``.
**kwargs : optional Additional keyword arguments are included as styles for the graph.
Examples ========
>>> from sympy import dotprint >>> from sympy.abc import x >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE digraph{ <BLANKLINE> # Graph style "ordering"="out" "rankdir"="TD" <BLANKLINE> ######### # Nodes # ######### <BLANKLINE> "Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"]; "Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"]; <BLANKLINE> ######### # Edges # ######### <BLANKLINE> "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)"; "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)"; }
"""
# repeat works by adding a signature tuple to the end of each node for its # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0]. graphstyle = _graphstyle.copy() graphstyle.update(kwargs)
nodes = [] edges = [] def traverse(e, depth, pos=()): nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat)) if maxdepth and depth >= maxdepth: return edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat)) [traverse(arg, depth+1, pos + (i,)) for i, arg in enumerate(e.args) if not atom(arg)] traverse(expr, 0)
return template%{'graphstyle': attrprint(graphstyle, delimiter='\n'), 'nodes': '\n'.join(nodes), 'edges': '\n'.join(edges)}
|