图片解析应用
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

413 lines
12 KiB

  1. from typing import Set as tSet
  2. from warnings import warn
  3. import inspect
  4. from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
  5. from .utils import expand_tuples
  6. import itertools as itl
  7. class MDNotImplementedError(NotImplementedError):
  8. """ A NotImplementedError for multiple dispatch """
  9. ### Functions for on_ambiguity
  10. def ambiguity_warn(dispatcher, ambiguities):
  11. """ Raise warning when ambiguity is detected
  12. Parameters
  13. ----------
  14. dispatcher : Dispatcher
  15. The dispatcher on which the ambiguity was detected
  16. ambiguities : set
  17. Set of type signature pairs that are ambiguous within this dispatcher
  18. See Also:
  19. Dispatcher.add
  20. warning_text
  21. """
  22. warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
  23. class RaiseNotImplementedError:
  24. """Raise ``NotImplementedError`` when called."""
  25. def __init__(self, dispatcher):
  26. self.dispatcher = dispatcher
  27. def __call__(self, *args, **kwargs):
  28. types = tuple(type(a) for a in args)
  29. raise NotImplementedError(
  30. "Ambiguous signature for %s: <%s>" % (
  31. self.dispatcher.name, str_signature(types)
  32. ))
  33. def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
  34. """
  35. If super signature for ambiguous types is duplicate types, ignore it.
  36. Else, register instance of ``RaiseNotImplementedError`` for ambiguous types.
  37. Parameters
  38. ----------
  39. dispatcher : Dispatcher
  40. The dispatcher on which the ambiguity was detected
  41. ambiguities : set
  42. Set of type signature pairs that are ambiguous within this dispatcher
  43. See Also:
  44. Dispatcher.add
  45. ambiguity_warn
  46. """
  47. for amb in ambiguities:
  48. signature = tuple(super_signature(amb))
  49. if len(set(signature)) == 1:
  50. continue
  51. dispatcher.add(
  52. signature, RaiseNotImplementedError(dispatcher),
  53. on_ambiguity=ambiguity_register_error_ignore_dup
  54. )
  55. ###
