图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

204 lines
6.6 KiB

  1. from sympy.core.expr import ExprBuilder
  2. from sympy.core.function import (Function, FunctionClass, Lambda)
  3. from sympy.core.symbol import Dummy
  4. from sympy.core.sympify import sympify, _sympify
  5. from sympy.matrices.expressions import MatrixExpr
  6. from sympy.matrices.matrices import MatrixBase
  7. class ElementwiseApplyFunction(MatrixExpr):
  8. r"""
  9. Apply function to a matrix elementwise without evaluating.
  10. Examples
  11. ========
  12. It can be created by calling ``.applyfunc(<function>)`` on a matrix
  13. expression:
  14. >>> from sympy import MatrixSymbol
  15. >>> from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
  16. >>> from sympy import exp
  17. >>> X = MatrixSymbol("X", 3, 3)
  18. >>> X.applyfunc(exp)
  19. Lambda(_d, exp(_d)).(X)
  20. Otherwise using the class constructor:
  21. >>> from sympy import eye
  22. >>> expr = ElementwiseApplyFunction(exp, eye(3))
  23. >>> expr
  24. Lambda(_d, exp(_d)).(Matrix([
  25. [1, 0, 0],
  26. [0, 1, 0],
  27. [0, 0, 1]]))
  28. >>> expr.doit()
  29. Matrix([
  30. [E, 1, 1],
  31. [1, E, 1],
  32. [1, 1, E]])
  33. Notice the difference with the real mathematical functions:
  34. >>> exp(eye(3))
  35. Matrix([
  36. [E, 0, 0],
  37. [0, E, 0],
  38. [0, 0, E]])
  39. """
  40. def __new__(cls, function, expr):
  41. expr = _sympify(expr)
  42. if not expr.is_Matrix:
  43. raise ValueError("{} must be a matrix instance.".format(expr))
  44. if expr.shape == (1, 1):
  45. # Check if the function returns a matrix, in that case, just apply
  46. # the function instead of creating an ElementwiseApplyFunc object:
  47. ret = function(expr)
  48. if isinstance(ret, MatrixExpr):
  49. return ret
  50. if not isinstance(function, (FunctionClass, Lambda)):
  51. d = Dummy('d')
  52. function = Lambda(d, function(d))
  53. function = sympify(function)
  54. if not isinstance(function, (FunctionClass, Lambda)):
  55. raise ValueError(
  56. "{} should be compatible with SymPy function classes."
  57. .format(function))
  58. if 1 not in function.nargs:
  59. raise ValueError(
  60. '{} should be able to accept 1 arguments.'.format(function))
  61. if not isinstance(function, Lambda):
  62. d = Dummy('d')
  63. function = Lambda(d, function(d))
  64. obj = MatrixExpr.__new__(cls, function, expr)
  65. return obj
  66. @property
  67. def function(self):
  68. return self.args[0]
  69. @property
  70. def expr(self):
  71. return self.args[1]
  72. @property
  73. def shape(self):
  74. return self.expr.shape
  75. def doit(self, **kwargs):
  76. deep = kwargs.get("deep", True)
  77. expr = self.expr
  78. if deep:
  79. expr = expr.doit(**kwargs)
  80. function = self.function
  81. if isinstance(function, Lambda) and function.is_identity:
  82. # This is a Lambda containing the identity function.
  83. return expr
  84. if isinstance(expr, MatrixBase):
  85. return expr.applyfunc(self.function)
  86. elif isinstance(expr, ElementwiseApplyFunction):
  87. return ElementwiseApplyFunction(
  88. lambda x: self.function(expr.function(x)),
  89. expr.expr
  90. ).doit()
  91. else:
  92. return self
  93. def _entry(self, i, j, **kwargs):
  94. return self.function(self.expr._entry(i, j, **kwargs))
  95. def _get_function_fdiff(self):
  96. d = Dummy("d")
  97. function = self.function(d)
  98. fdiff = function.diff(d)
  99. if isinstance(fdiff, Function):
  100. fdiff = type(fdiff)
  101. else:
  102. fdiff = Lambda(d, fdiff)
  103. return fdiff
  104. def _eval_derivative(self, x):
  105. from sympy.matrices.expressions.hadamard import hadamard_product
  106. dexpr = self.expr.diff(x)
  107. fdiff = self._get_function_fdiff()
  108. return hadamard_product(
  109. dexpr,
  110. ElementwiseApplyFunction(fdiff, self.expr)
  111. )
  112. def _eval_derivative_matrix_lines(self, x):
  113. from sympy.matrices.expressions.special import Identity
  114. from sympy.tensor.array.expressions.array_expressions import ArrayContraction
  115. from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal
  116. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  117. fdiff = self._get_function_fdiff()
  118. lr = self.expr._eval_derivative_matrix_lines(x)
  119. ewdiff = ElementwiseApplyFunction(fdiff, self.expr)
  120. if 1 in x.shape:
  121. # Vector:
  122. iscolumn = self.shape[1] == 1
  123. for i in lr:
  124. if iscolumn:
  125. ptr1 = i.first_pointer
  126. ptr2 = Identity(self.shape[1])
  127. else:
  128. ptr1 = Identity(self.shape[0])
  129. ptr2 = i.second_pointer
  130. subexpr = ExprBuilder(
  131. ArrayDiagonal,
  132. [
  133. ExprBuilder(
  134. ArrayTensorProduct,
  135. [
  136. ewdiff,
  137. ptr1,
  138. ptr2,
  139. ]
  140. ),
  141. (0, 2) if iscolumn else (1, 4)
  142. ],
  143. validator=ArrayDiagonal._validate
  144. )
  145. i._lines = [subexpr]
  146. i._first_pointer_parent = subexpr.args[0].args
  147. i._first_pointer_index = 1
  148. i._second_pointer_parent = subexpr.args[0].args
  149. i._second_pointer_index = 2
  150. else:
  151. # Matrix case:
  152. for i in lr:
  153. ptr1 = i.first_pointer
  154. ptr2 = i.second_pointer
  155. newptr1 = Identity(ptr1.shape[1])
  156. newptr2 = Identity(ptr2.shape[1])
  157. subexpr = ExprBuilder(
  158. ArrayContraction,
  159. [
  160. ExprBuilder(
  161. ArrayTensorProduct,
  162. [ptr1, newptr1, ewdiff, ptr2, newptr2]
  163. ),
  164. (1, 2, 4),
  165. (5, 7, 8),
  166. ],
  167. validator=ArrayContraction._validate
  168. )
  169. i._first_pointer_parent = subexpr.args[0].args
  170. i._first_pointer_index = 1
  171. i._second_pointer_parent = subexpr.args[0].args
  172. i._second_pointer_index = 4
  173. i._lines = [subexpr]
  174. return lr
  175. def _eval_transpose(self):
  176. from sympy.matrices.expressions.transpose import Transpose
  177. return self.func(self.function, Transpose(self.expr).doit())