m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

467 lines
14 KiB

from sympy.assumptions.ask import ask, Q
from sympy.assumptions.refine import handlers_dict
from sympy.core import Basic, sympify, S
from sympy.core.mul import mul, Mul
from sympy.core.numbers import Number, Integer
from sympy.core.symbol import Dummy
from sympy.functions import adjoint
from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust,
do_one, new)
from sympy.matrices.common import ShapeError, NonInvertibleMatrixError
from sympy.matrices.matrices import MatrixBase
from .inverse import Inverse
from .matexpr import MatrixExpr
from .matpow import MatPow
from .transpose import transpose
from .permutation import PermutationMatrix
from .special import ZeroMatrix, Identity, GenericIdentity, OneMatrix
# XXX: MatMul should perhaps not subclass directly from Mul
class MatMul(MatrixExpr, Mul):
    """
    A product of matrix expressions

    Scalar (non-matrix) factors may appear among the arguments;
    ``as_coeff_matrices`` separates them from the matrix factors.

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C
    """
    is_MatMul = True

    # Returned by ``MatMul()`` with no arguments and stripped from the
    # arguments of every other call in ``__new__``.
    identity = GenericIdentity()

    def __new__(cls, *args, evaluate=False, check=True, _sympify=True):
        """Create the matrix product of ``args``.

        Parameters
        ==========

        evaluate : bool, optional
            When true, the new product is passed through ``canonicalize``.
        check : bool, optional
            When true, ``validate`` is run on the matrix factors; it raises
            ``ShapeError`` if two neighbouring factors are not aligned.
        _sympify : bool, optional
            When true, every argument is converted with ``sympify``.

        Returns
        =======

        ``cls.identity`` when called without arguments, the bare scalar
        factor when no matrix factor remains, otherwise the product.
        """
        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericIdentity().shape
        args = list(filter(lambda i: cls.identity != i, args))
        if _sympify:
            args = list(map(sympify, args))
        obj = Basic.__new__(cls, *args)
        factor, matrices = obj.as_coeff_matrices()

        if check:
            validate(*matrices)

        if not matrices:
            # Should it be
            #
            # return Basic.__neq__(cls, factor, GenericIdentity()) ?
            #
            # NOTE(review): ``__neq__`` above is presumably a typo for
            # ``__new__``.
            return factor

        if evaluate:
            return canonicalize(obj)

        return obj

    @property
    def shape(self):
        """Rows of the first matrix factor by columns of the last one."""
        matrices = [arg for arg in self.args if arg.is_Matrix]
        return (matrices[0].rows, matrices[-1].cols)

    def could_extract_minus_sign(self):
        # Delegates to the first argument only.
        return self.args[0].could_extract_minus_sign()

    def _entry(self, i, j, expand=True, **kwargs):
        """Return entry ``(i, j)`` of the product as an explicit sum.

        Each contraction index between neighbouring matrix factors is a
        ``Dummy`` running from 0 to (columns of the left factor) - 1. The
        ``dummy_generator`` keyword lets callers supply the generator of
        dummies; it is forwarded to every factor's ``_entry`` so nested
        products draw their indices from the same counter.

        The sum is evaluated with ``doit()`` only when ``expand`` holds (or
        some factor entry contains an ``ImmutableMatrix``) and at least one
        summation bound is a concrete integer.
        """
        # Avoid cyclic imports
        from sympy.concrete.summations import Sum
        from sympy.matrices.immutable import ImmutableMatrix

        coeff, matrices = self.as_coeff_matrices()

        if len(matrices) == 1:  # situation like 2*X, matmul is just X
            return coeff * matrices[0][i, j]

        # indices[k] is the row index of factor k (and the column index of
        # factor k-1); the first and last are the requested i and j, the
        # interior ones are summed over with upper bounds in ind_ranges.
        indices = [None]*(len(matrices) + 1)
        ind_ranges = [None]*(len(matrices) - 1)
        indices[0] = i
        indices[-1] = j

        def f():
            counter = 1
            while True:
                yield Dummy("i_%i" % counter)
                counter += 1

        dummy_generator = kwargs.get("dummy_generator", f())

        # NOTE: ``i`` is reused as a loop variable below; the requested row
        # index has already been stored in indices[0].
        for i in range(1, len(matrices)):
            indices[i] = next(dummy_generator)

        for i, arg in enumerate(matrices[:-1]):
            ind_ranges[i] = arg.shape[1] - 1
        matrices = [arg._entry(indices[i], indices[i+1], dummy_generator=dummy_generator) for i, arg in enumerate(matrices)]
        expr_in_sum = Mul.fromiter(matrices)
        if any(v.has(ImmutableMatrix) for v in matrices):
            expand = True
        result = coeff*Sum(
                expr_in_sum,
                *zip(indices[1:-1], [0]*len(ind_ranges), ind_ranges)
            )

        # Don't waste time in result.doit() if the sum bounds are symbolic
        if not any(isinstance(v, (Integer, int)) for v in ind_ranges):
            expand = False
        return result.doit() if expand else result

    def as_coeff_matrices(self):
        """Split the arguments into ``(coeff, matrices)``.

        ``coeff`` is the ``Mul`` of all non-matrix arguments and
        ``matrices`` the list of matrix arguments in their original order.
        Raises ``NotImplementedError`` when ``coeff`` is known to be
        noncommutative.
        """
        scalars = [x for x in self.args if not x.is_Matrix]
        matrices = [x for x in self.args if x.is_Matrix]
        coeff = Mul(*scalars)
        if coeff.is_commutative is False:
            raise NotImplementedError("noncommutative scalars in MatMul are not supported.")

        return coeff, matrices

    def as_coeff_mmul(self):
        """Like ``as_coeff_matrices`` but with the matrices rebuilt as a MatMul."""
        coeff, matrices = self.as_coeff_matrices()
        return coeff, MatMul(*matrices)

    def _eval_transpose(self):
        """Transposition of matrix multiplication.

        Notes
        =====

        The following rules are applied.

        Transposition for matrix multiplied with another matrix:
        `\\left(A B\\right)^{T} = B^{T} A^{T}`

        Transposition for matrix multiplied with scalar:
        `\\left(c A\\right)^{T} = c A^{T}`

        References
        ==========

        .. [1] https://en.wikipedia.org/wiki/Transpose
        """
        coeff, matrices = self.as_coeff_matrices()
        return MatMul(
            coeff, *[transpose(arg) for arg in matrices[::-1]]).doit()

    def _eval_adjoint(self):
        # (A B)^H = B^H A^H: reverse the order and take the adjoint of every
        # argument, scalar factors included.
        return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit()

    def _eval_trace(self):
        # Only a scalar coefficient is pulled out of the trace here; with a
        # coefficient of 1 there is nothing to do and NotImplementedError
        # is raised.
        factor, mmul = self.as_coeff_mmul()
        if factor != 1:
            from .trace import trace
            return factor * trace(mmul.doit())
        else:
            raise NotImplementedError("Can't simplify any further")

    def _eval_determinant(self):
        # det(c*M) = c**n * det(M) for an n x n product; the matrix factors
        # are grouped by only_squares into square sub-products whose
        # determinants are multiplied together.
        from sympy.matrices.expressions.determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        square_matrices = only_squares(*matrices)
        return factor**self.rows * Mul(*list(map(Determinant, square_matrices)))

    def _eval_inverse(self):
        # (A B)^-1 = B^-1 A^-1: reverse the order, invert matrix factors with
        # .inverse() and scalars with **-1. A ShapeError leaves the inverse
        # unevaluated.
        try:
            return MatMul(*[
                arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1
                    for arg in self.args[::-1]]).doit()
        except ShapeError:
            return Inverse(self)

    def doit(self, **kwargs):
        """Evaluate the arguments (unless ``deep=False``) and canonicalize."""
        deep = kwargs.get('deep', True)
        if deep:
            args = [arg.doit(**kwargs) for arg in self.args]
        else:
            args = self.args

        # treat scalar*MatrixSymbol or scalar*MatPow separately
        expr = canonicalize(MatMul(*args))
        return expr

    # Needed for partial compatibility with Mul
    def args_cnc(self, **kwargs):
        # Returns [commutative args, noncommutative args]; the keyword
        # arguments are accepted but not used.
        coeff_c = [x for x in self.args if x.is_commutative]
        coeff_nc = [x for x in self.args if not x.is_commutative]
        return [coeff_c, coeff_nc]

    def _eval_derivative_matrix_lines(self, x):
        # Product rule: for every argument containing ``x``, take that
        # argument's derivative lines and attach the transposed factors to
        # its left (in reversed order) and the remaining factors to its
        # right. An Identity stands in for an empty side.
        from .transpose import Transpose
        with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
        lines = []
        for ind in with_x_ind:
            left_args = self.args[:ind]
            right_args = self.args[ind+1:]

            if right_args:
                right_mat = MatMul.fromiter(right_args)
            else:
                right_mat = Identity(self.shape[1])
            if left_args:
                left_rev = MatMul.fromiter([Transpose(i).doit() if i.is_Matrix else i for i in reversed(left_args)])
            else:
                left_rev = Identity(self.shape[0])

            d = self.args[ind]._eval_derivative_matrix_lines(x)
            for i in d:
                i.append_first(left_rev)
                i.append_second(right_mat)
                lines.append(i)

        return lines


