m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

209 lines
6.6 KiB

#
# sympy.polys.matrices.linsolve module
#
# This module defines the _linsolve function which is the internal workhorse
# used by linsolve. This computes the solution of a system of linear equations
# using the SDM sparse matrix implementation in sympy.polys.matrices.sdm. This
# is a replacement for solve_lin_sys in sympy.polys.solvers which is
# inefficient for large sparse systems due to the use of a PolyRing with many
# generators:
#
# https://github.com/sympy/sympy/issues/20857
#
# The implementation of _linsolve here handles:
#
# - Extracting the coefficients from the Expr/Eq input equations.
# - Constructing a domain and converting the coefficients to
# that domain.
# - Using the SDM.rref, SDM.nullspace etc methods to generate the full
# solution working with arithmetic only in the domain of the coefficients.
#
# The routines here are particularly designed to be efficient for large sparse
# systems of linear equations, although they also work for dense systems. It is
# possible that for some small dense systems solve_lin_sys, which uses the
# dense matrix implementation DDM, will be more efficient. With smaller systems,
# though, the bulk of the time is spent just preprocessing the inputs, and the
# relative time spent in rref is too small to be noticeable.
#
from collections import defaultdict
from sympy.core.add import Add
from sympy.core.mul import Mul
from sympy.core.singleton import S
from sympy.polys.constructor import construct_domain
from sympy.polys.solvers import PolyNonlinearError
from .sdm import (
SDM,
sdm_irref,
sdm_particular_from_rref,
sdm_nullspace_from_rref
)
def _linsolve(eqs, syms):
    """Solve a linear system of equations.

    Parameters
    ==========

    eqs : list
        The equations, each either an ``Expr`` (understood as ``expr == 0``)
        or an ``Eq``.
    syms : list
        The unknowns to solve for.

    Returns
    =======

    dict or None
        ``None`` if the system is inconsistent. Otherwise a dict mapping every
        symbol in ``syms`` to its solution; unconstrained symbols map to
        themselves and appear in the solutions of the other symbols.

    Examples
    ========

    Solve a linear system with a unique solution:

    >>> from sympy import symbols, Eq
    >>> from sympy.polys.matrices.linsolve import _linsolve
    >>> x, y = symbols('x, y')
    >>> eqs = [Eq(x + y, 1), Eq(x - y, 2)]
    >>> _linsolve(eqs, [x, y])
    {x: 3/2, y: -1/2}

    In the case of underdetermined systems the solution will be expressed in
    terms of the unknown symbols that are unconstrained:

    >>> _linsolve([Eq(x + y, 0)], [x, y])
    {x: -y, y: y}

    An inconsistent system has no solution:

    >>> _linsolve([Eq(x, 1), Eq(x, 2)], [x]) is None
    True
    """
    # Number of unknowns (columns in the non-augmented matrix)
    nsyms = len(syms)

    # Convert to sparse augmented matrix (len(eqs) x (nsyms+1))
    eqsdict, rhs = _linear_eq_to_dict(eqs, syms)
    Aaug = sympy_dict_to_dm(eqsdict, rhs, syms)
    K = Aaug.domain

    # sdm_irref has issues with float matrices. This uses the ddm_rref()
    # function. When sdm_rref() can handle float matrices reasonably this
    # should be removed...
    if K.is_RealField or K.is_ComplexField:
        Aaug = Aaug.to_ddm().rref()[0].to_sdm()

    # Compute reduced-row echelon form (RREF)
    Arref, pivots, nzcols = sdm_irref(Aaug)

    # No solution: a pivot in the augmented (last) column means a row
    # reduced to 0 == nonzero.
    if pivots and pivots[-1] == nsyms:
        return None

    # Particular solution for non-homogeneous system:
    P = sdm_particular_from_rref(Arref, nsyms + 1, pivots)

    # Nullspace - general solution to homogeneous system
    # Note: using nsyms not nsyms+1 to ignore last column
    V, nonpivots = sdm_nullspace_from_rref(Arref, K.one, nsyms, pivots, nzcols)

    # Collect together terms from particular and nullspace:
    sol = defaultdict(list)
    for i, v in P.items():
        sol[syms[i]].append(K.to_sympy(v))
    for npi, Vi in zip(nonpivots, V):
        sym = syms[npi]
        for i, v in Vi.items():
            sol[syms[i]].append(sym * K.to_sympy(v))

    # Use a single call to Add for each term:
    sol = {s: Add(*terms) for s, terms in sol.items()}

    # Fill in the zeros. Iterating over ``syms`` rather than a set difference
    # keeps the insertion order of the result independent of hash order.
    zero = S.Zero
    for s in syms:
        if s not in sol:
            sol[s] = zero

    # All done!
    return sol
