m2m模型翻译
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

593 lines
19 KiB

6 months ago
  1. # Ported from latex2sympy by @augustt198
  2. # https://github.com/augustt198/latex2sympy
  3. # See license in LICENSE.txt
  4. import sympy
  5. from sympy.external import import_module
  6. from sympy.printing.str import StrPrinter
  7. from sympy.physics.quantum.state import Bra, Ket
  8. from .errors import LaTeXParsingError
  # Optional ANTLR-backed names. Each one stays ``None`` unless the import
  # that provides it succeeds.
  LaTeXParser = None
  LaTeXLexer = None
  MathErrorListener = None

  try:
      LaTeXParser = import_module(
          'sympy.parsing.latex._antlr.latexparser',
          import_kwargs={'fromlist': ['LaTeXParser']},
      ).LaTeXParser
      LaTeXLexer = import_module(
          'sympy.parsing.latex._antlr.latexlexer',
          import_kwargs={'fromlist': ['LaTeXLexer']},
      ).LaTeXLexer
  except Exception:
      # Best effort: a name that could not be loaded keeps its None default.
      pass

  # Module that provides the listener base class used further down.
  ErrorListener = import_module(
      'antlr4.error.ErrorListener',
      warn_not_installed=True,
      import_kwargs={'fromlist': ['ErrorListener']},
  )
  if ErrorListener:
      class MathErrorListener(ErrorListener.ErrorListener):  # type: ignore
          """Error listener that reports syntax errors as LaTeXParsingError.

          Every reported error is rendered as three lines: a message, the
          source text given at construction, and a ``~~~^`` marker whose caret
          sits under the offending column.
          """

          def __init__(self, src):
              """Store ``src``, the text being parsed, for use in messages."""
              # Name this class (not its base) so the MRO lookup starts right
              # after MathErrorListener and the base initialiser is not skipped.
              super(MathErrorListener, self).__init__()
              self.src = src

          def syntaxError(self, recog, symbol, line, col, msg, e):
              """Raise LaTeXParsingError describing the error at column ``col``.

              The wording is chosen from the prefix of ``msg``; for
              "mismatched" errors the expected literal tokens are listed when
              there are fewer than ten of them.
              """
              fmt = "%s\n%s\n%s"
              marker = "~" * col + "^"

              if msg.startswith("missing"):
                  err = fmt % (msg, self.src, marker)
              elif msg.startswith("no viable"):
                  err = fmt % ("I expected something else here", self.src, marker)
              elif msg.startswith("mismatched"):
                  names = LaTeXParser.literalNames
                  # Keep only token types that index into ``names``. The lower
                  # bound matters: a negative type would otherwise wrap around
                  # through negative indexing and name an unrelated literal.
                  expected = [
                      names[i] for i in e.getExpectedTokens()
                      if 0 <= i < len(names)
                  ]
                  if len(expected) < 10:
                      expected = " ".join(expected)
                      err = (fmt % ("I expected one of these: " + expected,
                                    self.src, marker))
                  else:
                      err = (fmt % ("I expected something else here", self.src,
                                    marker))
              else:
                  err = fmt % ("I don't understand this", self.src, marker)
              raise LaTeXParsingError(err)
  def parse_latex(sympy):
      """Parse a LaTeX string and return the converted expression.

      Parameters
      ----------
      sympy : str
          The LaTeX source text. Inside this function the name shadows the
          ``sympy`` module; it is only used as the text to lex and as the
          source shown in error messages.

      Returns
      -------
      The result of ``convert_relation`` on the parsed top-level relation.

      Raises
      ------
      ImportError
          If the antlr4 runtime, the error listener, or the generated
          lexer/parser classes are unavailable.
      LaTeXParsingError
          Raised by the error listener when the input has a syntax error.
      """
      antlr4 = import_module('antlr4', warn_not_installed=True)

      if None in [antlr4, MathErrorListener]:
          raise ImportError("LaTeX parsing requires the antlr4 Python package,"
                            " provided by pip (antlr4-python2-runtime or"
                            " antlr4-python3-runtime) or"
                            " conda (antlr-python-runtime)")

      # The generated classes are imported on a best-effort basis at module
      # load; fail clearly here instead of with "'NoneType' is not callable".
      if LaTeXLexer is None or LaTeXParser is None:
          raise ImportError("LaTeX parsing requires the generated ANTLR lexer"
                            " and parser modules"
                            " (sympy.parsing.latex._antlr), which could not"
                            " be imported")

      matherror = MathErrorListener(sympy)

      stream = antlr4.InputStream(sympy)
      lex = LaTeXLexer(stream)
      lex.removeErrorListeners()
      lex.addErrorListener(matherror)

      tokens = antlr4.CommonTokenStream(lex)
      parser = LaTeXParser(tokens)

      # remove default console error listener
      parser.removeErrorListeners()
      parser.addErrorListener(matherror)

      relation = parser.math().relation()
      expr = convert_relation(relation)

      return expr
  def convert_relation(rel):
      """Convert a ``relation`` parse node.

      A node that wraps a plain expression is converted with ``convert_expr``.
      Otherwise both sub-relations are converted and combined with the
      relational class matching the operator token present on the node.
      Returns ``None`` implicitly-equivalent when no operator token matches.
      """
      if rel.expr():
          return convert_expr(rel.expr())

      left = convert_relation(rel.relation(0))
      right = convert_relation(rel.relation(1))

      # Checked in order; the first operator token found decides the class.
      operators = (
          (rel.LT, sympy.StrictLessThan),
          (rel.LTE, sympy.LessThan),
          (rel.GT, sympy.StrictGreaterThan),
          (rel.GTE, sympy.GreaterThan),
          (rel.EQUAL, sympy.Eq),
          (rel.NEQ, sympy.Ne),
      )
      for has_token, relation_cls in operators:
          if has_token():
              return relation_cls(left, right)
      return None
  def convert_expr(expr):
      """Convert an ``expr`` parse node by converting its additive child."""
      additive_node = expr.additive()
      return convert_add(additive_node)
  def convert_add(add):
      """Convert an ``additive`` parse node into an unevaluated sum.

      Subtraction is represented as adding ``Mul(-1, rhs)``. A node with
      neither an ADD nor a SUB token is handed to ``convert_mp``.
      """
      is_addition = add.ADD()
      if not is_addition and not add.SUB():
          return convert_mp(add.mp())

      left = convert_add(add.additive(0))
      right = convert_add(add.additive(1))
      if not is_addition:
          # a - b is built as a + (-1 * b), all unevaluated
          right = sympy.Mul(-1, right, evaluate=False)
      return sympy.Add(left, right, evaluate=False)
  def convert_mp(mp):
      """Convert a multiplication/division parse node.

      Works on both node flavours: those exposing ``mp``/``unary`` accessors
      and those exposing ``mp_nofunc``/``unary_nofunc``. Division is built as
      multiplication by ``Pow(rhs, -1)``; everything stays unevaluated.
      """
      child = mp.mp if hasattr(mp, 'mp') else mp.mp_nofunc
      left_node = child(0)
      right_node = child(1)

      if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
          left = convert_mp(left_node)
          right = convert_mp(right_node)
          return sympy.Mul(left, right, evaluate=False)

      if mp.DIV() or mp.CMD_DIV() or mp.COLON():
          numerator = convert_mp(left_node)
          denominator = convert_mp(right_node)
          reciprocal = sympy.Pow(denominator, -1, evaluate=False)
          return sympy.Mul(numerator, reciprocal, evaluate=False)

      # No operator token: this node is just a unary expression.
      if hasattr(mp, 'unary'):
          return convert_unary(mp.unary())
      return convert_unary(mp.unary_nofunc())
  def convert_unary(unary):
      """Convert a unary parse node (signed expression or postfix sequence).

      A leading ``+`` is dropped, a leading ``-`` negates the converted
      operand, and otherwise the node's postfix items are converted with
      ``convert_postfix_list``. Returns ``None`` when none of these apply.
      """
      if hasattr(unary, 'unary'):
          operand = unary.unary()
      else:
          operand = unary.unary_nofunc()

      if hasattr(unary, 'postfix_nofunc'):
          # This flavour splits the sequence into a head plus a tail list.
          head = unary.postfix()
          rest = unary.postfix_nofunc()
          items = [head] + rest
      else:
          items = unary.postfix()

      if unary.ADD():
          return convert_unary(operand)
      if unary.SUB():
          # Negate the converted value directly (Integer(-n) rather than
          # Mul(-1, n) for numbers).
          magnitude = convert_unary(operand)
          return -magnitude
      if items:
          return convert_postfix_list(items)
      return None
  def convert_postfix_list(arr, i=0):
      """Convert the postfix items ``arr[i:]`` into a single expression.

      Adjacent items are multiplied together (unevaluated). An item that
      does not convert to a ``sympy.Expr`` is treated as a derivative
      operator whose first element is the differentiation variable; it is
      applied to the conversion of the remaining items.

      Parameters
      ----------
      arr : sequence of postfix parse nodes
      i : int
          Index of the item to start from.

      Raises
      ------
      LaTeXParsingError
          If ``i`` is past the end of ``arr``, or a derivative operator is
          the last item and so has nothing to act on.
      """
      if i >= len(arr):
          raise LaTeXParsingError("Index out of bounds")

      last = len(arr) - 1
      res = convert_postfix(arr[i])

      if not isinstance(res, sympy.Expr):  # must be derivative
          wrt = res[0]
          if i == last:
              raise LaTeXParsingError("Expected expression for derivative")
          expr = convert_postfix_list(arr, i + 1)
          return sympy.Derivative(expr, wrt)

      if i == last:
          return res  # nothing to multiply by

      if i > 0:
          left = convert_postfix(arr[i - 1])
          right = convert_postfix(arr[i + 1])
          if isinstance(left, sympy.Expr) and isinstance(right, sympy.Expr):
              # Reuse the neighbours converted just above instead of
              # converting each of them a second time.
              left_syms = left.atoms(sympy.Symbol)
              right_syms = right.atoms(sympy.Symbol)
              # if the left and right sides contain no variables and the
              # symbol in between is 'x', treat as multiplication: the 'x'
              # itself is dropped and only the remaining items are returned.
              if not (left_syms or right_syms) and str(res) == 'x':
                  return convert_postfix_list(arr, i + 1)

      # multiply by next
      return sympy.Mul(res, convert_postfix_list(arr, i + 1), evaluate=False)
  def do_subs(expr, at):
      """Substitute into ``expr`` according to an evaluate-at parse node.

      If ``at`` holds a bare expression, one symbol taken from that
      expression is replaced by the whole expression (``expr`` is returned
      unchanged when it contains no symbols). If ``at`` holds an equality,
      its left side is replaced by its right side. Returns ``None`` when
      ``at`` holds neither.
      """
      if at.expr():
          value = convert_expr(at.expr())
          symbols = value.atoms(sympy.Symbol)
          if not symbols:
              return expr
          target = next(iter(symbols))
          return expr.subs(target, value)

      if at.equality():
          old = convert_expr(at.equality().expr(0))
          new = convert_expr(at.equality().expr(1))
          return expr.subs(old, new)

      return None
  def convert_postfix(postfix):
      """Convert a postfix parse node: a base followed by postfix operators.

      ``!`` wraps the running value in an unevaluated factorial. An
      evaluate-at operator substitutes the upper and/or lower bound; when
      both are present the result is ``upper - lower``.

      Raises LaTeXParsingError if ``!`` is applied to a derivative operator
      (a value that is a list).
      """
      if hasattr(postfix, 'exp'):
          base_node = postfix.exp()
      else:
          base_node = postfix.exp_nofunc()
      value = convert_exp(base_node)

      for op in postfix.postfix_op():
          if op.BANG():
              if isinstance(value, list):
                  raise LaTeXParsingError("Cannot apply postfix to derivative")
              value = sympy.factorial(value, evaluate=False)
              continue

          if not op.eval_at():
              continue

          bounds = op.eval_at()
          upper = None
          lower = None
          if bounds.eval_at_sup():
              upper = do_subs(value, bounds.eval_at_sup())
          if bounds.eval_at_sub():
              lower = do_subs(value, bounds.eval_at_sub())

          if upper is not None and lower is not None:
              value = sympy.Add(upper, -1 * lower, evaluate=False)
          elif upper is not None:
              value = upper
          elif lower is not None:
              value = lower

      return value
  def convert_exp(exp):
      """Convert an exponentiation parse node.

      With a nested base node present the result is an unevaluated
      ``Pow(base, exponent)``, the exponent coming from the node's atom or
      expr child. Without one, the node's ``comp``/``comp_nofunc`` child is
      converted instead.

      Raises LaTeXParsingError if the base is a derivative operator (a list).
      """
      inner = exp.exp() if hasattr(exp, 'exp') else exp.exp_nofunc()

      if not inner:
          if hasattr(exp, 'comp'):
              return convert_comp(exp.comp())
          return convert_comp(exp.comp_nofunc())

      base = convert_exp(inner)
      if isinstance(base, list):
          raise LaTeXParsingError("Cannot raise derivative to power")

      if exp.atom():
          exponent = convert_atom(exp.atom())
      elif exp.expr():
          exponent = convert_expr(exp.expr())
      return sympy.Pow(base, exponent, evaluate=False)
  def convert_comp(comp):
      """Convert a ``comp`` parse node by dispatching on the child it holds.

      Children are tried in a fixed order: group, absolute-value group, atom,
      fraction, binomial, floor, ceiling, function. Returns ``None`` when the
      node holds none of them.
      """
      node = comp.group()
      if node:
          return convert_expr(node.expr())

      node = comp.abs_group()
      if node:
          return sympy.Abs(convert_expr(node.expr()), evaluate=False)

      # Remaining children are passed straight to their converter.
      simple_rules = (
          (comp.atom, convert_atom),
          (comp.frac, convert_frac),
          (comp.binom, convert_binom),
          (comp.floor, convert_floor),
          (comp.ceil, convert_ceil),
          (comp.func, convert_func),
      )
      for accessor, converter in simple_rules:
          node = accessor()
          if node:
              return converter(node)
      return None
  def convert_atom(atom):
      """Convert an ``atom`` parse node.

      Letters and backslash-symbols become ``Symbol`` objects, with any
      subscript appended as ``_{...}``; the symbol named ``infty`` becomes
      ``sympy.oo``. Numbers have commas removed before conversion. A
      differential becomes a symbol named ``d<var>``. ``mathit`` text becomes
      a symbol, and bra/ket contents are wrapped in ``Bra``/``Ket``.
      Returns ``None`` when the node holds none of these.
      """

      def subscript_suffix():
          # Printed subscript wrapped as "_{...}", or "" without a subscript.
          if not atom.subexpr():
              return ''
          if atom.subexpr().expr():  # subscript is expr
              sub = convert_expr(atom.subexpr().expr())
          else:  # subscript is atom
              sub = convert_atom(atom.subexpr().atom())
          return '_{' + StrPrinter().doprint(sub) + '}'

      if atom.LETTER():
          suffix = subscript_suffix()
          return sympy.Symbol(atom.LETTER().getText() + suffix)

      if atom.SYMBOL():
          # Drop the leading backslash of the command name.
          label = atom.SYMBOL().getText()[1:]
          if label == "infty":
              return sympy.oo
          label += subscript_suffix()
          return sympy.Symbol(label)

      if atom.NUMBER():
          digits = atom.NUMBER().getText().replace(",", "")
          return sympy.Number(digits)

      if atom.DIFFERENTIAL():
          var = get_differential_var(atom.DIFFERENTIAL())
          return sympy.Symbol('d' + var.name)

      if atom.mathit():
          return sympy.Symbol(rule2text(atom.mathit().mathit_text()))

      if atom.bra():
          return Bra(convert_expr(atom.bra().expr()))

      if atom.ket():
          return Ket(convert_expr(atom.ket().expr()))

      return None
  def rule2text(ctx):
      """Return the source text spanned by the parse node ``ctx``.

      The text is read from the input stream of the node's first token,
      from that token's start index up to the last token's stop index.
      """
      first_token = ctx.start
      last_token = ctx.stop
      source = first_token.getInputStream()
      return source.getText(first_token.start, last_token.stop)
  def convert_frac(frac):
      r"""Convert a ``\frac`` parse node.

      Three outcomes are possible:

      * a bare differential operator (``d`` over ``dx``, or ``\partial`` over
        ``\partial x``) is returned as the one-element list ``[wrt]``;
      * a numerator that starts with the operator and continues with an
        expression is returned as ``Derivative(expr, wrt)``;
      * anything else is an ordinary quotient, built as the unevaluated
        ``upper * lower**-1`` (or just ``lower**-1`` when ``upper == 1``).
      """
      diff_op = False
      partial_op = False
      lower_itv = frac.lower.getSourceInterval()
      lower_itv_len = lower_itv[1] - lower_itv[0] + 1

      # Work out whether the denominator is a lone differential ("dx") or a
      # "\partial" followed by exactly one letter/symbol token.
      if (frac.lower.start == frac.lower.stop
              and frac.lower.start.type == LaTeXLexer.DIFFERENTIAL):
          wrt = get_differential_var_str(frac.lower.start.text)
          diff_op = True
      elif (lower_itv_len == 2 and frac.lower.start.type == LaTeXLexer.SYMBOL
            and frac.lower.start.text == '\\partial'
            and (frac.lower.stop.type == LaTeXLexer.LETTER
                 or frac.lower.stop.type == LaTeXLexer.SYMBOL)):
          partial_op = True
          wrt = frac.lower.stop.text
          if frac.lower.stop.type == LaTeXLexer.SYMBOL:
              wrt = wrt[1:]  # drop the leading backslash

      if diff_op or partial_op:
          wrt = sympy.Symbol(wrt)

          # Numerator is only the operator itself: return it as an operator.
          if (diff_op and frac.upper.start == frac.upper.stop
                  and frac.upper.start.type == LaTeXLexer.LETTER
                  and frac.upper.start.text == 'd'):
              return [wrt]
          elif (partial_op and frac.upper.start == frac.upper.stop
                and frac.upper.start.type == LaTeXLexer.SYMBOL
                and frac.upper.start.text == '\\partial'):
              return [wrt]

          upper_text = rule2text(frac.upper)

          # Numerator is the operator followed by an expression: parse the
          # remainder of the text on its own and differentiate it.
          expr_top = None
          if diff_op and upper_text.startswith('d'):
              expr_top = parse_latex(upper_text[1:])
          elif partial_op and frac.upper.start.text == '\\partial':
              expr_top = parse_latex(upper_text[len('\\partial'):])
          # Compare against None rather than relying on truthiness: a parsed
          # numerator such as 0 is falsy but is still a valid operand.
          if expr_top is not None:
              return sympy.Derivative(expr_top, wrt)

      # Ordinary fraction.
      expr_top = convert_expr(frac.upper)
      expr_bot = convert_expr(frac.lower)
      inverse_denom = sympy.Pow(expr_bot, -1, evaluate=False)
      if expr_top == 1:
          return inverse_denom
      else:
          return sympy.Mul(expr_top, inverse_denom, evaluate=False)
  def convert_binom(binom):
      """Convert a binomial parse node to an unevaluated ``binomial(n, k)``."""
      upper = convert_expr(binom.n)
      lower = convert_expr(binom.k)
      return sympy.binomial(upper, lower, evaluate=False)
  def convert_floor(floor):
      """Convert a floor parse node to an unevaluated ``sympy.floor``."""
      inner = convert_expr(floor.val)
      return sympy.floor(inner, evaluate=False)
  def convert_ceil(ceil):
      """Convert a ceiling parse node to an unevaluated ``sympy.ceiling``."""
      inner = convert_expr(ceil.val)
      return sympy.ceiling(inner, evaluate=False)
  def convert_func(func):
      r"""Convert a ``func`` parse node.

      Handles, in order: named functions (trig, hyperbolic and their
      inverses, ``exp``, ``log``/``ln``), user functions written as a letter
      or symbol applied to arguments, integrals, roots, overlines
      (conjugation), sums, products and limits. Returns ``None`` when the
      node matches none of them.

      For named functions a superscript is applied as a power, except that a
      superscript of ``-1`` on a trig/hyperbolic function selects the inverse
      function instead. ``\log`` without a subscript uses base 10 and ``\ln``
      uses ``sympy.E``.
      """
      if func.func_normal():
          if func.L_PAREN():  # function called with parenthesis
              arg = convert_func_arg(func.func_arg())
          else:
              arg = convert_func_arg(func.func_arg_noparens())

          # Command name without its leading backslash.
          name = func.func_normal().start.text[1:]

          # change arc<trig> -> a<trig>
          if name in [
                  "arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot"
          ]:
              name = "a" + name[3:]
              expr = getattr(sympy.functions, name)(arg, evaluate=False)
          # change ar<hyperbolic> -> a<hyperbolic>
          if name in ["arsinh", "arcosh", "artanh"]:
              name = "a" + name[2:]
              expr = getattr(sympy.functions, name)(arg, evaluate=False)

          if name == "exp":
              expr = sympy.exp(arg, evaluate=False)

          if (name == "log" or name == "ln"):
              if func.subexpr():
                  # An explicit subscript gives the base.
                  if func.subexpr().expr():
                      base = convert_expr(func.subexpr().expr())
                  else:
                      base = convert_atom(func.subexpr().atom())
              elif name == "log":
                  base = 10
              elif name == "ln":
                  base = sympy.E
              expr = sympy.log(arg, base, evaluate=False)

          func_pow = None
          should_pow = True
          if func.supexpr():
              if func.supexpr().expr():
                  func_pow = convert_expr(func.supexpr().expr())
              else:
                  func_pow = convert_atom(func.supexpr().atom())

          if name in [
                  "sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh",
                  "tanh"
          ]:
              if func_pow == -1:
                  # f^{-1} means the inverse function, not a reciprocal.
                  name = "a" + name
                  should_pow = False
              expr = getattr(sympy.functions, name)(arg, evaluate=False)

          # Test for "no superscript" explicitly: an exponent of 0 is falsy
          # but must still be applied.
          if func_pow is not None and should_pow:
              expr = sympy.Pow(expr, func_pow, evaluate=False)

          return expr
      elif func.LETTER() or func.SYMBOL():
          if func.LETTER():
              fname = func.LETTER().getText()
          elif func.SYMBOL():
              fname = func.SYMBOL().getText()[1:]
          fname = str(fname)  # can't be unicode
          if func.subexpr():
              subscript = None
              if func.subexpr().expr():  # subscript is expr
                  subscript = convert_expr(func.subexpr().expr())
              else:  # subscript is atom
                  subscript = convert_atom(func.subexpr().atom())
              subscriptName = StrPrinter().doprint(subscript)
              fname += '_{' + subscriptName + '}'
          input_args = func.args()
          output_args = []
          # The argument list is nested: each level holds one expression and,
          # optionally, the rest of the list.
          while input_args.args():  # handle multiple arguments to function
              output_args.append(convert_expr(input_args.expr()))
              input_args = input_args.args()
          output_args.append(convert_expr(input_args.expr()))
          return sympy.Function(fname)(*output_args)
      elif func.FUNC_INT():
          return handle_integral(func)
      elif func.FUNC_SQRT():
          expr = convert_expr(func.base)
          if func.root:
              r = convert_expr(func.root)
              return sympy.root(expr, r, evaluate=False)
          else:
              return sympy.sqrt(expr, evaluate=False)
      elif func.FUNC_OVERLINE():
          expr = convert_expr(func.base)
          return sympy.conjugate(expr, evaluate=False)
      elif func.FUNC_SUM():
          return handle_sum_or_prod(func, "summation")
      elif func.FUNC_PROD():
          return handle_sum_or_prod(func, "product")
      elif func.FUNC_LIM():
          return handle_limit(func)
  def convert_func_arg(arg):
      """Convert a function-argument parse node.

      Nodes exposing an ``expr`` accessor are converted as expressions; the
      others are converted through their ``mp_nofunc`` child.
      """
      if not hasattr(arg, 'expr'):
          return convert_mp(arg.mp_nofunc())
      return convert_expr(arg.expr())
  def handle_integral(func):
      """Build a ``sympy.Integral`` from an integral parse node.

      The integrand comes from the node's additive or fraction child and
      defaults to 1. The integration variable comes from an explicit
      differential token if present; otherwise it is recovered from a symbol
      in the integrand whose name starts with ``d`` (that symbol is then
      replaced by 1), falling back to ``x``. Limits are attached when the
      node has a subscript.
      """
      if func.additive():
          integrand = convert_add(func.additive())
      elif func.frac():
          integrand = convert_frac(func.frac())
      else:
          integrand = 1

      if func.DIFFERENTIAL():
          int_var = get_differential_var(func.DIFFERENTIAL())
      else:
          int_var = None
          for candidate in integrand.atoms(sympy.Symbol):
              label = str(candidate)
              if len(label) <= 1 or label[0] != 'd':
                  continue
              # Skip the 'd', and a backslash directly after it if present.
              offset = 2 if label[1] == '\\' else 1
              int_var = sympy.Symbol(label[offset:])
              int_sym = candidate
          if int_var:
              integrand = integrand.subs(int_sym, 1)
          else:
              # Assume dx by default
              int_var = sympy.Symbol('x')

      if not func.subexpr():
          return sympy.Integral(integrand, int_var)

      if func.subexpr().atom():
          lower = convert_atom(func.subexpr().atom())
      else:
          lower = convert_expr(func.subexpr().expr())

      if func.supexpr().atom():
          upper = convert_atom(func.supexpr().atom())
      else:
          upper = convert_expr(func.supexpr().expr())

      return sympy.Integral(integrand, (int_var, lower, upper))
  def handle_sum_or_prod(func, name):
      """Build a ``Sum`` or ``Product`` from a sum/product parse node.

      ``name`` selects the result: ``"summation"`` gives ``sympy.Sum`` and
      ``"product"`` gives ``sympy.Product``; any other value yields ``None``.
      The index variable and start value come from the subscript equality,
      the end value from the superscript.
      """
      body = convert_mp(func.mp())
      index = convert_expr(func.subeq().equality().expr(0))
      first = convert_expr(func.subeq().equality().expr(1))

      if func.supexpr().expr():  # ^{expr}
          last = convert_expr(func.supexpr().expr())
      else:  # ^atom
          last = convert_atom(func.supexpr().atom())

      limits = (index, first, last)
      if name == "summation":
          return sympy.Sum(body, limits)
      if name == "product":
          return sympy.Product(body, limits)
      return None
  def handle_limit(func):
      """Build a ``sympy.Limit`` from a limit parse node.

      The variable is taken from the letter or symbol in the limit subscript
      and defaults to ``x``. The direction is ``"-"`` when the subscript has
      a SUB token and ``"+"`` otherwise.
      """
      sub = func.limit_sub()

      if sub.LETTER():
          var_name = sub.LETTER().getText()
      elif sub.SYMBOL():
          var_name = sub.SYMBOL().getText()[1:]  # drop the backslash
      else:
          var_name = 'x'
      var = sympy.Symbol(var_name)

      direction = "-" if sub.SUB() else "+"
      approaching = convert_expr(sub.expr())
      content = convert_mp(func.mp())

      return sympy.Limit(content, var, approaching, direction)
  def get_differential_var(d):
      """Return the ``Symbol`` named by the differential token ``d``."""
      var_name = get_differential_var_str(d.getText())
      return sympy.Symbol(var_name)
  def get_differential_var_str(text):
      r"""Return the variable name contained in a differential's text.

      The first character (the ``d``) and any spaces, tabs, carriage returns
      or newlines following it are dropped, then a single leading backslash
      is removed: ``"dx"`` -> ``"x"``, ``"d \theta"`` -> ``"theta"``.

      Raises
      ------
      LaTeXParsingError
          If nothing but whitespace follows the first character.
      """
      name = text[1:].lstrip(" \r\n\t")
      if not name:
          # Report this as a parsing error instead of failing on a variable
          # that was never assigned.
          raise LaTeXParsingError(
              "Expected a variable name in differential: %r" % (text,))
      if name.startswith("\\"):
          name = name[1:]
      return name