图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

209 lines
6.2 KiB

  1. import functools
  2. from typing import List
  3. from sympy.core.basic import Basic
  4. from sympy.core.containers import Tuple
  5. from sympy.core.singleton import S
  6. from sympy.core.sympify import _sympify
  7. from sympy.simplify.simplify import simplify
  8. from sympy.tensor.array.mutable_ndim_array import MutableNDimArray
  9. from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray, ArrayKind
  10. from sympy.utilities.iterables import flatten
  class DenseNDimArray(NDimArray):
      """
      Common behaviour of N-dim arrays that keep every element in a flat list.

      The elements live in ``self._array``; a multi-dimensional index is
      turned into a position of that list by ``self._parse_index``.
      Instantiating this class directly yields an ``ImmutableDenseNDimArray``.
      """

      # Flat storage of all elements of the array.
      _array: List[Basic]

      def __new__(self, *args, **kwargs):
          """
          Build an ``ImmutableDenseNDimArray`` from the given arguments.

          The first argument (the class being instantiated) is not used: the
          immutable dense implementation is always returned.
          """
          return ImmutableDenseNDimArray(*args, **kwargs)

      @property
      def kind(self) -> ArrayKind:
          """Return ``ArrayKind._union`` evaluated over the stored elements."""
          return ArrayKind._union(self._array)

      def __getitem__(self, index):
          """
          Allows to get items from N-dim array.

          If ``index`` is a tuple containing at least one slice, a new array of
          the same type is returned whose shape keeps only the sliced axes;
          otherwise the single stored element is returned.

          Examples
          ========

          >>> from sympy import MutableDenseNDimArray
          >>> a = MutableDenseNDimArray([0, 1, 2, 3], (2, 2))
          >>> a
          [[0, 1], [2, 3]]
          >>> a[0, 0]
          0
          >>> a[1, 1]
          3
          >>> a[0]
          [0, 1]
          >>> a[1]
          [2, 3]

          Symbolic index:

          >>> from sympy.abc import i, j
          >>> a[i, j]
          [[0, 1], [2, 3]][i, j]

          Replace `i` and `j` to get element `(1, 1)`:

          >>> a[i, j].subs({i: 1, j: 1})
          3

          """
          # A symbolic index yields a non-None object that is returned as is
          # (see the ``a[i, j]`` example above).
          syindex = self._check_symbolic_index(index)
          if syindex is not None:
              return syindex

          index = self._check_index_for_getitem(index)

          if isinstance(index, tuple) and any(isinstance(i, slice) for i in index):
              sl_factors, eindices = self._get_slice_data_for_array_access(index)
              array = [self._array[self._parse_index(i)] for i in eindices]
              # Only axes addressed with a slice contribute to the new shape.
              nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)]
              return type(self)(array, nshape)
          else:
              index = self._parse_index(index)
              return self._array[index]

      @classmethod
      def zeros(cls, *shape):
          """
          Return an array of the given ``shape`` whose elements are all ``0``.

          The number of elements is the product of the entries of ``shape``
          (``S.One`` when ``shape`` is empty); creation is delegated to
          ``cls._new``.
          """
          list_length = functools.reduce(lambda x, y: x*y, shape, S.One)
          return cls._new(([0]*list_length,), shape)

      def tomatrix(self):
          """
          Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error.

          Raises ``ValueError`` if ``self.rank()`` is not 2.

          Examples
          ========

          >>> from sympy import MutableDenseNDimArray
          >>> a = MutableDenseNDimArray([1 for i in range(9)], (3, 3))
          >>> b = a.tomatrix()
          >>> b
          Matrix([
          [1, 1, 1],
          [1, 1, 1],
          [1, 1, 1]])

          """
          # Imported here rather than at module level, as in the original code.
          from sympy.matrices import Matrix

          if self.rank() != 2:
              raise ValueError('Dimensions must be of size of 2')

          return Matrix(self.shape[0], self.shape[1], self._array)

      def reshape(self, *newshape):
          """
          Returns MutableDenseNDimArray instance with new shape. Elements number
          must be suitable to new shape. The positional arguments of the method
          give the new shape.

          Raises ``ValueError`` if the product of ``newshape`` differs from the
          number of stored elements (``self._loop_size``).

          Examples
          ========

          >>> from sympy import MutableDenseNDimArray
          >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3))
          >>> a.shape
          (2, 3)
          >>> a
          [[1, 2, 3], [4, 5, 6]]
          >>> b = a.reshape(3, 2)
          >>> b.shape
          (3, 2)
          >>> b
          [[1, 2], [3, 4], [5, 6]]

          """
          new_total_size = functools.reduce(lambda x,y: x*y, newshape)
          if new_total_size != self._loop_size:
              # ``newshape`` is a tuple: it has to be converted explicitly,
              # otherwise ``str + tuple`` raises TypeError and the intended
              # ValueError never reaches the caller.
              raise ValueError("Invalid reshape parameters " + str(newshape))

          # there is no `.func` as this class does not subtype `Basic`:
          return type(self)(self._array, newshape)
  class ImmutableDenseNDimArray(DenseNDimArray, ImmutableNDimArray): # type: ignore
      """
      Dense N-dim array whose elements cannot be reassigned.

      Instances are created through ``Basic.__new__`` with two arguments: a
      ``Tuple`` holding the flattened elements and a ``Tuple`` holding the
      sympified shape. Item assignment raises ``TypeError``.
      """

      def __new__(cls, iterable, shape=None, **kwargs):
          """Create the array; all the work is delegated to ``_new``."""
          return cls._new(iterable, shape, **kwargs)

      @classmethod
      def _new(cls, iterable, shape, **kwargs):
          """
          Construct an instance from ``iterable`` and ``shape``.

          The inputs are first processed by ``_handle_ndarray_creation_inputs``
          and checked by ``_check_special_bounds``; the flattened elements and
          the shape are then stored both as ``Basic`` arguments and as the
          ``_array``, ``_shape``, ``_rank`` and ``_loop_size`` attributes.
          """
          raw_shape, raw_items = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs)
          sym_shape = Tuple(*(_sympify(dim) for dim in raw_shape))
          cls._check_special_bounds(raw_items, sym_shape)
          items = Tuple(*flatten(raw_items))

          obj = Basic.__new__(cls, items, sym_shape, **kwargs)
          obj._shape = sym_shape
          obj._array = list(items)
          obj._rank = len(sym_shape)

          # Number of elements: product of all dimensions, 1 for an empty shape.
          total = 1
          for dim in sym_shape:
              total = total * dim
          obj._loop_size = total
          return obj

      def __setitem__(self, index, value):
          """Always raise ``TypeError``: the array is immutable."""
          raise TypeError('immutable N-dim array')

      def as_mutable(self):
          """Return a ``MutableDenseNDimArray`` built from this array."""
          return MutableDenseNDimArray(self)

      def _eval_simplify(self, **kwargs):
          """Return the result of applying ``simplify`` through ``applyfunc``."""
          return self.applyfunc(simplify)
  class MutableDenseNDimArray(DenseNDimArray, MutableNDimArray):
      """
      Dense N-dim array whose elements can be reassigned in place.

      Instances are plain Python objects created with ``object.__new__``; the
      elements are kept in the flat list ``_array``.
      """

      def __new__(cls, iterable=None, shape=None, **kwargs):
          """Create the array; all the work is delegated to ``_new``."""
          return cls._new(iterable, shape, **kwargs)

      @classmethod
      def _new(cls, iterable, shape, **kwargs):
          """
          Construct an instance from ``iterable`` and ``shape``.

          The inputs are processed by ``_handle_ndarray_creation_inputs`` and
          the resulting elements are flattened into ``_array``. ``_loop_size``
          is the product of the dimensions, or the number of flattened
          elements when the shape is empty.
          """
          dims, raw_items = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs)
          items = flatten(raw_items)

          obj = object.__new__(cls)
          obj._shape = dims
          obj._array = list(items)
          obj._rank = len(dims)
          if dims:
              obj._loop_size = functools.reduce(lambda x,y: x*y, dims)
          else:
              obj._loop_size = len(items)
          return obj

      def __setitem__(self, index, value):
          """Allows to set items to MutableDenseNDimArray.

          A tuple index containing at least one slice assigns a block of
          values; any other index assigns a single, sympified element.

          Examples
          ========

          >>> from sympy import MutableDenseNDimArray
          >>> a = MutableDenseNDimArray.zeros(2, 2)
          >>> a[0,0] = 1
          >>> a[1,1] = 1
          >>> a
          [[1, 0], [0, 1]]

          """
          has_slice = isinstance(index, tuple) and any(isinstance(part, slice) for part in index)

          if not has_slice:
              # Single element: resolve the flat position, validate the value
              # with ``_setter_iterable_check`` and store it sympified.
              flat_pos = self._parse_index(index)
              self._setter_iterable_check(value)
              self._array[flat_pos] = _sympify(value)
              return

          value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value)
          for target in eindices:
              # Position inside ``value``: subtract the offset on every axis
              # whose offset is not None.
              source = [pos - off for pos, off in zip(target, slice_offsets) if off is not None]
              self._array[self._parse_index(target)] = value[source]

      def as_immutable(self):
          """Return an ``ImmutableDenseNDimArray`` built from this array."""
          return ImmutableDenseNDimArray(self)

      @property
      def free_symbols(self):
          """Return the set of the free symbols of all stored elements."""
          symbols = set()
          for element in self._array:
              symbols.update(element.free_symbols)
          return symbols