def sympy_dict_to_dm(eqs_coeffs, eqs_rhs, syms):
    """Convert a system of dict equations to a sparse augmented matrix.

    Parameters
    ==========

    eqs_coeffs : list of dict
        One dict per equation mapping a symbol from ``syms`` to its (SymPy)
        coefficient.
    eqs_rhs : list
        One symbol-independent term per equation. Each equation is read as
        ``const + sum(coeff*sym) == 0``, so the term is negated when placed
        in the augmented column.
    syms : list
        The unknowns; their positions give the column indices.

    Returns
    =======

    SDM
        A matrix of shape ``(len(eqs_coeffs), len(syms) + 1)`` over the domain
        chosen by ``construct_domain`` (with ``field=True`` and
        ``extension=True``). Entries that are zero in that domain are not
        stored, and all-zero equations are dropped, so the nonzero rows are
        packed at the top.
    """
    # Materialize once so that the sequence passed to construct_domain is
    # exactly the one zipped against its converted elements.
    elems = list(set(eqs_rhs).union(*(e.values() for e in eqs_coeffs)))
    K, elems_K = construct_domain(elems, field=True, extension=True)
    elem_map = dict(zip(elems, elems_K))

    neqs = len(eqs_coeffs)
    nsyms = len(syms)
    sym2index = dict(zip(syms, range(nsyms)))

    eqsdict = []
    for eq, rhs in zip(eqs_coeffs, eqs_rhs):
        # SDM rows must not hold explicit zeros: a stored zero could be
        # chosen as a pivot by sdm_irref. A coefficient can be zero here even
        # when it was kept in the dict (e.g. terms that cancelled in Add), or
        # become zero only after conversion into the domain, so test the
        # converted element.
        eqdict = {}
        for s, c in eq.items():
            c_K = elem_map[c]
            if c_K:
                eqdict[sym2index[s]] = c_K
        rhs_K = elem_map[rhs]
        if rhs_K:
            eqdict[nsyms] = -rhs_K
        if eqdict:
            eqsdict.append(eqdict)

    sdm_aug = SDM(enumerate(eqsdict), (neqs, nsyms + 1), K)
    return sdm_aug
def _expand_eqs_deprecated(eqs):
"""Use expand to cancel nonlinear terms.
This approach matches previous behaviour of linsolve but should be
deprecated.
"""
def expand_eq(eq):
if eq.is_Equality:
eq = eq.lhs - eq.rhs
return eq.expand()
return [expand_eq(eq) for eq in eqs]
def _linear_eq_to_dict(eqs, syms):
    """Convert a system of Expr/Eq equations into dict form.

    The equations are converted as given first. If that raises
    ``PolyNonlinearError``, they are expanded (so that nonlinear terms which
    cancel disappear) and converted again; a second failure propagates.
    """
    try:
        result = _linear_eq_to_dict_inner(eqs, syms)
    except PolyNonlinearError:
        # XXX: This should be deprecated:
        expanded_eqs = _expand_eqs_deprecated(eqs)
        result = _linear_eq_to_dict_inner(expanded_eqs, syms)
    return result
def _linear_eq_to_dict_inner(eqs, syms):
    """Convert a system of Expr/Eq equations into dict form.

    Returns two parallel lists with one entry per equation: the coefficient
    dicts and the symbol-independent terms, as produced by ``_lin_eq2dict``.
    """
    symset = set(syms)
    converted = [_lin_eq2dict(eq, symset) for eq in eqs]
    coeff_dicts = [terms for _, terms in converted]
    constants = [const for const, _ in converted]
    return coeff_dicts, constants
def _lin_eq2dict(a, symset):
    """Efficiently convert a linear equation to a dict of coefficients.

    ``a`` (or ``a.lhs - a.rhs`` when ``a`` is an ``Eq``) is split into a
    symbol-independent part ``coeff`` and a dict ``terms`` mapping each symbol
    from ``symset`` that occurs to its coefficient, so that ``a`` equals
    ``coeff + sum(c*s for s, c in terms.items())``.

    Returns
    =======

    (coeff, terms)

    Raises
    ======

    PolyNonlinearError
        If a product has more than one factor depending on ``symset``, or if
        any other subexpression depends on ``symset`` in a way that is not a
        plain symbol, sum, or product (e.g. ``x**2`` or ``sin(x)``).
    """
    if a in symset:
        return S.Zero, {a: S.One}
    elif a.is_Add:
        # Sum the constant parts and, per symbol, the coefficients of each
        # addend; one Add call per group avoids repeated re-canonicalisation.
        terms_list = defaultdict(list)
        coeff_list = []
        for ai in a.args:
            ci, ti = _lin_eq2dict(ai, symset)
            coeff_list.append(ci)
            for mij, cij in ti.items():
                terms_list[mij].append(cij)
        coeff = Add(*coeff_list)
        terms = {sym: Add(*coeffs) for sym, coeffs in terms_list.items()}
        return coeff, terms
    elif a.is_Mul:
        # At most one factor may depend on the symbols; every other factor is
        # a constant multiplier for it.
        terms = terms_coeff = None
        coeff_list = []
        for ai in a.args:
            ci, ti = _lin_eq2dict(ai, symset)
            if not ti:
                coeff_list.append(ci)
            elif terms is None:
                terms = ti
                terms_coeff = ci
            else:
                # A second symbol-dependent factor makes this a cross term.
                raise PolyNonlinearError('nonlinear cross-term: %s' % (a,))
        coeff = Mul(*coeff_list)
        if terms is None:
            return coeff, {}
        else:
            terms = {sym: coeff * c for sym, c in terms.items()}
            return coeff * terms_coeff, terms
    elif a.is_Equality:
        return _lin_eq2dict(a.lhs - a.rhs, symset)
    elif not a.has_free(*symset):
        return a, {}
    else:
        raise PolyNonlinearError('nonlinear term: %s' % (a,))