m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

137 lines
4.2 KiB

6 months ago
  1. """
  2. Provides functionality for multidimensional usage of scalar-functions.
  3. Read the vectorize docstring for more details.
  4. """
  5. from functools import wraps
  6. def apply_on_element(f, args, kwargs, n):
  7. """
  8. Returns a structure with the same dimension as the specified argument,
  9. where each basic element is replaced by the function f applied on it. All
  10. other arguments stay the same.
  11. """
  12. # Get the specified argument.
  13. if isinstance(n, int):
  14. structure = args[n]
  15. is_arg = True
  16. elif isinstance(n, str):
  17. structure = kwargs[n]
  18. is_arg = False
  19. # Define reduced function that is only dependent on the specified argument.
  20. def f_reduced(x):
  21. if hasattr(x, "__iter__"):
  22. return list(map(f_reduced, x))
  23. else:
  24. if is_arg:
  25. args[n] = x
  26. else:
  27. kwargs[n] = x
  28. return f(*args, **kwargs)
  29. # f_reduced will call itself recursively so that in the end f is applied to
  30. # all basic elements.
  31. return list(map(f_reduced, structure))
  32. def iter_copy(structure):
  33. """
  34. Returns a copy of an iterable object (also copying all embedded iterables).
  35. """
  36. l = []
  37. for i in structure:
  38. if hasattr(i, "__iter__"):
  39. l.append(iter_copy(i))
  40. else:
  41. l.append(i)
  42. return l
  43. def structure_copy(structure):
  44. """
  45. Returns a copy of the given structure (numpy-array, list, iterable, ..).
  46. """
  47. if hasattr(structure, "copy"):
  48. return structure.copy()
  49. return iter_copy(structure)
  50. class vectorize:
  51. """
  52. Generalizes a function taking scalars to accept multidimensional arguments.
  53. Examples
  54. ========
  55. >>> from sympy import vectorize, diff, sin, symbols, Function
  56. >>> x, y, z = symbols('x y z')
  57. >>> f, g, h = list(map(Function, 'fgh'))
  58. >>> @vectorize(0)
  59. ... def vsin(x):
  60. ... return sin(x)
  61. >>> vsin([1, x, y])
  62. [sin(1), sin(x), sin(y)]
  63. >>> @vectorize(0, 1)
  64. ... def vdiff(f, y):
  65. ... return diff(f, y)
  66. >>> vdiff([f(x, y, z), g(x, y, z), h(x, y, z)], [x, y, z])
  67. [[Derivative(f(x, y, z), x), Derivative(f(x, y, z), y), Derivative(f(x, y, z), z)], [Derivative(g(x, y, z), x), Derivative(g(x, y, z), y), Derivative(g(x, y, z), z)], [Derivative(h(x, y, z), x), Derivative(h(x, y, z), y), Derivative(h(x, y, z), z)]]
  68. """
  69. def __init__(self, *mdargs):
  70. """
  71. The given numbers and strings characterize the arguments that will be
  72. treated as data structures, where the decorated function will be applied
  73. to every single element.
  74. If no argument is given, everything is treated multidimensional.
  75. """
  76. for a in mdargs:
  77. if not isinstance(a, (int, str)):
  78. raise TypeError("a is of invalid type")
  79. self.mdargs = mdargs
  80. def __call__(self, f):
  81. """
  82. Returns a wrapper for the one-dimensional function that can handle
  83. multidimensional arguments.
  84. """
  85. @wraps(f)
  86. def wrapper(*args, **kwargs):
  87. # Get arguments that should be treated multidimensional
  88. if self.mdargs:
  89. mdargs = self.mdargs
  90. else:
  91. mdargs = range(len(args)) + kwargs.keys()
  92. arglength = len(args)
  93. for n in mdargs:
  94. if isinstance(n, int):
  95. if n >= arglength:
  96. continue
  97. entry = args[n]
  98. is_arg = True
  99. elif isinstance(n, str):
  100. try:
  101. entry = kwargs[n]
  102. except KeyError:
  103. continue
  104. is_arg = False
  105. if hasattr(entry, "__iter__"):
  106. # Create now a copy of the given array and manipulate then
  107. # the entries directly.
  108. if is_arg:
  109. args = list(args)
  110. args[n] = structure_copy(entry)
  111. else:
  112. kwargs[n] = structure_copy(entry)
  113. result = apply_on_element(wrapper, args, kwargs, n)
  114. return result
  115. return f(*args, **kwargs)
  116. return wrapper