m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

721 lines
27 KiB

6 months ago
"""
Parser for FullForm[DownValues[]] of Mathematica rules.

This parser is customised to emit the rules in MatchPy format. A constraint
made of several `Constraints` is split into individual `Constraints`, because
that lets MatchPy's `ManyToOneReplacer` backtrack earlier and improves speed.

The parsed output is made readable by passing it through `sympify` and
printing the result with `rubi_printer` (a `StrPrinter` subclass), so that
e.g. `Mul` and `Pow` are printed with their operator symbols.

Mathematica
===========

To get the full form from Wolfram Mathematica, type:

```
ShowSteps = False
Import["RubiLoader.m"]
Export["output.txt", ToString@FullForm@DownValues@Int]
```

The file ``output.txt`` will then contain the rules in parseable format.

References
==========

[1] http://reference.wolfram.com/language/ref/FullForm.html
[2] http://reference.wolfram.com/language/ref/DownValues.html
[3] https://gist.github.com/Upabjojr/bc07c49262944f9c1eb0
"""
  24. import re
  25. import os
  26. import inspect
  27. from sympy.core.function import Function
  28. from sympy.core.symbol import Symbol
  29. from sympy.core.sympify import sympify
  30. from sympy.sets.sets import Set
  31. from sympy.printing import StrPrinter
  32. from sympy.utilities.misc import debug
  33. class RubiStrPrinter(StrPrinter):
  34. def _print_Not(self, expr):
  35. return "Not(%s)" % self._print(expr.args[0])
  36. def rubi_printer(expr, **settings):
  37. return RubiStrPrinter(settings).doprint(expr)
