|
|
from typing import Set as tSet
from warnings import warn import inspect from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning from .utils import expand_tuples import itertools as itl
class MDNotImplementedError(NotImplementedError):
    """Signal that a dispatched implementation declines to handle a call.

    It is a ``NotImplementedError`` subclass reserved for multiple dispatch,
    so it can be told apart from an ordinary ``NotImplementedError``.
    """
### Functions for on_ambiguity
def ambiguity_warn(dispatcher, ambiguities):
    """ Emit an ``AmbiguityWarning`` describing the detected ambiguities

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected; only its
        ``name`` attribute is read here.
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher

    See Also:
        Dispatcher.add
        warning_text
    """
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)
class RaiseNotImplementedError:
    """Callable placeholder that raises ``NotImplementedError`` when called.

    The error message names the dispatcher it was created for and the types
    of the positional arguments it was called with.
    """

    def __init__(self, dispatcher):
        # Kept only so the error message can report the dispatcher's name.
        self.dispatcher = dispatcher

    def __call__(self, *args, **kwargs):
        arg_types = tuple(map(type, args))
        message = "Ambiguous signature for %s: <%s>" % (
            self.dispatcher.name, str_signature(arg_types))
        raise NotImplementedError(message)
def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
    """
    Register an error-raising implementation for each ambiguous pair.

    For every ambiguous pair the super signature is computed.  If all of its
    entries are the same type it is skipped; otherwise an instance of
    ``RaiseNotImplementedError`` is registered for that signature, using this
    same function as the ``on_ambiguity`` callback for the registration.

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher

    See Also:
        Dispatcher.add
        ambiguity_warn
    """
    for pair in ambiguities:
        sup_sig = tuple(super_signature(pair))
        if len(set(sup_sig)) != 1:
            dispatcher.add(
                sup_sig,
                RaiseNotImplementedError(dispatcher),
                on_ambiguity=ambiguity_register_error_ignore_dup,
            )
###
_unresolved_dispatchers = set() # type: tSet[Dispatcher] _resolve = [True]
def halt_ordering():
    """Switch off immediate reordering by clearing the ``_resolve`` flag.

    While the flag is cleared, ``Dispatcher.reorder`` records the dispatcher
    in ``_unresolved_dispatchers`` instead of recomputing its ordering.
    """
    _resolve[0] = False
def restart_ordering(on_ambiguity=ambiguity_warn):
    """Re-enable ordering and reorder every dispatcher deferred meanwhile.

    Parameters
    ----------
    on_ambiguity : callable
        Callback forwarded to each pending dispatcher's ``reorder`` call.
    """
    _resolve[0] = True
    pending = _unresolved_dispatchers
    while pending:
        pending.pop().reorder(on_ambiguity=on_ambiguity)
