m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

130 lines
4.0 KiB

from sympy.core.relational import (Eq, Ne)
from sympy.core.singleton import S
from sympy.core.symbol import symbols
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.external import import_module
from sympy.testing.pytest import skip
from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar, Replacer
matchpy = import_module("matchpy")
x, y, z = symbols("x y z")
def _get_first_match(expr, pattern):
from matchpy import ManyToOneMatcher, Pattern
matcher = ManyToOneMatcher()
matcher.add(Pattern(pattern))
return next(iter(matcher.match(expr)))
def test_matchpy_connector():
if matchpy is None:
skip("matchpy not installed")
from multiset import Multiset
from matchpy import Pattern, Substitution
w_ = WildDot("w_")
w__ = WildPlus("w__")
w___ = WildStar("w___")
expr = x + y
pattern = x + w_
p, subst = _get_first_match(expr, pattern)
assert p == Pattern(pattern)
assert subst == Substitution({'w_': y})
expr = x + y + z
pattern = x + w__
p, subst = _get_first_match(expr, pattern)
assert p == Pattern(pattern)
assert subst == Substitution({'w__': Multiset([y, z])})
expr = x + y + z
pattern = x + y + z + w___
p, subst = _get_first_match(expr, pattern)
assert p == Pattern(pattern)
assert subst == Substitution({'w___': Multiset()})
def test_matchpy_optional():
if matchpy is None:
skip("matchpy not installed")
from matchpy import Pattern, Substitution
from matchpy import ManyToOneReplacer, ReplacementRule
p = WildDot("p", optional=1)
q = WildDot("q", optional=0)
pattern = p*x + q
expr1 = 2*x
pa, subst = _get_first_match(expr1, pattern)
assert pa == Pattern(pattern)
assert subst == Substitution({'p': 2, 'q': 0})
expr2 = x + 3
pa, subst = _get_first_match(expr2, pattern)
assert pa == Pattern(pattern)
assert subst == Substitution({'p': 1, 'q': 3})
expr3 = x
pa, subst = _get_first_match(expr3, pattern)
assert pa == Pattern(pattern)
assert subst == Substitution({'p': 1, 'q': 0})
expr4 = x*y + z
pa, subst = _get_first_match(expr4, pattern)
assert pa == Pattern(pattern)
assert subst == Substitution({'p': y, 'q': z})
replacer = ManyToOneReplacer()
replacer.add(ReplacementRule(Pattern(pattern), lambda p, q: sin(p)*cos(q)))
assert replacer.replace(expr1) == sin(2)*cos(0)
assert replacer.replace(expr2) == sin(1)*cos(3)
assert replacer.replace(expr3) == sin(1)*cos(0)
assert replacer.replace(expr4) == sin(y)*cos(z)
def test_replacer():
if matchpy is None:
skip("matchpy not installed")
x1_ = WildDot("x1_")
x2_ = WildDot("x2_")
a_ = WildDot("a_", optional=S.One)
b_ = WildDot("b_", optional=S.One)
c_ = WildDot("c_", optional=S.Zero)
replacer = Replacer(common_constraints=[
matchpy.CustomConstraint(lambda a_: not a_.has(x)),
matchpy.CustomConstraint(lambda b_: not b_.has(x)),
matchpy.CustomConstraint(lambda c_: not c_.has(x)),
])
# Rewrite the equation into implicit form, unless it's already solved:
replacer.add(Eq(x1_, x2_), Eq(x1_ - x2_, 0), conditions_nonfalse=[Ne(x2_, 0), Ne(x1_, 0), Ne(x1_, x), Ne(x2_, x)])
# Simple equation solver for real numbers:
replacer.add(Eq(a_*x + b_, 0), Eq(x, -b_/a_))
disc = b_**2 - 4*a_*c_
replacer.add(
Eq(a_*x**2 + b_*x + c_, 0),
Eq(x, (-b_ - sqrt(disc))/(2*a_)) | Eq(x, (-b_ + sqrt(disc))/(2*a_)),
conditions_nonfalse=[disc >= 0]
)
replacer.add(
Eq(a_*x**2 + c_, 0),
Eq(x, sqrt(-c_/a_)) | Eq(x, -sqrt(-c_/a_)),
conditions_nonfalse=[-c_*a_ > 0]
)
assert replacer.replace(Eq(3*x, y)) == Eq(x, y/3)
assert replacer.replace(Eq(x**2 + 1, 0)) == Eq(x**2 + 1, 0)
assert replacer.replace(Eq(x**2, 4)) == (Eq(x, 2) | Eq(x, -2))
assert replacer.replace(Eq(x**2 + 4*y*x + 4*y**2, 0)) == Eq(x, -2*y)