# Mathematica head -> name emitted by `generate_sympy_from_parsed`.
# Heads not listed here are emitted unchanged (the commented-out entries,
# e.g. ArcTan and Gamma, are therefore passed through under their
# Mathematica names). Reciprocal functions are spelled as `1/f`, so that
# e.g. Cot[x] is emitted as `1/tan(x)`. Lowercase keys (cot, sec, ...) map
# the same way as their capitalised counterparts.
replacements = dict( # Mathematica equivalent functions in SymPy
    Times="Mul",
    Plus="Add",
    Power="Pow",
    Log='log',
    Exp='exp',
    Sqrt='sqrt',
    Cos='cos',
    Sin='sin',
    Tan='tan',
    Cot='1/tan',
    cot='1/tan',
    Sec='1/cos',
    sec='1/cos',
    Csc='1/sin',
    csc='1/sin',
    ArcSin='asin',
    ArcCos='acos',
    # ArcTan='atan',
    ArcCot='acot',
    ArcSec='asec',
    ArcCsc='acsc',
    Sinh='sinh',
    Cosh='cosh',
    Tanh='tanh',
    Coth='1/tanh',
    coth='1/tanh',
    Sech='1/cosh',
    sech='1/cosh',
    Csch='1/sinh',
    csch='1/sinh',
    ArcSinh='asinh',
    ArcCosh='acosh',
    ArcTanh='atanh',
    ArcCoth='acoth',
    ArcSech='asech',
    ArcCsch='acsch',
    Expand='expand',
    Im='im',
    Re='re',
    Flatten='flatten',
    Polylog='polylog',
    Cancel='cancel',
    #Gamma='gamma',
    TrigExpand='expand_trig',
    Sign='sign',
    Simplify='simplify',
    Defer='UnevaluatedExpr',
    Identity = 'S',
    Sum = 'Sum_doit',
    # Scoping constructs are all emitted as `With` (see `replaceWith`).
    Module = 'With',
    Block = 'With',
    Null = 'None'
)
  92. temporary_variable_replacement = { # Temporarily rename because it can raise errors while sympifying
  93. 'gcd' : "_gcd",
  94. 'jn' : "_jn",
  95. }
  96. permanent_variable_replacement = { # Permamenely rename these variables
  97. r"\[ImaginaryI]" : 'ImaginaryI',
  98. "$UseGamma": '_UseGamma',
  99. }
  100. # These functions have different return type in different cases. So better to use a try and except in the constraints, when any of these appear
  101. f_diff_return_type = ['BinomialParts', 'BinomialDegree', 'TrinomialParts', 'GeneralizedBinomialParts', 'GeneralizedTrinomialParts', 'PseudoBinomialParts', 'PerfectPowerTest',
  102. 'SquareFreeFactorTest', 'SubstForFractionalPowerOfQuotientOfLinears', 'FractionalPowerOfQuotientOfLinears', 'InverseFunctionOfQuotientOfLinears',
  103. 'FractionalPowerOfSquareQ', 'FunctionOfLinear', 'FunctionOfInverseLinear', 'FunctionOfTrig', 'FindTrigFactor', 'FunctionOfLog',
  104. 'PowerVariableExpn', 'FunctionOfSquareRootOfQuadratic', 'SubstForFractionalPowerOfLinear', 'FractionalPowerOfLinear', 'InverseFunctionOfLinear',
  105. 'Divides', 'DerivativeDivides', 'TrigSquare', 'SplitProduct', 'SubstForFractionalPowerOfQuotientOfLinears', 'InverseFunctionOfQuotientOfLinears',
  106. 'FunctionOfHyperbolic', 'SplitSum']
  107. def contains_diff_return_type(a):
  108. """
  109. This function returns whether an expression contains functions which have different return types in
  110. diiferent cases.
  111. """
  112. if isinstance(a, list):
  113. for i in a:
  114. if contains_diff_return_type(i):
  115. return True
  116. elif type(a) == Function('With') or type(a) == Function('Module'):
  117. for i in f_diff_return_type:
  118. if a.has(Function(i)):
  119. return True
  120. else:
  121. if a in f_diff_return_type:
  122. return True
  123. return False
  124. def parse_full_form(wmexpr):
  125. """
  126. Parses FullForm[Downvalues[]] generated by Mathematica
  127. """
  128. out = []
  129. stack = [out]
  130. generator = re.finditer(r'[\[\],]', wmexpr)
  131. last_pos = 0
  132. for match in generator:
  133. if match is None:
  134. break
  135. position = match.start()
  136. last_expr = wmexpr[last_pos:position].replace(',', '').replace(']', '').replace('[', '').strip()
  137. if match.group() == ',':
  138. if last_expr != '':
  139. stack[-1].append(last_expr)
  140. elif match.group() == ']':
  141. if last_expr != '':
  142. stack[-1].append(last_expr)
  143. stack.pop()
  144. elif match.group() == '[':
  145. stack[-1].append([last_expr])
  146. stack.append(stack[-1][-1])
  147. last_pos = match.end()
  148. return out[0]
  149. def get_default_values(parsed, default_values={}):
  150. """
  151. Returns Optional variables and their values in the pattern
  152. """
  153. if not isinstance(parsed, list):
  154. return default_values
  155. if parsed[0] == "Times": # find Default arguments for "Times"
  156. for i in parsed[1:]:
  157. if i[0] == "Optional":
  158. default_values[(i[1][1])] = 1
  159. if parsed[0] == "Plus": # find Default arguments for "Plus"
  160. for i in parsed[1:]:
  161. if i[0] == "Optional":
  162. default_values[(i[1][1])] = 0
  163. if parsed[0] == "Power": # find Default arguments for "Power"
  164. for i in parsed[1:]:
  165. if i[0] == "Optional":
  166. default_values[(i[1][1])] = 1
  167. if len(parsed) == 1:
  168. return default_values
  169. for i in parsed:
  170. default_values = get_default_values(i, default_values)
  171. return default_values
  172. def add_wildcards(string, optional={}):
  173. """
  174. Replaces `Pattern(variable)` by `variable` in `string`.
  175. Returns the free symbols present in the string.
  176. """
  177. symbols = [] # stores symbols present in the expression
  178. p = r'(Optional\(Pattern\((\w+), Blank\)\))'
  179. matches = re.findall(p, string)
  180. for i in matches:
  181. string = string.replace(i[0], "WC('{}', S({}))".format(i[1], optional[i[1]]))
  182. symbols.append(i[1])
  183. p = r'(Pattern\((\w+), Blank\))'
  184. matches = re.findall(p, string)
  185. for i in matches:
  186. string = string.replace(i[0], i[1] + '_')
  187. symbols.append(i[1])
  188. p = r'(Pattern\((\w+), Blank\(Symbol\)\))'
  189. matches = re.findall(p, string)
  190. for i in matches:
  191. string = string.replace(i[0], i[1] + '_')
  192. symbols.append(i[1])
  193. return string, symbols
  194. def seperate_freeq(s, variables=[], x=None):
  195. """
  196. Returns list of symbols in FreeQ.
  197. """
  198. if s[0] == 'FreeQ':
  199. if len(s[1]) == 1:
  200. variables = [s[1]]
  201. else:
  202. variables = s[1][1:]
  203. x = s[2]
  204. else:
  205. for i in s[1:]:
  206. variables, x = seperate_freeq(i, variables, x)
  207. return variables, x
  208. return variables, x
  209. def parse_freeq(l, x, cons_index, cons_dict, cons_import, symbols=None):
  210. """
  211. Converts FreeQ constraints into MatchPy constraint
  212. """
  213. res = []
  214. cons = ''
  215. for i in l:
  216. if isinstance(i, str):
  217. r = ' return FreeQ({}, {})'.format(i, x)
  218. # First it checks if a constraint is already present in `cons_dict`, If yes, use it else create a new one.
  219. if r not in cons_dict.values():
  220. cons_index += 1
  221. c = '\n def cons_f{}({}, {}):\n'.format(cons_index, i, x)
  222. c += r
  223. c += '\n\n cons{} = CustomConstraint({})\n'.format(cons_index, 'cons_f{}'.format(cons_index))
  224. cons_name = 'cons{}'.format(cons_index)
  225. cons_dict[cons_name] = r
  226. else:
  227. c = ''
  228. cons_name = next(key for key, value in sorted(cons_dict.items()) if value == r)
  229. elif isinstance(i, list):
  230. s = sorted(set(get_free_symbols(i, symbols)))
  231. s = ', '.join(s)
  232. r = ' return FreeQ({}, {})'.format(generate_sympy_from_parsed(i), x)
  233. if r not in cons_dict.values():
  234. cons_index += 1
  235. c = '\n def cons_f{}({}):\n'.format(cons_index, s)
  236. c += r
  237. c += '\n\n cons{} = CustomConstraint({})\n'.format(cons_index, 'cons_f{}'.format(cons_index))
  238. cons_name = 'cons{}'.format(cons_index)
  239. cons_dict[cons_name] = r
  240. else:
  241. c = ''
  242. cons_name = next(key for key, value in cons_dict.items() if value == r)
  243. if cons_name not in cons_import:
  244. cons_import.append(cons_name)
  245. res.append(cons_name)
  246. cons += c
  247. if res != []:
  248. return ', ' + ', '.join(res), cons, cons_index
  249. return '', cons, cons_index
  250. def generate_sympy_from_parsed(parsed, wild=False, symbols=(), replace_Int=False):
  251. """
  252. Parses list into Python syntax.
  253. Parameters
  254. ==========
  255. wild : When set to True, the symbols are replaced as wild symbols.
  256. symbols : Symbols already present in the pattern.
  257. replace_Int: when set to True, `Int` is replaced by `Integral`(used to parse pattern).
  258. """
  259. out = ""
  260. if not isinstance(parsed, list):
  261. try: # return S(number) if parsed is Number
  262. float(parsed)
  263. return "S({})".format(parsed)
  264. except:
  265. pass
  266. if parsed in symbols:
  267. if wild:
  268. return parsed + '_'
  269. return parsed
  270. if parsed[0] == 'Rational':
  271. return 'S({})/S({})'.format(generate_sympy_from_parsed(parsed[1], wild=wild, symbols=symbols, replace_Int=replace_Int), generate_sympy_from_parsed(parsed[2], wild=wild, symbols=symbols, replace_Int=replace_Int))
  272. if parsed[0] in replacements:
  273. out += replacements[parsed[0]]
  274. elif parsed[0] == 'Int' and replace_Int:
  275. out += 'Integral'
  276. else:
  277. out += parsed[0]
  278. if len(parsed) == 1:
  279. return out
  280. result = [generate_sympy_from_parsed(i, wild=wild, symbols=symbols, replace_Int=replace_Int) for i in parsed[1:]]
  281. if '' in result:
  282. result.remove('')
  283. out += "("
  284. out += ", ".join(result)
  285. out += ")"
  286. return out
  287. def get_free_symbols(s, symbols, free_symbols=None):
  288. """
  289. Returns free_symbols present in `s`.
  290. """
  291. free_symbols = free_symbols or []
  292. if not isinstance(s, list):
  293. if s in symbols:
  294. free_symbols.append(s)
  295. return free_symbols
  296. for i in s:
  297. free_symbols = get_free_symbols(i, symbols, free_symbols)
  298. return free_symbols
