m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

503 lines
16 KiB

6 months ago
from sympy.core import symbols, Symbol, Tuple, oo, Dummy
from sympy.tensor.indexed import IndexException
from sympy.testing.pytest import raises, XFAIL
from sympy.utilities.iterables import iterable

# import test:
from sympy.concrete.summations import Sum
from sympy.core.function import Function, Subs, Derivative
from sympy.core.relational import (StrictLessThan, GreaterThan,
    StrictGreaterThan, LessThan)
from sympy.core.singleton import S
from sympy.functions.elementary.exponential import exp, log
from sympy.functions.elementary.trigonometric import cos, sin
from sympy.functions.special.tensor_functions import KroneckerDelta
from sympy.series.order import Order
from sympy.sets.fancysets import Range
from sympy.tensor.indexed import IndexedBase, Idx, Indexed


  17. def test_Idx_construction():
  18. i, a, b = symbols('i a b', integer=True)
  19. assert Idx(i) != Idx(i, 1)
  20. assert Idx(i, a) == Idx(i, (0, a - 1))
  21. assert Idx(i, oo) == Idx(i, (0, oo))
  22. x = symbols('x', integer=False)
  23. raises(TypeError, lambda: Idx(x))
  24. raises(TypeError, lambda: Idx(0.5))
  25. raises(TypeError, lambda: Idx(i, x))
  26. raises(TypeError, lambda: Idx(i, 0.5))
  27. raises(TypeError, lambda: Idx(i, (x, 5)))
  28. raises(TypeError, lambda: Idx(i, (2, x)))
  29. raises(TypeError, lambda: Idx(i, (2, 3.5)))
  30. def test_Idx_properties():
  31. i, a, b = symbols('i a b', integer=True)
  32. assert Idx(i).is_integer
  33. assert Idx(i).name == 'i'
  34. assert Idx(i + 2).name == 'i + 2'
  35. assert Idx('foo').name == 'foo'
  36. def test_Idx_bounds():
  37. i, a, b = symbols('i a b', integer=True)
  38. assert Idx(i).lower is None
  39. assert Idx(i).upper is None
  40. assert Idx(i, a).lower == 0
  41. assert Idx(i, a).upper == a - 1
  42. assert Idx(i, 5).lower == 0
  43. assert Idx(i, 5).upper == 4
  44. assert Idx(i, oo).lower == 0
  45. assert Idx(i, oo).upper is oo
  46. assert Idx(i, (a, b)).lower == a
  47. assert Idx(i, (a, b)).upper == b
  48. assert Idx(i, (1, 5)).lower == 1
  49. assert Idx(i, (1, 5)).upper == 5
  50. assert Idx(i, (-oo, oo)).lower is -oo
  51. assert Idx(i, (-oo, oo)).upper is oo
  52. def test_Idx_fixed_bounds():
  53. i, a, b, x = symbols('i a b x', integer=True)
  54. assert Idx(x).lower is None
  55. assert Idx(x).upper is None
  56. assert Idx(x, a).lower == 0
  57. assert Idx(x, a).upper == a - 1
  58. assert Idx(x, 5).lower == 0
  59. assert Idx(x, 5).upper == 4
  60. assert Idx(x, oo).lower == 0
  61. assert Idx(x, oo).upper is oo
  62. assert Idx(x, (a, b)).lower == a
  63. assert Idx(x, (a, b)).upper == b
  64. assert Idx(x, (1, 5)).lower == 1
  65. assert Idx(x, (1, 5)).upper == 5
  66. assert Idx(x, (-oo, oo)).lower is -oo
  67. assert Idx(x, (-oo, oo)).upper is oo
  68. def test_Idx_inequalities():
  69. i14 = Idx("i14", (1, 4))
  70. i79 = Idx("i79", (7, 9))
  71. i46 = Idx("i46", (4, 6))
  72. i35 = Idx("i35", (3, 5))
  73. assert i14 <= 5
  74. assert i14 < 5
  75. assert not (i14 >= 5)
  76. assert not (i14 > 5)
  77. assert 5 >= i14
  78. assert 5 > i14
  79. assert not (5 <= i14)
  80. assert not (5 < i14)
  81. assert LessThan(i14, 5)
  82. assert StrictLessThan(i14, 5)
  83. assert not GreaterThan(i14, 5)
  84. assert not StrictGreaterThan(i14, 5)
  85. assert i14 <= 4
  86. assert isinstance(i14 < 4, StrictLessThan)
  87. assert isinstance(i14 >= 4, GreaterThan)
  88. assert not (i14 > 4)
  89. assert isinstance(i14 <= 1, LessThan)
  90. assert not (i14 < 1)
  91. assert i14 >= 1
  92. assert isinstance(i14 > 1, StrictGreaterThan)
  93. assert not (i14 <= 0)
  94. assert not (i14 < 0)
  95. assert i14 >= 0
  96. assert i14 > 0
  97. from sympy.abc import x
  98. assert isinstance(i14 < x, StrictLessThan)
  99. assert isinstance(i14 > x, StrictGreaterThan)
  100. assert isinstance(i14 <= x, LessThan)
  101. assert isinstance(i14 >= x, GreaterThan)
  102. assert i14 < i79
  103. assert i14 <= i79
  104. assert not (i14 > i79)
  105. assert not (i14 >= i79)
  106. assert i14 <= i46
  107. assert isinstance(i14 < i46, StrictLessThan)
  108. assert isinstance(i14 >= i46, GreaterThan)
  109. assert not (i14 > i46)
  110. assert isinstance(i14 < i35, StrictLessThan)
  111. assert isinstance(i14 > i35, StrictGreaterThan)
  112. assert isinstance(i14 <= i35, LessThan)
  113. assert isinstance(i14 >= i35, GreaterThan)
  114. iNone1 = Idx("iNone1")
  115. iNone2 = Idx("iNone2")
  116. assert isinstance(iNone1 < iNone2, StrictLessThan)
  117. assert isinstance(iNone1 > iNone2, StrictGreaterThan)
  118. assert isinstance(iNone1 <= iNone2, LessThan)
  119. assert isinstance(iNone1 >= iNone2, GreaterThan)
  120. def test_Idx_inequalities_current_fails():
  121. i14 = Idx("i14", (1, 4))
  122. assert S(5) >= i14
  123. assert S(5) > i14
  124. assert not (S(5) <= i14)
  125. assert not (S(5) < i14)
  126. def test_Idx_func_args():
  127. i, a, b = symbols('i a b', integer=True)
  128. ii = Idx(i)
  129. assert ii.func(*ii.args) == ii
  130. ii = Idx(i, a)
  131. assert ii.func(*ii.args) == ii
  132. ii = Idx(i, (a, b))
  133. assert ii.func(*ii.args) == ii
  134. def test_Idx_subs():
  135. i, a, b = symbols('i a b', integer=True)
  136. assert Idx(i, a).subs(a, b) == Idx(i, b)
  137. assert Idx(i, a).subs(i, b) == Idx(b, a)
  138. assert Idx(i).subs(i, 2) == Idx(2)
  139. assert Idx(i, a).subs(a, 2) == Idx(i, 2)
  140. assert Idx(i, (a, b)).subs(i, 2) == Idx(2, (a, b))
  141. def test_IndexedBase_sugar():
  142. i, j = symbols('i j', integer=True)
  143. a = symbols('a')
  144. A1 = Indexed(a, i, j)
  145. A2 = IndexedBase(a)
  146. assert A1 == A2[i, j]
  147. assert A1 == A2[(i, j)]
  148. assert A1 == A2[[i, j]]
  149. assert A1 == A2[Tuple(i, j)]
  150. assert all(a.is_Integer for a in A2[1, 0].args[1:])
  151. def test_IndexedBase_subs():
  152. i = symbols('i', integer=True)
  153. a, b = symbols('a b')
  154. A = IndexedBase(a)
  155. B = IndexedBase(b)
  156. assert A[i] == B[i].subs(b, a)
  157. C = {1: 2}
  158. assert C[1] == A[1].subs(A, C)
  159. def test_IndexedBase_shape():
  160. i, j, m, n = symbols('i j m n', integer=True)
  161. a = IndexedBase('a', shape=(m, m))
  162. b = IndexedBase('a', shape=(m, n))
  163. assert b.shape == Tuple(m, n)
  164. assert a[i, j] != b[i, j]
  165. assert a[i, j] == b[i, j].subs(n, m)
  166. assert b.func(*b.args) == b
  167. assert b[i, j].func(*b[i, j].args) == b[i, j]
  168. raises(IndexException, lambda: b[i])
  169. raises(IndexException, lambda: b[i, i, j])
  170. F = IndexedBase("F", shape=m)
  171. assert F.shape == Tuple(m)
  172. assert F[i].subs(i, j) == F[j]
  173. raises(IndexException, lambda: F[i, j])
  174. def test_IndexedBase_assumptions():
  175. i = Symbol('i', integer=True)
  176. a = Symbol('a')
  177. A = IndexedBase(a, positive=True)
  178. for c in (A, A[i]):
  179. assert c.is_real
  180. assert c.is_complex
  181. assert not c.is_imaginary
  182. assert c.is_nonnegative
  183. assert c.is_nonzero
  184. assert c.is_commutative
  185. assert log(exp(c)) == c
  186. assert A != IndexedBase(a)
  187. assert A == IndexedBase(a, positive=True, real=True)
  188. assert A[i] != Indexed(a, i)
  189. def test_IndexedBase_assumptions_inheritance():
  190. I = Symbol('I', integer=True)
  191. I_inherit = IndexedBase(I)
  192. I_explicit = IndexedBase('I', integer=True)
  193. assert I_inherit.is_integer
  194. assert I_explicit.is_integer
  195. assert I_inherit.label.is_integer
  196. assert I_explicit.label.is_integer
  197. assert I_inherit == I_explicit
  198. def test_issue_17652():
  199. """Regression test issue #17652.
  200. IndexedBase.label should not upcast subclasses of Symbol
  201. """
  202. class SubClass(Symbol):
  203. pass
  204. x = SubClass('X')
  205. assert type(x) == SubClass
  206. base = IndexedBase(x)
  207. assert type(x) == SubClass
  208. assert type(base.label) == SubClass
  209. def test_Indexed_constructor():
  210. i, j = symbols('i j', integer=True)
  211. A = Indexed('A', i, j)
  212. assert A == Indexed(Symbol('A'), i, j)
  213. assert A == Indexed(IndexedBase('A'), i, j)
  214. raises(TypeError, lambda: Indexed(A, i, j))
  215. raises(IndexException, lambda: Indexed("A"))
  216. assert A.free_symbols == {A, A.base.label, i, j}
  217. def test_Indexed_func_args():
  218. i, j = symbols('i j', integer=True)
  219. a = symbols('a')
  220. A = Indexed(a, i, j)
  221. assert A == A.func(*A.args)
  222. def test_Indexed_subs():
  223. i, j, k = symbols('i j k', integer=True)
  224. a, b = symbols('a b')
  225. A = IndexedBase(a)
  226. B = IndexedBase(b)
  227. assert A[i, j] == B[i, j].subs(b, a)
  228. assert A[i, j] == A[i, k].subs(k, j)
  229. def test_Indexed_properties():
  230. i, j = symbols('i j', integer=True)
  231. A = Indexed('A', i, j)
  232. assert A.name == 'A[i, j]'
  233. assert A.rank == 2
  234. assert A.indices == (i, j)
  235. assert A.base == IndexedBase('A')
  236. assert A.ranges == [None, None]
  237. raises(IndexException, lambda: A.shape)
  238. n, m = symbols('n m', integer=True)
  239. assert Indexed('A', Idx(
  240. i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)]
  241. assert Indexed('A', Idx(i, m), Idx(j, n)).shape == Tuple(m, n)
  242. raises(IndexException, lambda: Indexed("A", Idx(i, m), Idx(j)).shape)
  243. def test_Indexed_shape_precedence():
  244. i, j = symbols('i j', integer=True)
  245. o, p = symbols('o p', integer=True)
  246. n, m = symbols('n m', integer=True)
  247. a = IndexedBase('a', shape=(o, p))
  248. assert a.shape == Tuple(o, p)
  249. assert Indexed(
  250. a, Idx(i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)]
  251. assert Indexed(a, Idx(i, m), Idx(j, n)).shape == Tuple(o, p)
  252. assert Indexed(
  253. a, Idx(i, m), Idx(j)).ranges == [Tuple(0, m - 1), (None, None)]
  254. assert Indexed(a, Idx(i, m), Idx(j)).shape == Tuple(o, p)
  255. def test_complex_indices():
  256. i, j = symbols('i j', integer=True)
  257. A = Indexed('A', i, i + j)
  258. assert A.rank == 2
  259. assert A.indices == (i, i + j)
  260. def test_not_interable():
  261. i, j = symbols('i j', integer=True)
  262. A = Indexed('A', i, i + j)
  263. assert not iterable(A)
  264. def test_Indexed_coeff():
  265. N = Symbol('N', integer=True)
  266. len_y = N
  267. i = Idx('i', len_y-1)
  268. y = IndexedBase('y', shape=(len_y,))
  269. a = (1/y[i+1]*y[i]).coeff(y[i])
  270. b = (y[i]/y[i+1]).coeff(y[i])
  271. assert a == b
  272. def test_differentiation():
  273. from sympy.functions.special.tensor_functions import KroneckerDelta
  274. i, j, k, l = symbols('i j k l', cls=Idx)
  275. a = symbols('a')
  276. m, n = symbols("m, n", integer=True, finite=True)
  277. assert m.is_real
  278. h, L = symbols('h L', cls=IndexedBase)
  279. hi, hj = h[i], h[j]
  280. expr = hi
  281. assert expr.diff(hj) == KroneckerDelta(i, j)
  282. assert expr.diff(hi) == KroneckerDelta(i, i)
  283. expr = S(2) * hi
  284. assert expr.diff(hj) == S(2) * KroneckerDelta(i, j)
  285. assert expr.diff(hi) == S(2) * KroneckerDelta(i, i)
  286. assert expr.diff(a) is S.Zero
  287. assert Sum(expr, (i, -oo, oo)).diff(hj) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo))
  288. assert Sum(expr.diff(hj), (i, -oo, oo)) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo))
  289. assert Sum(expr, (i, -oo, oo)).diff(hj).doit() == 2
  290. assert Sum(expr.diff(hi), (i, -oo, oo)).doit() == Sum(2, (i, -oo, oo)).doit()
  291. assert Sum(expr, (i, -oo, oo)).diff(hi).doit() is oo
  292. expr = a * hj * hj / S(2)
  293. assert expr.diff(hi) == a * h[j] * KroneckerDelta(i, j)
  294. assert expr.diff(a) == hj * hj / S(2)
  295. assert expr.diff(a, 2) is S.Zero
  296. assert Sum(expr, (i, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo))
  297. assert Sum(expr.diff(hi), (i, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo))
  298. assert Sum(expr, (i, -oo, oo)).diff(hi).doit() == a*h[j]
  299. assert Sum(expr, (j, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo))
  300. assert Sum(expr.diff(hi), (j, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo))
  301. assert Sum(expr, (j, -oo, oo)).diff(hi).doit() == a*h[i]
  302. expr = a * sin(hj * hj)
  303. assert expr.diff(hi) == 2*a*cos(hj * hj) * hj * KroneckerDelta(i, j)
  304. assert expr.diff(hj) == 2*a*cos(hj * hj) * hj
  305. expr = a * L[i, j] * h[j]
  306. assert expr.diff(hi) == a*L[i, j]*KroneckerDelta(i, j)
  307. assert expr.diff(hj) == a*L[i, j]
  308. assert expr.diff(L[i, j]) == a*h[j]
  309. assert expr.diff(L[k, l]) == a*KroneckerDelta(i, k)*KroneckerDelta(j, l)*h[j]
  310. assert expr.diff(L[i, l]) == a*KroneckerDelta(j, l)*h[j]
  311. assert Sum(expr, (j, -oo, oo)).diff(L[k, l]) == Sum(a * KroneckerDelta(i, k) * KroneckerDelta(j, l) * h[j], (j, -oo, oo))
  312. assert Sum(expr, (j, -oo, oo)).diff(L[k, l]).doit() == a * KroneckerDelta(i, k) * h[l]
  313. assert h[m].diff(h[m]) == 1
  314. assert h[m].diff(h[n]) == KroneckerDelta(m, n)
  315. assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (m, -oo, oo))
  316. assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]).doit() == a
  317. assert Sum(a*h[m], (n, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (n, -oo, oo))
  318. assert Sum(a*h[m], (m, -oo, oo)).diff(h[m]).doit() == oo*a
  319. def test_indexed_series():
  320. A = IndexedBase("A")
  321. i = symbols("i", integer=True)
  322. assert sin(A[i]).series(A[i]) == A[i] - A[i]**3/6 + A[i]**5/120 + Order(A[i]**6, A[i])
  323. def test_indexed_is_constant():
  324. A = IndexedBase("A")
  325. i, j, k = symbols("i,j,k")
  326. assert not A[i].is_constant()
  327. assert A[i].is_constant(j)
  328. assert not A[1+2*i, k].is_constant()
  329. assert not A[1+2*i, k].is_constant(i)
  330. assert A[1+2*i, k].is_constant(j)
  331. assert not A[1+2*i, k].is_constant(k)
  332. def test_issue_12533():
  333. d = IndexedBase('d')
  334. assert IndexedBase(range(5)) == Range(0, 5, 1)
  335. assert d[0].subs(Symbol("d"), range(5)) == 0
  336. assert d[0].subs(d, range(5)) == 0
  337. assert d[1].subs(d, range(5)) == 1
  338. assert Indexed(Range(5), 2) == 2
  339. def test_issue_12780():
  340. n = symbols("n")
  341. i = Idx("i", (0, n))
  342. raises(TypeError, lambda: i.subs(n, 1.5))
  343. def test_issue_18604():
  344. m = symbols("m")
  345. assert Idx("i", m).name == 'i'
  346. assert Idx("i", m).lower == 0
  347. assert Idx("i", m).upper == m - 1
  348. m = symbols("m", real=False)
  349. raises(TypeError, lambda: Idx("i", m))
  350. def test_Subs_with_Indexed():
  351. A = IndexedBase("A")
  352. i, j, k = symbols("i,j,k")
  353. x, y, z = symbols("x,y,z")
  354. f = Function("f")
  355. assert Subs(A[i], A[i], A[j]).diff(A[j]) == 1
  356. assert Subs(A[i], A[i], x).diff(A[i]) == 0
  357. assert Subs(A[i], A[i], x).diff(A[j]) == 0
  358. assert Subs(A[i], A[i], x).diff(x) == 1
  359. assert Subs(A[i], A[i], x).diff(y) == 0
  360. assert Subs(A[i], A[i], A[j]).diff(A[k]) == KroneckerDelta(j, k)
  361. assert Subs(x, x, A[i]).diff(A[j]) == KroneckerDelta(i, j)
  362. assert Subs(f(A[i]), A[i], x).diff(A[j]) == 0
  363. assert Subs(f(A[i]), A[i], A[k]).diff(A[j]) == Derivative(f(A[k]), A[k])*KroneckerDelta(j, k)
  364. assert Subs(x, x, A[i]**2).diff(A[j]) == 2*KroneckerDelta(i, j)*A[i]
  365. assert Subs(A[i], A[i], A[j]**2).diff(A[k]) == 2*KroneckerDelta(j, k)*A[j]
  366. assert Subs(A[i]*x, x, A[i]).diff(A[i]) == 2*A[i]
  367. assert Subs(A[i]*x, x, A[i]).diff(A[j]) == 2*A[i]*KroneckerDelta(i, j)
  368. assert Subs(A[i]*x, x, A[j]).diff(A[i]) == A[j] + A[i]*KroneckerDelta(i, j)
  369. assert Subs(A[i]*x, x, A[j]).diff(A[j]) == A[i] + A[j]*KroneckerDelta(i, j)
  370. assert Subs(A[i]*x, x, A[i]).diff(A[k]) == 2*A[i]*KroneckerDelta(i, k)
  371. assert Subs(A[i]*x, x, A[j]).diff(A[k]) == KroneckerDelta(i, k)*A[j] + KroneckerDelta(j, k)*A[i]
  372. assert Subs(A[i]*x, A[i], x).diff(A[i]) == 0
  373. assert Subs(A[i]*x, A[i], x).diff(A[j]) == 0
  374. assert Subs(A[i]*x, A[j], x).diff(A[i]) == x
  375. assert Subs(A[i]*x, A[j], x).diff(A[j]) == x*KroneckerDelta(i, j)
  376. assert Subs(A[i]*x, A[i], x).diff(A[k]) == 0
  377. assert Subs(A[i]*x, A[j], x).diff(A[k]) == x*KroneckerDelta(i, k)
  378. def test_complicated_derivative_with_Indexed():
  379. x, y = symbols("x,y", cls=IndexedBase)
  380. sigma = symbols("sigma")
  381. i, j, k = symbols("i,j,k")
  382. m0,m1,m2,m3,m4,m5 = symbols("m0:6")
  383. f = Function("f")
  384. expr = f((x[i] - y[i])**2/sigma)
  385. _xi_1 = symbols("xi_1", cls=Dummy)
  386. assert expr.diff(x[m0]).dummy_eq(
  387. (x[i] - y[i])*KroneckerDelta(i, m0)*\
  388. 2*Subs(
  389. Derivative(f(_xi_1), _xi_1),
  390. (_xi_1,),
  391. ((x[i] - y[i])**2/sigma,)
  392. )/sigma
  393. )
  394. assert expr.diff(x[m0]).diff(x[m1]).dummy_eq(
  395. 2*KroneckerDelta(i, m0)*\
  396. KroneckerDelta(i, m1)*Subs(
  397. Derivative(f(_xi_1), _xi_1),
  398. (_xi_1,),
  399. ((x[i] - y[i])**2/sigma,)
  400. )/sigma + \
  401. 4*(x[i] - y[i])**2*KroneckerDelta(i, m0)*KroneckerDelta(i, m1)*\
  402. Subs(
  403. Derivative(f(_xi_1), _xi_1, _xi_1),
  404. (_xi_1,),
  405. ((x[i] - y[i])**2/sigma,)
  406. )/sigma**2
  407. )