m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

480 lines
16 KiB

6 months ago
  1. import itertools
  2. from collections.abc import Iterable
  3. from sympy.core._print_helpers import Printable
  4. from sympy.core.containers import Tuple
  5. from sympy.core.function import diff
  6. from sympy.core.singleton import S
  7. from sympy.core.sympify import _sympify
  8. from sympy.tensor.array.ndim_array import NDimArray
  9. from sympy.tensor.array.dense_ndim_array import DenseNDimArray, ImmutableDenseNDimArray
  10. from sympy.tensor.array.sparse_ndim_array import SparseNDimArray
  11. def _arrayfy(a):
  12. from sympy.matrices import MatrixBase
  13. if isinstance(a, NDimArray):
  14. return a
  15. if isinstance(a, (MatrixBase, list, tuple, Tuple)):
  16. return ImmutableDenseNDimArray(a)
  17. return a
  18. def tensorproduct(*args):
  19. """
  20. Tensor product among scalars or array-like objects.
  21. Examples
  22. ========
  23. >>> from sympy.tensor.array import tensorproduct, Array
  24. >>> from sympy.abc import x, y, z, t
  25. >>> A = Array([[1, 2], [3, 4]])
  26. >>> B = Array([x, y])
  27. >>> tensorproduct(A, B)
  28. [[[x, y], [2*x, 2*y]], [[3*x, 3*y], [4*x, 4*y]]]
  29. >>> tensorproduct(A, x)
  30. [[x, 2*x], [3*x, 4*x]]
  31. >>> tensorproduct(A, B, B)
  32. [[[[x**2, x*y], [x*y, y**2]], [[2*x**2, 2*x*y], [2*x*y, 2*y**2]]], [[[3*x**2, 3*x*y], [3*x*y, 3*y**2]], [[4*x**2, 4*x*y], [4*x*y, 4*y**2]]]]
  33. Applying this function on two matrices will result in a rank 4 array.
  34. >>> from sympy import Matrix, eye
  35. >>> m = Matrix([[x, y], [z, t]])
  36. >>> p = tensorproduct(eye(3), m)
  37. >>> p
  38. [[[[x, y], [z, t]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[x, y], [z, t]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[x, y], [z, t]]]]
  39. """
  40. from sympy.tensor.array import SparseNDimArray, ImmutableSparseNDimArray
  41. if len(args) == 0:
  42. return S.One
  43. if len(args) == 1:
  44. return _arrayfy(args[0])
  45. from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract
  46. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  47. from sympy.tensor.array.expressions.array_expressions import _ArrayExpr
  48. from sympy.matrices.expressions.matexpr import MatrixSymbol
  49. if any(isinstance(arg, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)) for arg in args):
  50. return ArrayTensorProduct(*args)
  51. if len(args) > 2:
  52. return tensorproduct(tensorproduct(args[0], args[1]), *args[2:])
  53. # length of args is 2:
  54. a, b = map(_arrayfy, args)
  55. if not isinstance(a, NDimArray) or not isinstance(b, NDimArray):
  56. return a*b
  57. if isinstance(a, SparseNDimArray) and isinstance(b, SparseNDimArray):
  58. lp = len(b)
  59. new_array = {k1*lp + k2: v1*v2 for k1, v1 in a._sparse_array.items() for k2, v2 in b._sparse_array.items()}
  60. return ImmutableSparseNDimArray(new_array, a.shape + b.shape)
  61. product_list = [i*j for i in Flatten(a) for j in Flatten(b)]
  62. return ImmutableDenseNDimArray(product_list, a.shape + b.shape)
  63. def _util_contraction_diagonal(array, *contraction_or_diagonal_axes):
  64. array = _arrayfy(array)
  65. # Verify contraction_axes:
  66. taken_dims = set()
  67. for axes_group in contraction_or_diagonal_axes:
  68. if not isinstance(axes_group, Iterable):
  69. raise ValueError("collections of contraction/diagonal axes expected")
  70. dim = array.shape[axes_group[0]]
  71. for d in axes_group:
  72. if d in taken_dims:
  73. raise ValueError("dimension specified more than once")
  74. if dim != array.shape[d]:
  75. raise ValueError("cannot contract or diagonalize between axes of different dimension")
  76. taken_dims.add(d)
  77. rank = array.rank()
  78. remaining_shape = [dim for i, dim in enumerate(array.shape) if i not in taken_dims]
  79. cum_shape = [0]*rank
  80. _cumul = 1
  81. for i in range(rank):
  82. cum_shape[rank - i - 1] = _cumul
  83. _cumul *= int(array.shape[rank - i - 1])
  84. # DEFINITION: by absolute position it is meant the position along the one
  85. # dimensional array containing all the tensor components.
  86. # Possible future work on this module: move computation of absolute
  87. # positions to a class method.
  88. # Determine absolute positions of the uncontracted indices:
  89. remaining_indices = [[cum_shape[i]*j for j in range(array.shape[i])]
  90. for i in range(rank) if i not in taken_dims]
  91. # Determine absolute positions of the contracted indices:
  92. summed_deltas = []
  93. for axes_group in contraction_or_diagonal_axes:
  94. lidx = []
  95. for js in range(array.shape[axes_group[0]]):
  96. lidx.append(sum([cum_shape[ig] * js for ig in axes_group]))
  97. summed_deltas.append(lidx)
  98. return array, remaining_indices, remaining_shape, summed_deltas
  99. def tensorcontraction(array, *contraction_axes):
  100. """
  101. Contraction of an array-like object on the specified axes.
  102. Examples
  103. ========
  104. >>> from sympy import Array, tensorcontraction
  105. >>> from sympy import Matrix, eye
  106. >>> tensorcontraction(eye(3), (0, 1))
  107. 3
  108. >>> A = Array(range(18), (3, 2, 3))
  109. >>> A
  110. [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]]
  111. >>> tensorcontraction(A, (0, 2))
  112. [21, 30]
  113. Matrix multiplication may be emulated with a proper combination of
  114. ``tensorcontraction`` and ``tensorproduct``
  115. >>> from sympy import tensorproduct
  116. >>> from sympy.abc import a,b,c,d,e,f,g,h
  117. >>> m1 = Matrix([[a, b], [c, d]])
  118. >>> m2 = Matrix([[e, f], [g, h]])
  119. >>> p = tensorproduct(m1, m2)
  120. >>> p
  121. [[[[a*e, a*f], [a*g, a*h]], [[b*e, b*f], [b*g, b*h]]], [[[c*e, c*f], [c*g, c*h]], [[d*e, d*f], [d*g, d*h]]]]
  122. >>> tensorcontraction(p, (1, 2))
  123. [[a*e + b*g, a*f + b*h], [c*e + d*g, c*f + d*h]]
  124. >>> m1*m2
  125. Matrix([
  126. [a*e + b*g, a*f + b*h],
  127. [c*e + d*g, c*f + d*h]])
  128. """
  129. from sympy.tensor.array.expressions.array_expressions import _array_contraction
  130. from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract
  131. from sympy.tensor.array.expressions.array_expressions import _ArrayExpr
  132. from sympy.matrices.expressions.matexpr import MatrixSymbol
  133. if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)):
  134. return _array_contraction(array, *contraction_axes)
  135. array, remaining_indices, remaining_shape, summed_deltas = _util_contraction_diagonal(array, *contraction_axes)
  136. # Compute the contracted array:
  137. #
  138. # 1. external for loops on all uncontracted indices.
  139. # Uncontracted indices are determined by the combinatorial product of
  140. # the absolute positions of the remaining indices.
  141. # 2. internal loop on all contracted indices.
  142. # It sums the values of the absolute contracted index and the absolute
  143. # uncontracted index for the external loop.
  144. contracted_array = []
  145. for icontrib in itertools.product(*remaining_indices):
  146. index_base_position = sum(icontrib)
  147. isum = S.Zero
  148. for sum_to_index in itertools.product(*summed_deltas):
  149. idx = array._get_tuple_index(index_base_position + sum(sum_to_index))
  150. isum += array[idx]
  151. contracted_array.append(isum)
  152. if len(remaining_indices) == 0:
  153. assert len(contracted_array) == 1
  154. return contracted_array[0]
  155. return type(array)(contracted_array, remaining_shape)
  156. def tensordiagonal(array, *diagonal_axes):
  157. """
  158. Diagonalization of an array-like object on the specified axes.
  159. This is equivalent to multiplying the expression by Kronecker deltas
  160. uniting the axes.
  161. The diagonal indices are put at the end of the axes.
  162. Examples
  163. ========
  164. ``tensordiagonal`` acting on a 2-dimensional array by axes 0 and 1 is
  165. equivalent to the diagonal of the matrix:
  166. >>> from sympy import Array, tensordiagonal
  167. >>> from sympy import Matrix, eye
  168. >>> tensordiagonal(eye(3), (0, 1))
  169. [1, 1, 1]
  170. >>> from sympy.abc import a,b,c,d
  171. >>> m1 = Matrix([[a, b], [c, d]])
  172. >>> tensordiagonal(m1, [0, 1])
  173. [a, d]
  174. In case of higher dimensional arrays, the diagonalized out dimensions
  175. are appended removed and appended as a single dimension at the end:
  176. >>> A = Array(range(18), (3, 2, 3))
  177. >>> A
  178. [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]]
  179. >>> tensordiagonal(A, (0, 2))
  180. [[0, 7, 14], [3, 10, 17]]
  181. >>> from sympy import permutedims
  182. >>> tensordiagonal(A, (0, 2)) == permutedims(Array([A[0, :, 0], A[1, :, 1], A[2, :, 2]]), [1, 0])
  183. True
  184. """
  185. if any(len(i) <= 1 for i in diagonal_axes):
  186. raise ValueError("need at least two axes to diagonalize")
  187. from sympy.tensor.array.expressions.array_expressions import _ArrayExpr
  188. from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract
  189. from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal, _array_diagonal
  190. from sympy.matrices.expressions.matexpr import MatrixSymbol
  191. if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)):
  192. return _array_diagonal(array, *diagonal_axes)
  193. ArrayDiagonal._validate(array, *diagonal_axes)
  194. array, remaining_indices, remaining_shape, diagonal_deltas = _util_contraction_diagonal(array, *diagonal_axes)
  195. # Compute the diagonalized array:
  196. #
  197. # 1. external for loops on all undiagonalized indices.
  198. # Undiagonalized indices are determined by the combinatorial product of
  199. # the absolute positions of the remaining indices.
  200. # 2. internal loop on all diagonal indices.
  201. # It appends the values of the absolute diagonalized index and the absolute
  202. # undiagonalized index for the external loop.
  203. diagonalized_array = []
  204. diagonal_shape = [len(i) for i in diagonal_deltas]
  205. for icontrib in itertools.product(*remaining_indices):
  206. index_base_position = sum(icontrib)
  207. isum = []
  208. for sum_to_index in itertools.product(*diagonal_deltas):
  209. idx = array._get_tuple_index(index_base_position + sum(sum_to_index))
  210. isum.append(array[idx])
  211. isum = type(array)(isum).reshape(*diagonal_shape)
  212. diagonalized_array.append(isum)
  213. return type(array)(diagonalized_array, remaining_shape + diagonal_shape)
  214. def derive_by_array(expr, dx):
  215. r"""
  216. Derivative by arrays. Supports both arrays and scalars.
  217. Explanation
  218. ===========
  219. Given the array `A_{i_1, \ldots, i_N}` and the array `X_{j_1, \ldots, j_M}`
  220. this function will return a new array `B` defined by
  221. `B_{j_1,\ldots,j_M,i_1,\ldots,i_N} := \frac{\partial A_{i_1,\ldots,i_N}}{\partial X_{j_1,\ldots,j_M}}`
  222. Examples
  223. ========
  224. >>> from sympy import derive_by_array
  225. >>> from sympy.abc import x, y, z, t
  226. >>> from sympy import cos
  227. >>> derive_by_array(cos(x*t), x)
  228. -t*sin(t*x)
  229. >>> derive_by_array(cos(x*t), [x, y, z, t])
  230. [-t*sin(t*x), 0, 0, -x*sin(t*x)]
  231. >>> derive_by_array([x, y**2*z], [[x, y], [z, t]])
  232. [[[1, 0], [0, 2*y*z]], [[0, y**2], [0, 0]]]
  233. """
  234. from sympy.matrices import MatrixBase
  235. from sympy.tensor.array import SparseNDimArray
  236. array_types = (Iterable, MatrixBase, NDimArray)
  237. if isinstance(dx, array_types):
  238. dx = ImmutableDenseNDimArray(dx)
  239. for i in dx:
  240. if not i._diff_wrt:
  241. raise ValueError("cannot derive by this array")
  242. if isinstance(expr, array_types):
  243. if isinstance(expr, NDimArray):
  244. expr = expr.as_immutable()
  245. else:
  246. expr = ImmutableDenseNDimArray(expr)
  247. if isinstance(dx, array_types):
  248. if isinstance(expr, SparseNDimArray):
  249. lp = len(expr)
  250. new_array = {k + i*lp: v
  251. for i, x in enumerate(Flatten(dx))
  252. for k, v in expr.diff(x)._sparse_array.items()}
  253. else:
  254. new_array = [[y.diff(x) for y in Flatten(expr)] for x in Flatten(dx)]
  255. return type(expr)(new_array, dx.shape + expr.shape)
  256. else:
  257. return expr.diff(dx)
  258. else:
  259. expr = _sympify(expr)
  260. if isinstance(dx, array_types):
  261. return ImmutableDenseNDimArray([expr.diff(i) for i in Flatten(dx)], dx.shape)
  262. else:
  263. dx = _sympify(dx)
  264. return diff(expr, dx)
  265. def permutedims(expr, perm):
  266. """
  267. Permutes the indices of an array.
  268. Parameter specifies the permutation of the indices.
  269. Examples
  270. ========
  271. >>> from sympy.abc import x, y, z, t
  272. >>> from sympy import sin
  273. >>> from sympy import Array, permutedims
  274. >>> a = Array([[x, y, z], [t, sin(x), 0]])
  275. >>> a
  276. [[x, y, z], [t, sin(x), 0]]
  277. >>> permutedims(a, (1, 0))
  278. [[x, t], [y, sin(x)], [z, 0]]
  279. If the array is of second order, ``transpose`` can be used:
  280. >>> from sympy import transpose
  281. >>> transpose(a)
  282. [[x, t], [y, sin(x)], [z, 0]]
  283. Examples on higher dimensions:
  284. >>> b = Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
  285. >>> permutedims(b, (2, 1, 0))
  286. [[[1, 5], [3, 7]], [[2, 6], [4, 8]]]
  287. >>> permutedims(b, (1, 2, 0))
  288. [[[1, 5], [2, 6]], [[3, 7], [4, 8]]]
  289. ``Permutation`` objects are also allowed:
  290. >>> from sympy.combinatorics import Permutation
  291. >>> permutedims(b, Permutation([1, 2, 0]))
  292. [[[1, 5], [2, 6]], [[3, 7], [4, 8]]]
  293. """
  294. from sympy.tensor.array import SparseNDimArray
  295. from sympy.tensor.array.expressions.array_expressions import _ArrayExpr
  296. from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract
  297. from sympy.tensor.array.expressions.array_expressions import _permute_dims
  298. from sympy.matrices.expressions.matexpr import MatrixSymbol
  299. if isinstance(expr, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)):
  300. return _permute_dims(expr, perm)
  301. if not isinstance(expr, NDimArray):
  302. expr = ImmutableDenseNDimArray(expr)
  303. from sympy.combinatorics import Permutation
  304. if not isinstance(perm, Permutation):
  305. perm = Permutation(list(perm))
  306. if perm.size != expr.rank():
  307. raise ValueError("wrong permutation size")
  308. # Get the inverse permutation:
  309. iperm = ~perm
  310. new_shape = perm(expr.shape)
  311. if isinstance(expr, SparseNDimArray):
  312. return type(expr)({tuple(perm(expr._get_tuple_index(k))): v
  313. for k, v in expr._sparse_array.items()}, new_shape)
  314. indices_span = perm([range(i) for i in expr.shape])
  315. new_array = [None]*len(expr)
  316. for i, idx in enumerate(itertools.product(*indices_span)):
  317. t = iperm(idx)
  318. new_array[i] = expr[t]
  319. return type(expr)(new_array, new_shape)
  320. class Flatten(Printable):
  321. '''
  322. Flatten an iterable object to a list in a lazy-evaluation way.
  323. Notes
  324. =====
  325. This class is an iterator with which the memory cost can be economised.
  326. Optimisation has been considered to ameliorate the performance for some
  327. specific data types like DenseNDimArray and SparseNDimArray.
  328. Examples
  329. ========
  330. >>> from sympy.tensor.array.arrayop import Flatten
  331. >>> from sympy.tensor.array import Array
  332. >>> A = Array(range(6)).reshape(2, 3)
  333. >>> Flatten(A)
  334. Flatten([[0, 1, 2], [3, 4, 5]])
  335. >>> [i for i in Flatten(A)]
  336. [0, 1, 2, 3, 4, 5]
  337. '''
  338. def __init__(self, iterable):
  339. from sympy.matrices.matrices import MatrixBase
  340. from sympy.tensor.array import NDimArray
  341. if not isinstance(iterable, (Iterable, MatrixBase)):
  342. raise NotImplementedError("Data type not yet supported")
  343. if isinstance(iterable, list):
  344. iterable = NDimArray(iterable)
  345. self._iter = iterable
  346. self._idx = 0
  347. def __iter__(self):
  348. return self
  349. def __next__(self):
  350. from sympy.matrices.matrices import MatrixBase
  351. if len(self._iter) > self._idx:
  352. if isinstance(self._iter, DenseNDimArray):
  353. result = self._iter._array[self._idx]
  354. elif isinstance(self._iter, SparseNDimArray):
  355. if self._idx in self._iter._sparse_array:
  356. result = self._iter._sparse_array[self._idx]
  357. else:
  358. result = 0
  359. elif isinstance(self._iter, MatrixBase):
  360. result = self._iter[self._idx]
  361. elif hasattr(self._iter, '__next__'):
  362. result = next(self._iter)
  363. else:
  364. result = self._iter[self._idx]
  365. else:
  366. raise StopIteration
  367. self._idx += 1
  368. return result
  369. def next(self):
  370. return self.__next__()
  371. def _sympystr(self, printer):
  372. return type(self).__name__ + '(' + printer._print(self._iter) + ')'