def set_matchq_in_constraint(a, cons_index):
    """
    Takes care of the case when a pattern matching (``MatchQ``) has to be
    done inside a constraint.

    If `a` is a parsed ``MatchQ[...]`` list, returns
    ``("result_matchq", code)``. `code` is source text that:
    - defines a helper constraint ``_cons_f_<cons_index>``;
    - builds a MatchPy ``Pattern``;
    - assigns the match result to ``result_matchq``.

    Otherwise the elements of `a` are processed recursively and
    ``(lst, code)`` is returned. `lst` is `a` with every nested list replaced
    by the first item this function returned for it. `code` is the code from
    the last nested list processed; earlier ones are overwritten. A non-list
    `a` gives ``([], '')``.
    """
    lst = []
    res = ''
    if isinstance(a, list):
        if a[0] == 'MatchQ':
            s = a
            optional = get_default_values(s, {})
            r = generate_sympy_from_parsed(s, replace_Int=True)
            r, free_symbols = add_wildcards(r, optional=optional)
            free_symbols = sorted(set(free_symbols)) # remove common symbols
            r = sympify(r, locals={"Or": Function("Or"), "And": Function("And"), "Not":Function("Not")})
            # NOTE(review): assumes the second argument of MatchQ is a
            # two-argument (pattern, condition) expression such as Condition.
            pattern = r.args[1].args[0]
            cons = r.args[1].args[1]
            pattern = rubi_printer(pattern, sympy_integers=True)
            pattern = setWC(pattern)
            # All helpers generated with the same `cons_index` share names
            # (_cons_f_N / _cons_N).
            res = ' def _cons_f_{}({}):\n return {}\n'.format(cons_index, ', '.join(free_symbols), cons)
            res += ' _cons_{} = CustomConstraint(_cons_f_{})\n'.format(cons_index, cons_index)
            res += ' pat = Pattern(UtilityOperator({}, x), _cons_{})\n'.format(pattern, cons_index)
            res += ' result_matchq = is_match(UtilityOperator({}, x), pat)'.format(r.args[0])
            return "result_matchq", res
        else:
            for i in a:
                if isinstance(i, list):
                    r = set_matchq_in_constraint(i, cons_index)
                    lst.append(r[0])
                    # NOTE(review): overwrites rather than accumulates, so
                    # with several MatchQ only the last one's code survives.
                    res = r[1]
                else:
                    lst.append(i)
    return lst, res
