|
|
"""A module containing `numpy`-specific plugins for mypy."""
from __future__ import annotations
import typing as t
import numpy as np
# mypy is optional at runtime: try to import everything the plugin needs and,
# if mypy is missing, store the error in `MYPY_EX` instead of failing at
# module import time (the `plugin` entry point re-raises it).
try:
    import mypy.types
    from mypy.types import Type
    from mypy.plugin import Plugin, AnalyzeTypeContext
    from mypy.nodes import MypyFile, ImportFrom, Statement
    from mypy.build import PRI_MED

    # Signature of a mypy type-analyze hook: it receives the analysis context
    # and returns the type to use in place of the analyzed one.
    _HookFunc = t.Callable[[AnalyzeTypeContext], Type]
    # ``None`` means every mypy import above succeeded.
    MYPY_EX: t.Optional[ModuleNotFoundError] = None
except ModuleNotFoundError as ex:
    # Keep the exception object; `ex` itself is unbound once this block ends.
    MYPY_EX = ex

# Nothing in this module is exported via star-imports.
__all__: t.List[str] = []
def _get_precision_dict() -> t.Dict[str, str]:
    """Map every ``numpy.typing._nbit`` alias to a platform-specific bit-width name.

    For each platform-dependent scalar type the size in bits
    (``8 * itemsize``) is measured on the running interpreter. The alias
    ``numpy.typing._nbit.<alias>`` is then mapped to the name
    ``numpy._<bits>Bit``.

    Returns
    -------
    dict[str, str]
        Fully-qualified alias name -> fully-qualified ``NBitBase`` subclass
        name, in the order of the table below.
    """
    aliases = (
        # Integer types.
        ("_NBitByte", np.byte),
        ("_NBitShort", np.short),
        ("_NBitIntC", np.intc),
        ("_NBitIntP", np.intp),
        ("_NBitInt", np.int_),
        ("_NBitLongLong", np.longlong),
        # Floating-point types.
        ("_NBitHalf", np.half),
        ("_NBitSingle", np.single),
        ("_NBitDouble", np.double),
        ("_NBitLongDouble", np.longdouble),
    )
    return {
        f'numpy.typing._nbit.{alias}': f"numpy._{8 * np.dtype(scalar).itemsize}Bit"
        for alias, scalar in aliases
    }
def _get_extended_precision_list() -> t.List[str]:
    """Return the names of the extended-precision scalar types on this platform.

    Each candidate type is identified by its *dtype* name (for example
    ``np.dtype(np.longdouble).name == "float128"`` on x86-64 Linux). Only
    names that denote an extended-precision type are kept; platforms where
    ``longdouble`` is plain ``float64`` therefore yield an empty list.

    Returns
    -------
    list[str]
        Sized type names such as ``"float128"`` / ``"complex256"``, in the
        order of the candidate types below.
    """
    extended_types = [np.ulonglong, np.longlong, np.longdouble, np.clongdouble]
    extended_names = {
        "uint128",
        "uint256",
        "int128",
        "int256",
        "float80",
        "float96",
        "float128",
        "float256",
        "complex160",
        "complex192",
        "complex256",
        "complex512",
    }
    # Use the dtype name rather than ``type.__name__``: the scalar type's
    # ``__name__`` is not guaranteed to be the sized name (NumPy 2 reports
    # e.g. "longdouble"), which would silently make this list empty. The
    # dtype name is always derived from the type's actual bit-width.
    names = [np.dtype(typ).name for typ in extended_types]
    return [name for name in names if name in extended_names]
#: A dictionary mapping type-aliases in `numpy.typing._nbit` to #: concrete `numpy.typing.NBitBase` subclasses. _PRECISION_DICT: t.Final = _get_precision_dict()
#: A list with the names of all extended precision `np.number` subclasses. _EXTENDED_PRECISION_LIST: t.Final = _get_extended_precision_list()
def _hook(ctx: AnalyzeTypeContext) -> Type:
    """Resolve a ``numpy.typing._nbit`` alias to its concrete ``NBitBase`` subclass.

    Only the last dotted component of the analyzed type's name is used. It is
    re-qualified as ``numpy.typing._nbit.<alias>`` and looked up in
    `_PRECISION_DICT`. The resulting subclass name is returned as a named type.

    Raises
    ------
    KeyError
        If the re-qualified alias is not a key of `_PRECISION_DICT`.
    """
    unbound, _context, api = ctx
    alias = unbound.name.rpartition(".")[2]
    target = _PRECISION_DICT["numpy.typing._nbit." + alias]
    return api.named_type(target)
if t.TYPE_CHECKING or MYPY_EX is None:
    def _index(iterable: t.Iterable[Statement], id: str) -> int:
        """Return the position of the first element whose ``id`` attribute equals `id`.

        Used to locate the ``ImportFrom`` statement that imports from the
        module named `id`. Elements without an ``id`` attribute are skipped.

        Raises
        ------
        ValueError
            If no element of `iterable` has a matching ``id``.
        """
        position = next(
            (
                pos
                for pos, node in enumerate(iterable)
                if getattr(node, "id", None) == id
            ),
            None,
        )
        if position is None:
            raise ValueError("Failed to identify a `ImportFrom` instance "
                             f"with the following id: {id!r}")
        return position

    class _NumpyPlugin(Plugin):
        """mypy plugin that pins platform-specific `numpy.number` precisions."""

        def get_type_analyze_hook(self, fullname: str) -> t.Optional[_HookFunc]:
            """Return `_hook` for precision aliases, ``None`` for everything else.

            This is how platform-specific `numpy.number` subclasses (such as
            `numpy.int_`, `numpy.longlong` and `numpy.longdouble`) receive a
            concrete precision: only names that are keys of `_PRECISION_DICT`
            get the hook.
            """
            return _hook if fullname in _PRECISION_DICT else None

        def get_additional_deps(self, file: MypyFile) -> t.List[t.Tuple[int, str, int]]:
            """Narrow numpy's extended-precision import to this platform's types.

            When `file` is the ``numpy`` module itself, the first
            ``from numpy.typing._extended_precision import ...`` statement in
            both ``file.defs`` and ``file.imports`` is replaced in place. The
            replacement imports only the names in `_EXTENDED_PRECISION_LIST`
            (e.g. `numpy.float96`, `numpy.float128`, `numpy.complex256`).

            Returns
            -------
            list[tuple[int, str, int]]
                A single ``(PRI_MED, file.fullname, -1)`` entry for every file.
            """
            # NOTE(review): this always reports a dependency on the file's own
            # module; its purpose is not evident from here -- confirm.
            deps = [(PRI_MED, file.fullname, -1)]
            if file.fullname != "numpy":
                return deps

            # A ``from numpy.typing._extended_precision import ...`` statement
            # restricted to the extended types available on this platform.
            replacement = ImportFrom(
                "numpy.typing._extended_precision",
                0,
                names=[(name, name) for name in _EXTENDED_PRECISION_LIST],
            )
            replacement.is_top_level = True

            # Swap it in for the much broader extended-precision import
            # (defined in `numpy/__init__.pyi`); the same node object is
            # placed in both statement lists.
            for statements in (file.defs, file.imports):  # type: t.List[Statement]
                pos = _index(statements, "numpy.typing._extended_precision")
                statements[pos] = replacement
            return deps

    def plugin(version: str) -> t.Type[_NumpyPlugin]:
        """mypy entry point: hand back the plugin class (`version` is unused)."""
        return _NumpyPlugin

else:
    def plugin(version: str) -> t.Type[_NumpyPlugin]:
        """mypy entry point: mypy could not be imported, so re-raise that error."""
        raise MYPY_EX
|