m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

209 lines
6.2 KiB

6 months ago
  1. import functools
  2. from typing import List
  3. from sympy.core.basic import Basic
  4. from sympy.core.containers import Tuple
  5. from sympy.core.singleton import S
  6. from sympy.core.sympify import _sympify
  7. from sympy.simplify.simplify import simplify
  8. from sympy.tensor.array.mutable_ndim_array import MutableNDimArray
  9. from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray, ArrayKind
  10. from sympy.utilities.iterables import flatten
  11. class DenseNDimArray(NDimArray):
  12. _array: List[Basic]
  13. def __new__(self, *args, **kwargs):
  14. return ImmutableDenseNDimArray(*args, **kwargs)
  15. @property
  16. def kind(self) -> ArrayKind:
  17. return ArrayKind._union(self._array)
  18. def __getitem__(self, index):
  19. """
  20. Allows to get items from N-dim array.
  21. Examples
  22. ========
  23. >>> from sympy import MutableDenseNDimArray
  24. >>> a = MutableDenseNDimArray([0, 1, 2, 3], (2, 2))
  25. >>> a
  26. [[0, 1], [2, 3]]
  27. >>> a[0, 0]
  28. 0
  29. >>> a[1, 1]
  30. 3
  31. >>> a[0]
  32. [0, 1]
  33. >>> a[1]
  34. [2, 3]
  35. Symbolic index:
  36. >>> from sympy.abc import i, j
  37. >>> a[i, j]
  38. [[0, 1], [2, 3]][i, j]
  39. Replace `i` and `j` to get element `(1, 1)`:
  40. >>> a[i, j].subs({i: 1, j: 1})
  41. 3
  42. """
  43. syindex = self._check_symbolic_index(index)
  44. if syindex is not None:
  45. return syindex
  46. index = self._check_index_for_getitem(index)
  47. if isinstance(index, tuple) and any(isinstance(i, slice) for i in index):
  48. sl_factors, eindices = self._get_slice_data_for_array_access(index)
  49. array = [self._array[self._parse_index(i)] for i in eindices]
  50. nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)]
  51. return type(self)(array, nshape)
  52. else:
  53. index = self._parse_index(index)
  54. return self._array[index]
  55. @classmethod
  56. def zeros(cls, *shape):
  57. list_length = functools.reduce(lambda x, y: x*y, shape, S.One)
  58. return cls._new(([0]*list_length,), shape)
  59. def tomatrix(self):
  60. """
  61. Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error.
  62. Examples
  63. ========
  64. >>> from sympy import MutableDenseNDimArray
  65. >>> a = MutableDenseNDimArray([1 for i in range(9)], (3, 3))
  66. >>> b = a.tomatrix()
  67. >>> b
  68. Matrix([
  69. [1, 1, 1],
  70. [1, 1, 1],
  71. [1, 1, 1]])
  72. """
  73. from sympy.matrices import Matrix
  74. if self.rank() != 2:
  75. raise ValueError('Dimensions must be of size of 2')
  76. return Matrix(self.shape[0], self.shape[1], self._array)
  77. def reshape(self, *newshape):
  78. """
  79. Returns MutableDenseNDimArray instance with new shape. Elements number
  80. must be suitable to new shape. The only argument of method sets
  81. new shape.
  82. Examples
  83. ========
  84. >>> from sympy import MutableDenseNDimArray
  85. >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3))
  86. >>> a.shape
  87. (2, 3)
  88. >>> a
  89. [[1, 2, 3], [4, 5, 6]]
  90. >>> b = a.reshape(3, 2)
  91. >>> b.shape
  92. (3, 2)
  93. >>> b
  94. [[1, 2], [3, 4], [5, 6]]
  95. """
  96. new_total_size = functools.reduce(lambda x,y: x*y, newshape)
  97. if new_total_size != self._loop_size:
  98. raise ValueError("Invalid reshape parameters " + newshape)
  99. # there is no `.func` as this class does not subtype `Basic`:
  100. return type(self)(self._array, newshape)
  101. class ImmutableDenseNDimArray(DenseNDimArray, ImmutableNDimArray): # type: ignore
  102. """
  103. """
  104. def __new__(cls, iterable, shape=None, **kwargs):
  105. return cls._new(iterable, shape, **kwargs)
  106. @classmethod
  107. def _new(cls, iterable, shape, **kwargs):
  108. shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs)
  109. shape = Tuple(*map(_sympify, shape))
  110. cls._check_special_bounds(flat_list, shape)
  111. flat_list = flatten(flat_list)
  112. flat_list = Tuple(*flat_list)
  113. self = Basic.__new__(cls, flat_list, shape, **kwargs)
  114. self._shape = shape
  115. self._array = list(flat_list)
  116. self._rank = len(shape)
  117. self._loop_size = functools.reduce(lambda x,y: x*y, shape, 1)
  118. return self
  119. def __setitem__(self, index, value):
  120. raise TypeError('immutable N-dim array')
  121. def as_mutable(self):
  122. return MutableDenseNDimArray(self)
  123. def _eval_simplify(self, **kwargs):
  124. return self.applyfunc(simplify)
  125. class MutableDenseNDimArray(DenseNDimArray, MutableNDimArray):
  126. def __new__(cls, iterable=None, shape=None, **kwargs):
  127. return cls._new(iterable, shape, **kwargs)
  128. @classmethod
  129. def _new(cls, iterable, shape, **kwargs):
  130. shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs)
  131. flat_list = flatten(flat_list)
  132. self = object.__new__(cls)
  133. self._shape = shape
  134. self._array = list(flat_list)
  135. self._rank = len(shape)
  136. self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list)
  137. return self
  138. def __setitem__(self, index, value):
  139. """Allows to set items to MutableDenseNDimArray.
  140. Examples
  141. ========
  142. >>> from sympy import MutableDenseNDimArray
  143. >>> a = MutableDenseNDimArray.zeros(2, 2)
  144. >>> a[0,0] = 1
  145. >>> a[1,1] = 1
  146. >>> a
  147. [[1, 0], [0, 1]]
  148. """
  149. if isinstance(index, tuple) and any(isinstance(i, slice) for i in index):
  150. value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value)
  151. for i in eindices:
  152. other_i = [ind - j for ind, j in zip(i, slice_offsets) if j is not None]
  153. self._array[self._parse_index(i)] = value[other_i]
  154. else:
  155. index = self._parse_index(index)
  156. self._setter_iterable_check(value)
  157. value = _sympify(value)
  158. self._array[index] = value
  159. def as_immutable(self):
  160. return ImmutableDenseNDimArray(self)
  161. @property
  162. def free_symbols(self):
  163. return {i for j in self._array for i in j.free_symbols}