def _divide_constriant(s, symbols, cons_index, cons_dict, cons_import):
    """
    Turn one parsed constraint `s` into a named MatchPy constraint.

    Generates ``def cons_fN(<free symbols>): ...`` plus
    ``consN = CustomConstraint(cons_fN)``. If an identical body was already
    generated (present in `cons_dict`), that existing name is reused instead.
    The name is appended to `cons_import` if not already there.

    Returns ``(cons_name, new_definition_source, cons_index)``. The
    definition source is ``''`` when an existing constraint was reused.
    """
    # Parameters of the generated function: the pattern symbols used in `s`.
    lambda_symbols = sorted(set(get_free_symbols(s, symbols, [])))
    r = generate_sympy_from_parsed(s)
    r = sympify(r, locals={"Or": Function("Or"), "And": Function("And"), "Not":Function("Not")})
    if r.has(Function('MatchQ')):
        # Pattern matching inside the constraint: emit the helper code, then
        # return the constraint with MatchQ replaced by `result_matchq`.
        match_res = set_matchq_in_constraint(s, cons_index)
        res = match_res[1]
        res += '\n return {}'.format(rubi_printer(sympify(generate_sympy_from_parsed(match_res[0]), locals={"Or": Function("Or"), "And": Function("And"), "Not":Function("Not")}), sympy_integers = True))
    elif contains_diff_return_type(s):
        # Functions with case-dependent return types may raise here; treat
        # that as "constraint not satisfied".
        res = ' try:\n return {}\n except (TypeError, AttributeError):\n return False'.format(rubi_printer(r, sympy_integers=True))
    else:
        res = ' return {}'.format(rubi_printer(r, sympy_integers=True))
    # First it checks if a constraint is already present in `cons_dict`, If yes, use it else create a new one.
    if res not in cons_dict.values():
        cons_index += 1
        cons = '\n def cons_f{}({}):\n'.format(cons_index, ', '.join(lambda_symbols))
        if 'x' in lambda_symbols:
            # A numeric `x` makes the constraint fail outright.
            cons += ' if isinstance(x, (int, Integer, float, Float)):\n return False\n'
        cons += res
        cons += '\n\n cons{} = CustomConstraint({})\n'.format(cons_index, 'cons_f{}'.format(cons_index))
        cons_name = 'cons{}'.format(cons_index)
        cons_dict[cons_name] = res
    else:
        cons = ''
        cons_name = next(key for key, value in cons_dict.items() if value == res)
    if cons_name not in cons_import:
        cons_import.append(cons_name)
    return cons_name, cons, cons_index
  360. def divide_constraint(s, symbols, cons_index, cons_dict, cons_import):
  361. """
  362. Divides multiple constraints into smaller constraints.
  363. Parameters
  364. ==========
  365. s : constraint as list
  366. symbols : all the symbols present in the expression
  367. """
  368. result =[]
  369. cons = ''
  370. if s[0] == 'And':
  371. for i in s[1:]:
  372. if i[0]!= 'FreeQ':
  373. a = _divide_constriant(i, symbols, cons_index, cons_dict, cons_import)
  374. result.append(a[0])
  375. cons += a[1]
  376. cons_index = a[2]
  377. else:
  378. a = _divide_constriant(s, symbols, cons_index, cons_dict, cons_import)
  379. result.append(a[0])
  380. cons += a[1]
  381. cons_index = a[2]
  382. r = ['']
  383. for i in result:
  384. if i != '':
  385. r.append(i)
  386. return ', '.join(r),cons, cons_index
  387. def setWC(string):
  388. """
  389. Replaces `WC(a, b)` by `WC('a', S(b))`
  390. """
  391. p = r'(WC\((\w+), S\(([-+]?\d)\)\))'
  392. matches = re.findall(p, string)
  393. for i in matches:
  394. string = string.replace(i[0], "WC('{}', S({}))".format(i[1], i[2]))
  395. return string
