m2m模型翻译
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

210 lines
6.2 KiB

6 months ago
  1. from __future__ import annotations
  2. import sys
  3. import types
  4. from typing import (
  5. Any,
  6. ClassVar,
  7. FrozenSet,
  8. Generator,
  9. Iterable,
  10. Iterator,
  11. List,
  12. NoReturn,
  13. Tuple,
  14. Type,
  15. TypeVar,
  16. TYPE_CHECKING,
  17. )
  18. import numpy as np
  # Names exported by a star-import of this module.
  __all__ = [
      "_GenericAlias",
      "NDArray",
  ]

  # Self-type variable: lets the helpers and methods below declare that they
  # hand back the same `_GenericAlias` (sub)class they were given.
  _T = TypeVar("_T", bound="_GenericAlias")
  def _to_str(obj: object) -> str:
      """Render one component of an alias for `_GenericAlias.__repr__`.

      ``Ellipsis`` becomes ``'...'``.  A class that is not itself a generic
      alias is shown by its qualified name, prefixed with its module unless
      that module is ``builtins``.  Anything else falls back to ``repr(obj)``.
      """
      if obj is Ellipsis:
          return '...'

      # Generic aliases are excluded on purpose so that they are rendered
      # through their own ``repr`` below.
      plain_class = (
          isinstance(obj, type)
          and not isinstance(obj, _GENERIC_ALIAS_TYPE)
      )
      if not plain_class:
          return repr(obj)

      module = obj.__module__
      if module == 'builtins':
          return obj.__qualname__
      return f'{module}.{obj.__qualname__}'
  def _parse_parameters(args: Iterable[Any]) -> Generator[TypeVar, None, None]:
      """Yield every type variable found in `args`, in order of appearance.

      Helper function for `_GenericAlias.__init__`.

      An item exposing ``__parameters__`` contributes all of those
      parameters; a bare `~typing.TypeVar` contributes itself; every other
      item is skipped.  Repeated type variables are yielded once per
      occurrence.
      """
      for item in args:
          if not hasattr(item, "__parameters__"):
              # Not a parametrized object: only bare typevars are of interest.
              if isinstance(item, TypeVar):
                  yield item
              continue
          yield from item.__parameters__
  def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
      """Recursively replace all typevars with those from `parameters`.

      Helper function for `_GenericAlias.__getitem__`.

      The arguments of `alias` are visited left to right and `parameters`
      is consumed in that same order: one value per bare type variable, and
      one value per entry of ``__parameters__`` for any other parametrized
      argument.  The result is a new object of the same class as `alias`,
      built from ``alias.__origin__`` and the substituted arguments.
      """
      def _substitute(arg: Any) -> Any:
          # A bare typevar is swapped for the next replacement.
          if isinstance(arg, TypeVar):
              return next(parameters)
          # Nested backport aliases are rebuilt recursively.
          if isinstance(arg, _GenericAlias):
              return _reconstruct_alias(arg, parameters)
          # Any other parametrized object is re-subscripted with as many
          # replacements as it has parameters.
          if hasattr(arg, "__parameters__"):
              replacements = tuple(next(parameters) for _ in arg.__parameters__)
              return arg[replacements]
          # Concrete arguments are kept untouched.
          return arg

      new_args = [_substitute(arg) for arg in alias.__args__]
      return type(alias)(alias.__origin__, tuple(new_args))
  class _GenericAlias:
      """A python-based backport of the `types.GenericAlias` class.

      E.g. for ``t = list[int]``, ``t.__origin__`` is ``list`` and
      ``t.__args__`` is ``(int,)``.

      Attribute access on an instance is forwarded to ``__origin__`` for
      every name that is not listed in `_ATTR_EXCEPTIONS`.  Because of that,
      the methods below read the instance's own slots through
      ``super().__getattribute__`` rather than plain attribute access.

      See Also
      --------
      :pep:`585`
          The PEP responsible for introducing `types.GenericAlias`.

      """

      __slots__ = ("__weakref__", "_origin", "_args", "_parameters", "_hash")

      @property
      def __origin__(self) -> type:
          """The type that was parametrized."""
          return super().__getattribute__("_origin")

      @property
      def __args__(self) -> Tuple[Any, ...]:
          """The arguments `__origin__` was parametrized with."""
          return super().__getattribute__("_args")

      @property
      def __parameters__(self) -> Tuple[TypeVar, ...]:
          """Type variables in the ``GenericAlias``."""
          return super().__getattribute__("_parameters")

      def __init__(self, origin: type, args: Any) -> None:
          """Create an alias of `origin` parametrized by `args`.

          Parameters
          ----------
          origin : type
              The type being parametrized.
          args : object
              Either a tuple of arguments or a single argument; a single
              argument is wrapped in a 1-tuple.

          """
          # Normalise first, and search the *normalised* tuple for typevars:
          # iterating the raw value fails for a single non-iterable argument
          # (e.g. a lone `TypeVar` or class) and would walk through the
          # characters of a string argument.
          args_tup = args if isinstance(args, tuple) else (args,)
          self._origin = origin
          self._args = args_tup
          self._parameters = tuple(_parse_parameters(args_tup))

      @property
      def __call__(self) -> type:
          """Expose `__origin__` as the callable used when calling the alias."""
          return self.__origin__

      def __reduce__(self: _T) -> Tuple[Type[_T], Tuple[type, Tuple[Any, ...]]]:
          """Return the ``(class, (origin, args))`` pair used to rebuild `self`."""
          cls = type(self)
          return cls, (self.__origin__, self.__args__)

      def __mro_entries__(self, bases: Iterable[object]) -> Tuple[type]:
          """Return ``(__origin__,)`` as the replacement base-class tuple."""
          return (self.__origin__,)

      def __dir__(self) -> List[str]:
          """Implement ``dir(self)``."""
          cls = type(self)
          dir_origin = set(dir(self.__origin__))
          # The origin's names plus the handful served by the alias itself.
          return sorted(cls._ATTR_EXCEPTIONS | dir_origin)

      def __hash__(self) -> int:
          """Return ``hash(self)``."""
          # Attempt to use the cached hash; the ``_hash`` slot stays unset
          # until the first call, which surfaces as an `AttributeError`.
          try:
              return super().__getattribute__("_hash")
          except AttributeError:
              self._hash: int = hash(self.__origin__) ^ hash(self.__args__)
              return super().__getattribute__("_hash")

      def __instancecheck__(self, obj: object) -> NoReturn:
          """Check if an `obj` is an instance.

          Raises
          ------
          TypeError
              Always.

          """
          raise TypeError("isinstance() argument 2 cannot be a "
                          "parameterized generic")

      def __subclasscheck__(self, cls: type) -> NoReturn:
          """Check if a `cls` is a subclass.

          Raises
          ------
          TypeError
              Always.

          """
          raise TypeError("issubclass() argument 2 cannot be a "
                          "parameterized generic")

      def __repr__(self) -> str:
          """Return ``repr(self)``."""
          args = ", ".join(_to_str(i) for i in self.__args__)
          origin = _to_str(self.__origin__)
          return f"{origin}[{args}]"

      def __getitem__(self: _T, key: Any) -> _T:
          """Return ``self[key]``.

          `key` (a single object or a tuple) must supply exactly one value
          per entry of ``__parameters__``.

          Raises
          ------
          TypeError
              If the alias has no type variables left, or if `key` holds
              more or fewer values than there are type variables.

          """
          key_tup = key if isinstance(key, tuple) else (key,)
          n_params = len(self.__parameters__)

          if n_params == 0:
              raise TypeError(f"There are no type variables left in {self}")
          elif len(key_tup) > n_params:
              raise TypeError(f"Too many arguments for {self}")
          elif len(key_tup) < n_params:
              raise TypeError(f"Too few arguments for {self}")

          key_iter = iter(key_tup)
          return _reconstruct_alias(self, key_iter)

      def __eq__(self, value: object) -> bool:
          """Return ``self == value``."""
          if not isinstance(value, _GENERIC_ALIAS_TYPE):
              return NotImplemented
          return (
              self.__origin__ == value.__origin__ and
              self.__args__ == value.__args__
          )

      # Names that are looked up on the alias itself; every other attribute
      # is taken from `__origin__` (see `__getattribute__`).
      _ATTR_EXCEPTIONS: ClassVar[FrozenSet[str]] = frozenset({
          "__origin__",
          "__args__",
          "__parameters__",
          "__mro_entries__",
          "__reduce__",
          "__reduce_ex__",
          "__copy__",
          "__deepcopy__",
      })

      def __getattribute__(self, name: str) -> Any:
          """Return ``getattr(self, name)``."""
          # Pull the attribute from `__origin__` unless its
          # name is in `_ATTR_EXCEPTIONS`
          cls = type(self)
          if name in cls._ATTR_EXCEPTIONS:
              return super().__getattribute__(name)
          return getattr(self.__origin__, name)
  # The classes that `_GenericAlias.__eq__` accepts as the other operand
  # (see `_GenericAlias.__eq__`).  `types.GenericAlias` is only referenced on
  # Python 3.9 and later.
  if sys.version_info < (3, 9):
      _GENERIC_ALIAS_TYPE = (_GenericAlias,)
  else:
      _GENERIC_ALIAS_TYPE = (_GenericAlias, types.GenericAlias)
  # Covariant type variable restricted to `np.generic` subclasses.
  ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)

  if TYPE_CHECKING:
      # Static type checkers get the plain subscripted form.
      NDArray = np.ndarray[Any, np.dtype[ScalarType]]
  else:
      # At runtime build the alias explicitly: with `types.GenericAlias` on
      # Python 3.9+, otherwise with the pure-python `_GenericAlias` backport.
      _alias_factory = (
          types.GenericAlias if sys.version_info >= (3, 9) else _GenericAlias
      )
      _DType = _alias_factory(np.dtype, (ScalarType,))
      NDArray = _alias_factory(np.ndarray, (Any, _DType))
      # Temporary helper only; do not leave it in the module namespace.
      del _alias_factory