You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
5.9 KiB
186 lines
5.9 KiB
import operator
|
|
from functools import reduce, singledispatch
|
|
|
|
from sympy.core.expr import Expr
|
|
from sympy.core.singleton import S
|
|
from sympy.matrices.expressions.hadamard import HadamardProduct
|
|
from sympy.matrices.expressions.inverse import Inverse
|
|
from sympy.matrices.expressions.matexpr import (MatrixExpr, MatrixSymbol)
|
|
from sympy.matrices.expressions.special import Identity
|
|
from sympy.matrices.expressions.transpose import Transpose
|
|
from sympy.combinatorics.permutations import _af_invert
|
|
from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
|
|
from sympy.tensor.array.expressions.array_expressions import (
|
|
_ArrayExpr, ZeroArray, ArraySymbol, ArrayTensorProduct, ArrayAdd,
|
|
PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, get_rank,
|
|
get_shape, ArrayContraction, _array_tensor_product, _array_contraction,
|
|
_array_diagonal, _array_add, _permute_dims, Reshape)
|
|
from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
|
|
|
|
|
|
@singledispatch
def array_derive(expr, x):
    """Differentiate the array expression ``expr`` with respect to ``x``.

    This is the generic fallback of a ``functools.singledispatch`` function.
    Handlers for concrete expression types are attached with
    ``array_derive.register``. The registered handlers build results whose
    leading axes index ``x`` and whose trailing axes index ``expr``.

    Raises:
        NotImplementedError: always, because no handler is registered for
            ``type(expr)``.
    """
    unsupported = type(expr)
    raise NotImplementedError(f"not implemented for type {unsupported}")
|
|
|
|
|
|
@array_derive.register(Expr)
def _(expr: Expr, x: _ArrayExpr):
    """Derivative of an expression treated as independent of ``x``.

    Any ``Expr`` without a more specific registration lands here. The result
    is a zero array whose axes are those of ``x`` followed by those of
    ``expr``. For a scalar expression ``get_shape(expr)`` is empty, so the
    shape is just ``x.shape``. For a shaped expression reaching this fallback,
    its own axes are kept, so the result has the same layout as the other
    handlers' outputs.
    """
    return ZeroArray(*(x.shape + get_shape(expr)))
|
|
|
|
|
|
@array_derive.register(ArrayTensorProduct)
def _(expr: ArrayTensorProduct, x: Expr):
    """Differentiate a tensor product with the product rule.

    For ``expr = A_0 (x) A_1 (x) ... (x) A_k`` the derivative is the sum, over
    every factor ``A_i``, of the tensor product in which ``A_i`` is replaced
    by its derivative. Each term is permuted so that the axes of ``x`` come
    first, followed by the axes of ``A_0, ..., A_k`` in their original order.

    Returns:
        The single term, an ``ArrayAdd`` of the terms, or a ``ZeroArray`` of
        shape ``get_shape(x) + get_shape(expr)`` when no factor depends on
        ``x``. A shaped zero is returned rather than a scalar, so callers that
        index into the result's axes keep working.
    """
    args = expr.args
    rank_x = len(get_shape(x))
    addend_list = []
    for pos, arg in enumerate(args):
        darg = array_derive(arg, x)
        # Factors independent of x contribute nothing to the sum.
        if darg == 0 or isinstance(darg, ZeroArray):
            continue
        args_prev = args[:pos]
        args_succ = args[pos + 1:]
        rank_prev = len(reduce(operator.add, map(get_shape, args_prev), ()))
        rank_succ = len(reduce(operator.add, map(get_shape, args_succ), ()))
        rank_arg = len(get_shape(arg))
        # Axes of `addend` are ordered (prev, x, arg, succ) because `darg`
        # puts the axes of x before the axes of `arg`.
        addend = _array_tensor_product(*args_prev, darg, *args_succ)
        end_prev = rank_x + rank_prev
        end_arg = end_prev + rank_arg
        end_succ = end_arg + rank_succ
        # target[k] is where axis k of `addend` must end up in the order
        # (x, prev, arg, succ); `_permute_dims` receives its inverse.
        target = (
            list(range(rank_x, end_prev))
            + list(range(rank_x))
            + list(range(end_prev, end_arg))
            + list(range(end_arg, end_succ))
        )
        addend_list.append(_permute_dims(addend, _af_invert(target)))
    if not addend_list:
        return ZeroArray(*(get_shape(x) + get_shape(expr)))
    if len(addend_list) == 1:
        return addend_list[0]
    return _array_add(*addend_list)
