m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

186 lines
5.9 KiB

6 months ago
  1. import operator
  2. from functools import reduce, singledispatch
  3. from sympy.core.expr import Expr
  4. from sympy.core.singleton import S
  5. from sympy.matrices.expressions.hadamard import HadamardProduct
  6. from sympy.matrices.expressions.inverse import Inverse
  7. from sympy.matrices.expressions.matexpr import (MatrixExpr, MatrixSymbol)
  8. from sympy.matrices.expressions.special import Identity
  9. from sympy.matrices.expressions.transpose import Transpose
  10. from sympy.combinatorics.permutations import _af_invert
  11. from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
  12. from sympy.tensor.array.expressions.array_expressions import (
  13. _ArrayExpr, ZeroArray, ArraySymbol, ArrayTensorProduct, ArrayAdd,
  14. PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, get_rank,
  15. get_shape, ArrayContraction, _array_tensor_product, _array_contraction,
  16. _array_diagonal, _array_add, _permute_dims, Reshape)
  17. from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
  18. @singledispatch
  19. def array_derive(expr, x):
  20. raise NotImplementedError(f"not implemented for type {type(expr)}")
  21. @array_derive.register(Expr)
  22. def _(expr: Expr, x: _ArrayExpr):
  23. return ZeroArray(*x.shape)
  24. @array_derive.register(ArrayTensorProduct)
  25. def _(expr: ArrayTensorProduct, x: Expr):
  26. args = expr.args
  27. addend_list = []
  28. for i, arg in enumerate(expr.args):
  29. darg = array_derive(arg, x)
  30. if darg == 0:
  31. continue
  32. args_prev = args[:i]
  33. args_succ = args[i+1:]
  34. shape_prev = reduce(operator.add, map(get_shape, args_prev), ())
  35. shape_succ = reduce(operator.add, map(get_shape, args_succ), ())
  36. addend = _array_tensor_product(*args_prev, darg, *args_succ)
  37. tot1 = len(get_shape(x))
  38. tot2 = tot1 + len(shape_prev)
  39. tot3 = tot2 + len(get_shape(arg))
  40. tot4 = tot3 + len(shape_succ)
  41. perm = [i for i in range(tot1, tot2)] + \
  42. [i for i in range(tot1)] + [i for i in range(tot2, tot3)] + \
  43. [i for i in range(tot3, tot4)]
  44. addend = _permute_dims(addend, _af_invert(perm))
  45. addend_list.append(addend)
  46. if len(addend_list) == 1:
  47. return addend_list[0]
  48. elif len(addend_list) == 0:
  49. return S.Zero
  50. else:
  51. return _array_add(*addend_list)
  52. @array_derive.register(ArraySymbol)
  53. def _(expr: ArraySymbol, x: _ArrayExpr):
  54. if expr == x:
  55. return _permute_dims(
  56. ArrayTensorProduct.fromiter(Identity(i) for i in expr.shape),
  57. [2*i for i in range(len(expr.shape))] + [2*i+1 for i in range(len(expr.shape))]
  58. )
  59. return ZeroArray(*(x.shape + expr.shape))
  60. @array_derive.register(MatrixSymbol)
  61. def _(expr: MatrixSymbol, x: _ArrayExpr):
  62. m, n = expr.shape
  63. if expr == x:
  64. return _permute_dims(
  65. _array_tensor_product(Identity(m), Identity(n)),
  66. [0, 2, 1, 3]
  67. )
  68. return ZeroArray(*(x.shape + expr.shape))
  69. @array_derive.register(Identity)
  70. def _(expr: Identity, x: _ArrayExpr):
  71. return ZeroArray(*(x.shape + expr.shape))
  72. @array_derive.register(Transpose)
  73. def _(expr: Transpose, x: Expr):
  74. # D(A.T, A) ==> (m,n,i,j) ==> D(A_ji, A_mn) = d_mj d_ni
  75. # D(B.T, A) ==> (m,n,i,j) ==> D(B_ji, A_mn)
  76. fd = array_derive(expr.arg, x)
  77. return _permute_dims(fd, [0, 1, 3, 2])
  78. @array_derive.register(Inverse)
  79. def _(expr: Inverse, x: Expr):
  80. mat = expr.I
  81. dexpr = array_derive(mat, x)
  82. tp = _array_tensor_product(-expr, dexpr, expr)
  83. mp = _array_contraction(tp, (1, 4), (5, 6))
  84. pp = _permute_dims(mp, [1, 2, 0, 3])
  85. return pp
  86. @array_derive.register(ElementwiseApplyFunction)
  87. def _(expr: ElementwiseApplyFunction, x: Expr):
  88. assert get_rank(expr) == 2
  89. assert get_rank(x) == 2
  90. fdiff = expr._get_function_fdiff()
  91. dexpr = array_derive(expr.expr, x)
  92. tp = _array_tensor_product(
  93. ElementwiseApplyFunction(fdiff, expr.expr),
  94. dexpr
  95. )
  96. td = _array_diagonal(
  97. tp, (0, 4), (1, 5)
  98. )
  99. return td
  100. @array_derive.register(ArrayElementwiseApplyFunc)
  101. def _(expr: ArrayElementwiseApplyFunc, x: Expr):
  102. fdiff = expr._get_function_fdiff()
  103. subexpr = expr.expr
  104. dsubexpr = array_derive(subexpr, x)
  105. tp = _array_tensor_product(
  106. dsubexpr,
  107. ArrayElementwiseApplyFunc(fdiff, subexpr)
  108. )
  109. b = get_rank(x)
  110. c = get_rank(expr)
  111. diag_indices = [(b + i, b + c + i) for i in range(c)]
  112. return _array_diagonal(tp, *diag_indices)
  113. @array_derive.register(MatrixExpr)
  114. def _(expr: MatrixExpr, x: Expr):
  115. cg = convert_matrix_to_array(expr)
  116. return array_derive(cg, x)
  117. @array_derive.register(HadamardProduct)
  118. def _(expr: HadamardProduct, x: Expr):
  119. raise NotImplementedError()
  120. @array_derive.register(ArrayContraction)
  121. def _(expr: ArrayContraction, x: Expr):
  122. fd = array_derive(expr.expr, x)
  123. rank_x = len(get_shape(x))
  124. contraction_indices = expr.contraction_indices
  125. new_contraction_indices = [tuple(j + rank_x for j in i) for i in contraction_indices]
  126. return _array_contraction(fd, *new_contraction_indices)
  127. @array_derive.register(ArrayDiagonal)
  128. def _(expr: ArrayDiagonal, x: Expr):
  129. dsubexpr = array_derive(expr.expr, x)
  130. rank_x = len(get_shape(x))
  131. diag_indices = [[j + rank_x for j in i] for i in expr.diagonal_indices]
  132. return _array_diagonal(dsubexpr, *diag_indices)
  133. @array_derive.register(ArrayAdd)
  134. def _(expr: ArrayAdd, x: Expr):
  135. return _array_add(*[array_derive(arg, x) for arg in expr.args])
  136. @array_derive.register(PermuteDims)
  137. def _(expr: PermuteDims, x: Expr):
  138. de = array_derive(expr.expr, x)
  139. perm = [0, 1] + [i + 2 for i in expr.permutation.array_form]
  140. return _permute_dims(de, perm)
  141. @array_derive.register(Reshape)
  142. def _(expr: Reshape, x: Expr):
  143. de = array_derive(expr.expr, x)
  144. return Reshape(de, get_shape(x) + expr.shape)
  145. def matrix_derive(expr, x):
  146. from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix
  147. ce = convert_matrix_to_array(expr)
  148. dce = array_derive(ce, x)
  149. return convert_array_to_matrix(dce).doit()