m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

90 lines
2.6 KiB

6 months ago
  1. """Test the runtime usage of `numpy.typing`."""
  2. from __future__ import annotations
  3. import sys
  4. from typing import get_type_hints, Union, Tuple, NamedTuple
  5. import pytest
  6. import numpy as np
  7. import numpy.typing as npt
  8. try:
  9. from typing_extensions import get_args, get_origin
  10. SKIP = False
  11. except ImportError:
  12. SKIP = True
  13. class TypeTup(NamedTuple):
  14. typ: type
  15. args: Tuple[type, ...]
  16. origin: None | type
  17. if sys.version_info >= (3, 9):
  18. NDArrayTup = TypeTup(npt.NDArray, npt.NDArray.__args__, np.ndarray)
  19. else:
  20. NDArrayTup = TypeTup(npt.NDArray, (), None)
  21. TYPES = {
  22. "ArrayLike": TypeTup(npt.ArrayLike, npt.ArrayLike.__args__, Union),
  23. "DTypeLike": TypeTup(npt.DTypeLike, npt.DTypeLike.__args__, Union),
  24. "NBitBase": TypeTup(npt.NBitBase, (), None),
  25. "NDArray": NDArrayTup,
  26. }
  27. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  28. @pytest.mark.skipif(SKIP, reason="requires typing-extensions")
  29. def test_get_args(name: type, tup: TypeTup) -> None:
  30. """Test `typing.get_args`."""
  31. typ, ref = tup.typ, tup.args
  32. out = get_args(typ)
  33. assert out == ref
  34. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  35. @pytest.mark.skipif(SKIP, reason="requires typing-extensions")
  36. def test_get_origin(name: type, tup: TypeTup) -> None:
  37. """Test `typing.get_origin`."""
  38. typ, ref = tup.typ, tup.origin
  39. out = get_origin(typ)
  40. assert out == ref
  41. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  42. def test_get_type_hints(name: type, tup: TypeTup) -> None:
  43. """Test `typing.get_type_hints`."""
  44. typ = tup.typ
  45. # Explicitly set `__annotations__` in order to circumvent the
  46. # stringification performed by `from __future__ import annotations`
  47. def func(a): pass
  48. func.__annotations__ = {"a": typ, "return": None}
  49. out = get_type_hints(func)
  50. ref = {"a": typ, "return": type(None)}
  51. assert out == ref
  52. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  53. def test_get_type_hints_str(name: type, tup: TypeTup) -> None:
  54. """Test `typing.get_type_hints` with string-representation of types."""
  55. typ_str, typ = f"npt.{name}", tup.typ
  56. # Explicitly set `__annotations__` in order to circumvent the
  57. # stringification performed by `from __future__ import annotations`
  58. def func(a): pass
  59. func.__annotations__ = {"a": typ_str, "return": None}
  60. out = get_type_hints(func)
  61. ref = {"a": typ, "return": type(None)}
  62. assert out == ref
  63. def test_keys() -> None:
  64. """Test that ``TYPES.keys()`` and ``numpy.typing.__all__`` are synced."""
  65. keys = TYPES.keys()
  66. ref = set(npt.__all__)
  67. assert keys == ref