图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

503 lines
16 KiB

from sympy.core import symbols, Symbol, Tuple, oo, Dummy
from sympy.tensor.indexed import IndexException
from sympy.testing.pytest import raises, XFAIL
from sympy.utilities.iterables import iterable

# import test:
from sympy.concrete.summations import Sum
from sympy.core.function import Function, Subs, Derivative
from sympy.core.relational import (StrictLessThan, GreaterThan,
    StrictGreaterThan, LessThan)
from sympy.core.singleton import S
from sympy.functions.elementary.exponential import exp, log
from sympy.functions.elementary.trigonometric import cos, sin
from sympy.functions.special.tensor_functions import KroneckerDelta
from sympy.series.order import Order
from sympy.sets.fancysets import Range
from sympy.tensor.indexed import IndexedBase, Idx, Indexed
  17. def test_Idx_construction():
  18. i, a, b = symbols('i a b', integer=True)
  19. assert Idx(i) != Idx(i, 1)
  20. assert Idx(i, a) == Idx(i, (0, a - 1))
  21. assert Idx(i, oo) == Idx(i, (0, oo))
  22. x = symbols('x', integer=False)
  23. raises(TypeError, lambda: Idx(x))
  24. raises(TypeError, lambda: Idx(0.5))
  25. raises(TypeError, lambda: Idx(i, x))
  26. raises(TypeError, lambda: Idx(i, 0.5))
  27. raises(TypeError, lambda: Idx(i, (x, 5)))
  28. raises(TypeError, lambda: Idx(i, (2, x)))
  29. raises(TypeError, lambda: Idx(i, (2, 3.5)))
  30. def test_Idx_properties():
  31. i, a, b = symbols('i a b', integer=True)
  32. assert Idx(i).is_integer
  33. assert Idx(i).name == 'i'
  34. assert Idx(i + 2).name == 'i + 2'
  35. assert Idx('foo').name == 'foo'
  36. def test_Idx_bounds():
  37. i, a, b = symbols('i a b', integer=True)
  38. assert Idx(i).lower is None
  39. assert Idx(i).upper is None
  40. assert Idx(i, a).lower == 0
  41. assert Idx(i, a).upper == a - 1
  42. assert Idx(i, 5).lower == 0
  43. assert Idx(i, 5).upper == 4
  44. assert Idx(i, oo).lower == 0
  45. assert Idx(i, oo).upper is oo
  46. assert Idx(i, (a, b)).lower == a
  47. assert Idx(i, (a, b)).upper == b
  48. assert Idx(i, (1, 5)).lower == 1
  49. assert Idx(i, (1, 5)).upper == 5
  50. assert Idx(i, (-oo, oo)).lower is -oo
  51. assert Idx(i, (-oo, oo)).upper is oo
  52. def test_Idx_fixed_bounds():
  53. i, a, b, x = symbols('i a b x', integer=True)
  54. assert Idx(x).lower is None
  55. assert Idx(x).upper is None
  56. assert Idx(x, a).lower == 0
  57. assert Idx(x, a).upper == a - 1
  58. assert Idx(x, 5).lower == 0
  59. assert Idx(x, 5).upper == 4
  60. assert Idx(x, oo).lower == 0
  61. assert Idx(x, oo).upper is oo
  62. assert Idx(x, (a, b)).lower == a
  63. assert Idx(x, (a, b)).upper == b
  64. assert Idx(x, (1, 5)).lower == 1
  65. assert Idx(x, (1, 5)).upper == 5
  66. assert Idx(x, (-oo, oo)).lower is -oo
  67. assert Idx(x, (-oo, oo)).upper is oo
  68. def test_Idx_inequalities():
  69. i14 = Idx("i14", (1, 4))
  70. i79 = Idx("i79", (7, 9))
  71. i46 = Idx("i46", (4, 6))
  72. i35 = Idx("i35", (3, 5))
  73. assert i14 <= 5
  74. assert i14 < 5
  75. assert not (i14 >= 5)
  76. assert not (i14 > 5)
  77. assert 5 >= i14
  78. assert 5 > i14
  79. assert not (5 <= i14)
  80. assert not (5 < i14)
  81. assert LessThan(i14, 5)
  82. assert StrictLessThan(i14, 5)
  83. assert not GreaterThan(i14, 5)
  84. assert not StrictGreaterThan(i14, 5)
  85. assert i14 <= 4
  86. assert isinstance(i14 < 4, StrictLessThan)
  87. assert isinstance(i14 >= 4, GreaterThan)
  88. assert not (i14 > 4)
  89. assert isinstance(i14 <= 1, LessThan)
  90. assert not (i14 < 1)
  91. assert i14 >= 1
  92. assert isinstance(i14 > 1, StrictGreaterThan)
  93. assert not (i14 <= 0)
  94. assert not (i14 < 0)
  95. assert i14 >= 0
  96. assert i14 > 0
  97. from sympy.abc import x
  98. assert isinstance(i14 < x, StrictLessThan)
  99. assert isinstance(i14 > x, StrictGreaterThan)
  100. assert isinstance(i14 <= x, LessThan)
  101. assert isinstance(i14 >= x, GreaterThan)
  102. assert i14 < i79
  103. assert i14 <= i79
  104. assert not (i14 > i79)
  105. assert not (i14 >= i79)
  106. assert i14 <= i46
  107. assert isinstance(i14 < i46, StrictLessThan)
  108. assert isinstance(i14 >= i46, GreaterThan)
  109. assert not (i14 > i46)
  110. assert isinstance(i14 < i35, StrictLessThan)
  111. assert isinstance(i14 > i35, StrictGreaterThan)
  112. assert isinstance(i14 <= i35, LessThan)
  113. assert isinstance(i14 >= i35, GreaterThan)
  114. iNone1 = Idx("iNone1")
  115. iNone2 = Idx("iNone2")
  116. assert isinstance(iNone1 < iNone2, StrictLessThan)
  117. assert isinstance(iNone1 > iNone2, StrictGreaterThan)
  118. assert isinstance(iNone1 <= iNone2, LessThan)
  119. assert isinstance(iNone1 >= iNone2, GreaterThan)
  120. def test_Idx_inequalities_current_fails():
  121. i14 = Idx("i14", (1, 4))
  122. assert S(5) >= i14
  123. assert S(5) > i14
  124. assert not (S(5) <= i14)
  125. assert not (S(5) < i14)
  126. def test_Idx_func_args():
  127. i, a, b = symbols('i a b', integer=True)
  128. ii = Idx(i)
  129. assert ii.func(*ii.args) == ii
  130. ii = Idx(i, a)
  131. assert ii.func(*ii.args) == ii
  132. ii = Idx(i, (a, b))
  133. assert ii.func(*ii.args) == ii
  134. def test_Idx_subs():
  135. i, a, b = symbols('i a b', integer=True)
  136. assert Idx(i, a).subs(a, b) == Idx(i, b)
  137. assert Idx(i, a).subs(i, b) == Idx(b, a)
  138. assert Idx(i).subs(i, 2) == Idx(2)
  139. assert Idx(i, a).subs(a, 2) == Idx(i, 2)
  140. assert Idx(i, (a, b)).subs(i, 2) == Idx(2, (a, b))
  141. def test_IndexedBase_sugar():
  142. i, j = symbols('i j', integer=True)
  143. a = symbols('a')
  144. A1 = Indexed(a, i, j)
  145. A2 = IndexedBase(a)
  146. assert A1 == A2[i, j]
  147. assert A1 == A2[(i, j)]
  148. assert A1 == A2[[i, j]]
  149. assert A1 == A2[Tuple(i, j)]
  150. assert all(a.is_Integer for a in A2[1, 0].args[1:])
  151. def test_IndexedBase_subs():
  152. i = symbols('i', integer=True)
  153. a, b = symbols('a b')
  154. A = IndexedBase(a)
  155. B = IndexedBase(b)
  156. assert A[i] == B[i].subs(b, a)
  157. C = {1: 2}
  158. assert C[1] == A[1].subs(A, C)
  159. def test_IndexedBase_shape():
  160. i, j, m, n = symbols('i j m n', integer=True)
  161. a = IndexedBase('a', shape=(m, m))
  162. b = IndexedBase('a', shape=(m, n))
  163. assert b.shape == Tuple(m, n)
  164. assert a[i, j] != b[i, j]
  165. assert a[i, j] == b[i, j].subs(n, m)
  166. assert b.func(*b.args) == b
  167. assert b[i, j].func(*b[i, j].args) == b[i, j]
  168. raises(IndexException, lambda: b[i])
  169. raises(IndexException, lambda: b[i, i, j])
  170. F = IndexedBase("F", shape=m)
  171. assert F.shape == Tuple(m)
  172. assert F[i].subs(i, j) == F[j]
  173. raises(IndexException, lambda: F[i, j])
  174. def test_IndexedBase_assumptions():
  175. i = Symbol('i', integer=True)
  176. a = Symbol('a')
  177. A = IndexedBase(a, positive=True)
  178. for c in (A, A[i]):
  179. assert c.is_real
  180. assert c.is_complex
  181. assert not c.is_imaginary
  182. assert c.is_nonnegative
  183. assert c.is_nonzero
  184. assert c.is_commutative
  185. assert log(exp(c)) == c
  186. assert A != IndexedBase(a)
  187. assert A == IndexedBase(a, positive=True, real=True)
  188. assert A[i] != Indexed(a, i)
  189. def test_IndexedBase_assumptions_inheritance():
  190. I = Symbol('I', integer=True)
  191. I_inherit = IndexedBase(I)
  192. I_explicit = IndexedBase('I', integer=True)
  193. assert I_inherit.is_integer
  194. assert I_explicit.is_integer
  195. assert I_inherit.label.is_integer
  196. assert I_explicit.label.is_integer
  197. assert I_inherit == I_explicit
  198. def test_issue_17652():
  199. """Regression test issue #17652.
  200. IndexedBase.label should not upcast subclasses of Symbol
  201. """
  202. class SubClass(Symbol):
  203. pass
  204. x = SubClass('X')
  205. assert type(x) == SubClass
  206. base = IndexedBase(x)
  207. assert type(x) == SubClass
  208. assert type(base.label) == SubClass
  209. def test_Indexed_constructor():
  210. i, j = symbols('i j', integer=True)
  211. A = Indexed('A', i, j)
  212. assert A == Indexed(Symbol('A'), i, j)
  213. assert A == Indexed(IndexedBase('A'), i, j)
  214. raises(TypeError, lambda: Indexed(A, i, j))
  215. raises(IndexException, lambda: Indexed("A"))
  216. assert A.free_symbols == {A, A.base.label, i, j}
  217. def test_Indexed_func_args():
  218. i, j = symbols('i j', integer=True)
  219. a = symbols('a')
  220. A = Indexed(a, i, j)
  221. assert A == A.func(*A.args)
  222. def test_Indexed_subs():
  223. i, j, k = symbols('i j k', integer=True)
  224. a, b = symbols('a b')
  225. A = IndexedBase(a)
  226. B = IndexedBase(b)
  227. assert A[i, j] == B[i, j].subs(b, a)
  228. assert A[i, j] == A[i, k].subs(k, j)
  229. def test_Indexed_properties():
  230. i, j = symbols('i j', integer=True)
  231. A = Indexed('A', i, j)
  232. assert A.name == 'A[i, j]'
  233. assert A.rank == 2
  234. assert A.indices == (i, j)
  235. assert A.base == IndexedBase('A')
  236. assert A.ranges == [None, None]
  237. raises(IndexException, lambda: A.shape)
  238. n, m = symbols('n m', integer=True)
  239. assert Indexed('A', Idx(
  240. i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)]
  241. assert Indexed('A', Idx(i, m), Idx(j, n)).shape == Tuple(m, n)
  242. raises(IndexException, lambda: Indexed("A", Idx(i, m), Idx(j)).shape)
  243. def test_Indexed_shape_precedence():
  244. i, j = symbols('i j', integer=True)
  245. o, p = symbols('o p', integer=True)
  246. n, m = symbols('n m', integer=True)
  247. a = IndexedBase('a', shape=(o, p))
  248. assert a.shape == Tuple(o, p)
  249. assert Indexed(
  250. a, Idx(i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)]
  251. assert Indexed(a, Idx(i, m), Idx(j, n)).shape == Tuple(o, p)
  252. assert Indexed(
  253. a, Idx(i, m), Idx(j)).ranges == [Tuple(0, m - 1), (None, None)]
  254. assert Indexed(a, Idx(i, m), Idx(j)).shape == Tuple(o, p)
  255. def test_complex_indices():
  256. i, j = symbols('i j', integer=True)
  257. A = Indexed('A', i, i + j)
  258. assert A.rank == 2
  259. assert A.indices == (i, i + j)
  260. def test_not_interable():
  261. i, j = symbols('i j', integer=True)
  262. A = Indexed('A', i, i + j)
  263. assert not iterable(A)
  264. def test_Indexed_coeff():
  265. N = Symbol('N', integer=True)
  266. len_y = N
  267. i = Idx('i', len_y-1)
  268. y = IndexedBase('y', shape=(len_y,))
  269. a = (1/y[i+1]*y[i]).coeff(y[i])
  270. b = (y[i]/y[i+1]).coeff(y[i])
  271. assert a == b
  272. def test_differentiation():
  273. from sympy.functions.special.tensor_functions import KroneckerDelta
  274. i, j, k, l = symbols('i j k l', cls=Idx)
  275. a = symbols('a')
  276. m, n = symbols("m, n", integer=True, finite=True)
  277. assert m.is_real
  278. h, L = symbols('h L', cls=IndexedBase)
  279. hi, hj = h[i], h[j]
  280. expr = hi
  281. assert expr.diff(hj) == KroneckerDelta(i, j)
  282. assert expr.diff(hi) == KroneckerDelta(i, i)
  283. expr = S(2) * hi
  284. assert expr.diff(hj) == S(2) * KroneckerDelta(i, j)
  285. assert expr.diff(hi) == S(2) * KroneckerDelta(i, i)
  286. assert expr.diff(a) is S.Zero
  287. assert Sum(expr, (i, -oo, oo)).diff(hj) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo))
  288. assert Sum(expr.diff(hj), (i, -oo, oo)) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo))
  289. assert Sum(expr, (i, -oo, oo)).diff(hj).doit() == 2
  290. assert Sum(expr.diff(hi), (i, -oo, oo)).doit() == Sum(2, (i, -oo, oo)).doit()
  291. assert Sum(expr, (i, -oo, oo)).diff(hi).doit() is oo
  292. expr = a * hj * hj / S(2)
  293. assert expr.diff(hi) == a * h[j] * KroneckerDelta(i, j)
  294. assert expr.diff(a) == hj * hj / S(2)
  295. assert expr.diff(a, 2) is S.Zero
  296. assert Sum(expr, (i, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo))
  297. assert Sum(expr.diff(hi), (i, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo))
  298. assert Sum(expr, (i, -oo, oo)).diff(hi).doit() == a*h[j]
  299. assert Sum(expr, (j, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo))
  300. assert Sum(expr.diff(hi), (j, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo))
  301. assert Sum(expr, (j, -oo, oo)).diff(hi).doit() == a*h[i]
  302. expr = a * sin(hj * hj)
  303. assert expr.diff(hi) == 2*a*cos(hj * hj) * hj * KroneckerDelta(i, j)
  304. assert expr.diff(hj) == 2*a*cos(hj * hj) * hj
  305. expr = a * L[i, j] * h[j]
  306. assert expr.diff(hi) == a*L[i, j]*KroneckerDelta(i, j)
  307. assert expr.diff(hj) == a*L[i, j]
  308. assert expr.diff(L[i, j]) == a*h[j]
  309. assert expr.diff(L[k, l]) == a*KroneckerDelta(i, k)*KroneckerDelta(j, l)*h[j]
  310. assert expr.diff(L[i, l]) == a*KroneckerDelta(j, l)*h[j]
  311. assert Sum(expr, (j, -oo, oo)).diff(L[k, l]) == Sum(a * KroneckerDelta(i, k) * KroneckerDelta(j, l) * h[j], (j, -oo, oo))
  312. assert Sum(expr, (j, -oo, oo)).diff(L[k, l]).doit() == a * KroneckerDelta(i, k) * h[l]
  313. assert h[m].diff(h[m]) == 1
  314. assert h[m].diff(h[n]) == KroneckerDelta(m, n)
  315. assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (m, -oo, oo))
  316. assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]).doit() == a
  317. assert Sum(a*h[m], (n, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (n, -oo, oo))
  318. assert Sum(a*h[m], (m, -oo, oo)).diff(h[m]).doit() == oo*a
  319. def test_indexed_series():
  320. A = IndexedBase("A")
  321. i = symbols("i", integer=True)
  322. assert sin(A[i]).series(A[i]) == A[i] - A[i]**3/6 + A[i]**5/120 + Order(A[i]**6, A[i])
  323. def test_indexed_is_constant():
  324. A = IndexedBase("A")
  325. i, j, k = symbols("i,j,k")
  326. assert not A[i].is_constant()
  327. assert A[i].is_constant(j)
  328. assert not A[1+2*i, k].is_constant()
  329. assert not A[1+2*i, k].is_constant(i)
  330. assert A[1+2*i, k].is_constant(j)
  331. assert not A[1+2*i, k].is_constant(k)
  332. def test_issue_12533():
  333. d = IndexedBase('d')
  334. assert IndexedBase(range(5)) == Range(0, 5, 1)
  335. assert d[0].subs(Symbol("d"), range(5)) == 0
  336. assert d[0].subs(d, range(5)) == 0
  337. assert d[1].subs(d, range(5)) == 1
  338. assert Indexed(Range(5), 2) == 2
  339. def test_issue_12780():
  340. n = symbols("n")
  341. i = Idx("i", (0, n))
  342. raises(TypeError, lambda: i.subs(n, 1.5))
  343. def test_issue_18604():
  344. m = symbols("m")
  345. assert Idx("i", m).name == 'i'
  346. assert Idx("i", m).lower == 0
  347. assert Idx("i", m).upper == m - 1
  348. m = symbols("m", real=False)
  349. raises(TypeError, lambda: Idx("i", m))
  350. def test_Subs_with_Indexed():
  351. A = IndexedBase("A")
  352. i, j, k = symbols("i,j,k")
  353. x, y, z = symbols("x,y,z")
  354. f = Function("f")
  355. assert Subs(A[i], A[i], A[j]).diff(A[j]) == 1
  356. assert Subs(A[i], A[i], x).diff(A[i]) == 0
  357. assert Subs(A[i], A[i], x).diff(A[j]) == 0
  358. assert Subs(A[i], A[i], x).diff(x) == 1
  359. assert Subs(A[i], A[i], x).diff(y) == 0
  360. assert Subs(A[i], A[i], A[j]).diff(A[k]) == KroneckerDelta(j, k)
  361. assert Subs(x, x, A[i]).diff(A[j]) == KroneckerDelta(i, j)
  362. assert Subs(f(A[i]), A[i], x).diff(A[j]) == 0
  363. assert Subs(f(A[i]), A[i], A[k]).diff(A[j]) == Derivative(f(A[k]), A[k])*KroneckerDelta(j, k)
  364. assert Subs(x, x, A[i]**2).diff(A[j]) == 2*KroneckerDelta(i, j)*A[i]
  365. assert Subs(A[i], A[i], A[j]**2).diff(A[k]) == 2*KroneckerDelta(j, k)*A[j]
  366. assert Subs(A[i]*x, x, A[i]).diff(A[i]) == 2*A[i]
  367. assert Subs(A[i]*x, x, A[i]).diff(A[j]) == 2*A[i]*KroneckerDelta(i, j)
  368. assert Subs(A[i]*x, x, A[j]).diff(A[i]) == A[j] + A[i]*KroneckerDelta(i, j)
  369. assert Subs(A[i]*x, x, A[j]).diff(A[j]) == A[i] + A[j]*KroneckerDelta(i, j)
  370. assert Subs(A[i]*x, x, A[i]).diff(A[k]) == 2*A[i]*KroneckerDelta(i, k)
  371. assert Subs(A[i]*x, x, A[j]).diff(A[k]) == KroneckerDelta(i, k)*A[j] + KroneckerDelta(j, k)*A[i]
  372. assert Subs(A[i]*x, A[i], x).diff(A[i]) == 0
  373. assert Subs(A[i]*x, A[i], x).diff(A[j]) == 0
  374. assert Subs(A[i]*x, A[j], x).diff(A[i]) == x
  375. assert Subs(A[i]*x, A[j], x).diff(A[j]) == x*KroneckerDelta(i, j)
  376. assert Subs(A[i]*x, A[i], x).diff(A[k]) == 0
  377. assert Subs(A[i]*x, A[j], x).diff(A[k]) == x*KroneckerDelta(i, k)
  378. def test_complicated_derivative_with_Indexed():
  379. x, y = symbols("x,y", cls=IndexedBase)
  380. sigma = symbols("sigma")
  381. i, j, k = symbols("i,j,k")
  382. m0,m1,m2,m3,m4,m5 = symbols("m0:6")
  383. f = Function("f")
  384. expr = f((x[i] - y[i])**2/sigma)
  385. _xi_1 = symbols("xi_1", cls=Dummy)
  386. assert expr.diff(x[m0]).dummy_eq(
  387. (x[i] - y[i])*KroneckerDelta(i, m0)*\
  388. 2*Subs(
  389. Derivative(f(_xi_1), _xi_1),
  390. (_xi_1,),
  391. ((x[i] - y[i])**2/sigma,)
  392. )/sigma
  393. )
  394. assert expr.diff(x[m0]).diff(x[m1]).dummy_eq(
  395. 2*KroneckerDelta(i, m0)*\
  396. KroneckerDelta(i, m1)*Subs(
  397. Derivative(f(_xi_1), _xi_1),
  398. (_xi_1,),
  399. ((x[i] - y[i])**2/sigma,)
  400. )/sigma + \
  401. 4*(x[i] - y[i])**2*KroneckerDelta(i, m0)*KroneckerDelta(i, m1)*\
  402. Subs(
  403. Derivative(f(_xi_1), _xi_1, _xi_1),
  404. (_xi_1,),
  405. ((x[i] - y[i])**2/sigma,)
  406. )/sigma**2
  407. )