|
|
|
|
|
|
@array_derive.register(ArraySymbol)
def _(expr: ArraySymbol, x: _ArrayExpr):
    """Derivative of an array symbol.

    If ``expr`` is ``x`` itself, the result is built from one ``Identity``
    per axis of ``expr``. Their tensor product is rearranged so that all
    first identity axes come before all second identity axes. Otherwise
    ``expr`` does not depend on ``x``, and a zero array of shape
    ``x.shape + expr.shape`` is returned.
    """
    if expr == x:
        rank = len(expr.shape)
        deltas = ArrayTensorProduct.fromiter(Identity(dim) for dim in expr.shape)
        # Even positions hold the first axis of each Identity, odd ones the second.
        order = list(range(0, 2 * rank, 2)) + list(range(1, 2 * rank, 2))
        return _permute_dims(deltas, order)
    return ZeroArray(*(x.shape + expr.shape))
|
|
|
|
|
|
@array_derive.register(MatrixSymbol)
def _(expr: MatrixSymbol, x: _ArrayExpr):
    """Derivative of a matrix symbol.

    For ``expr == x`` the result is the tensor product
    ``Identity(rows) (x) Identity(cols)`` with axes permuted by
    ``[0, 2, 1, 3]``. Any other matrix symbol does not depend on ``x`` and
    yields a zero array of shape ``x.shape + expr.shape``.
    """
    if expr == x:
        rows, cols = expr.shape
        deltas = _array_tensor_product(Identity(rows), Identity(cols))
        return _permute_dims(deltas, [0, 2, 1, 3])
    return ZeroArray(*(x.shape + expr.shape))
|
|
|
|
|
|
@array_derive.register(Identity)
def _(expr: Identity, x: _ArrayExpr):
    """An identity matrix is constant, so its derivative is all zeros.

    The zero array has the axes of ``x`` followed by the two axes of ``expr``.
    """
    result_shape = x.shape + expr.shape
    return ZeroArray(*result_shape)
|
|
|
|
|
|
@array_derive.register(Transpose)
def _(expr: Transpose, x: Expr):
    """Derivative of a matrix transpose.

    The derivative of the transposed argument is computed, and its last two
    axes (the matrix axes) are swapped. The leading axes, which index ``x``,
    stay in place, so any rank of ``x`` is handled. For a rank-2 ``x`` the
    permutation is ``[0, 1, 3, 2]``.
    """
    # D(A.T, A) ==> (m,n,i,j) ==> D(A_ji, A_mn) = d_mj d_ni
    # D(B.T, A) ==> (m,n,i,j) ==> D(B_ji, A_mn)
    fd = array_derive(expr.arg, x)
    rank_x = len(get_shape(x))
    # Identity on the axes of x, then swap the two trailing matrix axes.
    perm = list(range(rank_x)) + [rank_x + 1, rank_x]
    return _permute_dims(fd, perm)
|
|
|
|
|
|
@array_derive.register(Inverse)
def _(expr: Inverse, x: Expr):
    """Derivative of a matrix inverse: ``D(M**-1) = -M**-1 . D(M) . M**-1``.

    The product is formed as a tensor product followed by two contractions.
    It is then permuted so that the axes of ``x`` come first and the two
    matrix axes last. Any rank of ``x`` is handled. For a rank-2 ``x`` this
    gives contractions ``(1, 4), (5, 6)`` and permutation ``[1, 2, 0, 3]``.
    """
    mat = expr.I  # inverse of ``expr``, i.e. the matrix being inverted
    dexpr = array_derive(mat, x)
    rank_x = len(get_shape(x))
    # Axes of tp: 0, 1 -> -expr; 2 .. rank_x + 1 -> x;
    # rank_x + 2, rank_x + 3 -> mat (inside dexpr); rank_x + 4, rank_x + 5 -> expr.
    tp = _array_tensor_product(-expr, dexpr, expr)
    # Contract expr's column with mat's row, and mat's column with expr's row.
    mp = _array_contraction(tp, (1, rank_x + 2), (rank_x + 3, rank_x + 4))
    # Remaining axes: (row of -expr, axes of x..., column of expr); put x first.
    perm = list(range(1, rank_x + 1)) + [0, rank_x + 1]
    return _permute_dims(mp, perm)
|
|
|
|
|
|
@array_derive.register(ElementwiseApplyFunction)
def _(expr: ElementwiseApplyFunction, x: Expr):
    """Chain rule for a function applied element-wise to a matrix.

    Computes ``f'.(A)`` and ``D(A, x)``, takes their tensor product, and
    returns its diagonal pairing each matrix axis of ``f'.(A)`` with the
    matching matrix axis of ``D(A, x)``. Any rank of ``x`` is handled. For a
    rank-2 ``x`` the diagonal indices are ``(0, 4), (1, 5)``.
    """
    assert get_rank(expr) == 2
    fdiff = expr._get_function_fdiff()
    dexpr = array_derive(expr.expr, x)
    rank_x = get_rank(x)
    # Axes of tp: 0, 1 -> f'.(A); 2 .. rank_x + 1 -> x; rank_x + 2, rank_x + 3 -> A.
    tp = _array_tensor_product(
        ElementwiseApplyFunction(fdiff, expr.expr),
        dexpr
    )
    return _array_diagonal(tp, (0, rank_x + 2), (1, rank_x + 3))
