图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

129 lines
4.8 KiB

  1. from typing import Optional
  2. from sympy.core.expr import Expr
  3. from sympy.core.function import Derivative
  4. from sympy.core.numbers import Integer
  5. from sympy.matrices.common import MatrixCommon
  6. from .ndim_array import NDimArray
  7. from .arrayop import derive_by_array
  8. from sympy.matrices.expressions.matexpr import MatrixExpr
  9. from sympy.matrices.expressions.special import ZeroMatrix
  10. from sympy.matrices.expressions.matexpr import _matrix_derivative
  11. class ArrayDerivative(Derivative):
  12. is_scalar = False
  13. def __new__(cls, expr, *variables, **kwargs):
  14. obj = super().__new__(cls, expr, *variables, **kwargs)
  15. if isinstance(obj, ArrayDerivative):
  16. obj._shape = obj._get_shape()
  17. return obj
  18. def _get_shape(self):
  19. shape = ()
  20. for v, count in self.variable_count:
  21. if hasattr(v, "shape"):
  22. for i in range(count):
  23. shape += v.shape
  24. if hasattr(self.expr, "shape"):
  25. shape += self.expr.shape
  26. return shape
  27. @property
  28. def shape(self):
  29. return self._shape
  30. @classmethod
  31. def _get_zero_with_shape_like(cls, expr):
  32. if isinstance(expr, (MatrixCommon, NDimArray)):
  33. return expr.zeros(*expr.shape)
  34. elif isinstance(expr, MatrixExpr):
  35. return ZeroMatrix(*expr.shape)
  36. else:
  37. raise RuntimeError("Unable to determine shape of array-derivative.")
  38. @staticmethod
  39. def _call_derive_scalar_by_matrix(expr, v): # type: (Expr, MatrixCommon) -> Expr
  40. return v.applyfunc(lambda x: expr.diff(x))
  41. @staticmethod
  42. def _call_derive_scalar_by_matexpr(expr, v): # type: (Expr, MatrixExpr) -> Expr
  43. if expr.has(v):
  44. return _matrix_derivative(expr, v)
  45. else:
  46. return ZeroMatrix(*v.shape)
  47. @staticmethod
  48. def _call_derive_scalar_by_array(expr, v): # type: (Expr, NDimArray) -> Expr
  49. return v.applyfunc(lambda x: expr.diff(x))
  50. @staticmethod
  51. def _call_derive_matrix_by_scalar(expr, v): # type: (MatrixCommon, Expr) -> Expr
  52. return _matrix_derivative(expr, v)
  53. @staticmethod
  54. def _call_derive_matexpr_by_scalar(expr, v): # type: (MatrixExpr, Expr) -> Expr
  55. return expr._eval_derivative(v)
  56. @staticmethod
  57. def _call_derive_array_by_scalar(expr, v): # type: (NDimArray, Expr) -> Expr
  58. return expr.applyfunc(lambda x: x.diff(v))
  59. @staticmethod
  60. def _call_derive_default(expr, v): # type: (Expr, Expr) -> Optional[Expr]
  61. if expr.has(v):
  62. return _matrix_derivative(expr, v)
  63. else:
  64. return None
  65. @classmethod
  66. def _dispatch_eval_derivative_n_times(cls, expr, v, count):
  67. # Evaluate the derivative `n` times. If
  68. # `_eval_derivative_n_times` is not overridden by the current
  69. # object, the default in `Basic` will call a loop over
  70. # `_eval_derivative`:
  71. if not isinstance(count, (int, Integer)) or ((count <= 0) == True):
  72. return None
  73. # TODO: this could be done with multiple-dispatching:
  74. if expr.is_scalar:
  75. if isinstance(v, MatrixCommon):
  76. result = cls._call_derive_scalar_by_matrix(expr, v)
  77. elif isinstance(v, MatrixExpr):
  78. result = cls._call_derive_scalar_by_matexpr(expr, v)
  79. elif isinstance(v, NDimArray):
  80. result = cls._call_derive_scalar_by_array(expr, v)
  81. elif v.is_scalar:
  82. # scalar by scalar has a special
  83. return super()._dispatch_eval_derivative_n_times(expr, v, count)
  84. else:
  85. return None
  86. elif v.is_scalar:
  87. if isinstance(expr, MatrixCommon):
  88. result = cls._call_derive_matrix_by_scalar(expr, v)
  89. elif isinstance(expr, MatrixExpr):
  90. result = cls._call_derive_matexpr_by_scalar(expr, v)
  91. elif isinstance(expr, NDimArray):
  92. result = cls._call_derive_array_by_scalar(expr, v)
  93. else:
  94. return None
  95. else:
  96. # Both `expr` and `v` are some array/matrix type:
  97. if isinstance(expr, MatrixCommon) or isinstance(expr, MatrixCommon):
  98. result = derive_by_array(expr, v)
  99. elif isinstance(expr, MatrixExpr) and isinstance(v, MatrixExpr):
  100. result = cls._call_derive_default(expr, v)
  101. elif isinstance(expr, MatrixExpr) or isinstance(v, MatrixExpr):
  102. # if one expression is a symbolic matrix expression while the other isn't, don't evaluate:
  103. return None
  104. else:
  105. result = derive_by_array(expr, v)
  106. if result is None:
  107. return None
  108. if count == 1:
  109. return result
  110. else:
  111. return cls._dispatch_eval_derivative_n_times(result, v, count - 1)