You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
227 lines
6.9 KiB
227 lines
6.9 KiB
from sympy.core import symbols, S, Pow, Function
|
|
from sympy.functions import exp
|
|
from sympy.testing.pytest import raises
|
|
from sympy.tensor.indexed import Idx, IndexedBase
|
|
from sympy.tensor.index_methods import IndexConformanceException
|
|
|
|
from sympy.tensor.index_methods import (get_contraction_structure, get_indices)
|
|
|
|
|
|
def test_trivial_indices():
    """Expressions built only from ordinary symbols have no free indices."""
    x, y = symbols('x y')
    scalar_exprs = (x, x*y, x + y, x**y)
    for scalar_expr in scalar_exprs:
        assert get_indices(scalar_expr) == (set(), {})
|
|
|
|
|
|
def test_get_indices_Indexed():
    """Distinct indices on an Indexed object are all reported as free."""
    base = IndexedBase('x')
    i, j = Idx('i'), Idx('j')
    expected = ({i, j}, {})
    # The order of indices on the object must not change the reported set.
    assert get_indices(base[i, j]) == expected
    assert get_indices(base[j, i]) == expected
|
|
|
|
|
|
def test_get_indices_Idx():
    """Bare Idx objects and Idx arguments of functions count as indices."""
    f = Function('f')
    i, j = Idx('i'), Idx('j')
    cases = [
        (f(i)*j, {i, j}),
        (f(j, i), {j, i}),
        # The same index appearing twice in a product is contracted away.
        (f(i)*i, set()),
    ]
    for expr, free in cases:
        assert get_indices(expr) == (free, {})
|
|
|
|
|
|
def test_get_indices_mul():
    """A product of factors carrying distinct indices exposes all of them."""
    x, y = IndexedBase('x'), IndexedBase('y')
    i, j = Idx('i'), Idx('j')
    for product in (x[j]*y[i], x[i]*y[j]):
        assert get_indices(product) == ({i, j}, {})
|
|
|
|
|
|
def test_get_indices_exceptions():
    """Adding terms whose free indices disagree must raise."""
    x = IndexedBase('x')
    y = IndexedBase('y')
    i, j = Idx('i'), Idx('j')
    # Completely disjoint free indices.
    raises(IndexConformanceException, lambda: get_indices(x[i] + y[j]))
    # Overlapping but unequal free-index sets ({i} vs {i, j}) are also
    # non-conforming; only fully contracted (scalar) terms may broadcast.
    raises(IndexConformanceException, lambda: get_indices(x[i] + y[i, j]))
|
|
|
|
|
|
def test_scalar_broadcast():
    """Terms with no free indices broadcast against indexed terms in a sum."""
    x, y = IndexedBase('x'), IndexedBase('y')
    i, j = Idx('i'), Idx('j')
    # y[i, i] and y[j, j] repeat their index, so they are contracted and
    # act as scalars; only x[i] contributes a free index to each sum.
    for expr in (x[i] + y[i, i], x[i] + y[j, j]):
        assert get_indices(expr) == ({i}, {})
|
|
|
|
|
|
def test_get_indices_add():
    """Sums whose terms all expose free index i report just {i}."""
    x = IndexedBase('x')
    y = IndexedBase('y')
    A = IndexedBase('A')
    i, j, k = Idx('i'), Idx('j'), Idx('k')
    sums = [
        x[i] + 2*y[i],
        y[i] + 2*A[i, j]*x[j],
        y[i] + 2*(x[i] + A[i, j]*x[j]),
        y[i] + x[i]*(A[j, j] + 1),
        y[i] + x[i]*x[j]*(y[j] + A[j, k]*x[k]),
    ]
    for expr in sums:
        assert get_indices(expr) == ({i}, {})
|
|
|
|
|
|
def test_get_indices_Pow():
    """Indices of a power's base and exponent are combined, not contracted."""
    x, y, A = IndexedBase('x'), IndexedBase('y'), IndexedBase('A')
    i, j, k = Idx('i'), Idx('j'), Idx('k')
    power_cases = [
        (Pow(x[i], y[j]), {i, j}),
        # k appears in both base and exponent yet is not summed over.
        (Pow(x[i, k], y[j, k]), {i, j, k}),
        (Pow(A[i, k], y[k] + A[k, j]*x[j]), {i, k}),
    ]
    for power, free in power_cases:
        assert get_indices(power) == (free, {})

    # exp(...) is handled like a power with a numeric base.
    assert get_indices(Pow(2, x[i])) == get_indices(exp(x[i]))

    # Design decision that may change: a numeric exponent leaves the
    # base's index free.
    assert get_indices(Pow(x[i], 2)) == ({i}, {})
|
|
|
|
|
|
def test_get_contraction_structure_basic():
    """Contraction structure of simple sums, products and powers."""
    x, y = IndexedBase('x'), IndexedBase('y')
    i, j = Idx('i'), Idx('j')
    expectations = [
        # No repeated index: every term is keyed under None.
        (x[i]*y[j], {None: {x[i]*y[j]}}),
        (x[i] + y[j], {None: {x[i], y[j]}}),
        # An index repeated across a product is a contraction keyed by (i,).
        (x[i]*y[i], {(i,): {x[i]*y[i]}}),
        (1 + x[i]*y[i], {None: {S.One}, (i,): {x[i]*y[i]}}),
        # Base and exponent sharing an index is not a contraction.
        (x[i]**y[i], {None: {x[i]**y[i]}}),
    ]
    for expr, structure in expectations:
        assert get_contraction_structure(expr) == structure
