# m2m模型翻译 (m2m model translation)
# You cannot select more than 25 topics. Topics must start with a letter or
# number, can include dashes ('-'), and can be up to 35 characters long.
#
# 128 lines
# 8.1 KiB
#
# 6 months ago
  1. import sys
  2. from sympy.external import import_module
  3. matchpy = import_module("matchpy")
  4. if not matchpy:
  5. #bin/test will not execute any tests now
  6. disabled = True
  7. if sys.version_info[:2] < (3, 6):
  8. disabled = True
  9. from sympy.integrals.rubi.parsetools.parse import (rubi_rule_parser,
  10. get_default_values, add_wildcards, parse_freeq, seperate_freeq,
  11. get_free_symbols, divide_constraint, generate_sympy_from_parsed,
  12. setWC, replaceWith, rubi_printer, set_matchq_in_constraint, contains_diff_return_type,
  13. process_return_type, extract_set)
  14. from sympy.core.symbol import (Symbol, symbols)
  15. from sympy.core.sympify import sympify
  16. from sympy.logic.boolalg import Not
  17. a, b, c, d, e, j, m, n, p, q, x, Pq, Pqq = symbols('a b c d e j m n p q x Pq Pqq')
  18. def test_rubi_rule_parser():
  19. header = '''
  20. from matchpy import Operation, CommutativeOperation
  21. rubi = ManyToOneReplacer()
  22. '''
  23. fullform = 'List[RuleDelayed[HoldPattern[Int[Power[Pattern[x,Blank[]],Optional[Pattern[m,Blank[]]]],Pattern[x,Blank[Symbol]]]],Condition[Times[Power[x,Plus[m,1]],Power[Plus[m,1],-1]],NonzeroQ[Plus[m,1]]]]]'
  24. rules, constraint = rubi_rule_parser(fullform, header)
  25. result_rule = '''
  26. from matchpy import Operation, CommutativeOperation
  27. rubi = ManyToOneReplacer()
  28. from sympy.integrals.rubi.constraints import cons1
  29. pattern1 = Pattern(Integral(x_**WC('m', S(1)), x_), cons1)
  30. def replacement1(m, x):
  31. rubi.append(1)
  32. return x**(m + S(1))/(m + S(1))
  33. rule1 = ReplacementRule(pattern1, replacement1)
  34. return [rule1, ]
  35. '''
  36. result_constraint = '''
  37. from matchpy import Operation, CommutativeOperation
  38. def cons_f1(m):
  39. return NonzeroQ(m + S(1))
  40. cons1 = CustomConstraint(cons_f1)
  41. '''
  42. assert len(result_rule.strip()) == len(rules.strip()) # failing randomly while using `result.strip() == rules`
  43. assert len(result_constraint.strip()) == len(constraint.strip())
  44. def test_get_default_values():
  45. s = ['Int', ['Power', ['Plus', ['Optional', ['Pattern', 'a', ['Blank']]], ['Times', ['Optional', ['Pattern', 'b', ['Blank']]], ['Pattern', 'x', ['Blank']]]], ['Pattern', 'm', ['Blank']]], ['Pattern', 'x', ['Blank', 'Symbol']]]
  46. assert get_default_values(s, {}) == {'a': 0, 'b': 1}
  47. s = ['Int', ['Power', ['Pattern', 'x', ['Blank']], ['Optional', ['Pattern', 'm', ['Blank']]]], ['Pattern', 'x', ['Blank', 'Symbol']]]
  48. assert get_default_values(s, {}) == {'m': 1}
  49. def test_add_wildcards():
  50. s = 'Integral(Pow(Pattern(x, Blank), Optional(Pattern(m, Blank))), Pattern(x, Blank(Symbol)))'
  51. assert add_wildcards(s, {'m': 1}) == ("Integral(Pow(x_, WC('m', S(1))), x_)", ['m', 'x', 'x'])
  52. def test_seperate_freeq():
  53. s = ['FreeQ', ['List', 'a', 'b'], 'x']
  54. assert seperate_freeq(s) == (['a', 'b'], 'x')
  55. def test_parse_freeq():
  56. l = ['a', 'b']
  57. x = 'x'
  58. symbols = ['x', 'a', 'b']
  59. assert parse_freeq(l, x, 0, {}, [], symbols) == (', cons1, cons2', '\n def cons_f1(a, x):\n return FreeQ(a, x)\n\n cons1 = CustomConstraint(cons_f1)\n\n def cons_f2(b, x):\n return FreeQ(b, x)\n\n cons2 = CustomConstraint(cons_f2)\n', 2)
  60. def test_get_free_symbols():
  61. s = ['NonzeroQ', ['Plus', 'm', '1']]
  62. symbols = ['m', 'x']
  63. assert get_free_symbols(s, symbols, []) == ['m']
  64. def test_divide_constraint():
  65. s = ['And', ['FreeQ', 'm', 'x'], ['NonzeroQ', ['Plus', 'm', '1']]]
  66. assert divide_constraint(s, ['m', 'x'], 0, {}, []) == (', cons1', '\n def cons_f1(m):\n return NonzeroQ(m + S(1))\n\n cons1 = CustomConstraint(cons_f1)\n', 1)
  67. def test_setWC():
  68. assert setWC('Integral(x_**WC(m, S(1)), x_)') == "Integral(x_**WC('m', S(1)), x_)"
  69. def test_replaceWith():
  70. s = sympify('Module(List(Set(r, Numerator(Rt(a/b, n))), Set(s, Denominator(Rt(a/b, n))), k, u), CompoundExpression(Set(u, Integral((r - s*x*cos(Pi*(2*k - 1)/n))/(r**2 - 2*r*s*x*cos(Pi*(2*k - 1)/n) + s**2*x**2), x)), Dist(2*r/(a*n), _Sum(u, List(k, 1, n/2 - 1/2)), x) + r*Integral(1/(r + s*x), x)/(a*n)))')
  71. symbols = ['x', 'a', 'n', 'b']
  72. assert replaceWith(s, symbols, 1) == (" def With1(x, a, n, b):\n r = Numerator(Rt(a/b, n))\n s = Denominator(Rt(a/b, n))\n k = Symbol('k')\n u = Symbol('u')\n u = Integral((r - s*x*cos(Pi*(S(2)*k + S(-1))/n))/(r**S(2) - S(2)*r*s*x*cos(Pi*(S(2)*k + S(-1))/n) + s**S(2)*x**S(2)), x)\n u = Integral((r - s*x*cos(Pi*(2*k - 1)/n))/(r**2 - 2*r*s*x*cos(Pi*(2*k - 1)/n) + s**2*x**2), x)\n rubi.append(1)\n return Dist(S(2)*r/(a*n), _Sum(u, List(k, S(1), n/S(2) + S(-1)/2)), x) + r*Integral(S(1)/(r + s*x), x)/(a*n)", ' ', None)
  73. def test_generate_sympy_from_parsed():
  74. s = ['Int', ['Power', ['Plus', ['Pattern', 'a', ['Blank']], ['Times', ['Optional', ['Pattern', 'b', ['Blank']]], ['Power', ['Pattern', 'x', ['Blank']], ['Pattern', 'n', ['Blank']]]]], '-1'], ['Pattern', 'x', ['Blank', 'Symbol']]]
  75. assert generate_sympy_from_parsed(s, wild=True) == 'Int(Pow(Add(Pattern(a, Blank), Mul(Optional(Pattern(b, Blank)), Pow(Pattern(x, Blank), Pattern(n, Blank)))), S(-1)), Pattern(x, Blank(Symbol)))'
  76. assert generate_sympy_from_parsed(s ,replace_Int=True) == 'Integral(Pow(Add(Pattern(a, Blank), Mul(Optional(Pattern(b, Blank)), Pow(Pattern(x, Blank), Pattern(n, Blank)))), S(-1)), Pattern(x, Blank(Symbol)))'
  77. s = ['And', ['FreeQ', ['List', 'a', 'b'], 'x'], ['PositiveIntegerQ', ['Times', ['Plus', 'n', '-3'], ['Power', '2', '-1']]], ['PosQ', ['Times', 'a', ['Power', 'b', '-1']]]]
  78. assert generate_sympy_from_parsed(s) == 'And(FreeQ(List(a, b), x), PositiveIntegerQ(Mul(Add(n, S(-3)), Pow(S(2), S(-1)))), PosQ(Mul(a, Pow(b, S(-1)))))'
  79. def test_rubi_printer():
  80. #14819
  81. a = Symbol('a')
  82. assert rubi_printer(Not(a)) == 'Not(a)'
  83. def test_contains_diff_return_type():
  84. assert contains_diff_return_type(['Plus', ['BinomialDegree', 'u', 'x'], ['Times', '-1', ['BinomialDegree', 'z', 'x']]])
  85. def test_set_matchq_in_constraint():
  86. expected = ('result_matchq', " def _cons_f_1229(g, m):\n return FreeQ(List(g, m), x)\n _cons_1229 = CustomConstraint(_cons_f_1229)\n pat = Pattern(UtilityOperator((x*WC('g', S(1)))**WC('m', S(1)), x), _cons_1229)\n result_matchq = is_match(UtilityOperator(v, x), pat)")
  87. expected1 = ('result_matchq', " def _cons_f_1229(m, g):\n return FreeQ(List(g, m), x)\n _cons_1229 = CustomConstraint(_cons_f_1229)\n pat = Pattern(UtilityOperator((x*WC('g', S(1)))**WC('m', S(1)), x), _cons_1229)\n result_matchq = is_match(UtilityOperator(v, x), pat)")
  88. result = set_matchq_in_constraint(['MatchQ', 'v', ['Condition', ['Power', ['Times', ['Optional',\
  89. ['Pattern', 'g', ['Blank']]], 'x'], ['Optional', ['Pattern', 'm', ['Blank']]]], ['FreeQ', ['List', 'g', 'm'], 'x']]], 1229)
  90. assert result == expected1 or result == expected
  91. def test_process_return_type():
  92. from sympy.core.function import Function
  93. Int = Function("Int")
  94. ExpandToSum = Function("ExpandToSum")
  95. s = ('\n q = Expon(Pq, x)\n Pqq = Coeff(Pq, x, q)', 'With(List(Set(Pqq, Coeff(Pq, x, q))), Pqq*c**(n - q + S(-1))*(c*x)**(m - n + q + S(1))*(a*x**j + b*x**n)**(p + S(1))/(b*(m + n*p + q + S(1))) + Int((c*x)**m*(a*x**j + b*x**n)**p*ExpandToSum(Pq - Pqq*a*x**(-n + q)*(m - n + q + S(1))/(b*(m + n*p + q + S(1))) - Pqq*x**q, x), x))')
  96. result = process_return_type(s, [])
  97. expected = ('\n Pqq = Coeff(Pq, x, q)',\
  98. Pqq*c**(n - q - 1)*(c*x)**(m - n + q + 1)*(a*x**j + b*x**n)**(p + 1)/(b*(m + n*p + q + 1)) + Int((c*x)**m*(a*x**j + b*x**n)**p*ExpandToSum(Pq - Pqq*a*x**(-n + q)*(m - n + q + 1)/(b*(m + n*p + q + 1)) - Pqq*x**q, x), x),\
  99. True)
  100. assert result == expected
  101. def test_extract_set():
  102. s = sympify('Module(List(Set(r, Numerator(Rt(a/b, n))), Set(s, Denominator(Rt(a/b, n))), k, u), CompoundExpression(Set(u, Integral((r - s*x*cos(Pi*(2*k - 1)/n))/(r**2 - 2*r*s*x*cos(Pi*(2*k - 1)/n) + s**2*x**2), x)), Dist(2*r/(a*n), _Sum(u, List(k, 1, n/2 - 1/2)), x) + r*Integral(1/(r + s*x), x)/(a*n)))')
  103. expected = list(sympify('Set(r, Numerator(Rt(a/b, n))), Set(s, Denominator(Rt(a/b, n))), Set(u, Integral((r - s*x*cos(Pi*(2*k - 1)/n))/(r**2 - 2*r*s*x*cos(Pi*(2*k - 1)/n) + s**2*x**2), x))'))
  104. assert extract_set(s, []) == expected