m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

227 lines
7.9 KiB

6 months ago
  1. """Implementation of __array_function__ overrides from NEP-18."""
  2. import collections
  3. import functools
  4. import os
  5. import textwrap
  6. from numpy.core._multiarray_umath import (
  7. add_docstring, implement_array_function, _get_implementing_args)
  8. from numpy.compat._inspect import getargspec
  9. ARRAY_FUNCTION_ENABLED = bool(
  10. int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)))
  11. array_function_like_doc = (
  12. """like : array_like
  13. Reference object to allow the creation of arrays which are not
  14. NumPy arrays. If an array-like passed in as ``like`` supports
  15. the ``__array_function__`` protocol, the result will be defined
  16. by it. In this case, it ensures the creation of an array object
  17. compatible with that passed in via this argument."""
  18. )
  19. def set_array_function_like_doc(public_api):
  20. if public_api.__doc__ is not None:
  21. public_api.__doc__ = public_api.__doc__.replace(
  22. "${ARRAY_FUNCTION_LIKE}",
  23. array_function_like_doc,
  24. )
  25. return public_api
  26. add_docstring(
  27. implement_array_function,
  28. """
  29. Implement a function with checks for __array_function__ overrides.
  30. All arguments are required, and can only be passed by position.
  31. Parameters
  32. ----------
  33. implementation : function
  34. Function that implements the operation on NumPy array without
  35. overrides when called like ``implementation(*args, **kwargs)``.
  36. public_api : function
  37. Function exposed by NumPy's public API originally called like
  38. ``public_api(*args, **kwargs)`` on which arguments are now being
  39. checked.
  40. relevant_args : iterable
  41. Iterable of arguments to check for __array_function__ methods.
  42. args : tuple
  43. Arbitrary positional arguments originally passed into ``public_api``.
  44. kwargs : dict
  45. Arbitrary keyword arguments originally passed into ``public_api``.
  46. Returns
  47. -------
  48. Result from calling ``implementation()`` or an ``__array_function__``
  49. method, as appropriate.
  50. Raises
  51. ------
  52. TypeError : if no implementation is found.
  53. """)
  54. # exposed for testing purposes; used internally by implement_array_function
  55. add_docstring(
  56. _get_implementing_args,
  57. """
  58. Collect arguments on which to call __array_function__.
  59. Parameters
  60. ----------
  61. relevant_args : iterable of array-like
  62. Iterable of possibly array-like arguments to check for
  63. __array_function__ methods.
  64. Returns
  65. -------
  66. Sequence of arguments with __array_function__ methods, in the order in
  67. which they should be called.
  68. """)
  69. ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
  70. def verify_matching_signatures(implementation, dispatcher):
  71. """Verify that a dispatcher function has the right signature."""
  72. implementation_spec = ArgSpec(*getargspec(implementation))
  73. dispatcher_spec = ArgSpec(*getargspec(dispatcher))
  74. if (implementation_spec.args != dispatcher_spec.args or
  75. implementation_spec.varargs != dispatcher_spec.varargs or
  76. implementation_spec.keywords != dispatcher_spec.keywords or
  77. (bool(implementation_spec.defaults) !=
  78. bool(dispatcher_spec.defaults)) or
  79. (implementation_spec.defaults is not None and
  80. len(implementation_spec.defaults) !=
  81. len(dispatcher_spec.defaults))):
  82. raise RuntimeError('implementation and dispatcher for %s have '
  83. 'different function signatures' % implementation)
  84. if implementation_spec.defaults is not None:
  85. if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
  86. raise RuntimeError('dispatcher functions can only use None for '
  87. 'default argument values')
  88. def set_module(module):
  89. """Decorator for overriding __module__ on a function or class.
  90. Example usage::
  91. @set_module('numpy')
  92. def example():
  93. pass
  94. assert example.__module__ == 'numpy'
  95. """
  96. def decorator(func):
  97. if module is not None:
  98. func.__module__ = module
  99. return func
  100. return decorator
  101. # Call textwrap.dedent here instead of in the function so as to avoid
  102. # calling dedent multiple times on the same text
  103. _wrapped_func_source = textwrap.dedent("""
  104. @functools.wraps(implementation)
  105. def {name}(*args, **kwargs):
  106. relevant_args = dispatcher(*args, **kwargs)
  107. return implement_array_function(
  108. implementation, {name}, relevant_args, args, kwargs)
  109. """)
  110. def array_function_dispatch(dispatcher, module=None, verify=True,
  111. docs_from_dispatcher=False):
  112. """Decorator for adding dispatch with the __array_function__ protocol.
  113. See NEP-18 for example usage.
  114. Parameters
  115. ----------
  116. dispatcher : callable
  117. Function that when called like ``dispatcher(*args, **kwargs)`` with
  118. arguments from the NumPy function call returns an iterable of
  119. array-like arguments to check for ``__array_function__``.
  120. module : str, optional
  121. __module__ attribute to set on new function, e.g., ``module='numpy'``.
  122. By default, module is copied from the decorated function.
  123. verify : bool, optional
  124. If True, verify the that the signature of the dispatcher and decorated
  125. function signatures match exactly: all required and optional arguments
  126. should appear in order with the same names, but the default values for
  127. all optional arguments should be ``None``. Only disable verification
  128. if the dispatcher's signature needs to deviate for some particular
  129. reason, e.g., because the function has a signature like
  130. ``func(*args, **kwargs)``.
  131. docs_from_dispatcher : bool, optional
  132. If True, copy docs from the dispatcher function onto the dispatched
  133. function, rather than from the implementation. This is useful for
  134. functions defined in C, which otherwise don't have docstrings.
  135. Returns
  136. -------
  137. Function suitable for decorating the implementation of a NumPy function.
  138. """
  139. if not ARRAY_FUNCTION_ENABLED:
  140. def decorator(implementation):
  141. if docs_from_dispatcher:
  142. add_docstring(implementation, dispatcher.__doc__)
  143. if module is not None:
  144. implementation.__module__ = module
  145. return implementation
  146. return decorator
  147. def decorator(implementation):
  148. if verify:
  149. verify_matching_signatures(implementation, dispatcher)
  150. if docs_from_dispatcher:
  151. add_docstring(implementation, dispatcher.__doc__)
  152. # Equivalently, we could define this function directly instead of using
  153. # exec. This version has the advantage of giving the helper function a
  154. # more interpettable name. Otherwise, the original function does not
  155. # show up at all in many cases, e.g., if it's written in C or if the
  156. # dispatcher gets an invalid keyword argument.
  157. source = _wrapped_func_source.format(name=implementation.__name__)
  158. source_object = compile(
  159. source, filename='<__array_function__ internals>', mode='exec')
  160. scope = {
  161. 'implementation': implementation,
  162. 'dispatcher': dispatcher,
  163. 'functools': functools,
  164. 'implement_array_function': implement_array_function,
  165. }
  166. exec(source_object, scope)
  167. public_api = scope[implementation.__name__]
  168. if module is not None:
  169. public_api.__module__ = module
  170. public_api._implementation = implementation
  171. return public_api
  172. return decorator
  173. def array_function_from_dispatcher(
  174. implementation, module=None, verify=True, docs_from_dispatcher=True):
  175. """Like array_function_dispatcher, but with function arguments flipped."""
  176. def decorator(dispatcher):
  177. return array_function_dispatch(
  178. dispatcher, module, verify=verify,
  179. docs_from_dispatcher=docs_from_dispatcher)(implementation)
  180. return decorator