# Register MatMul with the ``mul`` dispatcher for the (Mul, MatMul) pair of
# classes -- presumably so that combining Mul and MatMul yields a MatMul;
# see sympy.core.mul to confirm.
mul.register_handlerclass((Mul, MatMul), MatMul)
def validate(*matrices):
    """Raise ShapeError unless every pair of neighbouring matrices is aligned.

    Two neighbours ``left, right`` are aligned when ``left.cols`` equals
    ``right.rows``. Zero or one matrix is always valid.
    """
    for left, right in zip(matrices, matrices[1:]):
        if left.cols != right.rows:
            raise ShapeError("Matrices %s and %s are not aligned" % (left, right))
# Rules
def newmul(*args):
    """Build a MatMul from ``args`` with sympy.strategies ``new``.

    The first argument is the scalar coefficient; it is dropped when it
    equals 1.
    """
    operands = args[1:] if args[0] == 1 else args
    return new(MatMul, *operands)
def any_zeros(mul):
    """Collapse a MatMul that has a zero factor into a ZeroMatrix.

    If some argument is a scalar zero (``is_zero``) or a ZeroMatrix, return
    ``ZeroMatrix(rows of the first matrix, cols of the last matrix)``;
    otherwise return ``mul`` unchanged.
    """
    def _is_zero_factor(arg):
        return arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix)

    if not any(map(_is_zero_factor, mul.args)):
        return mul
    matrices = [arg for arg in mul.args if arg.is_Matrix]
    return ZeroMatrix(matrices[0].rows, matrices[-1].cols)