# Dispatchers whose ``reorder`` call happened while ordering was halted
# (``_resolve[0]`` is False); ``restart_ordering`` pops and reorders them.
_unresolved_dispatchers = set() # type: tSet[Dispatcher]
# One-element list used as a mutable module-level flag: ``_resolve[0]`` is
# True when ``Dispatcher.reorder`` should recompute the ordering immediately.
_resolve = [True]
  58. def halt_ordering():
  59. _resolve[0] = False
  60. def restart_ordering(on_ambiguity=ambiguity_warn):
  61. _resolve[0] = True
  62. while _unresolved_dispatchers:
  63. dispatcher = _unresolved_dispatchers.pop()
  64. dispatcher.reorder(on_ambiguity=on_ambiguity)
  65. class Dispatcher:
  66. """ Dispatch methods based on type signature
  67. Use ``dispatch`` to add implementations
  68. Examples
  69. --------
  70. >>> from sympy.multipledispatch import dispatch
  71. >>> @dispatch(int)
  72. ... def f(x):
  73. ... return x + 1
  74. >>> @dispatch(float)
  75. ... def f(x): # noqa: F811
  76. ... return x - 1
  77. >>> f(3)
  78. 4
  79. >>> f(3.0)
  80. 2.0
  81. """
  82. __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'
  83. def __init__(self, name, doc=None):
  84. self.name = self.__name__ = name
  85. self.funcs = dict()
  86. self._cache = dict()
  87. self.ordering = []
  88. self.doc = doc
  89. def register(self, *types, **kwargs):
  90. """ Register dispatcher with new implementation
  91. >>> from sympy.multipledispatch.dispatcher import Dispatcher
  92. >>> f = Dispatcher('f')
  93. >>> @f.register(int)
  94. ... def inc(x):
  95. ... return x + 1
  96. >>> @f.register(float)
  97. ... def dec(x):
  98. ... return x - 1
  99. >>> @f.register(list)
  100. ... @f.register(tuple)
  101. ... def reverse(x):
  102. ... return x[::-1]
  103. >>> f(1)
  104. 2
  105. >>> f(1.0)
  106. 0.0
  107. >>> f([1, 2, 3])
  108. [3, 2, 1]
  109. """
  110. def _(func):
  111. self.add(types, func, **kwargs)
  112. return func
  113. return _
  114. @classmethod
  115. def get_func_params(cls, func):
  116. if hasattr(inspect, "signature"):
  117. sig = inspect.signature(func)
  118. return sig.parameters.values()
  119. @classmethod
  120. def get_func_annotations(cls, func):
  121. """ Get annotations of function positional parameters
  122. """
  123. params = cls.get_func_params(func)
  124. if params:
  125. Parameter = inspect.Parameter
  126. params = (param for param in params
  127. if param.kind in
  128. (Parameter.POSITIONAL_ONLY,
  129. Parameter.POSITIONAL_OR_KEYWORD))
  130. annotations = tuple(
  131. param.annotation
  132. for param in params)
  133. if not any(ann is Parameter.empty for ann in annotations):
  134. return annotations
  135. def add(self, signature, func, on_ambiguity=ambiguity_warn):
  136. """ Add new types/method pair to dispatcher
  137. >>> from sympy.multipledispatch import Dispatcher
  138. >>> D = Dispatcher('add')
  139. >>> D.add((int, int), lambda x, y: x + y)
  140. >>> D.add((float, float), lambda x, y: x + y)
  141. >>> D(1, 2)
  142. 3
  143. >>> D(1, 2.0)
  144. Traceback (most recent call last):
  145. ...
  146. NotImplementedError: Could not find signature for add: <int, float>
  147. When ``add`` detects a warning it calls the ``on_ambiguity`` callback
  148. with a dispatcher/itself, and a set of ambiguous type signature pairs
  149. as inputs. See ``ambiguity_warn`` for an example.
  150. """
  151. # Handle annotations
  152. if not signature:
  153. annotations = self.get_func_annotations(func)
  154. if annotations:
  155. signature = annotations
  156. # Handle union types
  157. if any(isinstance(typ, tuple) for typ in signature):
  158. for typs in expand_tuples(signature):
  159. self.add(typs, func, on_ambiguity)
  160. return
  161. for typ in signature:
  162. if not isinstance(typ, type):
  163. str_sig = ', '.join(c.__name__ if isinstance(c, type)
  164. else str(c) for c in signature)
  165. raise TypeError("Tried to dispatch on non-type: %s\n"
  166. "In signature: <%s>\n"
  167. "In function: %s" %
  168. (typ, str_sig, self.name))
  169. self.funcs[signature] = func
  170. self.reorder(on_ambiguity=on_ambiguity)
  171. self._cache.clear()
  172. def reorder(self, on_ambiguity=ambiguity_warn):
  173. if _resolve[0]:
  174. self.ordering = ordering(self.funcs)
  175. amb = ambiguities(self.funcs)
  176. if amb:
  177. on_ambiguity(self, amb)
  178. else:
  179. _unresolved_dispatchers.add(self)
  180. def __call__(self, *args, **kwargs):
  181. types = tuple([type(arg) for arg in args])
  182. try:
  183. func = self._cache[types]
  184. except KeyError:
  185. func = self.dispatch(*types)
  186. if not func:
  187. raise NotImplementedError(
  188. 'Could not find signature for %s: <%s>' %
  189. (self.name, str_signature(types)))
  190. self._cache[types] = func
  191. try:
  192. return func(*args, **kwargs)
  193. except MDNotImplementedError:
  194. funcs = self.dispatch_iter(*types)
  195. next(funcs) # burn first
  196. for func in funcs:
  197. try:
  198. return func(*args, **kwargs)
  199. except MDNotImplementedError:
  200. pass
  201. raise NotImplementedError("Matching functions for "
  202. "%s: <%s> found, but none completed successfully"
  203. % (self.name, str_signature(types)))
  204. def __str__(self):
  205. return "<dispatched %s>" % self.name
  206. __repr__ = __str__
  207. def dispatch(self, *types):
  208. """ Deterimine appropriate implementation for this type signature
  209. This method is internal. Users should call this object as a function.
  210. Implementation resolution occurs within the ``__call__`` method.
  211. >>> from sympy.multipledispatch import dispatch
  212. >>> @dispatch(int)
  213. ... def inc(x):
  214. ... return x + 1
  215. >>> implementation = inc.dispatch(int)
  216. >>> implementation(3)
  217. 4
  218. >>> print(inc.dispatch(float))
  219. None
  220. See Also:
  221. ``sympy.multipledispatch.conflict`` - module to determine resolution order
  222. """
  223. if types in self.funcs:
  224. return self.funcs[types]
  225. try:
  226. return next(self.dispatch_iter(*types))
  227. except StopIteration:
  228. return None
  229. def dispatch_iter(self, *types):
  230. n = len(types)
  231. for signature in self.ordering:
  232. if len(signature) == n and all(map(issubclass, types, signature)):
  233. result = self.funcs[signature]
  234. yield result
  235. def resolve(self, types):
  236. """ Deterimine appropriate implementation for this type signature
  237. .. deprecated:: 0.4.4
  238. Use ``dispatch(*types)`` instead
  239. """
  240. warn("resolve() is deprecated, use dispatch(*types)",
  241. DeprecationWarning)
  242. return self.dispatch(*types)
  243. def __getstate__(self):
  244. return {'name': self.name,
  245. 'funcs': self.funcs}
  246. def __setstate__(self, d):
  247. self.name = d['name']
  248. self.funcs = d['funcs']
  249. self.ordering = ordering(self.funcs)
  250. self._cache = dict()
  251. @property
  252. def __doc__(self):
  253. docs = ["Multiply dispatched method: %s" % self.name]
  254. if self.doc:
  255. docs.append(self.doc)
  256. other = []
  257. for sig in self.ordering[::-1]:
  258. func = self.funcs[sig]
  259. if func.__doc__:
  260. s = 'Inputs: <%s>\n' % str_signature(sig)
  261. s += '-' * len(s) + '\n'
  262. s += func.__doc__.strip()
  263. docs.append(s)
  264. else:
  265. other.append(str_signature(sig))
  266. if other:
  267. docs.append('Other signatures:\n ' + '\n '.join(other))
  268. return '\n\n'.join(docs)
  269. def _help(self, *args):
  270. return self.dispatch(*map(type, args)).__doc__
  271. def help(self, *args, **kwargs):
  272. """ Print docstring for the function corresponding to inputs """
  273. print(self._help(*args))
  274. def _source(self, *args):
  275. func = self.dispatch(*map(type, args))
  276. if not func:
  277. raise TypeError("No function found")
  278. return source(func)
  279. def source(self, *args, **kwargs):
  280. """ Print source code for the function corresponding to inputs """
  281. print(self._source(*args))
  282. def source(func):
  283. s = 'File: %s\n\n' % inspect.getsourcefile(func)
  284. s = s + inspect.getsource(func)
  285. return s
  286. class MethodDispatcher(Dispatcher):
  287. """ Dispatch methods based on type signature
  288. See Also:
  289. Dispatcher
  290. """
  291. @classmethod
  292. def get_func_params(cls, func):
  293. if hasattr(inspect, "signature"):
  294. sig = inspect.signature(func)
  295. return itl.islice(sig.parameters.values(), 1, None)
  296. def __get__(self, instance, owner):
  297. self.obj = instance
  298. self.cls = owner
  299. return self
  300. def __call__(self, *args, **kwargs):
  301. types = tuple([type(arg) for arg in args])
  302. func = self.dispatch(*types)
  303. if not func:
  304. raise NotImplementedError('Could not find signature for %s: <%s>' %
  305. (self.name, str_signature(types)))
  306. return func(self.obj, *args, **kwargs)
  307. def str_signature(sig):
  308. """ String representation of type signature
  309. >>> from sympy.multipledispatch.dispatcher import str_signature
  310. >>> str_signature((int, float))
  311. 'int, float'
  312. """
  313. return ', '.join(cls.__name__ for cls in sig)
  314. def warning_text(name, amb):
  315. """ The text for ambiguity warnings """
  316. text = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
  317. text += "The following signatures may result in ambiguous behavior:\n"
  318. for pair in amb:
  319. text += "\t" + \
  320. ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
  321. text += "\n\nConsider making the following additions:\n\n"
  322. text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
  323. + ')\ndef %s(...)' % name for s in amb])
  324. return text