def process_return_type(a1, L):
    """
    Functions like `Set`, `With` and `CompoundExpression` have to be taken
    special care of.

    `a1` is a ``(local_definitions, expression_source)`` pair; only
    ``a1[1]`` is used. When it sympifies to a ``With``/``Module``:
    - assignments (`Set`) not already in `L` are emitted as extra local
      definitions;
    - the value to return is extracted from the body.

    Returns ``(extra_definitions, return_value, processed)``. `processed` is
    False when nothing was extracted.
    """
    a = sympify(a1[1])
    x = ''
    processed = False
    return_value = ''
    if type(a) == Function('With') or type(a) == Function('Module'):
        for i in a.args:
            # Local assignments inside this argument become definitions.
            for s in i.args:
                if isinstance(s, Set) and s not in L:
                    x += '\n {} = {}'.format(s.args[0], rubi_printer(s.args[1], sympy_integers=True))
            if not type(i) in (Function('List'), Function('CompoundExpression')) and not i.has(Function('CompoundExpression')):
                # Plain body: return it as is.
                return_value = i
                processed = True
            elif type(i) == Function('CompoundExpression'):
                # Sequence of statements: return the last one.
                return_value = i.args[-1]
                processed = True
            elif type(i.args[0]) == Function('CompoundExpression'):
                # Wrapper around a sequence: keep the wrapper, using the
                # sequence's last statement as its first argument.
                C = i.args[0]
                return_value = '{}({}, {})'.format(i.func, C.args[-1], i.args[1])
                processed = True
    return x, return_value, processed
  420. def extract_set(s, L):
  421. """
  422. this function extracts all `Set` functions
  423. """
  424. lst = []
  425. if isinstance(s, Set) and s not in L:
  426. lst.append(s)
  427. else:
  428. try:
  429. for i in s.args:
  430. lst += extract_set(i, L)
  431. except: # when s has no attribute args (like `bool`)
  432. pass
  433. return lst
