m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

198 lines
6.6 KiB

6 months ago
  1. from sympy.core.numbers import Number
  2. from sympy.core.singleton import S
  3. from sympy.core.symbol import Symbol
  4. from sympy.core.sympify import sympify
  5. from sympy.tensor.array.dense_ndim_array import MutableDenseNDimArray
  6. from sympy.tensor.tensor import (Tensor, TensExpr, TensAdd, TensMul,
  7. TensorIndex)
  8. class PartialDerivative(TensExpr):
  9. """
  10. Partial derivative for tensor expressions.
  11. Examples
  12. ========
  13. >>> from sympy.tensor.tensor import TensorIndexType, TensorHead
  14. >>> from sympy.tensor.toperators import PartialDerivative
  15. >>> from sympy import symbols
  16. >>> L = TensorIndexType("L")
  17. >>> A = TensorHead("A", [L])
  18. >>> i, j = symbols("i j")
  19. >>> expr = PartialDerivative(A(i), A(j))
  20. >>> expr
  21. PartialDerivative(A(i), A(j))
  22. The ``PartialDerivative`` object behaves like a tensorial expression:
  23. >>> expr.get_indices()
  24. [i, -j]
  25. Indices can be contracted:
  26. >>> expr = PartialDerivative(A(i), A(i))
  27. >>> expr
  28. PartialDerivative(A(L_0), A(L_0))
  29. >>> expr.get_indices()
  30. [L_0, -L_0]
  31. """
  32. def __new__(cls, expr, *variables):
  33. # Flatten:
  34. if isinstance(expr, PartialDerivative):
  35. variables = expr.variables + variables
  36. expr = expr.expr
  37. args, indices, free, dum = cls._contract_indices_for_derivative(
  38. S(expr), variables)
  39. obj = TensExpr.__new__(cls, *args)
  40. obj._indices = indices
  41. obj._free = free
  42. obj._dum = dum
  43. return obj
  44. @property
  45. def coeff(self):
  46. return S.One
  47. @property
  48. def nocoeff(self):
  49. return self
  50. @classmethod
  51. def _contract_indices_for_derivative(cls, expr, variables):
  52. variables_opposite_valence = []
  53. for i in variables:
  54. if isinstance(i, Tensor):
  55. i_free_indices = i.get_free_indices()
  56. variables_opposite_valence.append(
  57. i.xreplace({k: -k for k in i_free_indices}))
  58. elif isinstance(i, Symbol):
  59. variables_opposite_valence.append(i)
  60. args, indices, free, dum = TensMul._tensMul_contract_indices(
  61. [expr] + variables_opposite_valence, replace_indices=True)
  62. for i in range(1, len(args)):
  63. args_i = args[i]
  64. if isinstance(args_i, Tensor):
  65. i_indices = args[i].get_free_indices()
  66. args[i] = args[i].xreplace({k: -k for k in i_indices})
  67. return args, indices, free, dum
  68. def doit(self):
  69. args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables)
  70. obj = self.func(*args)
  71. obj._indices = indices
  72. obj._free = free
  73. obj._dum = dum
  74. return obj
  75. def _expand_partial_derivative(self):
  76. args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables)
  77. obj = self.func(*args)
  78. obj._indices = indices
  79. obj._free = free
  80. obj._dum = dum
  81. result = obj
  82. if not args[0].free_symbols:
  83. return S.Zero
  84. elif isinstance(obj.expr, TensAdd):
  85. # take care of sums of multi PDs
  86. result = obj.expr.func(*[
  87. self.func(a, *obj.variables)._expand_partial_derivative()
  88. for a in result.expr.args])
  89. elif isinstance(obj.expr, TensMul):
  90. # take care of products of multi PDs
  91. if len(obj.variables) == 1:
  92. # derivative with respect to single variable
  93. terms = []
  94. mulargs = list(obj.expr.args)
  95. for ind in range(len(mulargs)):
  96. if not isinstance(sympify(mulargs[ind]), Number):
  97. # a number coefficient is not considered for
  98. # expansion of PartialDerivative
  99. d = self.func(mulargs[ind], *obj.variables)._expand_partial_derivative()
  100. terms.append(TensMul(*(mulargs[:ind]
  101. + [d]
  102. + mulargs[(ind + 1):])))
  103. result = TensAdd.fromiter(terms)
  104. else:
  105. # derivative with respect to multiple variables
  106. # decompose:
  107. # partial(expr, (u, v))
  108. # = partial(partial(expr, u).doit(), v).doit()
  109. result = obj.expr # init with expr
  110. for v in obj.variables:
  111. result = self.func(result, v)._expand_partial_derivative()
  112. # then throw PD on it
  113. return result
  114. def _perform_derivative(self):
  115. result = self.expr
  116. for v in self.variables:
  117. if isinstance(result, TensExpr):
  118. result = result._eval_partial_derivative(v)
  119. else:
  120. if v._diff_wrt:
  121. result = result._eval_derivative(v)
  122. else:
  123. result = S.Zero
  124. return result
  125. def get_indices(self):
  126. return self._indices
  127. def get_free_indices(self):
  128. free = sorted(self._free, key=lambda x: x[1])
  129. return [i[0] for i in free]
  130. def _replace_indices(self, repl):
  131. expr = self.expr.xreplace(repl)
  132. mirrored = {-k: -v for k, v in repl.items()}
  133. variables = [i.xreplace(mirrored) for i in self.variables]
  134. return self.func(expr, *variables)
  135. @property
  136. def expr(self):
  137. return self.args[0]
  138. @property
  139. def variables(self):
  140. return self.args[1:]
  141. def _extract_data(self, replacement_dict):
  142. from .array import derive_by_array, tensorcontraction
  143. indices, array = self.expr._extract_data(replacement_dict)
  144. for variable in self.variables:
  145. var_indices, var_array = variable._extract_data(replacement_dict)
  146. var_indices = [-i for i in var_indices]
  147. coeff_array, var_array = zip(*[i.as_coeff_Mul() for i in var_array])
  148. array = derive_by_array(array, var_array)
  149. array = array.as_mutable() # type: MutableDenseNDimArray
  150. varindex = var_indices[0] # type: TensorIndex
  151. # Remove coefficients of base vector:
  152. coeff_index = [0] + [slice(None) for i in range(len(indices))]
  153. for i, coeff in enumerate(coeff_array):
  154. coeff_index[0] = i
  155. array[tuple(coeff_index)] /= coeff
  156. if -varindex in indices:
  157. pos = indices.index(-varindex)
  158. array = tensorcontraction(array, (0, pos+1))
  159. indices.pop(pos)
  160. else:
  161. indices.append(varindex)
  162. return indices, array