def merge_explicit(matmul):
    """ Merge explicit MatrixBase arguments

    Runs of adjacent explicit matrices (``MatrixBase``) and numbers are
    multiplied out; any other factor ends the current run. If no argument
    is an explicit matrix, ``matmul`` is returned unchanged.

    Examples
    ========

    >>> from sympy import MatrixSymbol, Matrix, MatMul, pprint
    >>> from sympy.matrices.expressions.matmul import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = Matrix([[1, 1], [1, 1]])
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatMul(A, B, C)
    >>> pprint(X)
      [1  1] [1  2]
    A*[    ]*[    ]
      [1  1] [3  4]
    >>> pprint(merge_explicit(X))
      [4  6]
    A*[    ]
      [4  6]

    >>> X = MatMul(B, A, C)
    >>> pprint(X)
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    >>> pprint(merge_explicit(X))
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    """
    args = matmul.args
    if not any(isinstance(arg, MatrixBase) for arg in args):
        return matmul

    explicit = (MatrixBase, Number)
    merged = [args[0]]
    for arg in args[1:]:
        prev = merged[-1]
        if isinstance(arg, explicit) and isinstance(prev, explicit):
            # Multiply in left-to-right order to keep the product correct.
            merged[-1] = prev * arg
        else:
            merged.append(arg)
    return MatMul(*merged)