|
|
|
|
|
|
def test_get_contraction_structure_complex():
    """A sum nested inside a contracted product is reported as a sub-structure."""
    x = IndexedBase('x')
    y = IndexedBase('y')
    A = IndexedBase('A')
    i, j, k = Idx('i'), Idx('j'), Idx('k')

    inner = y[i] + A[i, j]*x[j]
    inner_structure = {None: {y[i]}, (j,): {A[i, j]*x[j]}}
    assert get_contraction_structure(inner) == inner_structure

    # Multiplying by A[k, i] contracts over i; the structure of the sum
    # inside that product is listed under the product itself.
    outer_term = inner*A[k, i]
    outer = outer_term + x[k]
    outer_structure = {
        None: {x[k]},
        (i,): {outer_term},
        outer_term: [inner_structure],
    }
    assert get_contraction_structure(outer) == outer_structure
|
|
|
|
|
|
def test_contraction_structure_simple_Pow():
    """Contractions inside a power's base/exponent are nested under the power."""
    x, y = IndexedBase('x'), IndexedBase('y')
    i, j, k = Idx('i'), Idx('j'), Idx('k')

    # Base and exponent each contain a contraction: one sub-structure each.
    ii_jj = x[i, i]**y[j, j]
    assert get_contraction_structure(ii_jj) == {
        None: {ii_jj},
        ii_jj: [{(i,): {x[i, i]}}, {(j,): {y[j, j]}}],
    }

    # Only the base repeats an index, so only it gets a sub-structure.
    ii_jk = x[i, i]**y[j, k]
    assert get_contraction_structure(ii_jk) == {
        None: {ii_jk},
        ii_jk: [{(i,): {x[i, i]}}],
    }
|
|
|
|
|
|
def test_contraction_structure_Mul_and_Pow():
    """Powers inside products: only indices shared by plain factors contract."""
    x = IndexedBase('x')
    y = IndexedBase('y')
    i, j, k = Idx('i'), Idx('j'), Idx('k')

    # An index shared between a power's base and exponent is not contracted.
    for power in (x[i]**(y[j]*x[i]), (x[i]*y[j])**(y[i])):
        assert get_contraction_structure(power) == {None: {power}}

    # An index shared between a plain factor and a power is contracted.
    for product in (x[j]*(x[i]*y[j])**(y[i]), x[j]*x[i]**(y[j]*x[i])):
        assert get_contraction_structure(product) == {(j,): {product}}

    ij_exp_kki = x[i]*y[j]*exp(y[i]*y[k, k])
    exp_arg = y[i]*y[k, k]
    exp_term = exp(exp_arg)
    expected = {
        (i,): {ij_exp_kki},
        ij_exp_kki: [{
            None: {exp_term},
            exp_term: [{
                None: {exp_arg},
                exp_arg: [{(k,): {y[k, k]}}],
            }],
        }],
    }
    assert get_contraction_structure(ij_exp_kki) == expected
|
|
|
|
|
|
def test_contraction_structure_Add_in_Pow():
    """Sums in a power's base/exponent get their own nested sub-structures."""
    x = IndexedBase('x')
    y = IndexedBase('y')
    i, j, k = Idx('i'), Idx('j'), Idx('k')

    s_ii_jj_s = (1 + x[i, i])**(1 + y[j, j])
    assert get_contraction_structure(s_ii_jj_s) == {
        None: {s_ii_jj_s},
        s_ii_jj_s: [
            {None: {S.One}, (i,): {x[i, i]}},
            {None: {S.One}, (j,): {y[j, j]}},
        ],
    }

    # The exponent has no repeated index, so only the base's sum is
    # listed as a sub-structure.
    s_ii_jk_s = (1 + x[i, i])**(1 + y[j, k])
    assert get_contraction_structure(s_ii_jk_s) == {
        None: {(x[i, i] + 1)**(y[j, k] + 1)},
        s_ii_jk_s: [{None: {S.One}, (i,): {x[i, i]}}],
    }
|
|
|
|
|
|
def test_contraction_structure_Pow_in_Pow():
    """A power tower nests sub-structures recursively through its exponent."""
    x, y, z = map(IndexedBase, ('x', 'y', 'z'))
    i, j, k = Idx('i'), Idx('j'), Idx('k')

    # ** is right-associative, so the exponent here is y[j, j]**z[k, k].
    tower = x[i, i]**y[j, j]**z[k, k]
    exponent = y[j, j]**z[k, k]
    exponent_structure = {
        None: {exponent},
        exponent: [{(j,): {y[j, j]}}, {(k,): {z[k, k]}}],
    }
    expected = {
        None: {tower},
        tower: [{(i,): {x[i, i]}}, exponent_structure],
    }
    assert get_contraction_structure(tower) == expected
|
|
|
|
|
|
def test_ufunc_support():
    """Indices inside arguments of undefined functions take part in analysis."""
    f, g = Function('f'), Function('g')
    x, y = IndexedBase('x'), IndexedBase('y')
    i, j = Idx('i'), Idx('j')
    a = symbols('a')

    index_cases = [
        (f(x[i]), {i}),
        (f(x[i], y[j]), {i, j}),
        # i occurs in both function factors of the product, so it contracts.
        (f(y[i])*g(x[i]), set()),
        (f(a, x[i]), {i}),
        (f(a, y[i], x[j])*g(x[i]), {j}),
        (g(f(x[i])), {i}),
    ]
    for expr, free in index_cases:
        assert get_indices(expr) == (free, {})

    assert get_contraction_structure(f(x[i])) == {None: {f(x[i])}}
    contracted_products = (
        f(y[i])*g(x[i]),
        f(y[i])*g(f(x[i])),
        f(x[j], y[i])*g(x[i]),
    )
    for product in contracted_products:
        assert get_contraction_structure(product) == {(i,): {product}}
|