m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

271 lines
8.8 KiB

6 months ago
  1. from typing import Callable, Dict as tDict, Optional, Tuple as tTuple, Union as tUnion
  2. from collections import OrderedDict
  3. import os
  4. import re
  5. import subprocess
  6. from .util import (
  7. find_binary_of_command, unique_list, CompileError
  8. )
  9. class CompilerRunner:
  10. """ CompilerRunner base class.
  11. Parameters
  12. ==========
  13. sources : list of str
  14. Paths to sources.
  15. out : str
  16. flags : iterable of str
  17. Compiler flags.
  18. run_linker : bool
  19. compiler_name_exe : (str, str) tuple
  20. Tuple of compiler name & command to call.
  21. cwd : str
  22. Path of root of relative paths.
  23. include_dirs : list of str
  24. Include directories.
  25. libraries : list of str
  26. Libraries to link against.
  27. library_dirs : list of str
  28. Paths to search for shared libraries.
  29. std : str
  30. Standard string, e.g. ``'c++11'``, ``'c99'``, ``'f2003'``.
  31. define: iterable of strings
  32. macros to define
  33. undef : iterable of strings
  34. macros to undefine
  35. preferred_vendor : string
  36. name of preferred vendor e.g. 'gnu' or 'intel'
  37. Methods
  38. =======
  39. run():
  40. Invoke compilation as a subprocess.
  41. """
  42. # Subclass to vendor/binary dict
  43. compiler_dict = None # type: tDict[str, str]
  44. # Standards should be a tuple of supported standards
  45. # (first one will be the default)
  46. standards = None # type: tTuple[tUnion[None, str], ...]
  47. # Subclass to dict of binary/formater-callback
  48. std_formater = None # type: tDict[str, Callable[[Optional[str]], str]]
  49. # subclass to be e.g. {'gcc': 'gnu', ...}
  50. compiler_name_vendor_mapping = None # type: tDict[str, str]
  51. def __init__(self, sources, out, flags=None, run_linker=True, compiler=None, cwd='.',
  52. include_dirs=None, libraries=None, library_dirs=None, std=None, define=None,
  53. undef=None, strict_aliasing=None, preferred_vendor=None, linkline=None, **kwargs):
  54. if isinstance(sources, str):
  55. raise ValueError("Expected argument sources to be a list of strings.")
  56. self.sources = list(sources)
  57. self.out = out
  58. self.flags = flags or []
  59. self.cwd = cwd
  60. if compiler:
  61. self.compiler_name, self.compiler_binary = compiler
  62. else:
  63. # Find a compiler
  64. if preferred_vendor is None:
  65. preferred_vendor = os.environ.get('SYMPY_COMPILER_VENDOR', None)
  66. self.compiler_name, self.compiler_binary, self.compiler_vendor = self.find_compiler(preferred_vendor)
  67. if self.compiler_binary is None:
  68. raise ValueError("No compiler found (searched: {})".format(', '.join(self.compiler_dict.values())))
  69. self.define = define or []
  70. self.undef = undef or []
  71. self.include_dirs = include_dirs or []
  72. self.libraries = libraries or []
  73. self.library_dirs = library_dirs or []
  74. self.std = std or self.standards[0]
  75. self.run_linker = run_linker
  76. if self.run_linker:
  77. # both gnu and intel compilers use '-c' for disabling linker
  78. self.flags = list(filter(lambda x: x != '-c', self.flags))
  79. else:
  80. if '-c' not in self.flags:
  81. self.flags.append('-c')
  82. if self.std:
  83. self.flags.append(self.std_formater[
  84. self.compiler_name](self.std))
  85. self.linkline = linkline or []
  86. if strict_aliasing is not None:
  87. nsa_re = re.compile("no-strict-aliasing$")
  88. sa_re = re.compile("strict-aliasing$")
  89. if strict_aliasing is True:
  90. if any(map(nsa_re.match, flags)):
  91. raise CompileError("Strict aliasing cannot be both enforced and disabled")
  92. elif any(map(sa_re.match, flags)):
  93. pass # already enforced
  94. else:
  95. flags.append('-fstrict-aliasing')
  96. elif strict_aliasing is False:
  97. if any(map(nsa_re.match, flags)):
  98. pass # already disabled
  99. else:
  100. if any(map(sa_re.match, flags)):
  101. raise CompileError("Strict aliasing cannot be both enforced and disabled")
  102. else:
  103. flags.append('-fno-strict-aliasing')
  104. else:
  105. msg = "Expected argument strict_aliasing to be True/False, got {}"
  106. raise ValueError(msg.format(strict_aliasing))
  107. @classmethod
  108. def find_compiler(cls, preferred_vendor=None):
  109. """ Identify a suitable C/fortran/other compiler. """
  110. candidates = list(cls.compiler_dict.keys())
  111. if preferred_vendor:
  112. if preferred_vendor in candidates:
  113. candidates = [preferred_vendor]+candidates
  114. else:
  115. raise ValueError("Unknown vendor {}".format(preferred_vendor))
  116. name, path = find_binary_of_command([cls.compiler_dict[x] for x in candidates])
  117. return name, path, cls.compiler_name_vendor_mapping[name]
  118. def cmd(self):
  119. """ List of arguments (str) to be passed to e.g. ``subprocess.Popen``. """
  120. cmd = (
  121. [self.compiler_binary] +
  122. self.flags +
  123. ['-U'+x for x in self.undef] +
  124. ['-D'+x for x in self.define] +
  125. ['-I'+x for x in self.include_dirs] +
  126. self.sources
  127. )
  128. if self.run_linker:
  129. cmd += (['-L'+x for x in self.library_dirs] +
  130. ['-l'+x for x in self.libraries] +
  131. self.linkline)
  132. counted = []
  133. for envvar in re.findall(r'\$\{(\w+)\}', ' '.join(cmd)):
  134. if os.getenv(envvar) is None:
  135. if envvar not in counted:
  136. counted.append(envvar)
  137. msg = "Environment variable '{}' undefined.".format(envvar)
  138. raise CompileError(msg)
  139. return cmd
  140. def run(self):
  141. self.flags = unique_list(self.flags)
  142. # Append output flag and name to tail of flags
  143. self.flags.extend(['-o', self.out])
  144. env = os.environ.copy()
  145. env['PWD'] = self.cwd
  146. # NOTE: intel compilers seems to need shell=True
  147. p = subprocess.Popen(' '.join(self.cmd()),
  148. shell=True,
  149. cwd=self.cwd,
  150. stdin=subprocess.PIPE,
  151. stdout=subprocess.PIPE,
  152. stderr=subprocess.STDOUT,
  153. env=env)
  154. comm = p.communicate()
  155. try:
  156. self.cmd_outerr = comm[0].decode('utf-8')
  157. except UnicodeDecodeError:
  158. self.cmd_outerr = comm[0].decode('iso-8859-1') # win32
  159. self.cmd_returncode = p.returncode
  160. # Error handling
  161. if self.cmd_returncode != 0:
  162. msg = "Error executing '{}' in {} (exited status {}):\n {}\n".format(
  163. ' '.join(self.cmd()), self.cwd, str(self.cmd_returncode), self.cmd_outerr
  164. )
  165. raise CompileError(msg)
  166. return self.cmd_outerr, self.cmd_returncode
  167. class CCompilerRunner(CompilerRunner):
  168. compiler_dict = OrderedDict([
  169. ('gnu', 'gcc'),
  170. ('intel', 'icc'),
  171. ('llvm', 'clang'),
  172. ])
  173. standards = ('c89', 'c90', 'c99', 'c11') # First is default
  174. std_formater = {
  175. 'gcc': '-std={}'.format,
  176. 'icc': '-std={}'.format,
  177. 'clang': '-std={}'.format,
  178. }
  179. compiler_name_vendor_mapping = {
  180. 'gcc': 'gnu',
  181. 'icc': 'intel',
  182. 'clang': 'llvm'
  183. }
  184. def _mk_flag_filter(cmplr_name): # helper for class initialization
  185. not_welcome = {'g++': ("Wimplicit-interface",)} # "Wstrict-prototypes",)}
  186. if cmplr_name in not_welcome:
  187. def fltr(x):
  188. for nw in not_welcome[cmplr_name]:
  189. if nw in x:
  190. return False
  191. return True
  192. else:
  193. def fltr(x):
  194. return True
  195. return fltr
  196. class CppCompilerRunner(CompilerRunner):
  197. compiler_dict = OrderedDict([
  198. ('gnu', 'g++'),
  199. ('intel', 'icpc'),
  200. ('llvm', 'clang++'),
  201. ])
  202. # First is the default, c++0x == c++11
  203. standards = ('c++98', 'c++0x')
  204. std_formater = {
  205. 'g++': '-std={}'.format,
  206. 'icpc': '-std={}'.format,
  207. 'clang++': '-std={}'.format,
  208. }
  209. compiler_name_vendor_mapping = {
  210. 'g++': 'gnu',
  211. 'icpc': 'intel',
  212. 'clang++': 'llvm'
  213. }
  214. class FortranCompilerRunner(CompilerRunner):
  215. standards = (None, 'f77', 'f95', 'f2003', 'f2008')
  216. std_formater = {
  217. 'gfortran': lambda x: '-std=gnu' if x is None else '-std=legacy' if x == 'f77' else '-std={}'.format(x),
  218. 'ifort': lambda x: '-stand f08' if x is None else '-stand f{}'.format(x[-2:]), # f2008 => f08
  219. }
  220. compiler_dict = OrderedDict([
  221. ('gnu', 'gfortran'),
  222. ('intel', 'ifort'),
  223. ])
  224. compiler_name_vendor_mapping = {
  225. 'gfortran': 'gnu',
  226. 'ifort': 'intel',
  227. }