图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

439 lines
13 KiB

  1. from typing import Any, Dict as tDict, Tuple as tTuple
  2. from itertools import product
  3. import re
  4. from sympy.core.sympify import sympify
  def mathematica(s, additional_translations=None):
      '''
      Translate a Mathematica expression string and sympify the result.

      The string ``s`` is rewritten by ``MathematicaParser.parse`` and the
      rewritten string is passed to ``sympify``.

      Users can add their own translation dictionary through
      ``additional_translations``; a variable-length argument needs the
      '*' character.

      Examples
      ========

      >>> from sympy.parsing.mathematica import mathematica
      >>> mathematica('Log3[9]', {'Log3[x]':'log(x,3)'})
      2
      >>> mathematica('F[7,5,3]', {'F[*x]':'Max(*x)*Min(*x)'})
      21

      '''
      translated = MathematicaParser(additional_translations).parse(s)
      return sympify(translated)
  def _deco(cls):
      '''Class decorator: call ``cls._initialize_class()`` and return ``cls``.'''
      cls._initialize_class()
      return cls


  @_deco
  class MathematicaParser:
      '''Convert a string of a basic Mathematica expression to SymPy style.

      The output of ``parse`` is a string, not a SymPy object.

      ``additional_translations`` may hold extra
      ``{'Mathematica form': 'SymPy form'}`` pairs, for example
      ``{'Log3[x]': 'log(x,3)'}``. A variable-length argument is written
      with a leading ``*`` (``{'F[*x]': 'Max(*x)*Min(*x)'}``).

      ``ValueError`` is raised for a non-dict ``additional_translations``,
      for a malformed function form, for unbalanced brackets, for lists
      (``{``/``}``) and for functions that have no registered translation.
      '''

      # left: Mathematica, right: SymPy
      CORRESPONDENCES = {
          'Sqrt[x]': 'sqrt(x)',
          'Exp[x]': 'exp(x)',
          'Log[x]': 'log(x)',
          'Log[x,y]': 'log(y,x)',
          'Log2[x]': 'log(x,2)',
          'Log10[x]': 'log(x,10)',
          'Mod[x,y]': 'Mod(x,y)',
          'Max[*x]': 'Max(*x)',
          'Min[*x]': 'Min(*x)',
          'Pochhammer[x,y]': 'rf(x,y)',
          'ArcTan[x,y]': 'atan2(y,x)',
          'ExpIntegralEi[x]': 'Ei(x)',
          'SinIntegral[x]': 'Si(x)',
          'CosIntegral[x]': 'Ci(x)',
          'AiryAi[x]': 'airyai(x)',
          'AiryAiPrime[x]': 'airyaiprime(x)',
          'AiryBi[x]': 'airybi(x)',
          'AiryBiPrime[x]': 'airybiprime(x)',
          'LogIntegral[x]': 'li(x)',
          'PrimePi[x]': 'primepi(x)',
          'Prime[x]': 'prime(x)',
          'PrimeQ[x]': 'isprime(x)',
      }

      # trigonometric, e.t.c.: Sin, ArcSin, Sinh, ArcSinh, ... A comprehension
      # is used so that no loop variable is left behind as a class attribute.
      CORRESPONDENCES.update({
          arc + tri + h + '[x]': ('a' if arc else '') + tri.lower() + h + '(x)'
          for arc, tri, h in product(
              ('', 'Arc'),
              ('Sin', 'Cos', 'Tan', 'Cot', 'Sec', 'Csc'),
              ('', 'h'))
      })

      # plain character-for-character replacements
      REPLACEMENTS = {
          ' ': '',
          '^': '**',
          '{': '[',
          '}': ']',
      }

      # regex based rewrites: name -> (compiled pattern, replacement)
      RULES = {
          # a single whitespace to '*'
          'whitespace': (
              re.compile(r'''
                  (?<=[a-zA-Z\d])     # a letter or a number
                  \                   # a whitespace
                  (?=[a-zA-Z\d])      # a letter or a number
                  ''', re.VERBOSE),
              '*'),

          # add omitted '*' character
          'add*_1': (
              re.compile(r'''
                  (?<=[])\d])         # ], ) or a number
                                      # ''
                  (?=[(a-zA-Z])       # ( or a single letter
                  ''', re.VERBOSE),
              '*'),

          # add omitted '*' character (variable letter preceding)
          'add*_2': (
              re.compile(r'''
                  (?<=[a-zA-Z])       # a letter
                  \(                  # ( as a character
                  (?=.)               # any characters
                  ''', re.VERBOSE),
              '*('),

          # convert 'Pi' to 'pi'
          'Pi': (
              re.compile(r'''
                  (?<![a-zA-Z])       # at the top or after a non-letter
                  Pi                  # 'Pi' is 3.14159... in Mathematica
                  (?![a-zA-Z])        # at the end or before a non-letter
                  ''', re.VERBOSE),
              'pi'),
      }

      # Mathematica function name pattern
      FM_PATTERN = re.compile(r'''
          (?<![a-zA-Z])               # at the top or after a non-letter
          [A-Z][a-zA-Z\d]*            # Function
          (?=\[)                      # [ as a character
          ''', re.VERBOSE)

      # list or matrix pattern (for future usage)
      ARG_MTRX_PATTERN = re.compile(r'''
          \{.*\}
          ''', re.VERBOSE)

      # regex string for function argument pattern; the negative lookarounds
      # also accept the very start and the very end of a template, so a
      # model argument that closes the template (as in '2*x') is replaced too
      ARGS_PATTERN_TEMPLATE = r'''
          (?<![a-zA-Z])
          {arguments}                 # model argument like x, y,...
          (?![a-zA-Z])
          '''

      # will contain transformed CORRESPONDENCES dictionary
      TRANSLATIONS = {}  # type: tDict[tTuple[str, int], tDict[str, Any]]

      # cache for a raw users' translation dictionary (a private copy)
      cache_original = {}  # type: tDict[str, str]

      # cache for a compiled users' translation dictionary
      cache_compiled = {}  # type: tDict[tTuple[str, int], tDict[str, Any]]

      @classmethod
      def _initialize_class(cls):
          '''Compile CORRESPONDENCES into TRANSLATIONS (run by ``_deco``).'''
          cls.TRANSLATIONS.update(cls._compile_dictionary(cls.CORRESPONDENCES))

      def __init__(self, additional_translations=None):
          '''Collect the built-in translations plus the user supplied ones.

          The last user dictionary and its compiled form are cached on the
          class, so that passing an equal dictionary again skips compiling.

          Raises ``ValueError`` when ``additional_translations`` is neither
          ``None`` nor a dict, or when one of its entries is malformed.
          '''
          # start from the class constant TRANSLATIONS
          self.translations = dict(self.TRANSLATIONS)

          if additional_translations is None:
              additional_translations = {}
          if not isinstance(additional_translations, dict):
              raise ValueError('The argument must be dict type')

          klass = type(self)
          # check the latest added translations
          if klass.cache_original != additional_translations:
              compiled = self._compile_dictionary(additional_translations)
              # Cache a copy: holding the caller's own dict would make the
              # comparison above compare the dict with itself after the
              # caller mutates it, and the stale compiled cache would be used.
              klass.cache_original = dict(additional_translations)
              klass.cache_compiled = compiled

          # merge user's own translations
          self.translations.update(klass.cache_compiled)

      @classmethod
      def _compile_dictionary(cls, dic):
          '''Turn ``{'F[x,y]': 'template'}`` pairs into lookup entries.

          Returns a dict keyed by ``(function name, argument count)``, or by
          ``(function name, '*')`` when the last model argument is variadic.
          Each value holds the SymPy template (``'fs'``), the model argument
          names (``'args'``) and a compiled regex that finds those names in
          the template (``'pat'``).

          Raises ``ValueError`` for a malformed function form.
          '''
          compiled = {}

          for fm, fs in dic.items():
              # check function form
              cls._check_input(fm)
              cls._check_input(fs)

              # uncover '*' hiding behind a whitespace
              fm = cls._apply_rules(fm, 'whitespace')
              fs = cls._apply_rules(fs, 'whitespace')

              # remove whitespace(s)
              fm = cls._replace(fm, ' ')
              fs = cls._replace(fs, ' ')

              invalid = "'{f}' function form is invalid.".format(f=fm)

              # search Mathematica function name
              m = cls.FM_PATTERN.search(fm)
              if m is None:
                  raise ValueError(invalid)

              # get Mathematica function name like 'Log'
              fm_name = m.group()

              # get arguments of Mathematica function
              args, end = cls._get_args(m)

              # function side check. (e.g.) '2*Func[x]' is invalid.
              if m.start() != 0 or end != len(fm):
                  raise ValueError(invalid)

              # an empty model argument ('F[]', 'F[x,]') cannot be located
              # in the template, so it is rejected as an invalid form
              if not all(args):
                  raise ValueError(invalid)

              # check the last argument's 1st character
              key_arg = '*' if args[-1].startswith('*') else len(args)
              key = (fm_name, key_arg)

              # escape the names for regex use ('*x' becomes '\*x')
              re_args = [re.escape(x) for x in args]

              # for regex. Example: (?:(x|y|z))
              xyz = '(?:(' + '|'.join(re_args) + '))'

              # string for regex compile
              pat_str = cls.ARGS_PATTERN_TEMPLATE.format(arguments=xyz)

              compiled[key] = {
                  'fs': fs,        # SymPy function template
                  'args': args,    # args are ['x', 'y'] for example
                  'pat': re.compile(pat_str, re.VERBOSE),
              }

          return compiled

      def _convert_function(self, s):
          '''Parse Mathematica function to SymPy one'''
          pat = self.FM_PATTERN
          pieces = []     # converted parts, in order

          while True:
              m = pat.search(s)

              if m is None:
                  # append the rest of string
                  pieces.append(s)
                  break

              # get Mathematica function name
              fm = m.group()

              # get arguments, and the end position of fm function
              args, end = self._get_args(m)

              # the start position of fm function
              bgn = m.start()

              # convert Mathematica function to SymPy one
              s = self._convert_one_function(s, fm, args, bgn, end)

              # everything before the converted function is final; the
              # converted text itself is scanned again because its arguments
              # may still hold Mathematica functions
              pieces.append(s[:bgn])
              s = s[bgn:]

          return ''.join(pieces)

      def _convert_one_function(self, s, fm, args, bgn, end):
          '''Replace the call ``s[bgn:end]`` of ``fm`` by its SymPy form.

          ``args`` are the actual arguments of the call. Returns the whole
          string with that single call swapped.

          Raises ``ValueError`` when ``fm`` has no translation for this
          number of arguments.
          '''
          fixed_key = (fm, len(args))
          star_key = (fm, '*')

          # no variable-length argument
          if fixed_key in self.translations:
              key = fixed_key

              # x, y,... model arguments
              x_args = self.translations[key]['args']

              # make CORRESPONDENCES between model arguments and actual ones
              d = dict(zip(x_args, args))

          # with variable-length argument
          elif star_key in self.translations:
              key = star_key

              # x, y,..*args (model arguments)
              x_args = self.translations[key]['args']

              # make CORRESPONDENCES between model arguments and actual ones
              d = {}
              for i, x in enumerate(x_args):
                  if x.startswith('*'):
                      # the variadic argument takes all remaining ones
                      d[x] = ','.join(args[i:])
                      break
                  if i >= len(args):
                      err = "'{f}' is given too few arguments.".format(f=fm)
                      raise ValueError(err)
                  d[x] = args[i]

          # out of self.translations
          else:
              err = "'{f}' is out of the whitelist.".format(f=fm)
              raise ValueError(err)

          entry = self.translations[key]

          # put the actual arguments in place of the model ones; a function
          # is used as replacement so that backslashes are kept verbatim
          converted = entry['pat'].sub(lambda mo: d[mo.group()], entry['fs'])

          # update to swapped string
          return s[:bgn] + converted + s[end:]

      @classmethod
      def _get_args(cls, m):
          '''Get arguments of a Mathematica function

          ``m`` is a match of FM_PATTERN, so ``m.end()`` points at the
          opening ``[``. Returns ``(args, func_end)`` where ``func_end`` is
          the position next to the closing ``]``.

          Raises ``ValueError`` when the closing ``]`` is missing.
          '''
          s = m.string                # whole string
          anc = m.end() + 1           # pointing the first letter of arguments
          square, curly = [], []      # stack for brackets
          args = []

          invalid = "'{f}' function form is invalid.".format(f=s)

          # current cursor
          cur = anc
          for i, c in enumerate(s[anc:], anc):
              # extract one argument
              if c == ',' and not square and not curly:
                  args.append(s[cur:i])       # add an argument
                  cur = i + 1                 # move cursor

              # handle list or matrix (for future usage)
              if c == '{':
                  curly.append(c)
              elif c == '}':
                  if not curly:
                      raise ValueError(invalid)
                  curly.pop()

              # seek corresponding ']' with skipping irrelevant ones
              if c == '[':
                  square.append(c)
              elif c == ']':
                  if square:
                      square.pop()
                  else:   # empty stack: this ']' closes the function
                      args.append(s[cur:i])
                      return args, i + 1

          # the string ended before the function was closed
          raise ValueError(invalid)

      @classmethod
      def _replace(cls, s, bef):
          '''Replace every ``bef`` in ``s`` by ``REPLACEMENTS[bef]``.'''
          return s.replace(bef, cls.REPLACEMENTS[bef])

      @classmethod
      def _apply_rules(cls, s, bef):
          '''Apply the regex rule registered under ``bef`` in RULES to ``s``.'''
          pat, aft = cls.RULES[bef]
          return pat.sub(aft, s)

      @classmethod
      def _check_input(cls, s):
          '''Raise ``ValueError`` for unbalanced bracket counts or for lists.'''
          for opener, closer in (('[', ']'), ('{', '}'), ('(', ')')):
              if s.count(opener) != s.count(closer):
                  err = "'{f}' function form is invalid.".format(f=s)
                  raise ValueError(err)

          if '{' in s:
              err = "Currently list is not supported."
              raise ValueError(err)

      def parse(self, s):
          '''Return the Mathematica expression ``s`` rewritten as a SymPy string.

          Raises ``ValueError`` for malformed input or for a function that
          has no translation.
          '''
          # input check
          self._check_input(s)

          # uncover '*' hiding behind a whitespace
          s = self._apply_rules(s, 'whitespace')

          # remove whitespace(s)
          s = self._replace(s, ' ')

          # add omitted '*' character
          s = self._apply_rules(s, 'add*_1')
          s = self._apply_rules(s, 'add*_2')

          # translate function
          s = self._convert_function(s)

          # '^' to '**'
          s = self._replace(s, '^')

          # 'Pi' to 'pi'
          s = self._apply_rules(s, 'Pi')

          # '{', '}' to '[', ']' is skipped: lists are currently rejected
          # by _check_input

          return s