|
|
|
|
|
|
@array_derive.register(ArrayElementwiseApplyFunc)
def _(expr: ArrayElementwiseApplyFunc, x: Expr):
    """Chain rule for a function applied element-wise to an array.

    Forms the tensor product ``D(A, x) (x) f'.(A)`` and takes the diagonal
    that pairs every axis of ``A`` inside ``D(A, x)`` with the matching axis
    of ``f'.(A)``.
    """
    fdiff = expr._get_function_fdiff()
    inner = expr.expr
    product = _array_tensor_product(
        array_derive(inner, x),
        ArrayElementwiseApplyFunc(fdiff, inner),
    )
    rank_x = get_rank(x)
    rank_expr = get_rank(expr)
    # Axis rank_x + k of D(A, x) is the k-th axis of A; the same axis of
    # f'.(A) sits at rank_x + rank_expr + k in the product.
    pairs = [(rank_x + k, rank_x + rank_expr + k) for k in range(rank_expr)]
    return _array_diagonal(product, *pairs)
|
|
|
|
|
|
@array_derive.register(MatrixExpr)
def _(expr: MatrixExpr, x: Expr):
    """Fallback for matrix expressions without a dedicated handler.

    The expression is rewritten as an array expression with
    ``convert_matrix_to_array``, and that form is differentiated instead.
    """
    return array_derive(convert_matrix_to_array(expr), x)
|
|
|
|
|
|
@array_derive.register(HadamardProduct)
def _(expr: HadamardProduct, x: Expr):
    """Element-wise (Hadamard) matrix products are not supported.

    Registering this type explicitly stops ``singledispatch`` from routing
    ``HadamardProduct`` to the more general ``MatrixExpr`` handler.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
|
|
|
|
|
|
@array_derive.register(ArrayContraction)
def _(expr: ArrayContraction, x: Expr):
    """Differentiate under a contraction.

    The inner expression is differentiated, and the same contractions are
    applied to the result. Every axis index is shifted past the leading axes
    that index ``x``.
    """
    inner_derivative = array_derive(expr.expr, x)
    shift = len(get_shape(x))
    shifted = [
        tuple(axis + shift for axis in group)
        for group in expr.contraction_indices
    ]
    return _array_contraction(inner_derivative, *shifted)
|
|
|
|
|
|
@array_derive.register(ArrayDiagonal)
def _(expr: ArrayDiagonal, x: Expr):
    """Differentiate under a diagonal extraction.

    The inner expression is differentiated, and the same diagonals are taken
    on the result. Every axis index is shifted past the leading axes that
    index ``x``.
    """
    inner_derivative = array_derive(expr.expr, x)
    shift = len(get_shape(x))
    shifted = [
        [axis + shift for axis in group]
        for group in expr.diagonal_indices
    ]
    return _array_diagonal(inner_derivative, *shifted)
|
|
|
|
|
|
@array_derive.register(ArrayAdd)
def _(expr: ArrayAdd, x: Expr):
    """Differentiate a sum term by term and add the derivatives back up."""
    derivatives = [array_derive(term, x) for term in expr.args]
    return _array_add(*derivatives)
|
|
|
|
|
|
@array_derive.register(PermuteDims)
def _(expr: PermuteDims, x: Expr):
    """Differentiate under an axis permutation.

    The inner expression is differentiated, and the same permutation is
    applied to its trailing axes. The leading axes, which index ``x``, stay
    in place. Any rank of ``x`` is handled; a rank-2 ``x`` gives
    ``[0, 1] + [i + 2 ...]``.
    """
    de = array_derive(expr.expr, x)
    rank_x = len(get_shape(x))
    perm = list(range(rank_x)) + [i + rank_x for i in expr.permutation.array_form]
    return _permute_dims(de, perm)
|
|
|
|
|
|
@array_derive.register(Reshape)
def _(expr: Reshape, x: Expr):
    """Differentiate under a reshape.

    The derivative of the inner expression is reshaped to the axes of ``x``
    followed by the target shape of ``expr``.
    """
    inner_derivative = array_derive(expr.expr, x)
    new_shape = get_shape(x) + expr.shape
    return Reshape(inner_derivative, new_shape)
|
|
|
|
|
|
def matrix_derive(expr, x):
    """Differentiate the matrix expression ``expr`` with respect to ``x``.

    ``expr`` is converted to an array expression and differentiated with
    ``array_derive``. The result is converted back with
    ``convert_array_to_matrix`` and evaluated with ``doit()``.
    """
    # Function-level import; presumably avoids an import cycle at module load
    # time -- TODO confirm.
    from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix

    array_form = convert_matrix_to_array(expr)
    derivative = array_derive(array_form, x)
    return convert_array_to_matrix(derivative).doit()
|