def remove_ids(mul):
    """ Remove Identities from a MatMul

    This is a modified version of sympy.strategies.rm_id. It is necessary
    because MatMul may contain both MatrixExprs and Exprs as args: the
    scalar coefficient is split off, identities are removed from the
    matrix-only product, and the coefficient is reattached. ``mul`` is
    returned unchanged when nothing was removed.

    See Also
    ========

    sympy.strategies.rm_id
    """
    factor, matrix_product = mul.as_coeff_mmul()
    stripped = rm_id(lambda arg: arg.is_Identity is True)(matrix_product)
    if stripped != matrix_product:
        return newmul(factor, *stripped.args)
    return mul
def factor_in_front(mul):
    """Rebuild ``mul`` with its scalar coefficient as the first argument.

    ``mul`` is returned unchanged when the coefficient is 1.
    """
    factor, matrices = mul.as_coeff_matrices()
    return newmul(factor, *matrices) if factor != 1 else mul
def combine_powers(mul):
    r"""Combine consecutive powers with the same base into one, e.g.
    $$A \times A^2 \Rightarrow A^3$$

    This also cancels out the possible matrix inverses using the
    knowledgebase of :class:`~.Inverse`, e.g.,
    $$ Y \times X \times X^{-1} \Rightarrow Y $$

    Neighbours known to be non-square are left alone, and explicit
    ``MatrixBase`` factors are never tried for inverse cancellation.
    """
    def base_and_exp(M):
        # MatPow(X, e) -> (X, e); any other factor counts as X**1.
        if isinstance(M, MatPow):
            return M.args
        return M, S.One

    factor, matrices = mul.as_coeff_matrices()
    merged = [matrices[0]]
    for right in matrices[1:]:
        left = merged[-1]
        # ``== False`` on purpose: an unknown squareness (None) must not
        # block merging.
        if left.is_square == False or right.is_square == False:
            merged.append(right)
            continue

        left_base, left_exp = base_and_exp(left)
        right_base, right_exp = base_and_exp(right)

        if left_base == right_base:
            merged[-1] = MatPow(left_base, left_exp + right_exp).doit(deep=False)
            continue

        if not isinstance(right_base, MatrixBase):
            try:
                right_inv = right_base.inverse()
            except NonInvertibleMatrixError:
                right_inv = None
            if right_inv is not None and left_base == right_inv:
                # X**a * (X**-1)**b == X**(a - b)
                merged[-1] = MatPow(left_base, left_exp - right_exp).doit(deep=False)
                continue

        merged.append(right)

    return newmul(factor, *merged)
def combine_permutations(mul):
    """Fuse adjacent PermutationMatrix factors into one PermutationMatrix.

    Two neighbouring permutation matrices are replaced by a single
    PermutationMatrix built from the product of their first arguments
    (left * right). Products with fewer than two arguments are returned
    unchanged.
    """
    args = mul.args
    if len(args) < 2:
        return mul

    merged = [args[0]]
    for right in args[1:]:
        left = merged[-1]
        both_perms = isinstance(left, PermutationMatrix) and \
            isinstance(right, PermutationMatrix)
        if not both_perms:
            merged.append(right)
            continue
        merged[-1] = PermutationMatrix(left.args[0] * right.args[0])

    return MatMul(*merged)
def combine_one_matrices(mul):
    """
    Combine products of OneMatrix

    e.g. OneMatrix(2, 3) * OneMatrix(3, 4) -> 3 * OneMatrix(2, 4)
    """
    factor, matrices = mul.as_coeff_matrices()
    merged = [matrices[0]]
    for right in matrices[1:]:
        left = merged[-1]
        if isinstance(left, OneMatrix) and isinstance(right, OneMatrix):
            # Every entry of the product is a sum of left.shape[1] ones.
            merged[-1] = OneMatrix(left.shape[0], right.shape[1])
            factor *= left.shape[1]
        else:
            merged.append(right)
    return newmul(factor, *merged)
