m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

56 lines
2.7 KiB

6 months ago
  1. from sympy.integrals.rubi.parsetools.parse import generate_sympy_from_parsed, parse_full_form, rubi_printer
  2. from sympy.core.sympify import sympify
  3. from sympy.integrals.rubi.utility_function import List, If
  4. import os, inspect
  5. def rubi_sstr(a):
  6. return rubi_printer(a, sympy_integers=True)
  7. def generate_test_file():
  8. '''
  9. This function is assuming the name of file containing the fullform is test_1.m.
  10. It can be changes as per use.
  11. For more details, see
  12. `https://github.com/sympy/sympy/wiki/Rubi-parsing-guide#parsing-tests`
  13. '''
  14. res =[]
  15. file_name = 'test_1.m'
  16. with open(file_name) as myfile:
  17. fullform =myfile.read().replace('\n', '')
  18. fullform = fullform.replace('$VersionNumber', 'version_number')
  19. fullform = fullform.replace('Defer[Int][', 'Integrate[')
  20. path_header = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
  21. h = open(os.path.join(path_header, "header.py.txt")).read()
  22. header = "import sys\nfrom sympy.external import import_module\nmatchpy = import_module({})".format('\"matchpy\"')
  23. header += "\nif not matchpy:\n disabled = True\n"
  24. header += "if sys.version_info[:2] < (3, 6):\n disabled = True\n"
  25. header += "\n".join(h.split("\n")[8:-9])
  26. header += "from sympy.integrals.rubi.rubi import rubi_integrate\n"
  27. header += "from sympy import Integral as Integrate, exp, log\n"
  28. header += "\na, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z = symbols('a b c d e f g h i j k l m n o p q r s t u v w x y z')"
  29. header += "\nA, B, C, F, G, H, J, K, L, M, N, O, P, Q, R, T, U, V, W, X, Y, Z = symbols('A B C F G H J K L M N O P Q R T U V W X Y Z')"
  30. header += "\n\ndef {}():\n".format(file_name[0:-2])
  31. s = parse_full_form(fullform)
  32. tests = []
  33. for i in s:
  34. res[:] = []
  35. if i[0] == 'HoldComplete':
  36. ss = sympify(generate_sympy_from_parsed(i[1]), locals = { 'version_number' : 11, 'If' : If})
  37. ss = List(*ss.args)
  38. tests.append(ss)
  39. t = ''
  40. for a in tests:
  41. if len(a) == 5:
  42. r = 'rubi_integrate({}, x)'.format(rubi_sstr(a[0]))
  43. t += '\n assert rubi_test({}, {}, {}, expand=True, _diff=True, _numerical=True) or rubi_test({}, {}, {}, expand=True, _diff=True, _numerical=True)'.format(r, rubi_sstr(a[1]), rubi_sstr(a[3]), r, rubi_sstr(a[1]),rubi_sstr(a[4]))
  44. else:
  45. r = 'rubi_integrate({}, x)'.format(rubi_sstr(a[0]))
  46. t += '\n assert rubi_test({}, {}, {}, expand=True, _diff=True, _numerical=True)'.format(r, rubi_sstr(a[1]), rubi_sstr(a[3]))
  47. t = header+t+'\n'
  48. test = open('parsed_tests.py', 'w')
  49. test.write(t)
  50. test.close()