图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

142 lines
4.8 KiB

  1. from sympy.matrices.common import NonSquareMatrixError
  2. from .matexpr import MatrixExpr
  3. from .special import Identity
  4. from sympy.core import S
  5. from sympy.core.expr import ExprBuilder
  6. from sympy.core.cache import cacheit
  7. from sympy.core.power import Pow
  8. from sympy.core.sympify import _sympify
  9. from sympy.matrices import MatrixBase
  class MatPow(MatrixExpr):
      """Power ``base**exp`` of a square matrix expression.

      The power is left unevaluated unless ``evaluate=True`` is passed to the
      constructor or :meth:`doit` is called.
      """

      def __new__(cls, base, exp, evaluate=False, **options):
          """Build ``base**exp``.

          Raises ``TypeError`` if *base* is not a matrix and
          ``NonSquareMatrixError`` if *base* is not square. ``options`` is
          accepted for signature compatibility and is not used here.
          """
          base = _sympify(base)
          if not base.is_Matrix:
              raise TypeError("MatPow base should be a matrix")
          if not base.is_square:
              raise NonSquareMatrixError("Power of non-square matrix %s" % base)
          exp = _sympify(exp)
          obj = super().__new__(cls, base, exp)
          if evaluate:
              # Evaluate only the outer power; base/exp are not recursed into.
              obj = obj.doit(deep=False)
          return obj

      @property
      def base(self):
          """The matrix expression being raised to a power."""
          return self.args[0]

      @property
      def exp(self):
          """The (sympified) exponent."""
          return self.args[1]

      @property
      def shape(self):
          """Shape of the power, identical to the shape of :attr:`base`."""
          return self.base.shape

      @cacheit
      def _get_explicit_matrix(self):
          """Return the explicit form of :attr:`base` raised to :attr:`exp` (cached)."""
          return self.base.as_explicit()**self.exp

      def _entry(self, i, j, **kwargs):
          """Return the ``(i, j)`` element of this power.

          The power is evaluated first. If the result is still a ``MatPow``:
          a positive integer exponent is expanded into a ``MatMul``; a
          non-symbolic shape falls back to the explicit matrix; otherwise an
          unevaluated ``MatrixElement`` of ``self`` is returned.
          """
          from sympy.matrices.expressions import MatMul
          A = self.doit()
          if isinstance(A, MatPow):
              # We still have a MatPow, make an explicit MatMul out of it.
              if A.exp.is_Integer and A.exp.is_positive:
                  A = MatMul(*[A.base for _ in range(A.exp)])
              elif not self._is_shape_symbolic():
                  return A._get_explicit_matrix()[i, j]
              else:
                  # Leave the expression unevaluated:
                  from sympy.matrices.expressions.matexpr import MatrixElement
                  return MatrixElement(self, i, j)
          return A[i, j]

      def doit(self, **kwargs):
          """Evaluate the power as far as possible.

          With ``deep=True`` (the default) the base and exponent are evaluated
          first. Nested ``MatPow`` bases are flattened by multiplying the
          exponents. An explicit ``MatrixBase`` delegates to its own ``**``.
          Exponents 1, 0 and -1 are special-cased. Then a ``_eval_power`` hook
          on the base is tried. If none exists, an unevaluated ``MatPow`` is
          returned.
          """
          if kwargs.get('deep', True):
              base, exp = [arg.doit(**kwargs) for arg in self.args]
          else:
              base, exp = self.args

          # combine all powers, e.g. (A ** 2) ** 3 -> A ** 6
          while isinstance(base, MatPow):
              exp *= base.args[1]
              base = base.args[0]

          if isinstance(base, MatrixBase):
              # Delegate
              return base ** exp

          # Handle simple cases so that _eval_power() in MatrixExpr sub-classes can ignore them
          if exp == S.One:
              return base
          if exp == S.Zero:
              return Identity(base.rows)
          if exp == S.NegativeOne:
              from sympy.matrices.expressions import Inverse
              return Inverse(base).doit(**kwargs)

          eval_power = getattr(base, '_eval_power', None)
          if eval_power is not None:
              return eval_power(exp)

          return MatPow(base, exp)

      def _eval_transpose(self):
          """Return ``(base.T)**exp``."""
          base, exp = self.args
          return MatPow(base.T, exp)

      def _eval_derivative(self, x):
          """Differentiate by reusing the scalar ``Pow`` derivative rule."""
          return Pow._eval_derivative(self, x)

      def _eval_derivative_matrix_lines(self, x):
          """Return the matrix-derivative lines of this power with respect to *x*.

          A 1x1 base whose exponent does not contain *x* is handled with the
          scalar power rule ``exp*base**(exp - 1)``. Otherwise the exponent
          must be an integer. A positive one is expanded into a ``MatMul`` of
          copies of the base, and a negative one into a ``MatMul`` of copies of
          its ``Inverse``. -1 and 0 are delegated directly. Any other exponent
          raises ``NotImplementedError``.
          """
          from sympy.tensor.array.expressions.array_expressions import (
              ArrayContraction, ArrayTensorProduct)
          from .matmul import MatMul
          from .inverse import Inverse
          exp = self.exp
          if self.base.shape == (1, 1) and not exp.has(x):
              lr = self.base._eval_derivative_matrix_lines(x)
              for i in lr:
                  subexpr = ExprBuilder(
                      ArrayContraction,
                      [
                          ExprBuilder(
                              ArrayTensorProduct,
                              [
                                  Identity(1),
                                  i._lines[0],
                                  exp*self.base**(exp-1),
                                  i._lines[1],
                                  Identity(1),
                              ]
                          ),
                          (0, 3, 4), (5, 7, 8)
                      ],
                      validator=ArrayContraction._validate
                  )
                  i._first_pointer_parent = subexpr.args[0].args
                  i._first_pointer_index = 0
                  i._second_pointer_parent = subexpr.args[0].args
                  i._second_pointer_index = 4
                  i._lines = [subexpr]
              return lr
          # Expanding into a product needs an integer repeat count. Without the
          # is_Integer guard, a non-integer exponent such as a Float or a
          # positive symbol would crash inside range() with a TypeError
          # instead of reaching the NotImplementedError below.
          # The `== True` comparisons are deliberate: a SymPy comparison may
          # return an unevaluated relational rather than a bool.
          if exp.is_Integer and exp.is_positive:
              newexpr = MatMul.fromiter([self.base for _ in range(exp)])
          elif (exp == -1) == True:
              return Inverse(self.base)._eval_derivative_matrix_lines(x)
          elif exp.is_Integer and exp.is_negative:
              newexpr = MatMul.fromiter([Inverse(self.base) for _ in range(-exp)])
          elif (exp == 0) == True:
              return self.doit()._eval_derivative_matrix_lines(x)
          else:
              raise NotImplementedError("cannot evaluate %s derived by %s" % (self, x))
          return newexpr._eval_derivative_matrix_lines(x)

      def _eval_inverse(self):
          """Return ``base**(-exp)``."""
          return MatPow(self.base, -self.exp)