def replaceWith(s, symbols, index):
    """
    Replaces `With` and `Module` by Python functions.

    For a ``With``/``Module`` object `s`, generates the source of a function
    ``With<index>(<symbols>)`` that defines the local variables.

    Returns ``(source, constraints, return_type)``:

    - ``return_type is None``: the generated function itself computes the
      replacement.
    - otherwise the function is a constraint (``constraints`` is
      ``', CustomConstraint(With<index>)'``) and ``return_type`` is a
      ``(local_definitions, replacement_expression)`` pair from which the
      caller builds a separate replacement function.

    Any other `s` is printed as is: ``(printed, '', None)``.
    """
    return_type = None
    with_value = ''
    if type(s) == Function('With') or type(s) == Function('Module'):
        constraints = ' '
        result = '\n\n\ndef With{}({}):'.format(index, ', '.join(symbols))
        if type(s.args[0]) == Function('List'): # get all local variables of With and Module
            L = list(s.args[0].args)
        else:
            L = [s.args[0]]
        # Assignments found in the body are treated as local variables too.
        lst = []
        for i in s.args[1:]:
            lst += extract_set(i, L)
        L += lst
        for i in L: # define local variables
            if isinstance(i, Set):
                with_value += '\n {} = {}'.format(i.args[0], rubi_printer(i.args[1], sympy_integers=True))
            elif isinstance(i, Symbol):
                # Uninitialised local: bind it to a fresh Symbol.
                with_value += "\n {} = Symbol('{}')".format(i, i)
        #result += with_value
        if type(s.args[1]) == Function('CompoundExpression'): # Expand CompoundExpression
            C = s.args[1]
            result += with_value
            if isinstance(C.args[0], Set):
                result += '\n {} = {}'.format(C.args[0].args[0], C.args[0].args[1])
            result += '\n return {}'.format(rubi_printer(C.args[1], sympy_integers=True))
            return result, constraints, return_type
        elif type(s.args[1]) == Function('Condition'):
            C = s.args[1]
            if len(C.args) == 2:
                if all(j in symbols for j in [str(i) for i in C.free_symbols]):
                    # The condition only uses pattern symbols: the function
                    # returns the replacement directly.
                    result += with_value
                    #constraints += 'CustomConstraint(lambda {}: {})'.format(', '.join([str(i) for i in C.free_symbols]), sstr(C.args[1], sympy_integers=True))
                    result += '\n return {}'.format(rubi_printer(C.args[0], sympy_integers=True))
                else:
                    # The condition uses local variables: the function becomes
                    # a boolean constraint, and the replacement is returned
                    # separately through `return_type`.
                    if 'x' in symbols:
                        result += '\n if isinstance(x, (int, Integer, float, Float)):\n return False'
                    if contains_diff_return_type(s):
                        # Wrap the local definitions and the test in try/except.
                        n_with_value = with_value.replace('\n', '\n ')
                        result += '\n try:{}\n res = {}'.format(n_with_value, rubi_printer(C.args[1], sympy_integers=True))
                        result += '\n except (TypeError, AttributeError):\n return False'
                        result += '\n if res:'
                    else:
                        result+=with_value
                        result += '\n if {}:'.format(rubi_printer(C.args[1], sympy_integers=True))
                    return_type = (with_value, rubi_printer(C.args[0], sympy_integers=True))
                    return_type1 = process_return_type(return_type, L)
                    if return_type1[2]:
                        return_type = (with_value+return_type1[0], rubi_printer(return_type1[1]))
                    result += '\n return True'
                    result += '\n return False'
            constraints = ', CustomConstraint(With{})'.format(index)
            return result, constraints, return_type
        elif type(s.args[1]) == Function('Module') or type(s.args[1]) == Function('With'):
            # Nested scoping construct: flatten it into this function.
            C = s.args[1]
            result += with_value
            return_type = (with_value, rubi_printer(C, sympy_integers=True))
            return_type1 = process_return_type(return_type, L)
            if return_type1[2]:
                return_type = (with_value+return_type1[0], rubi_printer(return_type1[1]))
            result += return_type1[0]
            result += '\n return {}'.format(rubi_printer(return_type1[1]))
            return result, constraints, None
        elif s.args[1].has(Function("CompoundExpression")):
            C = s.args[1].args[0]
            result += with_value
            if isinstance(C.args[0], Set):
                result += '\n {} = {}'.format(C.args[0].args[0], C.args[0].args[1])
            result += '\n return {}({}, {})'.format(s.args[1].func, C.args[-1], s.args[1].args[1])
            return result, constraints, None
        # Plain body: return it after the local definitions.
        result += with_value
        result += '\n return {}'.format(rubi_printer(s.args[1], sympy_integers=True))
        return result, constraints, return_type
    else:
        return rubi_printer(s, sympy_integers=True), '', return_type
  512. def downvalues_rules(r, header, cons_dict, cons_index, index):
  513. """
  514. Function which generates parsed rules by substituting all possible
  515. combinations of default values.
  516. """
  517. rules = '['
  518. parsed = '\n\n'
  519. repl_funcs = '\n\n'
  520. cons = ''
  521. cons_import = [] # it contains name of constraints that need to be imported for rules.
  522. for i in r:
  523. debug('parsing rule {}'.format(r.index(i) + 1))
  524. # Parse Pattern
  525. if i[1][1][0] == 'Condition':
  526. p = i[1][1][1].copy()
  527. else:
  528. p = i[1][1].copy()
  529. optional = get_default_values(p, {})
  530. pattern = generate_sympy_from_parsed(p.copy(), replace_Int=True)
  531. pattern, free_symbols = add_wildcards(pattern, optional=optional)
  532. free_symbols = sorted(set(free_symbols)) #remove common symbols
  533. # Parse Transformed Expression and Constraints
  534. if i[2][0] == 'Condition': # parse rules without constraints separately
  535. constriant, constraint_def, cons_index = divide_constraint(i[2][2], free_symbols, cons_index, cons_dict, cons_import) # separate And constraints into individual constraints
  536. FreeQ_vars, FreeQ_x = seperate_freeq(i[2][2].copy()) # separate FreeQ into individual constraints
  537. transformed = generate_sympy_from_parsed(i[2][1].copy(), symbols=free_symbols)
  538. else:
  539. constriant = ''
  540. constraint_def = ''
  541. FreeQ_vars, FreeQ_x = [], []
  542. transformed = generate_sympy_from_parsed(i[2].copy(), symbols=free_symbols)
  543. FreeQ_constraint, free_cons_def, cons_index = parse_freeq(FreeQ_vars, FreeQ_x, cons_index, cons_dict, cons_import, free_symbols)
  544. pattern = sympify(pattern, locals={"Or": Function("Or"), "And": Function("And"), "Not":Function("Not") })
  545. pattern = rubi_printer(pattern, sympy_integers=True)
  546. pattern = setWC(pattern)
  547. transformed = sympify(transformed, locals={"Or": Function("Or"), "And": Function("And"), "Not":Function("Not") })
  548. constraint_def = constraint_def + free_cons_def
  549. cons += constraint_def
  550. index += 1
  551. # below are certain if - else condition depending on various situation that may be encountered
  552. if type(transformed) == Function('With') or type(transformed) == Function('Module'): # define separate function when With appears
  553. transformed, With_constraints, return_type = replaceWith(transformed, free_symbols, index)
  554. if return_type is None:
  555. repl_funcs += '{}'.format(transformed)
  556. parsed += '\n pattern' + str(index) + ' = Pattern(' + pattern + '' + FreeQ_constraint + '' + constriant + ')'
  557. parsed += '\n ' + 'rule' + str(index) + ' = ReplacementRule(' + 'pattern' + rubi_printer(index, sympy_integers=True) + ', With{}'.format(index) + ')\n'
  558. else:
  559. repl_funcs += '{}'.format(transformed)
  560. parsed += '\n pattern' + str(index) + ' = Pattern(' + pattern + '' + FreeQ_constraint + '' + constriant + With_constraints + ')'
  561. repl_funcs += '\n\n\ndef replacement{}({}):\n'.format(
  562. index, ', '.join(free_symbols)
  563. ) + return_type[0] + '\n return '.format(index) + return_type[1]
  564. parsed += '\n ' + 'rule' + str(index) + ' = ReplacementRule(' + 'pattern' + rubi_printer(index, sympy_integers=True) + ', replacement{}'.format(index) + ')\n'
  565. else:
  566. transformed = rubi_printer(transformed, sympy_integers=True)
  567. parsed += '\n pattern' + str(index) + ' = Pattern(' + pattern + '' + FreeQ_constraint + '' + constriant + ')'
  568. repl_funcs += '\n\n\ndef replacement{}({}):\n return '.format(index, ', '.join(free_symbols), index) + transformed
  569. parsed += '\n ' + 'rule' + str(index) + ' = ReplacementRule(' + 'pattern' + rubi_printer(index, sympy_integers=True) + ', replacement{}'.format(index) + ')\n'
  570. rules += 'rule{}, '.format(index)
  571. rules += ']'
  572. parsed += ' return ' + rules +'\n'
  573. header += ' from sympy.integrals.rubi.constraints import ' + ', '.join(word for word in cons_import)
  574. parsed = header + parsed + repl_funcs
  575. return parsed, cons_index, cons, index
  576. def rubi_rule_parser(fullform, header=None, module_name='rubi_object'):
  577. """
  578. Parses rules in MatchPy format.
  579. Parameters
  580. ==========
  581. fullform : FullForm of the rule as string.
  582. header : Header imports for the file. Uses default imports if None.
  583. module_name : name of RUBI module
  584. References
  585. ==========
  586. [1] http://reference.wolfram.com/language/ref/FullForm.html
  587. [2] http://reference.wolfram.com/language/ref/DownValues.html
  588. [3] https://gist.github.com/Upabjojr/bc07c49262944f9c1eb0
  589. """
  590. if header is None: # use default header values
  591. path_header = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
  592. header = open(os.path.join(path_header, "header.py.txt")).read()
  593. header = header.format(module_name)
  594. cons_dict = {} # dict keeps track of constraints that has been encountered, thus avoids repetition of constraints.
  595. cons_index = 0 # for index of a constraint
  596. index = 0 # indicates the number of a rule.
  597. cons = ''
  598. # Temporarily rename these variables because it
  599. # can raise errors while sympifying
  600. for i in temporary_variable_replacement:
  601. fullform = fullform.replace(i, temporary_variable_replacement[i])
  602. # Permanently rename these variables
  603. for i in permanent_variable_replacement:
  604. fullform = fullform.replace(i, permanent_variable_replacement[i])
  605. rules = []
  606. for i in parse_full_form(fullform): # separate all rules
  607. if i[0] == 'RuleDelayed':
  608. rules.append(i)
  609. parsed = downvalues_rules(rules, header, cons_dict, cons_index, index)
  610. result = parsed[0].strip() + '\n'
  611. cons += parsed[2]
  612. # Replace temporary variables by actual values
  613. for i in temporary_variable_replacement:
  614. cons = cons.replace(temporary_variable_replacement[i], i)
  615. result = result.replace(temporary_variable_replacement[i], i)
  616. cons = "\n".join(header.split("\n")[:-2]) + '\n' + cons
  617. return result, cons