图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

149 lines
4.4 KiB

  1. from functools import reduce
  2. import operator
  3. from sympy.core import Basic, sympify
  4. from sympy.core.add import add, Add, _could_extract_minus_sign
  5. from sympy.core.sorting import default_sort_key
  6. from sympy.functions import adjoint
  7. from sympy.matrices.common import ShapeError
  8. from sympy.matrices.matrices import MatrixBase
  9. from sympy.matrices.expressions.transpose import transpose
  10. from sympy.strategies import (rm_id, unpack, flatten, sort, condition,
  11. exhaust, do_one, glom)
  12. from sympy.matrices.expressions.matexpr import MatrixExpr
  13. from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix
  14. from sympy.utilities import sift
  15. # XXX: MatAdd should perhaps not subclass directly from Add
  16. class MatAdd(MatrixExpr, Add):
  17. """A Sum of Matrix Expressions
  18. MatAdd inherits from and operates like SymPy Add
  19. Examples
  20. ========
  21. >>> from sympy import MatAdd, MatrixSymbol
  22. >>> A = MatrixSymbol('A', 5, 5)
  23. >>> B = MatrixSymbol('B', 5, 5)
  24. >>> C = MatrixSymbol('C', 5, 5)
  25. >>> MatAdd(A, B, C)
  26. A + B + C
  27. """
  28. is_MatAdd = True
  29. identity = GenericZeroMatrix()
  30. def __new__(cls, *args, evaluate=False, check=False, _sympify=True):
  31. if not args:
  32. return cls.identity
  33. # This must be removed aggressively in the constructor to avoid
  34. # TypeErrors from GenericZeroMatrix().shape
  35. args = list(filter(lambda i: cls.identity != i, args))
  36. if _sympify:
  37. args = list(map(sympify, args))
  38. obj = Basic.__new__(cls, *args)
  39. if check:
  40. if not any(isinstance(i, MatrixExpr) for i in args):
  41. return Add.fromiter(args)
  42. validate(*args)
  43. if evaluate:
  44. if not any(isinstance(i, MatrixExpr) for i in args):
  45. return Add(*args, evaluate=True)
  46. obj = canonicalize(obj)
  47. return obj
  48. @property
  49. def shape(self):
  50. return self.args[0].shape
  51. def could_extract_minus_sign(self):
  52. return _could_extract_minus_sign(self)
  53. def _entry(self, i, j, **kwargs):
  54. return Add(*[arg._entry(i, j, **kwargs) for arg in self.args])
  55. def _eval_transpose(self):
  56. return MatAdd(*[transpose(arg) for arg in self.args]).doit()
  57. def _eval_adjoint(self):
  58. return MatAdd(*[adjoint(arg) for arg in self.args]).doit()
  59. def _eval_trace(self):
  60. from .trace import trace
  61. return Add(*[trace(arg) for arg in self.args]).doit()
  62. def doit(self, **kwargs):
  63. deep = kwargs.get('deep', True)
  64. if deep:
  65. args = [arg.doit(**kwargs) for arg in self.args]
  66. else:
  67. args = self.args
  68. return canonicalize(MatAdd(*args))
  69. def _eval_derivative_matrix_lines(self, x):
  70. add_lines = [arg._eval_derivative_matrix_lines(x) for arg in self.args]
  71. return [j for i in add_lines for j in i]
  72. add.register_handlerclass((Add, MatAdd), MatAdd)
  73. def validate(*args):
  74. if not all(arg.is_Matrix for arg in args):
  75. raise TypeError("Mix of Matrix and Scalar symbols")
  76. A = args[0]
  77. for B in args[1:]:
  78. if A.shape != B.shape:
  79. raise ShapeError("Matrices %s and %s are not aligned"%(A, B))
  80. factor_of = lambda arg: arg.as_coeff_mmul()[0]
  81. matrix_of = lambda arg: unpack(arg.as_coeff_mmul()[1])
  82. def combine(cnt, mat):
  83. if cnt == 1:
  84. return mat
  85. else:
  86. return cnt * mat
  87. def merge_explicit(matadd):
  88. """ Merge explicit MatrixBase arguments
  89. Examples
  90. ========
  91. >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
  92. >>> from sympy.matrices.expressions.matadd import merge_explicit
  93. >>> A = MatrixSymbol('A', 2, 2)
  94. >>> B = eye(2)
  95. >>> C = Matrix([[1, 2], [3, 4]])
  96. >>> X = MatAdd(A, B, C)
  97. >>> pprint(X)
  98. [1 0] [1 2]
  99. A + [ ] + [ ]
  100. [0 1] [3 4]
  101. >>> pprint(merge_explicit(X))
  102. [2 2]
  103. A + [ ]
  104. [3 5]
  105. """
  106. groups = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase))
  107. if len(groups[True]) > 1:
  108. return MatAdd(*(groups[False] + [reduce(operator.add, groups[True])]))
  109. else:
  110. return matadd
  111. rules = (rm_id(lambda x: x == 0 or isinstance(x, ZeroMatrix)),
  112. unpack,
  113. flatten,
  114. glom(matrix_of, factor_of, combine),
  115. merge_explicit,
  116. sort(default_sort_key))
  117. canonicalize = exhaust(condition(lambda x: isinstance(x, MatAdd),
  118. do_one(*rules)))