m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

131 lines
4.3 KiB

6 months ago
  1. """A module containing `numpy`-specific plugins for mypy."""
  2. from __future__ import annotations
  3. import typing as t
  4. import numpy as np
  5. try:
  6. import mypy.types
  7. from mypy.types import Type
  8. from mypy.plugin import Plugin, AnalyzeTypeContext
  9. from mypy.nodes import MypyFile, ImportFrom, Statement
  10. from mypy.build import PRI_MED
  11. _HookFunc = t.Callable[[AnalyzeTypeContext], Type]
  12. MYPY_EX: t.Optional[ModuleNotFoundError] = None
  13. except ModuleNotFoundError as ex:
  14. MYPY_EX = ex
  15. __all__: t.List[str] = []
  16. def _get_precision_dict() -> t.Dict[str, str]:
  17. names = [
  18. ("_NBitByte", np.byte),
  19. ("_NBitShort", np.short),
  20. ("_NBitIntC", np.intc),
  21. ("_NBitIntP", np.intp),
  22. ("_NBitInt", np.int_),
  23. ("_NBitLongLong", np.longlong),
  24. ("_NBitHalf", np.half),
  25. ("_NBitSingle", np.single),
  26. ("_NBitDouble", np.double),
  27. ("_NBitLongDouble", np.longdouble),
  28. ]
  29. ret = {}
  30. for name, typ in names:
  31. n: int = 8 * typ().dtype.itemsize
  32. ret[f'numpy.typing._nbit.{name}'] = f"numpy._{n}Bit"
  33. return ret
  34. def _get_extended_precision_list() -> t.List[str]:
  35. extended_types = [np.ulonglong, np.longlong, np.longdouble, np.clongdouble]
  36. extended_names = {
  37. "uint128",
  38. "uint256",
  39. "int128",
  40. "int256",
  41. "float80",
  42. "float96",
  43. "float128",
  44. "float256",
  45. "complex160",
  46. "complex192",
  47. "complex256",
  48. "complex512",
  49. }
  50. return [i.__name__ for i in extended_types if i.__name__ in extended_names]
  51. #: A dictionary mapping type-aliases in `numpy.typing._nbit` to
  52. #: concrete `numpy.typing.NBitBase` subclasses.
  53. _PRECISION_DICT: t.Final = _get_precision_dict()
  54. #: A list with the names of all extended precision `np.number` subclasses.
  55. _EXTENDED_PRECISION_LIST: t.Final = _get_extended_precision_list()
  56. def _hook(ctx: AnalyzeTypeContext) -> Type:
  57. """Replace a type-alias with a concrete ``NBitBase`` subclass."""
  58. typ, _, api = ctx
  59. name = typ.name.split(".")[-1]
  60. name_new = _PRECISION_DICT[f"numpy.typing._nbit.{name}"]
  61. return api.named_type(name_new)
  62. if t.TYPE_CHECKING or MYPY_EX is None:
  63. def _index(iterable: t.Iterable[Statement], id: str) -> int:
  64. """Identify the first ``ImportFrom`` instance the specified `id`."""
  65. for i, value in enumerate(iterable):
  66. if getattr(value, "id", None) == id:
  67. return i
  68. else:
  69. raise ValueError("Failed to identify a `ImportFrom` instance "
  70. f"with the following id: {id!r}")
  71. class _NumpyPlugin(Plugin):
  72. """A plugin for assigning platform-specific `numpy.number` precisions."""
  73. def get_type_analyze_hook(self, fullname: str) -> t.Optional[_HookFunc]:
  74. """Set the precision of platform-specific `numpy.number` subclasses.
  75. For example: `numpy.int_`, `numpy.longlong` and `numpy.longdouble`.
  76. """
  77. if fullname in _PRECISION_DICT:
  78. return _hook
  79. return None
  80. def get_additional_deps(self, file: MypyFile) -> t.List[t.Tuple[int, str, int]]:
  81. """Import platform-specific extended-precision `numpy.number` subclasses.
  82. For example: `numpy.float96`, `numpy.float128` and `numpy.complex256`.
  83. """
  84. ret = [(PRI_MED, file.fullname, -1)]
  85. if file.fullname == "numpy":
  86. # Import ONLY the extended precision types available to the
  87. # platform in question
  88. imports = ImportFrom(
  89. "numpy.typing._extended_precision", 0,
  90. names=[(v, v) for v in _EXTENDED_PRECISION_LIST],
  91. )
  92. imports.is_top_level = True
  93. # Replace the much broader extended-precision import
  94. # (defined in `numpy/__init__.pyi`) with a more specific one
  95. for lst in [file.defs, file.imports]: # type: t.List[Statement]
  96. i = _index(lst, "numpy.typing._extended_precision")
  97. lst[i] = imports
  98. return ret
  99. def plugin(version: str) -> t.Type[_NumpyPlugin]:
  100. """An entry-point for mypy."""
  101. return _NumpyPlugin
  102. else:
  103. def plugin(version: str) -> t.Type[_NumpyPlugin]:
  104. """An entry-point for mypy."""
  105. raise MYPY_EX