m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

129 lines
4.8 KiB

6 months ago
  1. from typing import Optional
  2. from sympy.core.expr import Expr
  3. from sympy.core.function import Derivative
  4. from sympy.core.numbers import Integer
  5. from sympy.matrices.common import MatrixCommon
  6. from .ndim_array import NDimArray
  7. from .arrayop import derive_by_array
  8. from sympy.matrices.expressions.matexpr import MatrixExpr
  9. from sympy.matrices.expressions.special import ZeroMatrix
  10. from sympy.matrices.expressions.matexpr import _matrix_derivative
  class ArrayDerivative(Derivative):
      """Derivative that allows arrays and matrices as expression or variables.

      Extends ``Derivative`` so that scalars, explicit matrices
      (``MatrixCommon``), symbolic matrix expressions (``MatrixExpr``) and
      N-dimensional arrays (``NDimArray``) may appear both as the expression
      being differentiated and as the differentiation variables.

      The shape of an unevaluated ``ArrayDerivative`` is the concatenation of
      the shape of every variable (repeated once per differentiation count),
      followed by the shape of the expression.
      """

      is_scalar = False

      def __new__(cls, expr, *variables, **kwargs):
          obj = super().__new__(cls, expr, *variables, **kwargs)
          # ``super().__new__`` may hand back an object that is not an
          # ArrayDerivative (e.g. an evaluated result); only genuine
          # ArrayDerivative instances get their shape cached.
          if isinstance(obj, ArrayDerivative):
              obj._shape = obj._get_shape()
          return obj

      def _get_shape(self):
          """Compute the shape tuple of this derivative.

          Every variable that has a ``shape`` contributes that shape once per
          differentiation count, in order; the expression's shape, if it has
          one, is appended last. Objects without a ``shape`` attribute
          contribute nothing.
          """
          shape = ()
          for v, count in self.variable_count:
              if hasattr(v, "shape"):
                  for _ in range(count):
                      shape += v.shape
          if hasattr(self.expr, "shape"):
              shape += self.expr.shape
          return shape

      @property
      def shape(self):
          """Shape tuple computed by ``_get_shape`` at construction time."""
          return self._shape

      @classmethod
      def _get_zero_with_shape_like(cls, expr):
          """Return a zero object with the same shape as ``expr``.

          Explicit matrices and N-dim arrays use their own ``zeros``
          constructor; symbolic matrix expressions get a ``ZeroMatrix``.

          Raises:
              RuntimeError: if ``expr`` is none of those types.
          """
          if isinstance(expr, (MatrixCommon, NDimArray)):
              return expr.zeros(*expr.shape)
          elif isinstance(expr, MatrixExpr):
              return ZeroMatrix(*expr.shape)
          else:
              raise RuntimeError("Unable to determine shape of array-derivative.")

      @staticmethod
      def _call_derive_scalar_by_matrix(expr, v):  # type: (Expr, MatrixCommon) -> MatrixCommon
          """Differentiate scalar ``expr`` by each entry of explicit matrix ``v``."""
          return v.applyfunc(lambda x: expr.diff(x))

      @staticmethod
      def _call_derive_scalar_by_matexpr(expr, v):  # type: (Expr, MatrixExpr) -> Expr
          """Differentiate scalar ``expr`` by symbolic matrix ``v``.

          Yields a ``ZeroMatrix`` of ``v``'s shape when ``expr`` does not
          contain ``v``.
          """
          if expr.has(v):
              return _matrix_derivative(expr, v)
          else:
              return ZeroMatrix(*v.shape)

      @staticmethod
      def _call_derive_scalar_by_array(expr, v):  # type: (Expr, NDimArray) -> NDimArray
          """Differentiate scalar ``expr`` by each entry of N-dim array ``v``."""
          return v.applyfunc(lambda x: expr.diff(x))

      @staticmethod
      def _call_derive_matrix_by_scalar(expr, v):  # type: (MatrixCommon, Expr) -> Expr
          """Differentiate explicit matrix ``expr`` by scalar ``v``."""
          return _matrix_derivative(expr, v)

      @staticmethod
      def _call_derive_matexpr_by_scalar(expr, v):  # type: (MatrixExpr, Expr) -> Expr
          """Differentiate symbolic matrix ``expr`` by scalar ``v``."""
          return expr._eval_derivative(v)

      @staticmethod
      def _call_derive_array_by_scalar(expr, v):  # type: (NDimArray, Expr) -> NDimArray
          """Differentiate each entry of N-dim array ``expr`` by scalar ``v``."""
          return expr.applyfunc(lambda x: x.diff(v))

      @staticmethod
      def _call_derive_default(expr, v):  # type: (Expr, Expr) -> Optional[Expr]
          """Matrix-calculus derivative of ``expr`` by ``v``.

          Returns ``None`` (leave unevaluated) when ``expr`` does not
          contain ``v``.
          """
          if expr.has(v):
              return _matrix_derivative(expr, v)
          else:
              return None

      @classmethod
      def _dispatch_eval_derivative_n_times(cls, expr, v, count):
          """Evaluate the ``count``-th derivative of ``expr`` with respect to ``v``.

          Dispatches on whether ``expr`` and ``v`` are scalars, explicit
          matrices, symbolic matrix expressions or N-dim arrays.
          Scalar-by-scalar derivatives are delegated to the parent class.

          Returns ``None`` (meaning "leave unevaluated") in any of these cases:
          ``count`` is not a positive ``int``/``Integer``; the type
          combination is unsupported; or a differentiation step produces
          ``None``. For ``count > 1``, differentiation is repeated on the
          intermediate result.
          """
          # Evaluate the derivative `n` times. If
          # `_eval_derivative_n_times` is not overridden by the current
          # object, the default in `Basic` will call a loop over
          # `_eval_derivative`:
          if not isinstance(count, (int, Integer)) or ((count <= 0) == True):
              return None
          # TODO: this could be done with multiple-dispatching:
          if expr.is_scalar:
              if isinstance(v, MatrixCommon):
                  result = cls._call_derive_scalar_by_matrix(expr, v)
              elif isinstance(v, MatrixExpr):
                  result = cls._call_derive_scalar_by_matexpr(expr, v)
              elif isinstance(v, NDimArray):
                  result = cls._call_derive_scalar_by_array(expr, v)
              elif v.is_scalar:
                  # scalar by scalar has a special
                  return super()._dispatch_eval_derivative_n_times(expr, v, count)
              else:
                  return None
          elif v.is_scalar:
              if isinstance(expr, MatrixCommon):
                  result = cls._call_derive_matrix_by_scalar(expr, v)
              elif isinstance(expr, MatrixExpr):
                  result = cls._call_derive_matexpr_by_scalar(expr, v)
              elif isinstance(expr, NDimArray):
                  result = cls._call_derive_array_by_scalar(expr, v)
              else:
                  return None
          else:
              # Both `expr` and `v` are some array/matrix type.
              # NOTE(review): only `expr` is tested against MatrixCommon here.
              # Also testing `v` would send symbolic-matrix-by-explicit-matrix
              # derivatives to `derive_by_array` instead of leaving them
              # unevaluated by the branch below -- confirm intent before
              # widening this condition.
              if isinstance(expr, MatrixCommon):
                  result = derive_by_array(expr, v)
              elif isinstance(expr, MatrixExpr) and isinstance(v, MatrixExpr):
                  result = cls._call_derive_default(expr, v)
              elif isinstance(expr, MatrixExpr) or isinstance(v, MatrixExpr):
                  # if one expression is a symbolic matrix expression while the other isn't, don't evaluate:
                  return None
              else:
                  result = derive_by_array(expr, v)
          if result is None:
              return None
          if count == 1:
              return result
          else:
              return cls._dispatch_eval_derivative_n_times(result, v, count - 1)