def distribute_monom(mul):
    """
    Simplify MatMul expressions by distributing a rational coefficient
    over a matrix sum.

    e.g. 2*(A+B) -> 2*A + 2*B

    Only two-argument products are handled; anything else is returned
    unchanged.
    """
    args = mul.args
    if len(args) != 2:
        return mul

    from .matadd import MatAdd
    first, second = args
    if first.is_MatAdd and second.is_Rational:
        return MatAdd(*[MatMul(term, second).doit() for term in first.args])
    if second.is_MatAdd and first.is_Rational:
        return MatAdd(*[MatMul(first, term).doit() for term in second.args])
    return mul
# Rewrite rules for ``canonicalize``, tried in the order listed. Per the
# sympy.strategies documentation (not shown here -- verify), ``do_one``
# applies the first rule that changes the expression and ``exhaust`` repeats
# until nothing changes, so the order of this tuple matters.
rules = (
    distribute_monom, any_zeros, remove_ids, combine_one_matrices, combine_powers, unpack, rm_id(lambda x: x == 1),
    merge_explicit, factor_in_front, flatten, combine_permutations)

# Normal-form simplifier used by ``MatMul.__new__(..., evaluate=True)`` and
# ``MatMul.doit``; ``typed`` keys the rule set on the MatMul type.
canonicalize = exhaust(typed({MatMul: do_one(*rules)}))
def only_squares(*matrices):
    """Group ``matrices`` into consecutive runs whose products are square.

    A run starting at ``matrices[k]`` ends at the first matrix whose column
    count equals ``matrices[k].rows``; each run is multiplied out with
    ``MatMul(...).doit()`` and the list of these products is returned.
    Raises RuntimeError when the rows of the first matrix differ from the
    columns of the last one.
    """
    if matrices[0].rows != matrices[-1].cols:
        raise RuntimeError("Invalid matrices being multiplied")
    squares = []
    begin = 0
    for end, mat in enumerate(matrices, 1):
        if mat.cols == matrices[begin].rows:
            squares.append(MatMul(*matrices[begin:end]).doit())
            begin = end
    return squares
def refine_MatMul(expr, assumptions):
    """
    Cancel neighbouring ``X*X.T`` (orthogonal ``X``) and ``X*adjoint(X)``
    (unitary ``X``) pairs in a MatMul, replacing each pair by an Identity.

    Scalar arguments are kept and placed in front of the matrix factors.
    ``expr`` is returned unchanged when it has no matrix arguments.

    >>> from sympy import MatrixSymbol, Q, assuming, refine
    >>> X = MatrixSymbol('X', 2, 2)
    >>> expr = X * X.T
    >>> print(expr)
    X*X.T
    >>> with assuming(Q.orthogonal(X)):
    ...     print(refine(expr))
    I
    """
    newargs = []
    exprargs = []

    for arg in expr.args:
        if arg.is_Matrix:
            exprargs.append(arg)
        else:
            newargs.append(arg)

    if not exprargs:
        # Nothing to cancel; also avoids indexing an empty list below.
        return expr

    last = exprargs[0]
    for arg in exprargs[1:]:
        if arg == last.T and ask(Q.orthogonal(arg), assumptions):
            last = Identity(arg.shape[0])
        # A unitary U satisfies U*U^H = I where U^H is the conjugate
        # *transpose* (adjoint); the element-wise conjugate does not cancel.
        # Unitarity is asked of ``last`` itself, which is equivalent since
        # ``arg`` is its adjoint.
        elif arg == adjoint(last) and ask(Q.unitary(last), assumptions):
            last = Identity(arg.shape[0])
        else:
            newargs.append(last)
            last = arg
    newargs.append(last)

    return MatMul(*newargs)
# Register refine_MatMul as the ``refine`` handler for MatMul expressions
# (handlers_dict is apparently keyed by class name).
handlers_dict['MatMul'] = refine_MatMul