class Dispatcher:
    """ Dispatch methods based on type signature

    Use ``dispatch`` to add implementations

    Examples
    --------

    >>> from sympy.multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1

    >>> @dispatch(float)
    ... def f(x): # noqa: F811
    ...     return x - 1

    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """

    # Slots:
    #   __name__, name : the dispatcher's name (both hold the same value)
    #   funcs          : dict mapping a signature tuple to its implementation
    #   ordering       : signatures in the order tried by ``dispatch_iter``
    #   _cache         : dict mapping argument-type tuples to the resolved
    #                    implementation; cleared whenever ``add`` registers
    #   doc            : optional extra text included in ``__doc__``
    __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        self.name = self.__name__ = name
        self.funcs = {}
        self._cache = {}
        self.ordering = []
        self.doc = doc

    def register(self, *types, **kwargs):
        """ Register dispatcher with new implementation

        Returns a decorator that adds the decorated function under the
        signature ``types`` (extra keyword arguments are forwarded to
        ``add``) and returns the function unchanged.

        >>> from sympy.multipledispatch.dispatcher import Dispatcher
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1

        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1

        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]

        >>> f(1)
        2

        >>> f(1.0)
        0.0

        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _(func):
            self.add(types, func, **kwargs)
            return func
        return _

    @classmethod
    def get_func_params(cls, func):
        """ Return the parameters of ``func`` from ``inspect.signature``

        Returns ``None`` if ``inspect.signature`` is unavailable.
        """
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ Get annotations of function positional parameters

        Returns a tuple with the annotation of every positional-only or
        positional-or-keyword parameter, or ``None`` if the parameters could
        not be obtained or at least one of them is not annotated.
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if not any(ann is Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func, on_ambiguity=ambiguity_warn):
        """ Add new types/method pair to dispatcher

        >>> from sympy.multipledispatch import Dispatcher
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)

        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>

        When ``add`` detects a warning it calls the ``on_ambiguity`` callback
        with a dispatcher/itself, and a set of ambiguous type signature pairs
        as inputs.  See ``ambiguity_warn`` for an example.

        Raises ``TypeError`` if an entry of the signature is not a type.
        """
        # Handle annotations
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Handle union types: a tuple entry means "any of these types", so
        # register the function once per expanded concrete signature.
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func, on_ambiguity)
            return

        for typ in signature:
            if not isinstance(typ, type):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError("Tried to dispatch on non-type: %s\n"
                                "In signature: <%s>\n"
                                "In function: %s" %
                                (typ, str_sig, self.name))

        self.funcs[signature] = func
        self.reorder(on_ambiguity=on_ambiguity)
        # Previously resolved argument types may now map elsewhere.
        self._cache.clear()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """ Recompute ``ordering`` and report ambiguities via ``on_ambiguity``

        If ordering is currently halted (``_resolve[0]`` is false) the
        dispatcher is recorded in ``_unresolved_dispatchers`` instead.
        """
        if _resolve[0]:
            self.ordering = ordering(self.funcs)
            amb = ambiguities(self.funcs)
            if amb:
                on_ambiguity(self, amb)
        else:
            _unresolved_dispatchers.add(self)

    def __call__(self, *args, **kwargs):
        """ Call the implementation matching the types of ``args``

        Raises ``NotImplementedError`` if no implementation matches, or if
        every matching implementation raised ``MDNotImplementedError``.
        """
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    'Could not find signature for %s: <%s>' %
                    (self.name, str_signature(types)))
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError:
            # The first choice declined; fall back to the remaining matches.
            funcs = self.dispatch_iter(*types)
            next(funcs)  # burn first
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass
            raise NotImplementedError("Matching functions for "
                                      "%s: <%s> found, but none completed successfully"
                                      % (self.name, str_signature(types)))

    def __str__(self):
        return "<dispatched %s>" % self.name
    __repr__ = __str__

    def dispatch(self, *types):
        """ Determine appropriate implementation for this type signature

        This method is internal.  Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.

        >>> from sympy.multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1

        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4

        >>> print(inc.dispatch(float))
        None

        See Also:
            ``sympy.multipledispatch.conflict`` - module to determine resolution order
        """
        # Exact signature match: no subclass search needed.
        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """ Yield every implementation whose signature accepts ``types``

        Signatures are visited in ``self.ordering`` order; one matches when
        it has the same length as ``types`` and each given type is a subclass
        of the corresponding signature entry.
        """
        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result

    def resolve(self, types):
        """ Determine appropriate implementation for this type signature

        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        warn("resolve() is deprecated, use dispatch(*types)",
             DeprecationWarning)

        return self.dispatch(*types)

    def __getstate__(self):
        """ Return the picklable state: name, registered functions and doc

        ``ordering`` and ``_cache`` are derived data and are rebuilt by
        ``__setstate__``.
        """
        return {'name': self.name,
                'funcs': self.funcs,
                'doc': getattr(self, 'doc', None)}

    def __setstate__(self, d):
        """ Restore the state produced by ``__getstate__``

        Every slot is assigned so that no attribute is left unset on the
        restored instance.  State dicts without a ``'doc'`` entry (the
        earlier format) are accepted and yield ``doc = None``.
        """
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        self.doc = d.get('doc')
        self.ordering = ordering(self.funcs)
        self._cache = {}

    @property
    def __doc__(self):
        """ Docstring assembled from ``doc`` and the implementations' docs """
        docs = ["Multiply dispatched method: %s" % self.name]

        if self.doc:
            docs.append(self.doc)

        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = 'Inputs: <%s>\n' % str_signature(sig)
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                # Undocumented implementations are only listed by signature.
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n ' + '\n '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        """ Return the docstring of the implementation chosen for ``args`` """
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        """ Return the source text of the implementation chosen for ``args``

        Raises ``TypeError`` if no implementation matches.
        """
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
def source(func):
    """Return the source file path of ``func`` followed by its source code.

    The result is a ``File: <path>`` line, a blank line, and then the text
    returned by ``inspect.getsource``.
    """
    return 'File: %s\n\n%s' % (inspect.getsourcefile(func),
                               inspect.getsource(func))
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature

    Works as a descriptor: attribute access stores the instance and owner
    on the dispatcher, and calls pass the stored instance as the first
    argument to the selected implementation.

    See Also:
        Dispatcher
    """

    @classmethod
    def get_func_params(cls, func):
        """ Return the parameters of ``func``, skipping the first one """
        if not hasattr(inspect, "signature"):
            return None
        parameters = inspect.signature(func).parameters
        return itl.islice(parameters.values(), 1, None)

    def __get__(self, instance, owner):
        # Remember what the dispatcher was accessed through so that
        # ``__call__`` can pass the instance along.
        self.obj = instance
        self.cls = owner
        return self

    def __call__(self, *args, **kwargs):
        arg_types = tuple(type(arg) for arg in args)
        impl = self.dispatch(*arg_types)
        if impl:
            return impl(self.obj, *args, **kwargs)
        raise NotImplementedError('Could not find signature for %s: <%s>' %
                                  (self.name, str_signature(arg_types)))
def str_signature(sig):
    """ Render a type signature as comma-separated type names

    >>> from sympy.multipledispatch.dispatcher import str_signature
    >>> str_signature((int, float))
    'int, float'
    """
    names = [typ.__name__ for typ in sig]
    return ', '.join(names)
def warning_text(name, amb):
    """ Build the message used for ambiguity warnings

    The message lists each ambiguous signature pair and then suggests, for
    each pair, a ``@dispatch`` registration on its super signature.
    """
    header = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
    header += "The following signatures may result in ambiguous behavior:\n"
    pair_lines = [
        "\t" + ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
        for pair in amb
    ]
    suggestions = [
        '@dispatch(' + str_signature(super_signature(s)) +
        ')\ndef %s(...)' % name
        for s in amb
    ]
    return (header
            + ''.join(pair_lines)
            + "\n\nConsider making the following additions:\n\n"
            + '\n\n'.join(suggestions))
|