m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

592 lines
18 KiB

6 months ago
  1. from sympy.core.basic import Basic
  2. from sympy.core.containers import (Dict, Tuple)
  3. from sympy.core.expr import Expr
  4. from sympy.core.kind import Kind, NumberKind, UndefinedKind
  5. from sympy.core.numbers import Integer
  6. from sympy.core.singleton import S
  7. from sympy.core.sympify import sympify
  8. from sympy.external.gmpy import SYMPY_INTS
  9. from sympy.printing.defaults import Printable
  10. import itertools
  11. from collections.abc import Iterable
  12. class ArrayKind(Kind):
  13. """
  14. Kind for N-dimensional array in SymPy.
  15. This kind represents the multidimensional array that algebraic
  16. operations are defined. Basic class for this kind is ``NDimArray``,
  17. but any expression representing the array can have this.
  18. Parameters
  19. ==========
  20. element_kind : Kind
  21. Kind of the element. Default is :obj:NumberKind `<sympy.core.kind.NumberKind>`,
  22. which means that the array contains only numbers.
  23. Examples
  24. ========
  25. Any instance of array class has ``ArrayKind``.
  26. >>> from sympy import NDimArray
  27. >>> NDimArray([1,2,3]).kind
  28. ArrayKind(NumberKind)
  29. Although expressions representing an array may be not instance of
  30. array class, it will have ``ArrayKind`` as well.
  31. >>> from sympy import Integral
  32. >>> from sympy.tensor.array import NDimArray
  33. >>> from sympy.abc import x
  34. >>> intA = Integral(NDimArray([1,2,3]), x)
  35. >>> isinstance(intA, NDimArray)
  36. False
  37. >>> intA.kind
  38. ArrayKind(NumberKind)
  39. Use ``isinstance()`` to check for ``ArrayKind` without specifying
  40. the element kind. Use ``is`` with specifying the element kind.
  41. >>> from sympy.tensor.array import ArrayKind
  42. >>> from sympy.core import NumberKind
  43. >>> boolA = NDimArray([True, False])
  44. >>> isinstance(boolA.kind, ArrayKind)
  45. True
  46. >>> boolA.kind is ArrayKind(NumberKind)
  47. False
  48. See Also
  49. ========
  50. shape : Function to return the shape of objects with ``MatrixKind``.
  51. """
  52. def __new__(cls, element_kind=NumberKind):
  53. obj = super().__new__(cls, element_kind)
  54. obj.element_kind = element_kind
  55. return obj
  56. def __repr__(self):
  57. return "ArrayKind(%s)" % self.element_kind
  58. @classmethod
  59. def _union(cls, kinds) -> 'ArrayKind':
  60. elem_kinds = set(e.kind for e in kinds)
  61. if len(elem_kinds) == 1:
  62. elemkind, = elem_kinds
  63. else:
  64. elemkind = UndefinedKind
  65. return ArrayKind(elemkind)
  66. class NDimArray(Printable):
  67. """
  68. Examples
  69. ========
  70. Create an N-dim array of zeros:
  71. >>> from sympy import MutableDenseNDimArray
  72. >>> a = MutableDenseNDimArray.zeros(2, 3, 4)
  73. >>> a
  74. [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
  75. Create an N-dim array from a list;
  76. >>> a = MutableDenseNDimArray([[2, 3], [4, 5]])
  77. >>> a
  78. [[2, 3], [4, 5]]
  79. >>> b = MutableDenseNDimArray([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
  80. >>> b
  81. [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]
  82. Create an N-dim array from a flat list with dimension shape:
  83. >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3))
  84. >>> a
  85. [[1, 2, 3], [4, 5, 6]]
  86. Create an N-dim array from a matrix:
  87. >>> from sympy import Matrix
  88. >>> a = Matrix([[1,2],[3,4]])
  89. >>> a
  90. Matrix([
  91. [1, 2],
  92. [3, 4]])
  93. >>> b = MutableDenseNDimArray(a)
  94. >>> b
  95. [[1, 2], [3, 4]]
  96. Arithmetic operations on N-dim arrays
  97. >>> a = MutableDenseNDimArray([1, 1, 1, 1], (2, 2))
  98. >>> b = MutableDenseNDimArray([4, 4, 4, 4], (2, 2))
  99. >>> c = a + b
  100. >>> c
  101. [[5, 5], [5, 5]]
  102. >>> a - b
  103. [[-3, -3], [-3, -3]]
  104. """
  105. _diff_wrt = True
  106. is_scalar = False
  107. def __new__(cls, iterable, shape=None, **kwargs):
  108. from sympy.tensor.array import ImmutableDenseNDimArray
  109. return ImmutableDenseNDimArray(iterable, shape, **kwargs)
  110. def _parse_index(self, index):
  111. if isinstance(index, (SYMPY_INTS, Integer)):
  112. raise ValueError("Only a tuple index is accepted")
  113. if self._loop_size == 0:
  114. raise ValueError("Index not valide with an empty array")
  115. if len(index) != self._rank:
  116. raise ValueError('Wrong number of array axes')
  117. real_index = 0
  118. # check if input index can exist in current indexing
  119. for i in range(self._rank):
  120. if (index[i] >= self.shape[i]) or (index[i] < -self.shape[i]):
  121. raise ValueError('Index ' + str(index) + ' out of border')
  122. if index[i] < 0:
  123. real_index += 1
  124. real_index = real_index*self.shape[i] + index[i]
  125. return real_index
  126. def _get_tuple_index(self, integer_index):
  127. index = []
  128. for i, sh in enumerate(reversed(self.shape)):
  129. index.append(integer_index % sh)
  130. integer_index //= sh
  131. index.reverse()
  132. return tuple(index)
  133. def _check_symbolic_index(self, index):
  134. # Check if any index is symbolic:
  135. tuple_index = (index if isinstance(index, tuple) else (index,))
  136. if any((isinstance(i, Expr) and (not i.is_number)) for i in tuple_index):
  137. for i, nth_dim in zip(tuple_index, self.shape):
  138. if ((i < 0) == True) or ((i >= nth_dim) == True):
  139. raise ValueError("index out of range")
  140. from sympy.tensor import Indexed
  141. return Indexed(self, *tuple_index)
  142. return None
  143. def _setter_iterable_check(self, value):
  144. from sympy.matrices.matrices import MatrixBase
  145. if isinstance(value, (Iterable, MatrixBase, NDimArray)):
  146. raise NotImplementedError
  147. @classmethod
  148. def _scan_iterable_shape(cls, iterable):
  149. def f(pointer):
  150. if not isinstance(pointer, Iterable):
  151. return [pointer], ()
  152. result = []
  153. elems, shapes = zip(*[f(i) for i in pointer])
  154. if len(set(shapes)) != 1:
  155. raise ValueError("could not determine shape unambiguously")
  156. for i in elems:
  157. result.extend(i)
  158. return result, (len(shapes),)+shapes[0]
  159. return f(iterable)
  160. @classmethod
  161. def _handle_ndarray_creation_inputs(cls, iterable=None, shape=None, **kwargs):
  162. from sympy.matrices.matrices import MatrixBase
  163. from sympy.tensor.array import SparseNDimArray
  164. if shape is None:
  165. if iterable is None:
  166. shape = ()
  167. iterable = ()
  168. # Construction of a sparse array from a sparse array
  169. elif isinstance(iterable, SparseNDimArray):
  170. return iterable._shape, iterable._sparse_array
  171. # Construct N-dim array from another N-dim array:
  172. elif isinstance(iterable, NDimArray):
  173. shape = iterable.shape
  174. # Construct N-dim array from an iterable (numpy arrays included):
  175. elif isinstance(iterable, Iterable):
  176. iterable, shape = cls._scan_iterable_shape(iterable)
  177. # Construct N-dim array from a Matrix:
  178. elif isinstance(iterable, MatrixBase):
  179. shape = iterable.shape
  180. else:
  181. shape = ()
  182. iterable = (iterable,)
  183. if isinstance(iterable, (Dict, dict)) and shape is not None:
  184. new_dict = iterable.copy()
  185. for k, v in new_dict.items():
  186. if isinstance(k, (tuple, Tuple)):
  187. new_key = 0
  188. for i, idx in enumerate(k):
  189. new_key = new_key * shape[i] + idx
  190. iterable[new_key] = iterable[k]
  191. del iterable[k]
  192. if isinstance(shape, (SYMPY_INTS, Integer)):
  193. shape = (shape,)
  194. if not all(isinstance(dim, (SYMPY_INTS, Integer)) for dim in shape):
  195. raise TypeError("Shape should contain integers only.")
  196. return tuple(shape), iterable
  197. def __len__(self):
  198. """Overload common function len(). Returns number of elements in array.
  199. Examples
  200. ========
  201. >>> from sympy import MutableDenseNDimArray
  202. >>> a = MutableDenseNDimArray.zeros(3, 3)
  203. >>> a
  204. [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
  205. >>> len(a)
  206. 9
  207. """
  208. return self._loop_size
  209. @property
  210. def shape(self):
  211. """
  212. Returns array shape (dimension).
  213. Examples
  214. ========
  215. >>> from sympy import MutableDenseNDimArray
  216. >>> a = MutableDenseNDimArray.zeros(3, 3)
  217. >>> a.shape
  218. (3, 3)
  219. """
  220. return self._shape
  221. def rank(self):
  222. """
  223. Returns rank of array.
  224. Examples
  225. ========
  226. >>> from sympy import MutableDenseNDimArray
  227. >>> a = MutableDenseNDimArray.zeros(3,4,5,6,3)
  228. >>> a.rank()
  229. 5
  230. """
  231. return self._rank
  232. def diff(self, *args, **kwargs):
  233. """
  234. Calculate the derivative of each element in the array.
  235. Examples
  236. ========
  237. >>> from sympy import ImmutableDenseNDimArray
  238. >>> from sympy.abc import x, y
  239. >>> M = ImmutableDenseNDimArray([[x, y], [1, x*y]])
  240. >>> M.diff(x)
  241. [[1, 0], [0, y]]
  242. """
  243. from sympy.tensor.array.array_derivatives import ArrayDerivative
  244. kwargs.setdefault('evaluate', True)
  245. return ArrayDerivative(self.as_immutable(), *args, **kwargs)
  246. def _eval_derivative(self, base):
  247. # Types are (base: scalar, self: array)
  248. return self.applyfunc(lambda x: base.diff(x))
  249. def _eval_derivative_n_times(self, s, n):
  250. return Basic._eval_derivative_n_times(self, s, n)
  251. def applyfunc(self, f):
  252. """Apply a function to each element of the N-dim array.
  253. Examples
  254. ========
  255. >>> from sympy import ImmutableDenseNDimArray
  256. >>> m = ImmutableDenseNDimArray([i*2+j for i in range(2) for j in range(2)], (2, 2))
  257. >>> m
  258. [[0, 1], [2, 3]]
  259. >>> m.applyfunc(lambda i: 2*i)
  260. [[0, 2], [4, 6]]
  261. """
  262. from sympy.tensor.array import SparseNDimArray
  263. from sympy.tensor.array.arrayop import Flatten
  264. if isinstance(self, SparseNDimArray) and f(S.Zero) == 0:
  265. return type(self)({k: f(v) for k, v in self._sparse_array.items() if f(v) != 0}, self.shape)
  266. return type(self)(map(f, Flatten(self)), self.shape)
  267. def _sympystr(self, printer):
  268. def f(sh, shape_left, i, j):
  269. if len(shape_left) == 1:
  270. return "["+", ".join([printer._print(self[self._get_tuple_index(e)]) for e in range(i, j)])+"]"
  271. sh //= shape_left[0]
  272. return "[" + ", ".join([f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh) for e in range(shape_left[0])]) + "]" # + "\n"*len(shape_left)
  273. if self.rank() == 0:
  274. return printer._print(self[()])
  275. return f(self._loop_size, self.shape, 0, self._loop_size)
  276. def tolist(self):
  277. """
  278. Converting MutableDenseNDimArray to one-dim list
  279. Examples
  280. ========
  281. >>> from sympy import MutableDenseNDimArray
  282. >>> a = MutableDenseNDimArray([1, 2, 3, 4], (2, 2))
  283. >>> a
  284. [[1, 2], [3, 4]]
  285. >>> b = a.tolist()
  286. >>> b
  287. [[1, 2], [3, 4]]
  288. """
  289. def f(sh, shape_left, i, j):
  290. if len(shape_left) == 1:
  291. return [self[self._get_tuple_index(e)] for e in range(i, j)]
  292. result = []
  293. sh //= shape_left[0]
  294. for e in range(shape_left[0]):
  295. result.append(f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh))
  296. return result
  297. return f(self._loop_size, self.shape, 0, self._loop_size)
  298. def __add__(self, other):
  299. from sympy.tensor.array.arrayop import Flatten
  300. if not isinstance(other, NDimArray):
  301. return NotImplemented
  302. if self.shape != other.shape:
  303. raise ValueError("array shape mismatch")
  304. result_list = [i+j for i,j in zip(Flatten(self), Flatten(other))]
  305. return type(self)(result_list, self.shape)
  306. def __sub__(self, other):
  307. from sympy.tensor.array.arrayop import Flatten
  308. if not isinstance(other, NDimArray):
  309. return NotImplemented
  310. if self.shape != other.shape:
  311. raise ValueError("array shape mismatch")
  312. result_list = [i-j for i,j in zip(Flatten(self), Flatten(other))]
  313. return type(self)(result_list, self.shape)
  314. def __mul__(self, other):
  315. from sympy.matrices.matrices import MatrixBase
  316. from sympy.tensor.array import SparseNDimArray
  317. from sympy.tensor.array.arrayop import Flatten
  318. if isinstance(other, (Iterable, NDimArray, MatrixBase)):
  319. raise ValueError("scalar expected, use tensorproduct(...) for tensorial product")
  320. other = sympify(other)
  321. if isinstance(self, SparseNDimArray):
  322. if other.is_zero:
  323. return type(self)({}, self.shape)
  324. return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape)
  325. result_list = [i*other for i in Flatten(self)]
  326. return type(self)(result_list, self.shape)
  327. def __rmul__(self, other):
  328. from sympy.matrices.matrices import MatrixBase
  329. from sympy.tensor.array import SparseNDimArray
  330. from sympy.tensor.array.arrayop import Flatten
  331. if isinstance(other, (Iterable, NDimArray, MatrixBase)):
  332. raise ValueError("scalar expected, use tensorproduct(...) for tensorial product")
  333. other = sympify(other)
  334. if isinstance(self, SparseNDimArray):
  335. if other.is_zero:
  336. return type(self)({}, self.shape)
  337. return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape)
  338. result_list = [other*i for i in Flatten(self)]
  339. return type(self)(result_list, self.shape)
  340. def __truediv__(self, other):
  341. from sympy.matrices.matrices import MatrixBase
  342. from sympy.tensor.array import SparseNDimArray
  343. from sympy.tensor.array.arrayop import Flatten
  344. if isinstance(other, (Iterable, NDimArray, MatrixBase)):
  345. raise ValueError("scalar expected")
  346. other = sympify(other)
  347. if isinstance(self, SparseNDimArray) and other != S.Zero:
  348. return type(self)({k: v/other for (k, v) in self._sparse_array.items()}, self.shape)
  349. result_list = [i/other for i in Flatten(self)]
  350. return type(self)(result_list, self.shape)
  351. def __rtruediv__(self, other):
  352. raise NotImplementedError('unsupported operation on NDimArray')
  353. def __neg__(self):
  354. from sympy.tensor.array import SparseNDimArray
  355. from sympy.tensor.array.arrayop import Flatten
  356. if isinstance(self, SparseNDimArray):
  357. return type(self)({k: -v for (k, v) in self._sparse_array.items()}, self.shape)
  358. result_list = [-i for i in Flatten(self)]
  359. return type(self)(result_list, self.shape)
  360. def __iter__(self):
  361. def iterator():
  362. if self._shape:
  363. for i in range(self._shape[0]):
  364. yield self[i]
  365. else:
  366. yield self[()]
  367. return iterator()
  368. def __eq__(self, other):
  369. """
  370. NDimArray instances can be compared to each other.
  371. Instances equal if they have same shape and data.
  372. Examples
  373. ========
  374. >>> from sympy import MutableDenseNDimArray
  375. >>> a = MutableDenseNDimArray.zeros(2, 3)
  376. >>> b = MutableDenseNDimArray.zeros(2, 3)
  377. >>> a == b
  378. True
  379. >>> c = a.reshape(3, 2)
  380. >>> c == b
  381. False
  382. >>> a[0,0] = 1
  383. >>> b[0,0] = 2
  384. >>> a == b
  385. False
  386. """
  387. from sympy.tensor.array import SparseNDimArray
  388. if not isinstance(other, NDimArray):
  389. return False
  390. if not self.shape == other.shape:
  391. return False
  392. if isinstance(self, SparseNDimArray) and isinstance(other, SparseNDimArray):
  393. return dict(self._sparse_array) == dict(other._sparse_array)
  394. return list(self) == list(other)
  395. def __ne__(self, other):
  396. return not self == other
  397. def _eval_transpose(self):
  398. if self.rank() != 2:
  399. raise ValueError("array rank not 2")
  400. from .arrayop import permutedims
  401. return permutedims(self, (1, 0))
  402. def transpose(self):
  403. return self._eval_transpose()
  404. def _eval_conjugate(self):
  405. from sympy.tensor.array.arrayop import Flatten
  406. return self.func([i.conjugate() for i in Flatten(self)], self.shape)
  407. def conjugate(self):
  408. return self._eval_conjugate()
  409. def _eval_adjoint(self):
  410. return self.transpose().conjugate()
  411. def adjoint(self):
  412. return self._eval_adjoint()
  413. def _slice_expand(self, s, dim):
  414. if not isinstance(s, slice):
  415. return (s,)
  416. start, stop, step = s.indices(dim)
  417. return [start + i*step for i in range((stop-start)//step)]
  418. def _get_slice_data_for_array_access(self, index):
  419. sl_factors = [self._slice_expand(i, dim) for (i, dim) in zip(index, self.shape)]
  420. eindices = itertools.product(*sl_factors)
  421. return sl_factors, eindices
  422. def _get_slice_data_for_array_assignment(self, index, value):
  423. if not isinstance(value, NDimArray):
  424. value = type(self)(value)
  425. sl_factors, eindices = self._get_slice_data_for_array_access(index)
  426. slice_offsets = [min(i) if isinstance(i, list) else None for i in sl_factors]
  427. # TODO: add checks for dimensions for `value`?
  428. return value, eindices, slice_offsets
  429. @classmethod
  430. def _check_special_bounds(cls, flat_list, shape):
  431. if shape == () and len(flat_list) != 1:
  432. raise ValueError("arrays without shape need one scalar value")
  433. if shape == (0,) and len(flat_list) > 0:
  434. raise ValueError("if array shape is (0,) there cannot be elements")
  435. def _check_index_for_getitem(self, index):
  436. if isinstance(index, (SYMPY_INTS, Integer, slice)):
  437. index = (index, )
  438. if len(index) < self.rank():
  439. index = tuple([i for i in index] + \
  440. [slice(None) for i in range(len(index), self.rank())])
  441. if len(index) > self.rank():
  442. raise ValueError('Dimension of index greater than rank of array')
  443. return index
  444. class ImmutableNDimArray(NDimArray, Basic):
  445. _op_priority = 11.0
  446. def __hash__(self):
  447. return Basic.__hash__(self)
  448. def as_immutable(self):
  449. return self
  450. def as_mutable(self):
  451. raise NotImplementedError("abstract method")