diff --git a/mypy/typeshed/stdlib/@tests/test_cases/asyncio/check_coroutines.py b/mypy/typeshed/stdlib/@tests/test_cases/asyncio/check_coroutines.py new file mode 100644 index 000000000000..160bd896469e --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/asyncio/check_coroutines.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from asyncio import iscoroutinefunction +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any +from typing_extensions import assert_type + + +def test_iscoroutinefunction( + x: Callable[[str, int], Coroutine[str, int, bytes]], + y: Callable[[str, int], Awaitable[bytes]], + z: Callable[[str, int], str | Awaitable[bytes]], + xx: object, +) -> None: + if iscoroutinefunction(x): + assert_type(x, Callable[[str, int], Coroutine[str, int, bytes]]) + + if iscoroutinefunction(y): + assert_type(y, Callable[[str, int], Coroutine[Any, Any, bytes]]) + + if iscoroutinefunction(z): + assert_type(z, Callable[[str, int], Coroutine[Any, Any, Any]]) + + if iscoroutinefunction(xx): + assert_type(xx, Callable[..., Coroutine[Any, Any, Any]]) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/asyncio/check_gather.py b/mypy/typeshed/stdlib/@tests/test_cases/asyncio/check_gather.py new file mode 100644 index 000000000000..02a01e39731a --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/asyncio/check_gather.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import asyncio +from typing import Awaitable, List, Tuple, Union +from typing_extensions import assert_type + + +async def coro1() -> int: + return 42 + + +async def coro2() -> str: + return "spam" + + +async def test_gather(awaitable1: Awaitable[int], awaitable2: Awaitable[str]) -> None: + a = await asyncio.gather(awaitable1) + assert_type(a, Tuple[int]) + + b = await asyncio.gather(awaitable1, awaitable2, return_exceptions=True) + assert_type(b, Tuple[Union[int, BaseException], Union[str, BaseException]]) + + c = await asyncio.gather(awaitable1, awaitable2, awaitable1, awaitable1, awaitable1, awaitable1) + assert_type(c, Tuple[int, str, int, int, int, int]) + + d = await asyncio.gather(awaitable1, awaitable1, awaitable1, awaitable1, awaitable1, awaitable1, awaitable1) + assert_type(d, List[int]) + + awaitables_list: list[Awaitable[int]] = [awaitable1] + e = await asyncio.gather(*awaitables_list) + assert_type(e, List[int]) + + # this case isn't reliable between typecheckers, no one would ever call it with no args anyway + # f = await asyncio.gather() + # assert_type(f, list[Any]) + + +asyncio.run(test_gather(coro1(), coro2())) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/asyncio/check_task.py b/mypy/typeshed/stdlib/@tests/test_cases/asyncio/check_task.py new file mode 100644 index 000000000000..69bcf8f782aa --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/asyncio/check_task.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import asyncio + + +class Waiter: + def __init__(self) -> None: + self.tasks: list[asyncio.Task[object]] = [] + + def add(self, t: asyncio.Task[object]) -> None: + self.tasks.append(t) + + async def join(self) -> None: + await asyncio.wait(self.tasks) + + +async def foo() -> int: + return 42 + + +async def main() -> None: + # asyncio.Task is covariant in its type argument, which is unusual since its parent class + # asyncio.Future is invariant in its type argument. This is only sound because asyncio.Task + # is not actually Liskov substitutable for asyncio.Future: it does not implement set_result. + w = Waiter() + t: asyncio.Task[int] = asyncio.create_task(foo()) + w.add(t) + await w.join() diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_dict-py39.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_dict-py39.py new file mode 100644 index 000000000000..d707cfed222e --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_dict-py39.py @@ -0,0 +1,67 @@ +""" +Tests for `dict.__(r)or__`. + +`dict.__or__` and `dict.__ror__` were only added in py39, +hence why these are in a separate file to the other test cases for `dict`. +""" + +from __future__ import annotations + +import os +import sys +from typing import Mapping, TypeVar, Union +from typing_extensions import Self, assert_type + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + +if sys.version_info >= (3, 9): + + class CustomDictSubclass(dict[_KT, _VT]): + pass + + class CustomMappingWithDunderOr(Mapping[_KT, _VT]): + def __or__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]: + return {} + + def __ror__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]: + return {} + + def __ior__(self, other: Mapping[_KT, _VT]) -> Self: + return self + + def test_dict_dot_or( + a: dict[int, int], + b: CustomDictSubclass[int, int], + c: dict[str, str], + d: Mapping[int, int], + e: CustomMappingWithDunderOr[str, str], + ) -> None: + # dict.__(r)or__ always returns a dict, even if called on a subclass of dict: + assert_type(a | b, dict[int, int]) + assert_type(b | a, dict[int, int]) + + assert_type(a | c, dict[Union[int, str], Union[int, str]]) + + # arbitrary mappings are not accepted by `dict.__or__`; + # it has to be a subclass of `dict` + a | d # type: ignore + + # but Mappings such as `os._Environ` or `CustomMappingWithDunderOr`, + # which define `__ror__` methods that accept `dict`, are fine: + assert_type(a | os.environ, dict[Union[str, int], Union[str, int]]) + assert_type(os.environ | a, dict[Union[str, int], Union[str, int]]) + + assert_type(c | os.environ, dict[str, str]) + assert_type(c | e, dict[str, str]) + + assert_type(os.environ | c, dict[str, str]) + assert_type(e | c, dict[str, str]) + + e |= c + e |= a # type: ignore + + # TODO: this test passes mypy, but fails pyright for some reason: + # c |= e + + c |= a # type: ignore diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_dict.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_dict.py new file mode 100644 index 000000000000..aa920d045cbc --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_dict.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import Dict, Generic, Iterable, TypeVar +from typing_extensions import assert_type + +# These do follow `__init__` overloads order: +# mypy and pyright have different opinions about this one: +# mypy raises: 'Need type annotation for "bad"' +# pyright is fine with it. +# bad = dict() +good: dict[str, str] = dict() +assert_type(good, Dict[str, str]) + +assert_type(dict(arg=1), Dict[str, int]) + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +class KeysAndGetItem(Generic[_KT, _VT]): + data: dict[_KT, _VT] + + def __init__(self, data: dict[_KT, _VT]) -> None: + self.data = data + + def keys(self) -> Iterable[_KT]: + return self.data.keys() + + def __getitem__(self, __k: _KT) -> _VT: + return self.data[__k] + + +kt1: KeysAndGetItem[int, str] = KeysAndGetItem({0: ""}) +assert_type(dict(kt1), Dict[int, str]) +dict(kt1, arg="a") # type: ignore + +kt2: KeysAndGetItem[str, int] = KeysAndGetItem({"": 0}) +assert_type(dict(kt2, arg=1), Dict[str, int]) + + +def test_iterable_tuple_overload(x: Iterable[tuple[int, str]]) -> dict[int, str]: + return dict(x) + + +i1: Iterable[tuple[int, str]] = [(1, "a"), (2, "b")] +test_iterable_tuple_overload(i1) +dict(i1, arg="a") # type: ignore + +i2: Iterable[tuple[str, int]] = [("a", 1), ("b", 2)] +assert_type(dict(i2, arg=1), Dict[str, int]) + +i3: Iterable[str] = ["a.b"] +i4: Iterable[bytes] = [b"a.b"] +assert_type(dict(string.split(".") for string in i3), Dict[str, str]) +assert_type(dict(string.split(b".") for string in i4), Dict[bytes, bytes]) + +dict(["foo", "bar", "baz"]) # type: ignore +dict([b"foo", b"bar", b"baz"]) # type: ignore diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_exception_group-py311.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_exception_group-py311.py new file mode 100644 index 000000000000..e53cd12288a4 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_exception_group-py311.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +import sys +from typing import TypeVar +from typing_extensions import assert_type + +if sys.version_info >= (3, 11): + # This can be removed later, but right now Flake8 does not know + # about these two classes: + from builtins import BaseExceptionGroup, ExceptionGroup + + # BaseExceptionGroup + # ================== + # `BaseExceptionGroup` can work with `BaseException`: + beg = BaseExceptionGroup("x", [SystemExit(), SystemExit()]) + assert_type(beg, BaseExceptionGroup[SystemExit]) + assert_type(beg.exceptions, tuple[SystemExit | BaseExceptionGroup[SystemExit], ...]) + + # Covariance works: + _beg1: BaseExceptionGroup[BaseException] = beg + + # `BaseExceptionGroup` can work with `Exception`: + beg2 = BaseExceptionGroup("x", [ValueError()]) + # FIXME: this is not right, runtime returns `ExceptionGroup` instance instead, + # but I am unable to represent this with types right now. + assert_type(beg2, BaseExceptionGroup[ValueError]) + + # .subgroup() + # ----------- + + assert_type(beg.subgroup(KeyboardInterrupt), BaseExceptionGroup[KeyboardInterrupt] | None) + assert_type(beg.subgroup((KeyboardInterrupt,)), BaseExceptionGroup[KeyboardInterrupt] | None) + + def is_base_exc(exc: BaseException) -> bool: + return isinstance(exc, BaseException) + + def is_specific(exc: SystemExit | BaseExceptionGroup[SystemExit]) -> bool: + return isinstance(exc, SystemExit) + + # This one does not have `BaseExceptionGroup` part, + # this is why we treat as an error. + def is_system_exit(exc: SystemExit) -> bool: + return isinstance(exc, SystemExit) + + def unrelated_subgroup(exc: KeyboardInterrupt) -> bool: + return False + + assert_type(beg.subgroup(is_base_exc), BaseExceptionGroup[SystemExit] | None) + assert_type(beg.subgroup(is_specific), BaseExceptionGroup[SystemExit] | None) + beg.subgroup(is_system_exit) # type: ignore + beg.subgroup(unrelated_subgroup) # type: ignore + + # `Exception`` subgroup returns `ExceptionGroup`: + assert_type(beg.subgroup(ValueError), ExceptionGroup[ValueError] | None) + assert_type(beg.subgroup((ValueError,)), ExceptionGroup[ValueError] | None) + + # Callable are harder, we don't support cast to `ExceptionGroup` here. + # Because callables might return `True` the first time. And `BaseExceptionGroup` + # will stick, no matter what arguments are. + + def is_exception(exc: Exception) -> bool: + return isinstance(exc, Exception) + + def is_exception_or_beg(exc: Exception | BaseExceptionGroup[SystemExit]) -> bool: + return isinstance(exc, Exception) + + # This is an error because of the `Exception` argument type, + # while `SystemExit` is needed instead. + beg.subgroup(is_exception_or_beg) # type: ignore + + # This is an error, because `BaseExceptionGroup` is not an `Exception` + # subclass. It is required. + beg.subgroup(is_exception) # type: ignore + + # .split() + # -------- + + assert_type( + beg.split(KeyboardInterrupt), tuple[BaseExceptionGroup[KeyboardInterrupt] | None, BaseExceptionGroup[SystemExit] | None] + ) + assert_type( + beg.split((KeyboardInterrupt,)), + tuple[BaseExceptionGroup[KeyboardInterrupt] | None, BaseExceptionGroup[SystemExit] | None], + ) + assert_type( + beg.split(ValueError), # there are no `ValueError` items in there, but anyway + tuple[ExceptionGroup[ValueError] | None, BaseExceptionGroup[SystemExit] | None], + ) + + excs_to_split: list[ValueError | KeyError | SystemExit] = [ValueError(), KeyError(), SystemExit()] + to_split = BaseExceptionGroup("x", excs_to_split) + assert_type(to_split, BaseExceptionGroup[ValueError | KeyError | SystemExit]) + + # Ideally the first part should be `ExceptionGroup[ValueError]` (done) + # and the second part should be `BaseExceptionGroup[KeyError | SystemExit]`, + # but we cannot subtract type from a union. + # We also cannot change `BaseExceptionGroup` to `ExceptionGroup` even if needed + # in the second part here because of that. + assert_type( + to_split.split(ValueError), + tuple[ExceptionGroup[ValueError] | None, BaseExceptionGroup[ValueError | KeyError | SystemExit] | None], + ) + + def split_callable1(exc: ValueError | KeyError | SystemExit | BaseExceptionGroup[ValueError | KeyError | SystemExit]) -> bool: + return True + + assert_type( + to_split.split(split_callable1), # Concrete type is ok + tuple[ + BaseExceptionGroup[ValueError | KeyError | SystemExit] | None, + BaseExceptionGroup[ValueError | KeyError | SystemExit] | None, + ], + ) + assert_type( + to_split.split(is_base_exc), # Base class is ok + tuple[ + BaseExceptionGroup[ValueError | KeyError | SystemExit] | None, + BaseExceptionGroup[ValueError | KeyError | SystemExit] | None, + ], + ) + # `Exception` cannot be used: `BaseExceptionGroup` is not a subtype of it. + to_split.split(is_exception) # type: ignore + + # .derive() + # --------- + + assert_type(beg.derive([ValueError()]), ExceptionGroup[ValueError]) + assert_type(beg.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt]) + + # ExceptionGroup + # ============== + + # `ExceptionGroup` can work with `Exception`: + excs: list[ValueError | KeyError] = [ValueError(), KeyError()] + eg = ExceptionGroup("x", excs) + assert_type(eg, ExceptionGroup[ValueError | KeyError]) + assert_type(eg.exceptions, tuple[ValueError | KeyError | ExceptionGroup[ValueError | KeyError], ...]) + + # Covariance works: + _eg1: ExceptionGroup[Exception] = eg + + # `ExceptionGroup` cannot work with `BaseException`: + ExceptionGroup("x", [SystemExit()]) # type: ignore + + # .subgroup() + # ----------- + + # Our decision is to ban cases like:: + # + # >>> eg = ExceptionGroup('x', [ValueError()]) + # >>> eg.subgroup(BaseException) + # ExceptionGroup('e', [ValueError()]) + # + # are possible in runtime. + # We do it because, it does not make sense for all other base exception types. + # Supporting just `BaseException` looks like an overkill. + eg.subgroup(BaseException) # type: ignore + eg.subgroup((KeyboardInterrupt, SystemExit)) # type: ignore + + assert_type(eg.subgroup(Exception), ExceptionGroup[Exception] | None) + assert_type(eg.subgroup(ValueError), ExceptionGroup[ValueError] | None) + assert_type(eg.subgroup((ValueError,)), ExceptionGroup[ValueError] | None) + + def subgroup_eg1(exc: ValueError | KeyError | ExceptionGroup[ValueError | KeyError]) -> bool: + return True + + def subgroup_eg2(exc: ValueError | KeyError) -> bool: + return True + + assert_type(eg.subgroup(subgroup_eg1), ExceptionGroup[ValueError | KeyError] | None) + assert_type(eg.subgroup(is_exception), ExceptionGroup[ValueError | KeyError] | None) + assert_type(eg.subgroup(is_base_exc), ExceptionGroup[ValueError | KeyError] | None) + assert_type(eg.subgroup(is_base_exc), ExceptionGroup[ValueError | KeyError] | None) + + # Does not have `ExceptionGroup` part: + eg.subgroup(subgroup_eg2) # type: ignore + + # .split() + # -------- + + assert_type(eg.split(TypeError), tuple[ExceptionGroup[TypeError] | None, ExceptionGroup[ValueError | KeyError] | None]) + assert_type(eg.split((TypeError,)), tuple[ExceptionGroup[TypeError] | None, ExceptionGroup[ValueError | KeyError] | None]) + assert_type( + eg.split(is_exception), tuple[ExceptionGroup[ValueError | KeyError] | None, ExceptionGroup[ValueError | KeyError] | None] + ) + assert_type( + eg.split(is_base_exc), + # is not converted, because `ExceptionGroup` cannot have + # direct `BaseException` subclasses inside. + tuple[ExceptionGroup[ValueError | KeyError] | None, ExceptionGroup[ValueError | KeyError] | None], + ) + + # It does not include `ExceptionGroup` itself, so it will fail: + def value_or_key_error(exc: ValueError | KeyError) -> bool: + return isinstance(exc, (ValueError, KeyError)) + + eg.split(value_or_key_error) # type: ignore + + # `ExceptionGroup` cannot have direct `BaseException` subclasses inside. + eg.split(BaseException) # type: ignore + eg.split((SystemExit, GeneratorExit)) # type: ignore + + # .derive() + # --------- + + assert_type(eg.derive([ValueError()]), ExceptionGroup[ValueError]) + assert_type(eg.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt]) + + # BaseExceptionGroup Custom Subclass + # ================================== + # In some cases `Self` type can be preserved in runtime, + # but it is impossible to express. That's why we always fallback to + # `BaseExceptionGroup` and `ExceptionGroup`. + + _BE = TypeVar("_BE", bound=BaseException) + + class CustomBaseGroup(BaseExceptionGroup[_BE]): ... + + cb1 = CustomBaseGroup("x", [SystemExit()]) + assert_type(cb1, CustomBaseGroup[SystemExit]) + cb2 = CustomBaseGroup("x", [ValueError()]) + assert_type(cb2, CustomBaseGroup[ValueError]) + + # .subgroup() + # ----------- + + assert_type(cb1.subgroup(KeyboardInterrupt), BaseExceptionGroup[KeyboardInterrupt] | None) + assert_type(cb2.subgroup((KeyboardInterrupt,)), BaseExceptionGroup[KeyboardInterrupt] | None) + + assert_type(cb1.subgroup(ValueError), ExceptionGroup[ValueError] | None) + assert_type(cb2.subgroup((KeyError,)), ExceptionGroup[KeyError] | None) + + def cb_subgroup1(exc: SystemExit | CustomBaseGroup[SystemExit]) -> bool: + return True + + def cb_subgroup2(exc: ValueError | CustomBaseGroup[ValueError]) -> bool: + return True + + assert_type(cb1.subgroup(cb_subgroup1), BaseExceptionGroup[SystemExit] | None) + assert_type(cb2.subgroup(cb_subgroup2), BaseExceptionGroup[ValueError] | None) + cb1.subgroup(cb_subgroup2) # type: ignore + cb2.subgroup(cb_subgroup1) # type: ignore + + # .split() + # -------- + + assert_type( + cb1.split(KeyboardInterrupt), tuple[BaseExceptionGroup[KeyboardInterrupt] | None, BaseExceptionGroup[SystemExit] | None] + ) + assert_type(cb1.split(TypeError), tuple[ExceptionGroup[TypeError] | None, BaseExceptionGroup[SystemExit] | None]) + assert_type(cb2.split((TypeError,)), tuple[ExceptionGroup[TypeError] | None, BaseExceptionGroup[ValueError] | None]) + + def cb_split1(exc: SystemExit | CustomBaseGroup[SystemExit]) -> bool: + return True + + def cb_split2(exc: ValueError | CustomBaseGroup[ValueError]) -> bool: + return True + + assert_type(cb1.split(cb_split1), tuple[BaseExceptionGroup[SystemExit] | None, BaseExceptionGroup[SystemExit] | None]) + assert_type(cb2.split(cb_split2), tuple[BaseExceptionGroup[ValueError] | None, BaseExceptionGroup[ValueError] | None]) + cb1.split(cb_split2) # type: ignore + cb2.split(cb_split1) # type: ignore + + # .derive() + # --------- + + # Note, that `Self` type is not preserved in runtime. + assert_type(cb1.derive([ValueError()]), ExceptionGroup[ValueError]) + assert_type(cb1.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt]) + assert_type(cb2.derive([ValueError()]), ExceptionGroup[ValueError]) + assert_type(cb2.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt]) + + # ExceptionGroup Custom Subclass + # ============================== + + _E = TypeVar("_E", bound=Exception) + + class CustomGroup(ExceptionGroup[_E]): ... + + CustomGroup("x", [SystemExit()]) # type: ignore + cg1 = CustomGroup("x", [ValueError()]) + assert_type(cg1, CustomGroup[ValueError]) + + # .subgroup() + # ----------- + + cg1.subgroup(BaseException) # type: ignore + cg1.subgroup((KeyboardInterrupt, SystemExit)) # type: ignore + + assert_type(cg1.subgroup(ValueError), ExceptionGroup[ValueError] | None) + assert_type(cg1.subgroup((KeyError,)), ExceptionGroup[KeyError] | None) + + def cg_subgroup1(exc: ValueError | CustomGroup[ValueError]) -> bool: + return True + + def cg_subgroup2(exc: ValueError) -> bool: + return True + + assert_type(cg1.subgroup(cg_subgroup1), ExceptionGroup[ValueError] | None) + cg1.subgroup(cb_subgroup2) # type: ignore + + # .split() + # -------- + + assert_type(cg1.split(TypeError), tuple[ExceptionGroup[TypeError] | None, ExceptionGroup[ValueError] | None]) + assert_type(cg1.split((TypeError,)), tuple[ExceptionGroup[TypeError] | None, ExceptionGroup[ValueError] | None]) + cg1.split(BaseException) # type: ignore + + def cg_split1(exc: ValueError | CustomGroup[ValueError]) -> bool: + return True + + def cg_split2(exc: ValueError) -> bool: + return True + + assert_type(cg1.split(cg_split1), tuple[ExceptionGroup[ValueError] | None, ExceptionGroup[ValueError] | None]) + cg1.split(cg_split2) # type: ignore + + # .derive() + # --------- + + # Note, that `Self` type is not preserved in runtime. + assert_type(cg1.derive([ValueError()]), ExceptionGroup[ValueError]) + assert_type(cg1.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt]) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_iteration.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_iteration.py new file mode 100644 index 000000000000..3d609635377e --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_iteration.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import Iterator +from typing_extensions import assert_type + + +class OldStyleIter: + def __getitem__(self, index: int) -> str: + return str(index) + + +for x in iter(OldStyleIter()): + assert_type(x, str) + +assert_type(iter(OldStyleIter()), Iterator[str]) +assert_type(next(iter(OldStyleIter())), str) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_list.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_list.py new file mode 100644 index 000000000000..4113f5c66182 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_list.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import List, Union +from typing_extensions import assert_type + + +# list.__add__ example from #8292 +class Foo: + def asd(self) -> int: + return 1 + + +class Bar: + def asd(self) -> int: + return 2 + + +combined = [Foo()] + [Bar()] +assert_type(combined, List[Union[Foo, Bar]]) +for item in combined: + assert_type(item.asd(), int) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_object.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_object.py new file mode 100644 index 000000000000..60df1143f727 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_object.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing import Any + + +# The following should pass without error (see #6661): +class Diagnostic: + def __reduce__(self) -> str | tuple[Any, ...]: + res = super().__reduce__() + if isinstance(res, tuple) and len(res) >= 3: + res[2]["_info"] = 42 + + return res diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_pow.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_pow.py new file mode 100644 index 000000000000..1f38710d6bea --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_pow.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from decimal import Decimal +from fractions import Fraction +from typing import Any, Literal +from typing_extensions import assert_type + +# See #7163 +assert_type(pow(1, 0), Literal[1]) +assert_type(1**0, Literal[1]) +assert_type(pow(1, 0, None), Literal[1]) + +# TODO: We don't have a good way of expressing the fact +# that passing 0 for the third argument will lead to an exception being raised +# (see discussion in #8566) +# +# assert_type(pow(2, 4, 0), NoReturn) + +assert_type(pow(2, 4), int) +assert_type(2**4, int) +assert_type(pow(4, 6, None), int) + +assert_type(pow(5, -7), float) +assert_type(5**-7, float) + +assert_type(pow(2, 4, 5), int) # pow(, , ) +assert_type(pow(2, 35, 3), int) # pow(, , ) + +assert_type(pow(2, 8.5), float) +assert_type(2**8.6, float) +assert_type(pow(2, 8.6, None), float) + +# TODO: Why does this pass pyright but not mypy?? +# assert_type((-2) ** 0.5, complex) + +assert_type(pow((-5), 8.42, None), complex) + +assert_type(pow(4.6, 8), float) +assert_type(4.6**8, float) +assert_type(pow(5.1, 4, None), float) + +assert_type(pow(complex(6), 6.2), complex) +assert_type(complex(6) ** 6.2, complex) +assert_type(pow(complex(9), 7.3, None), complex) + +assert_type(pow(Fraction(), 4, None), Fraction) +assert_type(Fraction() ** 4, Fraction) + +assert_type(pow(Fraction(3, 7), complex(1, 8)), complex) +assert_type(Fraction(3, 7) ** complex(1, 8), complex) + +assert_type(pow(complex(4, -8), Fraction(2, 3)), complex) +assert_type(complex(4, -8) ** Fraction(2, 3), complex) + +assert_type(pow(Decimal("1.0"), Decimal("1.6")), Decimal) +assert_type(Decimal("1.0") ** Decimal("1.6"), Decimal) + +assert_type(pow(Decimal("1.0"), Decimal("1.0"), Decimal("1.0")), Decimal) +assert_type(pow(Decimal("4.6"), 7, None), Decimal) +assert_type(Decimal("4.6") ** 7, Decimal) + +# These would ideally be more precise, but `Any` is acceptable +# They have to be `Any` due to the fact that type-checkers can't distinguish +# between positive and negative numbers for the second argument to `pow()` +# +# int for positive 2nd-arg, float otherwise +assert_type(pow(4, 65), Any) +assert_type(pow(2, -45), Any) +assert_type(pow(3, 57, None), Any) +assert_type(pow(67, 0.98, None), Any) +assert_type(87**7.32, Any) +# pow(, ) -> float +# pow(, ) -> complex +assert_type(pow(4.7, 7.4), Any) +assert_type(pow(-9.8, 8.3), Any) +assert_type(pow(-9.3, -88.2), Any) +assert_type(pow(8.2, -9.8), Any) +assert_type(pow(4.7, 9.2, None), Any) +# See #7046 -- float for a positive 1st arg, complex otherwise +assert_type((-95) ** 8.42, Any) + +# All of the following cases should fail a type-checker. +pow(1.9, 4, 6) # type: ignore +pow(4, 7, 4.32) # type: ignore +pow(6.2, 5.9, 73) # type: ignore +pow(complex(6), 6.2, 7) # type: ignore +pow(Fraction(), 5, 8) # type: ignore +Decimal("8.7") ** 3.14 # type: ignore + +# TODO: This fails at runtime, but currently passes mypy and pyright: +pow(Decimal("8.5"), 3.21) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_reversed.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_reversed.py new file mode 100644 index 000000000000..2a43a57deb4e --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_reversed.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import Generic, TypeVar +from typing_extensions import assert_type + +x: list[int] = [] +assert_type(list(reversed(x)), "list[int]") + + +class MyReversible: + def __iter__(self) -> Iterator[str]: + yield "blah" + + def __reversed__(self) -> Iterator[str]: + yield "blah" + + +assert_type(list(reversed(MyReversible())), "list[str]") + + +_T = TypeVar("_T") + + +class MyLenAndGetItem(Generic[_T]): + def __len__(self) -> int: + return 0 + + def __getitem__(self, item: int) -> _T: + raise KeyError + + +len_and_get_item: MyLenAndGetItem[int] = MyLenAndGetItem() +assert_type(list(reversed(len_and_get_item)), "list[int]") diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_round.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_round.py new file mode 100644 index 000000000000..84081f3665b9 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_round.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import overload +from typing_extensions import assert_type + + +class CustomIndex: + def __index__(self) -> int: + return 1 + + +# float: + +assert_type(round(5.5), int) +assert_type(round(5.5, None), int) +assert_type(round(5.5, 0), float) +assert_type(round(5.5, 1), float) +assert_type(round(5.5, 5), float) +assert_type(round(5.5, CustomIndex()), float) + +# int: + +assert_type(round(1), int) +assert_type(round(1, 1), int) +assert_type(round(1, None), int) +assert_type(round(1, CustomIndex()), int) + +# Protocols: + + +class WithCustomRound1: + def __round__(self) -> str: + return "a" + + +assert_type(round(WithCustomRound1()), str) +assert_type(round(WithCustomRound1(), None), str) +# Errors: +round(WithCustomRound1(), 1) # type: ignore +round(WithCustomRound1(), CustomIndex()) # type: ignore + + +class WithCustomRound2: + def __round__(self, digits: int) -> str: + return "a" + + +assert_type(round(WithCustomRound2(), 1), str) +assert_type(round(WithCustomRound2(), CustomIndex()), str) +# Errors: +round(WithCustomRound2(), None) # type: ignore +round(WithCustomRound2()) # type: ignore + + +class WithOverloadedRound: + @overload + def __round__(self, ndigits: None = ...) -> str: ... + + @overload + def __round__(self, ndigits: int) -> bytes: ... + + def __round__(self, ndigits: int | None = None) -> str | bytes: + return b"" if ndigits is None else "" + + +assert_type(round(WithOverloadedRound()), str) +assert_type(round(WithOverloadedRound(), None), str) +assert_type(round(WithOverloadedRound(), 1), bytes) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_sum.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_sum.py new file mode 100644 index 000000000000..cda7eadbbe41 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_sum.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import Any, List, Literal, Union +from typing_extensions import assert_type + + +class Foo: + def __add__(self, other: Any) -> Foo: + return Foo() + + +class Bar: + def __radd__(self, other: Any) -> Bar: + return Bar() + + +class Baz: + def __add__(self, other: Any) -> Baz: + return Baz() + + def __radd__(self, other: Any) -> Baz: + return Baz() + + +literal_list: list[Literal[0, 1]] = [0, 1, 1] + +assert_type(sum([2, 4]), int) +assert_type(sum([3, 5], 4), int) + +assert_type(sum([True, False]), int) +assert_type(sum([True, False], True), int) +assert_type(sum(literal_list), int) + +assert_type(sum([["foo"], ["bar"]], ["baz"]), List[str]) + +assert_type(sum([Foo(), Foo()], Foo()), Foo) +assert_type(sum([Baz(), Baz()]), Union[Baz, Literal[0]]) + +# mypy and pyright infer the types differently for these, so we can't use assert_type +# Just test that no error is emitted for any of these +sum([("foo",), ("bar", "baz")], ()) # mypy: `tuple[str, ...]`; pyright: `tuple[()] | tuple[str] | tuple[str, str]` +sum([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]` +sum([2.5, 5.8], 5) # mypy: `float`; pyright: `float | int` + +# These all fail at runtime +sum("abcde") # type: ignore +sum([["foo"], ["bar"]]) # type: ignore +sum([("foo",), ("bar", "baz")]) # type: ignore +sum([Foo(), Foo()]) # type: ignore +sum([Bar(), Bar()], Bar()) # type: ignore +sum([Bar(), Bar()]) # type: ignore + +# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error: +# sum([3, Fraction(7, 22), complex(8, 0), 9.83]) +# sum([3, Decimal('0.98')]) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_tuple.py b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_tuple.py new file mode 100644 index 000000000000..bc0d8db28389 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/builtins/check_tuple.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing import Tuple +from typing_extensions import assert_type + + +# Empty tuples, see #8275 +class TupleSub(Tuple[int, ...]): + pass + + +assert_type(TupleSub(), TupleSub) +assert_type(TupleSub([1, 2, 3]), TupleSub) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_codecs.py b/mypy/typeshed/stdlib/@tests/test_cases/check_codecs.py new file mode 100644 index 000000000000..19e663ceeaaf --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_codecs.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import codecs +from typing_extensions import assert_type + +assert_type(codecs.decode("x", "unicode-escape"), str) +assert_type(codecs.decode(b"x", "unicode-escape"), str) + +assert_type(codecs.decode(b"x", "utf-8"), str) +codecs.decode("x", "utf-8") # type: ignore + +assert_type(codecs.decode("ab", "hex"), bytes) +assert_type(codecs.decode(b"ab", "hex"), bytes) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_concurrent_futures.py b/mypy/typeshed/stdlib/@tests/test_cases/check_concurrent_futures.py new file mode 100644 index 000000000000..962ec23c6b48 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_concurrent_futures.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from collections.abc import Callable, Iterator +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing_extensions import assert_type + + +class Parent: ... + + +class Child(Parent): ... + + +def check_as_completed_covariance() -> None: + with ThreadPoolExecutor() as executor: + f1 = executor.submit(lambda: Parent()) + f2 = executor.submit(lambda: Child()) + fs: list[Future[Parent] | Future[Child]] = [f1, f2] + assert_type(as_completed(fs), Iterator[Future[Parent]]) + for future in as_completed(fs): + assert_type(future.result(), Parent) + + +def check_future_invariance() -> None: + def execute_callback(callback: Callable[[], Parent], future: Future[Parent]) -> None: + future.set_result(callback()) + + fut: Future[Child] = Future() + execute_callback(lambda: Parent(), fut) # type: ignore + assert isinstance(fut.result(), Child) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_contextlib.py b/mypy/typeshed/stdlib/@tests/test_cases/check_contextlib.py new file mode 100644 index 000000000000..648661bca856 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_contextlib.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from contextlib import ExitStack +from typing_extensions import assert_type + + +# See issue #7961 +class Thing(ExitStack): + pass + + +stack = ExitStack() +thing = Thing() +assert_type(stack.enter_context(Thing()), Thing) +assert_type(thing.enter_context(ExitStack()), ExitStack) + +with stack as cm: + assert_type(cm, ExitStack) +with thing as cm2: + assert_type(cm2, Thing) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_dataclasses.py b/mypy/typeshed/stdlib/@tests/test_cases/check_dataclasses.py new file mode 100644 index 000000000000..76ce8e1bd260 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_dataclasses.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import dataclasses as dc +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Tuple, Type, Union +from typing_extensions import Annotated, assert_type + +if TYPE_CHECKING: + from _typeshed import DataclassInstance + + +@dc.dataclass +class Foo: + attr: str + + +assert_type(dc.fields(Foo), Tuple[dc.Field[Any], ...]) + +# Mypy correctly emits errors on these +# due to the fact it's a dataclass class, not an instance. +# Pyright, however, handles ClassVar members in protocols differently. +# See https://github.com/microsoft/pyright/issues/4339 +# +# dc.asdict(Foo) +# dc.astuple(Foo) +# dc.replace(Foo) + +# See #9723 for why we can't make this assertion +# if dc.is_dataclass(Foo): +# assert_type(Foo, Type[Foo]) + +f = Foo(attr="attr") + +assert_type(dc.fields(f), Tuple[dc.Field[Any], ...]) +assert_type(dc.asdict(f), Dict[str, Any]) +assert_type(dc.astuple(f), Tuple[Any, ...]) +assert_type(dc.replace(f, attr="new"), Foo) + +if dc.is_dataclass(f): + # The inferred type doesn't change + # if it's already known to be a subtype of _DataclassInstance + assert_type(f, Foo) + + +def check_other_isdataclass_overloads(x: type, y: object) -> None: + # TODO: pyright correctly emits an error on this, but mypy does not -- why? + # dc.fields(x) + + dc.fields(y) # type: ignore + + dc.asdict(x) # type: ignore + dc.asdict(y) # type: ignore + + dc.astuple(x) # type: ignore + dc.astuple(y) # type: ignore + + dc.replace(x) # type: ignore + dc.replace(y) # type: ignore + + if dc.is_dataclass(x): + assert_type(x, Type["DataclassInstance"]) + assert_type(dc.fields(x), Tuple[dc.Field[Any], ...]) + + # Mypy correctly emits an error on these due to the fact + # that it's a dataclass class, not a dataclass instance. + # Pyright, however, handles ClassVar members in protocols differently. + # See https://github.com/microsoft/pyright/issues/4339 + # + # dc.asdict(x) + # dc.astuple(x) + # dc.replace(x) + + if dc.is_dataclass(y): + assert_type(y, Union["DataclassInstance", Type["DataclassInstance"]]) + assert_type(dc.fields(y), Tuple[dc.Field[Any], ...]) + + # Mypy correctly emits an error on these due to the fact we don't know + # whether it's a dataclass class or a dataclass instance. + # Pyright, however, handles ClassVar members in protocols differently. + # See https://github.com/microsoft/pyright/issues/4339 + # + # dc.asdict(y) + # dc.astuple(y) + # dc.replace(y) + + if dc.is_dataclass(y) and not isinstance(y, type): + assert_type(y, "DataclassInstance") + assert_type(dc.fields(y), Tuple[dc.Field[Any], ...]) + assert_type(dc.asdict(y), Dict[str, Any]) + assert_type(dc.astuple(y), Tuple[Any, ...]) + dc.replace(y) + + +# Regression test for #11653 +D = dc.make_dataclass( + "D", [("a", Union[int, None]), "y", ("z", Annotated[FrozenSet[bytes], "metadata"], dc.field(default=frozenset({b"foo"})))] +) +# Check that it's inferred by the type checker as a class object of some kind +# (but don't assert the exact type that `D` is inferred as, +# in case a type checker decides to add some special-casing for +# `make_dataclass` in the future) +assert_type(D.__mro__, Tuple[type, ...]) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_enum.py b/mypy/typeshed/stdlib/@tests/test_cases/check_enum.py new file mode 100644 index 000000000000..4ea4947c811d --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_enum.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import enum +import sys +from typing import Literal, Type +from typing_extensions import assert_type + +A = enum.Enum("A", "spam eggs bacon") +B = enum.Enum("B", ["spam", "eggs", "bacon"]) +C = enum.Enum("Bar", [("spam", 1), ("eggs", 2), ("bacon", 3)]) +D = enum.Enum("Bar", {"spam": 1, "eggs": 2}) + +assert_type(A, Type[A]) +assert_type(B, Type[B]) +assert_type(C, Type[C]) +assert_type(D, Type[D]) + + +class EnumOfTuples(enum.Enum): + X = 1, 2, 3 + Y = 4, 5, 6 + + +assert_type(EnumOfTuples((1, 2, 3)), EnumOfTuples) + +# TODO: ideally this test would pass: +# +# if sys.version_info >= (3, 12): +# assert_type(EnumOfTuples(1, 2, 3), EnumOfTuples) + + +if sys.version_info >= (3, 11): + + class Foo(enum.StrEnum): + X = enum.auto() + + assert_type(Foo.X, Literal[Foo.X]) + assert_type(Foo.X.value, str) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_functools.py b/mypy/typeshed/stdlib/@tests/test_cases/check_functools.py new file mode 100644 index 000000000000..dca572683f8d --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_functools.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from functools import cached_property, wraps +from typing import Callable, TypeVar +from typing_extensions import ParamSpec, assert_type + +P = ParamSpec("P") +T_co = TypeVar("T_co", covariant=True) + + +def my_decorator(func: Callable[P, T_co]) -> Callable[P, T_co]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: + print(args) + return func(*args, **kwargs) + + # verify that the wrapped function has all these attributes + wrapper.__annotations__ = func.__annotations__ + wrapper.__doc__ = func.__doc__ + wrapper.__module__ = func.__module__ + wrapper.__name__ = func.__name__ + wrapper.__qualname__ = func.__qualname__ + return wrapper + + +class A: + def __init__(self, x: int): + self.x = x + + @cached_property + def x(self) -> int: + return 0 + + +assert_type(A(x=1).x, int) + + +class B: + @cached_property + def x(self) -> int: + return 0 + + +def check_cached_property_settable(x: int) -> None: + b = B() + assert_type(b.x, int) + b.x = x + assert_type(b.x, int) + + +# https://github.com/python/typeshed/issues/10048 +class Parent: ... + + +class Child(Parent): ... + + +class X: + @cached_property + def some(self) -> Parent: + return Parent() + + +class Y(X): + @cached_property + def some(self) -> Child: # safe override + return Child() diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_importlib.py b/mypy/typeshed/stdlib/@tests/test_cases/check_importlib.py new file mode 100644 index 000000000000..17eefdafc971 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_importlib.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import importlib.abc +import importlib.util +import pathlib +import sys +import zipfile +from collections.abc import Sequence +from importlib.machinery import ModuleSpec +from types import ModuleType +from typing_extensions import Self + +# Assert that some Path classes are Traversable. +if sys.version_info >= (3, 9): + + def traverse(t: importlib.abc.Traversable) -> None: + pass + + traverse(pathlib.Path()) + traverse(zipfile.Path("")) + + +class MetaFinder: + @classmethod + def find_spec(cls, fullname: str, path: Sequence[str] | None, target: ModuleType | None = None) -> ModuleSpec | None: + return None # simplified mock for demonstration purposes only + + +class PathFinder: + @classmethod + def path_hook(cls, path_entry: str) -> type[Self]: + return cls # simplified mock for demonstration purposes only + + @classmethod + def find_spec(cls, fullname: str, target: ModuleType | None = None) -> ModuleSpec | None: + return None # simplified mock for demonstration purposes only + + +class Loader: + @classmethod + def load_module(cls, fullname: str) -> ModuleType: + return ModuleType(fullname) + + +sys.meta_path.append(MetaFinder) +sys.path_hooks.append(PathFinder.path_hook) +importlib.util.spec_from_loader("xxxx42xxxx", Loader) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_importlib_metadata.py b/mypy/typeshed/stdlib/@tests/test_cases/check_importlib_metadata.py new file mode 100644 index 000000000000..f1322e16c54f --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_importlib_metadata.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import sys +from _typeshed import StrPath +from os import PathLike +from pathlib import Path +from typing import Any +from zipfile import Path as ZipPath + +if sys.version_info >= (3, 10): + from importlib.metadata._meta import SimplePath + + # Simplified version of zipfile.Path + class MyPath: + @property + def parent(self) -> PathLike[str]: ... # undocumented + + def read_text(self, encoding: str | None = ..., errors: str | None = ...) -> str: ... + def joinpath(self, *other: StrPath) -> MyPath: ... + def __truediv__(self, add: StrPath) -> MyPath: ... + + if sys.version_info >= (3, 12): + + def takes_simple_path(p: SimplePath[Any]) -> None: ... + + else: + + def takes_simple_path(p: SimplePath) -> None: ... + + takes_simple_path(Path()) + takes_simple_path(ZipPath("")) + takes_simple_path(MyPath()) + takes_simple_path("some string") # type: ignore diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_io.py b/mypy/typeshed/stdlib/@tests/test_cases/check_io.py new file mode 100644 index 000000000000..abf84dd5a103 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_io.py @@ -0,0 +1,6 @@ +from gzip import GzipFile +from io import FileIO, TextIOWrapper + +TextIOWrapper(FileIO("")) +TextIOWrapper(FileIO(13)) +TextIOWrapper(GzipFile("")) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_logging.py b/mypy/typeshed/stdlib/@tests/test_cases/check_logging.py new file mode 100644 index 000000000000..fe3d8eb16fd0 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_logging.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import logging +import logging.handlers +import multiprocessing +import queue +from typing import Any + +# This pattern comes from the logging docs, and should therefore pass a type checker +# See https://docs.python.org/3/library/logging.html#logrecord-objects + +old_factory = logging.getLogRecordFactory() + + +def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: + record = old_factory(*args, **kwargs) + record.custom_attribute = 0xDECAFBAD + return record + + +logging.setLogRecordFactory(record_factory) + +# The logging docs say that QueueHandler and QueueListener can take "any queue-like object" +# We test that here (regression test for #10168) +logging.handlers.QueueHandler(queue.Queue()) +logging.handlers.QueueHandler(queue.SimpleQueue()) +logging.handlers.QueueHandler(multiprocessing.Queue()) +logging.handlers.QueueListener(queue.Queue()) +logging.handlers.QueueListener(queue.SimpleQueue()) +logging.handlers.QueueListener(multiprocessing.Queue()) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_multiprocessing.py b/mypy/typeshed/stdlib/@tests/test_cases/check_multiprocessing.py new file mode 100644 index 000000000000..201f96c0c4c8 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_multiprocessing.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from ctypes import c_char, c_float +from multiprocessing import Array, Value +from multiprocessing.sharedctypes import Synchronized, SynchronizedString +from typing_extensions import assert_type + +string = Array(c_char, 12) +assert_type(string, SynchronizedString) +assert_type(string.value, bytes) + +field = Value(c_float, 0.0) +assert_type(field, Synchronized[float]) +field.value = 1.2 diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_pathlib.py b/mypy/typeshed/stdlib/@tests/test_cases/check_pathlib.py new file mode 100644 index 000000000000..0b52c3669d07 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_pathlib.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from pathlib import Path, PureWindowsPath + +if Path("asdf") == Path("asdf"): + ... + +# https://github.com/python/typeshed/issues/10661 +# Provide a true positive error when comparing Path to str +# mypy should report a comparison-overlap error with --strict-equality, +# and pyright should report a reportUnnecessaryComparison error +if Path("asdf") == "asdf": # type: ignore + ... + +# Errors on comparison here are technically false positives. However, this comparison is a little +# interesting: it can never hold true on Posix, but could hold true on Windows. We should experiment +# with more accurate __new__, such that we only get an error for such comparisons on platforms +# where they can never hold true. +if PureWindowsPath("asdf") == Path("asdf"): # type: ignore + ... diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_re.py b/mypy/typeshed/stdlib/@tests/test_cases/check_re.py new file mode 100644 index 000000000000..b6ab2b0d59d2 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_re.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import mmap +import re +import typing as t +from typing_extensions import assert_type + + +def check_search(str_pat: re.Pattern[str], bytes_pat: re.Pattern[bytes]) -> None: + assert_type(str_pat.search("x"), t.Optional[t.Match[str]]) + assert_type(bytes_pat.search(b"x"), t.Optional[t.Match[bytes]]) + assert_type(bytes_pat.search(bytearray(b"x")), t.Optional[t.Match[bytes]]) + assert_type(bytes_pat.search(mmap.mmap(0, 10)), t.Optional[t.Match[bytes]]) + + +def check_search_with_AnyStr(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> re.Match[t.AnyStr]: + """See issue #9591""" + match = pattern.search(string) + if match is None: + raise ValueError(f"'{string!r}' does not match {pattern!r}") + return match + + +def check_no_ReadableBuffer_false_negatives() -> None: + re.compile("foo").search(bytearray(b"foo")) # type: ignore + re.compile("foo").search(mmap.mmap(0, 10)) # type: ignore diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_sqlite3.py b/mypy/typeshed/stdlib/@tests/test_cases/check_sqlite3.py new file mode 100644 index 000000000000..3ec47ceccb90 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_sqlite3.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import sqlite3 +from typing_extensions import assert_type + + +class MyConnection(sqlite3.Connection): + pass + + +# Default return-type is Connection. +assert_type(sqlite3.connect(":memory:"), sqlite3.Connection) + +# Providing an alternate factory changes the return-type. +assert_type(sqlite3.connect(":memory:", factory=MyConnection), MyConnection) + +# Provides a true positive error. When checking the connect() function, +# mypy should report an arg-type error for the factory argument. +with sqlite3.connect(":memory:", factory=None) as con: # type: ignore + pass + +# The Connection class also accepts a `factory` arg but it does not affect +# the return-type. This use case is not idiomatic--connections should be +# established using the `connect()` function, not directly (as shown here). +assert_type(sqlite3.Connection(":memory:", factory=None), sqlite3.Connection) +assert_type(sqlite3.Connection(":memory:", factory=MyConnection), sqlite3.Connection) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_tarfile.py b/mypy/typeshed/stdlib/@tests/test_cases/check_tarfile.py new file mode 100644 index 000000000000..54510a3d7626 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_tarfile.py @@ -0,0 +1,13 @@ +import tarfile + +with tarfile.open("test.tar.xz", "w:xz") as tar: + pass + +# Test with valid preset values +tarfile.open("test.tar.xz", "w:xz", preset=0) +tarfile.open("test.tar.xz", "w:xz", preset=5) +tarfile.open("test.tar.xz", "w:xz", preset=9) + +# Test with invalid preset values +tarfile.open("test.tar.xz", "w:xz", preset=-1) # type: ignore +tarfile.open("test.tar.xz", "w:xz", preset=10) # type: ignore diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_tempfile.py b/mypy/typeshed/stdlib/@tests/test_cases/check_tempfile.py new file mode 100644 index 000000000000..c259c192a140 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_tempfile.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import io +import sys +from tempfile import TemporaryFile, _TemporaryFileWrapper +from typing_extensions import assert_type + +if sys.platform == "win32": + assert_type(TemporaryFile(), _TemporaryFileWrapper[bytes]) + assert_type(TemporaryFile("w+"), _TemporaryFileWrapper[str]) + assert_type(TemporaryFile("w+b"), _TemporaryFileWrapper[bytes]) + assert_type(TemporaryFile("wb"), _TemporaryFileWrapper[bytes]) + assert_type(TemporaryFile("rb"), _TemporaryFileWrapper[bytes]) + assert_type(TemporaryFile("wb", 0), _TemporaryFileWrapper[bytes]) + assert_type(TemporaryFile(mode="w+"), _TemporaryFileWrapper[str]) + assert_type(TemporaryFile(mode="w+b"), _TemporaryFileWrapper[bytes]) + assert_type(TemporaryFile(mode="wb"), _TemporaryFileWrapper[bytes]) + assert_type(TemporaryFile(mode="rb"), _TemporaryFileWrapper[bytes]) + assert_type(TemporaryFile(buffering=0), _TemporaryFileWrapper[bytes]) +else: + assert_type(TemporaryFile(), io.BufferedRandom) + assert_type(TemporaryFile("w+"), io.TextIOWrapper) + assert_type(TemporaryFile("w+b"), io.BufferedRandom) + assert_type(TemporaryFile("wb"), io.BufferedWriter) + assert_type(TemporaryFile("rb"), io.BufferedReader) + assert_type(TemporaryFile("wb", 0), io.FileIO) + assert_type(TemporaryFile(mode="w+"), io.TextIOWrapper) + assert_type(TemporaryFile(mode="w+b"), io.BufferedRandom) + assert_type(TemporaryFile(mode="wb"), io.BufferedWriter) + assert_type(TemporaryFile(mode="rb"), io.BufferedReader) + assert_type(TemporaryFile(buffering=0), io.FileIO) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_threading.py b/mypy/typeshed/stdlib/@tests/test_cases/check_threading.py new file mode 100644 index 000000000000..eddfc2549a64 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_threading.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import _threading_local +import threading + +loc = threading.local() +loc.foo = 42 +del loc.foo +loc.baz = ["spam", "eggs"] +del loc.baz + +l2 = _threading_local.local() +l2.asdfasdf = 56 +del l2.asdfasdf diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_tkinter.py b/mypy/typeshed/stdlib/@tests/test_cases/check_tkinter.py new file mode 100644 index 000000000000..befac6697519 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_tkinter.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import tkinter +import traceback +import types + + +def custom_handler(exc: type[BaseException], val: BaseException, tb: types.TracebackType | None) -> None: + print("oh no") + + +root = tkinter.Tk() +root.report_callback_exception = traceback.print_exception +root.report_callback_exception = custom_handler + + +def foo(x: int, y: str) -> None: + pass + + +root.after(1000, foo, 10, "lol") +root.after(1000, foo, 10, 10) # type: ignore + + +# Font size must be integer +label = tkinter.Label() +label.config(font=("", 12)) +label.config(font=("", 12.34)) # type: ignore +label.config(font=("", 12, "bold")) +label.config(font=("", 12.34, "bold")) # type: ignore diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_unittest.py b/mypy/typeshed/stdlib/@tests/test_cases/check_unittest.py new file mode 100644 index 000000000000..40c6efaa8ca0 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_unittest.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import unittest +from collections.abc import Iterator, Mapping +from datetime import datetime, timedelta +from decimal import Decimal +from fractions import Fraction +from typing import TypedDict +from typing_extensions import assert_type +from unittest.mock import MagicMock, Mock, patch + +case = unittest.TestCase() + +### +# Tests for assertAlmostEqual +### + +case.assertAlmostEqual(1, 2.4) +case.assertAlmostEqual(2.4, 2.41) +case.assertAlmostEqual(Fraction(49, 50), Fraction(48, 50)) +case.assertAlmostEqual(3.14, complex(5, 6)) +case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1)) +case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), None, "foo", timedelta(hours=1)) +case.assertAlmostEqual(Decimal("1.1"), Decimal("1.11")) +case.assertAlmostEqual(2.4, 2.41, places=8) +case.assertAlmostEqual(2.4, 2.41, delta=0.02) +case.assertAlmostEqual(2.4, 2.41, None, "foo", 0.02) + +case.assertAlmostEqual(2.4, 2.41, places=9, delta=0.02) # type: ignore +case.assertAlmostEqual("foo", "bar") # type: ignore +case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore +case.assertAlmostEqual(Decimal("0.4"), Fraction(1, 2)) # type: ignore +case.assertAlmostEqual(complex(2, 3), Decimal("0.9")) # type: ignore + +### +# Tests for assertNotAlmostEqual +### + +case.assertAlmostEqual(1, 2.4) +case.assertNotAlmostEqual(Fraction(49, 50), Fraction(48, 50)) +case.assertAlmostEqual(3.14, complex(5, 6)) +case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1)) +case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), None, "foo", timedelta(hours=1)) + +case.assertNotAlmostEqual(2.4, 2.41, places=9, delta=0.02) # type: ignore +case.assertNotAlmostEqual("foo", "bar") # type: ignore +case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore +case.assertNotAlmostEqual(Decimal("0.4"), Fraction(1, 2)) # type: ignore +case.assertNotAlmostEqual(complex(2, 3), Decimal("0.9")) # type: ignore + +### +# Tests for assertGreater +### + + +class Spam: + def __lt__(self, other: object) -> bool: + return True + + +class Eggs: + def __gt__(self, other: object) -> bool: + return True + + +class Ham: + def __lt__(self, other: Ham) -> bool: + if not isinstance(other, Ham): + return NotImplemented + return True + + +class Bacon: + def __gt__(self, other: Bacon) -> bool: + if not isinstance(other, Bacon): + return NotImplemented + return True + + +case.assertGreater(5.8, 3) +case.assertGreater(Decimal("4.5"), Fraction(3, 2)) +case.assertGreater(Fraction(3, 2), 0.9) +case.assertGreater(Eggs(), object()) +case.assertGreater(object(), Spam()) +case.assertGreater(Ham(), Ham()) +case.assertGreater(Bacon(), Bacon()) + +case.assertGreater(object(), object()) # type: ignore +case.assertGreater(datetime(1999, 1, 2), 1) # type: ignore +case.assertGreater(Spam(), Eggs()) # type: ignore +case.assertGreater(Ham(), Bacon()) # type: ignore +case.assertGreater(Bacon(), Ham()) # type: ignore + + +### +# Tests for assertDictEqual +### + + +class TD1(TypedDict): + x: int + y: str + + +class TD2(TypedDict): + a: bool + b: bool + + +class MyMapping(Mapping[str, int]): + def __getitem__(self, __key: str) -> int: + return 42 + + def __iter__(self) -> Iterator[str]: + return iter([]) + + def __len__(self) -> int: + return 0 + + +td1: TD1 = {"x": 1, "y": "foo"} +td2: TD2 = {"a": True, "b": False} +m = MyMapping() + +case.assertDictEqual({}, {}) +case.assertDictEqual({"x": 1, "y": 2}, {"x": 1, "y": 2}) +case.assertDictEqual({"x": 1, "y": "foo"}, {"y": "foo", "x": 1}) +case.assertDictEqual({"x": 1}, {}) +case.assertDictEqual({}, {"x": 1}) +case.assertDictEqual({1: "x"}, {"y": 222}) +case.assertDictEqual({1: "x"}, td1) +case.assertDictEqual(td1, {1: "x"}) +case.assertDictEqual(td1, td2) + +case.assertDictEqual(1, {}) # type: ignore +case.assertDictEqual({}, 1) # type: ignore + +# These should fail, but don't due to TypedDict limitations: +# case.assertDictEqual(m, {"": 0}) # xtype: ignore +# case.assertDictEqual({"": 0}, m) # xtype: ignore + +### +# Tests for mock.patch +### + + +@patch("sys.exit") +def f_default_new(i: int, mock: MagicMock) -> str: + return "asdf" + + +@patch("sys.exit", new=42) +def f_explicit_new(i: int) -> str: + return "asdf" + + +assert_type(f_default_new(1), str) +f_default_new("a") # Not an error due to ParamSpec limitations +assert_type(f_explicit_new(1), str) +f_explicit_new("a") # type: ignore[arg-type] + + +@patch("sys.exit", new=Mock()) +class TestXYZ(unittest.TestCase): + attr: int = 5 + + @staticmethod + def method() -> int: + return 123 + + +assert_type(TestXYZ.attr, int) +assert_type(TestXYZ.method(), int) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/check_xml.py b/mypy/typeshed/stdlib/@tests/test_cases/check_xml.py new file mode 100644 index 000000000000..b485dac8dc29 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/check_xml.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import sys +from typing_extensions import assert_type +from xml.dom.minidom import Document + +document = Document() + +assert_type(document.toxml(), str) +assert_type(document.toxml(encoding=None), str) +assert_type(document.toxml(encoding="UTF8"), bytes) +assert_type(document.toxml("UTF8"), bytes) +if sys.version_info >= (3, 9): + assert_type(document.toxml(standalone=True), str) + assert_type(document.toxml("UTF8", True), bytes) + assert_type(document.toxml(encoding="UTF8", standalone=True), bytes) + + +# Because toprettyxml can mix positional and keyword variants of the "encoding" argument, which +# determines the return type, the proper stub typing isn't immediately obvious. This is a basic +# brute-force sanity check. +# Test cases like toxml +assert_type(document.toprettyxml(), str) +assert_type(document.toprettyxml(encoding=None), str) +assert_type(document.toprettyxml(encoding="UTF8"), bytes) +if sys.version_info >= (3, 9): + assert_type(document.toprettyxml(standalone=True), str) + assert_type(document.toprettyxml(encoding="UTF8", standalone=True), bytes) +# Test cases unique to toprettyxml +assert_type(document.toprettyxml(" "), str) +assert_type(document.toprettyxml(" ", "\r\n"), str) +assert_type(document.toprettyxml(" ", "\r\n", "UTF8"), bytes) +if sys.version_info >= (3, 9): + assert_type(document.toprettyxml(" ", "\r\n", "UTF8", True), bytes) + assert_type(document.toprettyxml(" ", "\r\n", standalone=True), str) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/collections/check_defaultdict-py39.py b/mypy/typeshed/stdlib/@tests/test_cases/collections/check_defaultdict-py39.py new file mode 100644 index 000000000000..9fe5ec8076ce --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/collections/check_defaultdict-py39.py @@ -0,0 +1,69 @@ +""" +Tests for `defaultdict.__or__` and `defaultdict.__ror__`. +These methods were only added in py39. +""" + +from __future__ import annotations + +import os +import sys +from collections import defaultdict +from typing import Mapping, TypeVar, Union +from typing_extensions import Self, assert_type + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +if sys.version_info >= (3, 9): + + class CustomDefaultDictSubclass(defaultdict[_KT, _VT]): + pass + + class CustomMappingWithDunderOr(Mapping[_KT, _VT]): + def __or__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]: + return {} + + def __ror__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]: + return {} + + def __ior__(self, other: Mapping[_KT, _VT]) -> Self: + return self + + def test_defaultdict_dot_or( + a: defaultdict[int, int], + b: CustomDefaultDictSubclass[int, int], + c: defaultdict[str, str], + d: Mapping[int, int], + e: CustomMappingWithDunderOr[str, str], + ) -> None: + assert_type(a | b, defaultdict[int, int]) + + # In contrast to `dict.__or__`, `defaultdict.__or__` returns `Self` if called on a subclass of `defaultdict`: + assert_type(b | a, CustomDefaultDictSubclass[int, int]) + + assert_type(a | c, defaultdict[Union[int, str], Union[int, str]]) + + # arbitrary mappings are not accepted by `defaultdict.__or__`; + # it has to be a subclass of `dict` + a | d # type: ignore + + # but Mappings such as `os._Environ` or `CustomMappingWithDunderOr`, + # which define `__ror__` methods that accept `dict`, are fine + # (`os._Environ.__(r)or__` always returns `dict`, even if a `defaultdict` is passed): + assert_type(a | os.environ, dict[Union[str, int], Union[str, int]]) + assert_type(os.environ | a, dict[Union[str, int], Union[str, int]]) + + assert_type(c | os.environ, dict[str, str]) + assert_type(c | e, dict[str, str]) + + assert_type(os.environ | c, dict[str, str]) + assert_type(e | c, dict[str, str]) + + e |= c + e |= a # type: ignore + + # TODO: this test passes mypy, but fails pyright for some reason: + # c |= e + + c |= a # type: ignore diff --git a/mypy/typeshed/stdlib/@tests/test_cases/email/check_message.py b/mypy/typeshed/stdlib/@tests/test_cases/email/check_message.py new file mode 100644 index 000000000000..a9b43e23fb27 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/email/check_message.py @@ -0,0 +1,6 @@ +from email.headerregistry import Address +from email.message import EmailMessage + +msg = EmailMessage() +msg["To"] = "receiver@example.com" +msg["From"] = Address("Sender Name", "sender", "example.com") diff --git a/mypy/typeshed/stdlib/@tests/test_cases/itertools/check_itertools_recipes.py b/mypy/typeshed/stdlib/@tests/test_cases/itertools/check_itertools_recipes.py new file mode 100644 index 000000000000..c45ffee28cee --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/itertools/check_itertools_recipes.py @@ -0,0 +1,410 @@ +"""Type-annotated versions of the recipes from the itertools docs. + +These are all meant to be examples of idiomatic itertools usage, +so they should all type-check without error. +""" + +from __future__ import annotations + +import collections +import math +import operator +import sys +from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, product, repeat, starmap, tee, zip_longest +from typing import ( + Any, + Callable, + Collection, + Hashable, + Iterable, + Iterator, + Literal, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) +from typing_extensions import TypeAlias, TypeVarTuple, Unpack + +_T = TypeVar("_T") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") +_HashableT = TypeVar("_HashableT", bound=Hashable) +_Ts = TypeVarTuple("_Ts") + + +def take(n: int, iterable: Iterable[_T]) -> list[_T]: + "Return first n items of the iterable as a list" + return list(islice(iterable, n)) + + +# Note: the itertools docs uses the parameter name "iterator", +# but the function actually accepts any iterable +# as its second argument +def prepend(value: _T1, iterator: Iterable[_T2]) -> Iterator[_T1 | _T2]: + "Prepend a single value in front of an iterator" + # prepend(1, [2, 3, 4]) --> 1 2 3 4 + return chain([value], iterator) + + +def tabulate(function: Callable[[int], _T], start: int = 0) -> Iterator[_T]: + "Return function(0), function(1), ..." + return map(function, count(start)) + + +def repeatfunc(func: Callable[[Unpack[_Ts]], _T], times: int | None = None, *args: Unpack[_Ts]) -> Iterator[_T]: + """Repeat calls to func with specified arguments. + + Example: repeatfunc(random.random) + """ + if times is None: + return starmap(func, repeat(args)) + return starmap(func, repeat(args, times)) + + +def flatten(list_of_lists: Iterable[Iterable[_T]]) -> Iterator[_T]: + "Flatten one level of nesting" + return chain.from_iterable(list_of_lists) + + +def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: + "Returns the sequence elements n times" + return chain.from_iterable(repeat(tuple(iterable), n)) + + +def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: + "Return an iterator over the last n items" + # tail(3, 'ABCDEFG') --> E F G + return iter(collections.deque(iterable, maxlen=n)) + + +# This function *accepts* any iterable, +# but it only *makes sense* to use it with an iterator +def consume(iterator: Iterator[object], n: int | None = None) -> None: + "Advance the iterator n-steps ahead. If n is None, consume entirely." + # Use functions that consume iterators at C speed. + if n is None: + # feed the entire iterator into a zero-length deque + collections.deque(iterator, maxlen=0) + else: + # advance to the empty slice starting at position n + next(islice(iterator, n, n), None) + + +@overload +def nth(iterable: Iterable[_T], n: int, default: None = None) -> _T | None: ... + + +@overload +def nth(iterable: Iterable[_T], n: int, default: _T1) -> _T | _T1: ... + + +def nth(iterable: Iterable[object], n: int, default: object = None) -> object: + "Returns the nth item or a default value" + return next(islice(iterable, n, None), default) + + +@overload +def quantify(iterable: Iterable[object]) -> int: ... + + +@overload +def quantify(iterable: Iterable[_T], pred: Callable[[_T], bool]) -> int: ... + + +def quantify(iterable: Iterable[object], pred: Callable[[Any], bool] = bool) -> int: + "Given a predicate that returns True or False, count the True results." + return sum(map(pred, iterable)) + + +@overload +def first_true( + iterable: Iterable[_T], default: Literal[False] = False, pred: Callable[[_T], bool] | None = None +) -> _T | Literal[False]: ... + + +@overload +def first_true(iterable: Iterable[_T], default: _T1, pred: Callable[[_T], bool] | None = None) -> _T | _T1: ... + + +def first_true(iterable: Iterable[object], default: object = False, pred: Callable[[Any], bool] | None = None) -> object: + """Returns the first true value in the iterable. + If no true value is found, returns *default* + If *pred* is not None, returns the first item + for which pred(item) is true. + """ + # first_true([a,b,c], x) --> a or b or c or x + # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x + return next(filter(pred, iterable), default) + + +_ExceptionOrExceptionTuple: TypeAlias = Union[Type[BaseException], Tuple[Type[BaseException], ...]] + + +@overload +def iter_except(func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: None = None) -> Iterator[_T]: ... + + +@overload +def iter_except( + func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: Callable[[], _T1] +) -> Iterator[_T | _T1]: ... + + +def iter_except( + func: Callable[[], object], exception: _ExceptionOrExceptionTuple, first: Callable[[], object] | None = None +) -> Iterator[object]: + """Call a function repeatedly until an exception is raised. + Converts a call-until-exception interface to an iterator interface. + Like builtins.iter(func, sentinel) but uses an exception instead + of a sentinel to end the loop. + Examples: + iter_except(functools.partial(heappop, h), IndexError) # priority queue iterator + iter_except(d.popitem, KeyError) # non-blocking dict iterator + iter_except(d.popleft, IndexError) # non-blocking deque iterator + iter_except(q.get_nowait, Queue.Empty) # loop over a producer Queue + iter_except(s.pop, KeyError) # non-blocking set iterator + """ + try: + if first is not None: + yield first() # For database APIs needing an initial cast to db.first() + while True: + yield func() + except exception: + pass + + +def sliding_window(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]: + # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG + it = iter(iterable) + window = collections.deque(islice(it, n - 1), maxlen=n) + for x in it: + window.append(x) + yield tuple(window) + + +def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: + "roundrobin('ABC', 'D', 'EF') --> A D E B F C" + # Recipe credited to George Sakkis + num_active = len(iterables) + nexts: Iterator[Callable[[], _T]] = cycle(iter(it).__next__ for it in iterables) + while num_active: + try: + for next in nexts: + yield next() + except StopIteration: + # Remove the iterator we just exhausted from the cycle. + num_active -= 1 + nexts = cycle(islice(nexts, num_active)) + + +def partition(pred: Callable[[_T], bool], iterable: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]: + """Partition entries into false entries and true entries. + If *pred* is slow, consider wrapping it with functools.lru_cache(). + """ + # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 + t1, t2 = tee(iterable) + return filterfalse(pred, t1), filter(pred, t2) + + +def subslices(seq: Sequence[_T]) -> Iterator[Sequence[_T]]: + "Return all contiguous non-empty subslices of a sequence" + # subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D + slices = starmap(slice, combinations(range(len(seq) + 1), 2)) + return map(operator.getitem, repeat(seq), slices) + + +def before_and_after(predicate: Callable[[_T], bool], it: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]: + """Variant of takewhile() that allows complete + access to the remainder of the iterator. + >>> it = iter('ABCdEfGhI') + >>> all_upper, remainder = before_and_after(str.isupper, it) + >>> ''.join(all_upper) + 'ABC' + >>> ''.join(remainder) # takewhile() would lose the 'd' + 'dEfGhI' + Note that the first iterator must be fully + consumed before the second iterator can + generate valid results. + """ + it = iter(it) + transition: list[_T] = [] + + def true_iterator() -> Iterator[_T]: + for elem in it: + if predicate(elem): + yield elem + else: + transition.append(elem) + return + + def remainder_iterator() -> Iterator[_T]: + yield from transition + yield from it + + return true_iterator(), remainder_iterator() + + +@overload +def unique_everseen(iterable: Iterable[_HashableT], key: None = None) -> Iterator[_HashableT]: ... + + +@overload +def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable]) -> Iterator[_T]: ... + + +def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable] | None = None) -> Iterator[_T]: + "List unique elements, preserving order. Remember all elements ever seen." + # unique_everseen('AAAABBBCCDAABBB') --> A B C D + # unique_everseen('ABBcCAD', str.lower) --> A B c D + seen: set[Hashable] = set() + if key is None: + for element in filterfalse(seen.__contains__, iterable): + seen.add(element) + yield element + # For order preserving deduplication, + # a faster but non-lazy solution is: + # yield from dict.fromkeys(iterable) + else: + for element in iterable: + k = key(element) + if k not in seen: + seen.add(k) + yield element + # For use cases that allow the last matching element to be returned, + # a faster but non-lazy solution is: + # t1, t2 = tee(iterable) + # yield from dict(zip(map(key, t1), t2)).values() + + +# Slightly adapted from the docs recipe; a one-liner was a bit much for pyright +def unique_justseen(iterable: Iterable[_T], key: Callable[[_T], bool] | None = None) -> Iterator[_T]: + "List unique elements, preserving order. Remember only the element just seen." + # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B + # unique_justseen('ABBcCAD', str.lower) --> A B c A D + g: groupby[_T | bool, _T] = groupby(iterable, key) + return map(next, map(operator.itemgetter(1), g)) + + +def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: + "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" + s = list(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + + +def polynomial_derivative(coefficients: Sequence[float]) -> list[float]: + """Compute the first derivative of a polynomial. + f(x) = x³ -4x² -17x + 60 + f'(x) = 3x² -8x -17 + """ + # polynomial_derivative([1, -4, -17, 60]) -> [3, -8, -17] + n = len(coefficients) + powers = reversed(range(1, n)) + return list(map(operator.mul, coefficients, powers)) + + +def nth_combination(iterable: Iterable[_T], r: int, index: int) -> tuple[_T, ...]: + "Equivalent to list(combinations(iterable, r))[index]" + pool = tuple(iterable) + n = len(pool) + c = math.comb(n, r) + if index < 0: + index += c + if index < 0 or index >= c: + raise IndexError + result: list[_T] = [] + while r: + c, n, r = c * r // n, n - 1, r - 1 + while index >= c: + index -= c + c, n = c * (n - r) // n, n - 1 + result.append(pool[-1 - n]) + return tuple(result) + + +if sys.version_info >= (3, 10): + + @overload + def grouper( + iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: None = None + ) -> Iterator[tuple[_T | None, ...]]: ... + + @overload + def grouper( + iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: _T1 + ) -> Iterator[tuple[_T | _T1, ...]]: ... + + @overload + def grouper( + iterable: Iterable[_T], n: int, *, incomplete: Literal["strict", "ignore"], fillvalue: None = None + ) -> Iterator[tuple[_T, ...]]: ... + + def grouper( + iterable: Iterable[object], n: int, *, incomplete: Literal["fill", "strict", "ignore"] = "fill", fillvalue: object = None + ) -> Iterator[tuple[object, ...]]: + "Collect data into non-overlapping fixed-length chunks or blocks" + # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx + # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError + # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF + args = [iter(iterable)] * n + if incomplete == "fill": + return zip_longest(*args, fillvalue=fillvalue) + if incomplete == "strict": + return zip(*args, strict=True) + if incomplete == "ignore": + return zip(*args) + else: + raise ValueError("Expected fill, strict, or ignore") + + def transpose(it: Iterable[Iterable[_T]]) -> Iterator[tuple[_T, ...]]: + "Swap the rows and columns of the input." + # transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33) + return zip(*it, strict=True) + + +if sys.version_info >= (3, 12): + from itertools import batched + + def sum_of_squares(it: Iterable[float]) -> float: + "Add up the squares of the input values." + # sum_of_squares([10, 20, 30]) -> 1400 + return math.sumprod(*tee(it)) + + def convolve(signal: Iterable[float], kernel: Iterable[float]) -> Iterator[float]: + """Discrete linear convolution of two iterables. + The kernel is fully consumed before the calculations begin. + The signal is consumed lazily and can be infinite. + Convolutions are mathematically commutative. + If the signal and kernel are swapped, + the output will be the same. + Article: https://betterexplained.com/articles/intuitive-convolution/ + Video: https://www.youtube.com/watch?v=KuXjwB4LzSA + """ + # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur) + # convolve(data, [1/2, 0, -1/2]) --> 1st derivative estimate + # convolve(data, [1, -2, 1]) --> 2nd derivative estimate + kernel = tuple(kernel)[::-1] + n = len(kernel) + padded_signal = chain(repeat(0, n - 1), signal, repeat(0, n - 1)) + windowed_signal = sliding_window(padded_signal, n) + return map(math.sumprod, repeat(kernel), windowed_signal) + + def polynomial_eval(coefficients: Sequence[float], x: float) -> float: + """Evaluate a polynomial at a specific value. + Computes with better numeric stability than Horner's method. + """ + # Evaluate x³ -4x² -17x + 60 at x = 2.5 + # polynomial_eval([1, -4, -17, 60], x=2.5) --> 8.125 + n = len(coefficients) + if not n: + return type(x)(0) + powers = map(pow, repeat(x), reversed(range(n))) + return math.sumprod(coefficients, powers) + + def matmul(m1: Sequence[Collection[float]], m2: Sequence[Collection[float]]) -> Iterator[tuple[float, ...]]: + "Multiply two matrices." + # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) --> (49, 80), (41, 60) + n = len(m2[0]) + return batched(starmap(math.sumprod, product(m1, transpose(m2))), n) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/typing/check_MutableMapping.py b/mypy/typeshed/stdlib/@tests/test_cases/typing/check_MutableMapping.py new file mode 100644 index 000000000000..10a33ffb83d5 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/typing/check_MutableMapping.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Any, Union +from typing_extensions import assert_type + + +def check_setdefault_method() -> None: + d: dict[int, str] = {} + d2: dict[int, str | None] = {} + d3: dict[int, Any] = {} + + d.setdefault(1) # type: ignore + assert_type(d.setdefault(1, "x"), str) + assert_type(d2.setdefault(1), Union[str, None]) + assert_type(d2.setdefault(1, None), Union[str, None]) + assert_type(d2.setdefault(1, "x"), Union[str, None]) + assert_type(d3.setdefault(1), Union[Any, None]) + assert_type(d3.setdefault(1, "x"), Any) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/typing/check_all.py b/mypy/typeshed/stdlib/@tests/test_cases/typing/check_all.py new file mode 100644 index 000000000000..44eb548e04a9 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/typing/check_all.py @@ -0,0 +1,14 @@ +# pyright: reportWildcardImportFromLibrary=false +""" +This tests that star imports work when using "all += " syntax. +""" +from __future__ import annotations + +import sys +from typing import * +from zipfile import * + +if sys.version_info >= (3, 9): + x: Annotated[int, 42] + +p: Path diff --git a/mypy/typeshed/stdlib/@tests/test_cases/typing/check_regression_issue_9296.py b/mypy/typeshed/stdlib/@tests/test_cases/typing/check_regression_issue_9296.py new file mode 100644 index 000000000000..34c5631aeb1a --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/typing/check_regression_issue_9296.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import typing as t + +KT = t.TypeVar("KT") + + +class MyKeysView(t.KeysView[KT]): + pass + + +d: dict[t.Any, t.Any] = {} +dict_keys = type(d.keys()) + +# This should not cause an error like `Member "register" is unknown`: +MyKeysView.register(dict_keys) diff --git a/mypy/typeshed/stdlib/@tests/test_cases/typing/check_typing_io.py b/mypy/typeshed/stdlib/@tests/test_cases/typing/check_typing_io.py new file mode 100644 index 000000000000..67f16dc91765 --- /dev/null +++ b/mypy/typeshed/stdlib/@tests/test_cases/typing/check_typing_io.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import mmap +from typing import IO, AnyStr + + +def check_write(io_bytes: IO[bytes], io_str: IO[str], io_anystr: IO[AnyStr], any_str: AnyStr, buf: mmap.mmap) -> None: + io_bytes.write(b"") + io_bytes.write(buf) + io_bytes.write("") # type: ignore + io_bytes.write(any_str) # type: ignore + + io_str.write(b"") # type: ignore + io_str.write(buf) # type: ignore + io_str.write("") + io_str.write(any_str) # type: ignore + + io_anystr.write(b"") # type: ignore + io_anystr.write(buf) # type: ignore + io_anystr.write("") # type: ignore + io_anystr.write(any_str) diff --git a/mypy/typeshed/stdlib/_typeshed/importlib.pyi b/mypy/typeshed/stdlib/_typeshed/importlib.pyi new file mode 100644 index 000000000000..a4e56cdaff62 --- /dev/null +++ b/mypy/typeshed/stdlib/_typeshed/importlib.pyi @@ -0,0 +1,18 @@ +# Implicit protocols used in importlib. +# We intentionally omit deprecated and optional methods. + +from collections.abc import Sequence +from importlib.machinery import ModuleSpec +from types import ModuleType +from typing import Protocol + +__all__ = ["LoaderProtocol", "MetaPathFinderProtocol", "PathEntryFinderProtocol"] + +class LoaderProtocol(Protocol): + def load_module(self, fullname: str, /) -> ModuleType: ... + +class MetaPathFinderProtocol(Protocol): + def find_spec(self, fullname: str, path: Sequence[str] | None, target: ModuleType | None = ..., /) -> ModuleSpec | None: ... + +class PathEntryFinderProtocol(Protocol): + def find_spec(self, fullname: str, target: ModuleType | None = ..., /) -> ModuleSpec | None: ... diff --git a/mypy/typeshed/stdlib/ast.pyi b/mypy/typeshed/stdlib/ast.pyi index 4ab325b5baa7..2525c3642a6f 100644 --- a/mypy/typeshed/stdlib/ast.pyi +++ b/mypy/typeshed/stdlib/ast.pyi @@ -181,82 +181,172 @@ class NodeTransformer(NodeVisitor): _T = _TypeVar("_T", bound=AST) -@overload -def parse( - source: str | ReadableBuffer, - filename: str | ReadableBuffer | os.PathLike[Any] = "", - mode: Literal["exec"] = "exec", - *, - type_comments: bool = False, - feature_version: None | int | tuple[int, int] = None, -) -> Module: ... -@overload -def parse( - source: str | ReadableBuffer, - filename: str | ReadableBuffer | os.PathLike[Any], - mode: Literal["eval"], - *, - type_comments: bool = False, - feature_version: None | int | tuple[int, int] = None, -) -> Expression: ... -@overload -def parse( - source: str | ReadableBuffer, - filename: str | ReadableBuffer | os.PathLike[Any], - mode: Literal["func_type"], - *, - type_comments: bool = False, - feature_version: None | int | tuple[int, int] = None, -) -> FunctionType: ... -@overload -def parse( - source: str | ReadableBuffer, - filename: str | ReadableBuffer | os.PathLike[Any], - mode: Literal["single"], - *, - type_comments: bool = False, - feature_version: None | int | tuple[int, int] = None, -) -> Interactive: ... -@overload -def parse( - source: str | ReadableBuffer, - *, - mode: Literal["eval"], - type_comments: bool = False, - feature_version: None | int | tuple[int, int] = None, -) -> Expression: ... -@overload -def parse( - source: str | ReadableBuffer, - *, - mode: Literal["func_type"], - type_comments: bool = False, - feature_version: None | int | tuple[int, int] = None, -) -> FunctionType: ... -@overload -def parse( - source: str | ReadableBuffer, - *, - mode: Literal["single"], - type_comments: bool = False, - feature_version: None | int | tuple[int, int] = None, -) -> Interactive: ... -@overload -def parse( - source: str | ReadableBuffer, - filename: str | ReadableBuffer | os.PathLike[Any] = "", - mode: str = "exec", - *, - type_comments: bool = False, - feature_version: None | int | tuple[int, int] = None, -) -> AST: ... +if sys.version_info >= (3, 13): + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any] = "", + mode: Literal["exec"] = "exec", + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + optimize: Literal[-1, 0, 1, 2] = -1, + ) -> Module: ... + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any], + mode: Literal["eval"], + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + optimize: Literal[-1, 0, 1, 2] = -1, + ) -> Expression: ... + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any], + mode: Literal["func_type"], + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + optimize: Literal[-1, 0, 1, 2] = -1, + ) -> FunctionType: ... + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any], + mode: Literal["single"], + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + optimize: Literal[-1, 0, 1, 2] = -1, + ) -> Interactive: ... + @overload + def parse( + source: str | ReadableBuffer, + *, + mode: Literal["eval"], + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + optimize: Literal[-1, 0, 1, 2] = -1, + ) -> Expression: ... + @overload + def parse( + source: str | ReadableBuffer, + *, + mode: Literal["func_type"], + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + optimize: Literal[-1, 0, 1, 2] = -1, + ) -> FunctionType: ... + @overload + def parse( + source: str | ReadableBuffer, + *, + mode: Literal["single"], + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + optimize: Literal[-1, 0, 1, 2] = -1, + ) -> Interactive: ... + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any] = "", + mode: str = "exec", + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + optimize: Literal[-1, 0, 1, 2] = -1, + ) -> AST: ... + +else: + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any] = "", + mode: Literal["exec"] = "exec", + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + ) -> Module: ... + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any], + mode: Literal["eval"], + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + ) -> Expression: ... + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any], + mode: Literal["func_type"], + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + ) -> FunctionType: ... + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any], + mode: Literal["single"], + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + ) -> Interactive: ... + @overload + def parse( + source: str | ReadableBuffer, + *, + mode: Literal["eval"], + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + ) -> Expression: ... + @overload + def parse( + source: str | ReadableBuffer, + *, + mode: Literal["func_type"], + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + ) -> FunctionType: ... + @overload + def parse( + source: str | ReadableBuffer, + *, + mode: Literal["single"], + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + ) -> Interactive: ... + @overload + def parse( + source: str | ReadableBuffer, + filename: str | ReadableBuffer | os.PathLike[Any] = "", + mode: str = "exec", + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + ) -> AST: ... if sys.version_info >= (3, 9): def unparse(ast_obj: AST) -> str: ... def copy_location(new_node: _T, old_node: AST) -> _T: ... -if sys.version_info >= (3, 9): +if sys.version_info >= (3, 13): + def dump( + node: AST, + annotate_fields: bool = True, + include_attributes: bool = False, + *, + indent: int | str | None = None, + show_empty: bool = False, + ) -> str: ... + +elif sys.version_info >= (3, 9): def dump( node: AST, annotate_fields: bool = True, include_attributes: bool = False, *, indent: int | str | None = None ) -> str: ... diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi index 9e56c5430c52..4c47a0736e2e 100644 --- a/mypy/typeshed/stdlib/builtins.pyi +++ b/mypy/typeshed/stdlib/builtins.pyi @@ -31,7 +31,7 @@ from _typeshed import ( ) from collections.abc import Awaitable, Callable, Iterable, Iterator, MutableSet, Reversible, Set as AbstractSet, Sized from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWrapper -from types import CodeType, TracebackType, _Cell +from types import CellType, CodeType, TracebackType # mypy crashes if any of {ByteString, Sequence, MutableSequence, Mapping, MutableMapping} are imported from collections.abc in builtins.pyi from typing import ( # noqa: Y022 @@ -863,7 +863,7 @@ class tuple(Sequence[_T_co]): class function: # Make sure this class definition stays roughly in line with `types.FunctionType` @property - def __closure__(self) -> tuple[_Cell, ...] | None: ... + def __closure__(self) -> tuple[CellType, ...] | None: ... __code__: CodeType __defaults__: tuple[Any, ...] | None __dict__: dict[str, Any] @@ -1245,7 +1245,7 @@ if sys.version_info >= (3, 11): locals: Mapping[str, object] | None = None, /, *, - closure: tuple[_Cell, ...] | None = None, + closure: tuple[CellType, ...] | None = None, ) -> None: ... else: @@ -1706,7 +1706,7 @@ def __import__( fromlist: Sequence[str] = (), level: int = 0, ) -> types.ModuleType: ... -def __build_class__(func: Callable[[], _Cell | Any], name: str, /, *bases: Any, metaclass: Any = ..., **kwds: Any) -> Any: ... +def __build_class__(func: Callable[[], CellType | Any], name: str, /, *bases: Any, metaclass: Any = ..., **kwds: Any) -> Any: ... if sys.version_info >= (3, 10): from types import EllipsisType diff --git a/mypy/typeshed/stdlib/dbm/__init__.pyi b/mypy/typeshed/stdlib/dbm/__init__.pyi index 9c6989a1c151..f414763e02a6 100644 --- a/mypy/typeshed/stdlib/dbm/__init__.pyi +++ b/mypy/typeshed/stdlib/dbm/__init__.pyi @@ -1,3 +1,5 @@ +import sys +from _typeshed import StrOrBytesPath from collections.abc import Iterator, MutableMapping from types import TracebackType from typing import Literal @@ -91,5 +93,10 @@ class _error(Exception): ... error: tuple[type[_error], type[OSError]] -def whichdb(filename: str) -> str | None: ... -def open(file: str, flag: _TFlags = "r", mode: int = 0o666) -> _Database: ... +if sys.version_info >= (3, 11): + def whichdb(filename: StrOrBytesPath) -> str | None: ... + def open(file: StrOrBytesPath, flag: _TFlags = "r", mode: int = 0o666) -> _Database: ... + +else: + def whichdb(filename: str) -> str | None: ... + def open(file: str, flag: _TFlags = "r", mode: int = 0o666) -> _Database: ... diff --git a/mypy/typeshed/stdlib/dbm/dumb.pyi b/mypy/typeshed/stdlib/dbm/dumb.pyi index 1fc68cf71f9d..1c0b7756f292 100644 --- a/mypy/typeshed/stdlib/dbm/dumb.pyi +++ b/mypy/typeshed/stdlib/dbm/dumb.pyi @@ -1,3 +1,5 @@ +import sys +from _typeshed import StrOrBytesPath from collections.abc import Iterator, MutableMapping from types import TracebackType from typing_extensions import Self, TypeAlias @@ -28,4 +30,8 @@ class _Database(MutableMapping[_KeyType, bytes]): self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: ... -def open(file: str, flag: str = "c", mode: int = 0o666) -> _Database: ... +if sys.version_info >= (3, 11): + def open(file: StrOrBytesPath, flag: str = "c", mode: int = 0o666) -> _Database: ... + +else: + def open(file: str, flag: str = "c", mode: int = 0o666) -> _Database: ... diff --git a/mypy/typeshed/stdlib/dbm/gnu.pyi b/mypy/typeshed/stdlib/dbm/gnu.pyi index 8b562019fcfb..e80441cbb25b 100644 --- a/mypy/typeshed/stdlib/dbm/gnu.pyi +++ b/mypy/typeshed/stdlib/dbm/gnu.pyi @@ -1,5 +1,5 @@ import sys -from _typeshed import ReadOnlyBuffer +from _typeshed import ReadOnlyBuffer, StrOrBytesPath from types import TracebackType from typing import TypeVar, overload from typing_extensions import Self, TypeAlias @@ -38,4 +38,7 @@ if sys.platform != "win32": __new__: None # type: ignore[assignment] __init__: None # type: ignore[assignment] - def open(filename: str, flags: str = "r", mode: int = 0o666, /) -> _gdbm: ... + if sys.version_info >= (3, 11): + def open(filename: StrOrBytesPath, flags: str = "r", mode: int = 0o666, /) -> _gdbm: ... + else: + def open(filename: str, flags: str = "r", mode: int = 0o666, /) -> _gdbm: ... diff --git a/mypy/typeshed/stdlib/dbm/ndbm.pyi b/mypy/typeshed/stdlib/dbm/ndbm.pyi index 5eb84e6949fc..02bf23ec181c 100644 --- a/mypy/typeshed/stdlib/dbm/ndbm.pyi +++ b/mypy/typeshed/stdlib/dbm/ndbm.pyi @@ -1,5 +1,5 @@ import sys -from _typeshed import ReadOnlyBuffer +from _typeshed import ReadOnlyBuffer, StrOrBytesPath from types import TracebackType from typing import TypeVar, overload from typing_extensions import Self, TypeAlias @@ -34,4 +34,7 @@ if sys.platform != "win32": __new__: None # type: ignore[assignment] __init__: None # type: ignore[assignment] - def open(filename: str, flags: str = "r", mode: int = 0o666, /) -> _dbm: ... + if sys.version_info >= (3, 11): + def open(filename: StrOrBytesPath, flags: str = "r", mode: int = 0o666, /) -> _dbm: ... + else: + def open(filename: str, flags: str = "r", mode: int = 0o666, /) -> _dbm: ... diff --git a/mypy/typeshed/stdlib/importlib/abc.pyi b/mypy/typeshed/stdlib/importlib/abc.pyi index 75e78ed59172..3937481159dc 100644 --- a/mypy/typeshed/stdlib/importlib/abc.pyi +++ b/mypy/typeshed/stdlib/importlib/abc.pyi @@ -64,7 +64,7 @@ class SourceLoader(ResourceLoader, ExecutionLoader, metaclass=ABCMeta): # The base classes differ starting in 3.10: if sys.version_info >= (3, 10): - # Please keep in sync with sys._MetaPathFinder + # Please keep in sync with _typeshed.importlib.MetaPathFinderProtocol class MetaPathFinder(metaclass=ABCMeta): if sys.version_info < (3, 12): def find_module(self, fullname: str, path: Sequence[str] | None) -> Loader | None: ... @@ -85,7 +85,7 @@ if sys.version_info >= (3, 10): def find_spec(self, fullname: str, target: types.ModuleType | None = ...) -> ModuleSpec | None: ... else: - # Please keep in sync with sys._MetaPathFinder + # Please keep in sync with _typeshed.importlib.MetaPathFinderProtocol class MetaPathFinder(Finder): def find_module(self, fullname: str, path: Sequence[str] | None) -> Loader | None: ... def invalidate_caches(self) -> None: ... diff --git a/mypy/typeshed/stdlib/importlib/util.pyi b/mypy/typeshed/stdlib/importlib/util.pyi index 6608f70d4469..2492c76d5c6c 100644 --- a/mypy/typeshed/stdlib/importlib/util.pyi +++ b/mypy/typeshed/stdlib/importlib/util.pyi @@ -3,6 +3,7 @@ import importlib.machinery import sys import types from _typeshed import ReadableBuffer, StrOrBytesPath +from _typeshed.importlib import LoaderProtocol from collections.abc import Callable from typing import Any from typing_extensions import ParamSpec @@ -23,13 +24,13 @@ def source_from_cache(path: str) -> str: ... def decode_source(source_bytes: ReadableBuffer) -> str: ... def find_spec(name: str, package: str | None = None) -> importlib.machinery.ModuleSpec | None: ... def spec_from_loader( - name: str, loader: importlib.abc.Loader | None, *, origin: str | None = None, is_package: bool | None = None + name: str, loader: LoaderProtocol | None, *, origin: str | None = None, is_package: bool | None = None ) -> importlib.machinery.ModuleSpec | None: ... def spec_from_file_location( name: str, location: StrOrBytesPath | None = None, *, - loader: importlib.abc.Loader | None = None, + loader: LoaderProtocol | None = None, submodule_search_locations: list[str] | None = ..., ) -> importlib.machinery.ModuleSpec | None: ... def module_from_spec(spec: importlib.machinery.ModuleSpec) -> types.ModuleType: ... diff --git a/mypy/typeshed/stdlib/logging/__init__.pyi b/mypy/typeshed/stdlib/logging/__init__.pyi index f5f7f91ece61..7ceddfa7ff28 100644 --- a/mypy/typeshed/stdlib/logging/__init__.pyi +++ b/mypy/typeshed/stdlib/logging/__init__.pyi @@ -587,7 +587,7 @@ def setLoggerClass(klass: type[Logger]) -> None: ... def captureWarnings(capture: bool) -> None: ... def setLogRecordFactory(factory: Callable[..., LogRecord]) -> None: ... -lastResort: StreamHandler[Any] | None +lastResort: Handler | None _StreamT = TypeVar("_StreamT", bound=SupportsWrite[str]) diff --git a/mypy/typeshed/stdlib/pathlib.pyi b/mypy/typeshed/stdlib/pathlib.pyi index 5ea025095f68..0013e221f2e1 100644 --- a/mypy/typeshed/stdlib/pathlib.pyi +++ b/mypy/typeshed/stdlib/pathlib.pyi @@ -15,7 +15,7 @@ from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWra from os import PathLike, stat_result from types import TracebackType from typing import IO, Any, BinaryIO, Literal, overload -from typing_extensions import Self +from typing_extensions import Self, deprecated if sys.version_info >= (3, 9): from types import GenericAlias @@ -222,7 +222,11 @@ class Path(PurePath): else: def write_text(self, data: str, encoding: str | None = None, errors: str | None = None) -> int: ... if sys.version_info < (3, 12): - def link_to(self, target: StrOrBytesPath) -> None: ... + if sys.version_info >= (3, 10): + @deprecated("Deprecated as of Python 3.10 and removed in Python 3.12. Use hardlink_to() instead.") + def link_to(self, target: StrOrBytesPath) -> None: ... + else: + def link_to(self, target: StrOrBytesPath) -> None: ... if sys.version_info >= (3, 12): def walk( self, top_down: bool = ..., on_error: Callable[[OSError], object] | None = ..., follow_symlinks: bool = ... diff --git a/mypy/typeshed/stdlib/pkgutil.pyi b/mypy/typeshed/stdlib/pkgutil.pyi index 4a0c8d101b7a..7e7fa4fda9a1 100644 --- a/mypy/typeshed/stdlib/pkgutil.pyi +++ b/mypy/typeshed/stdlib/pkgutil.pyi @@ -1,7 +1,7 @@ import sys from _typeshed import SupportsRead +from _typeshed.importlib import LoaderProtocol, MetaPathFinderProtocol, PathEntryFinderProtocol from collections.abc import Callable, Iterable, Iterator -from importlib.abc import Loader, MetaPathFinder, PathEntryFinder from typing import IO, Any, NamedTuple, TypeVar from typing_extensions import deprecated @@ -23,7 +23,7 @@ if sys.version_info < (3, 12): _PathT = TypeVar("_PathT", bound=Iterable[str]) class ModuleInfo(NamedTuple): - module_finder: MetaPathFinder | PathEntryFinder + module_finder: MetaPathFinderProtocol | PathEntryFinderProtocol name: str ispkg: bool @@ -37,11 +37,11 @@ if sys.version_info < (3, 12): def __init__(self, fullname: str, file: IO[str], filename: str, etc: tuple[str, str, int]) -> None: ... @deprecated("Use importlib.util.find_spec() instead. Will be removed in Python 3.14.") -def find_loader(fullname: str) -> Loader | None: ... -def get_importer(path_item: str) -> PathEntryFinder | None: ... +def find_loader(fullname: str) -> LoaderProtocol | None: ... +def get_importer(path_item: str) -> PathEntryFinderProtocol | None: ... @deprecated("Use importlib.util.find_spec() instead. Will be removed in Python 3.14.") -def get_loader(module_or_name: str) -> Loader | None: ... -def iter_importers(fullname: str = "") -> Iterator[MetaPathFinder | PathEntryFinder]: ... +def get_loader(module_or_name: str) -> LoaderProtocol | None: ... +def iter_importers(fullname: str = "") -> Iterator[MetaPathFinderProtocol | PathEntryFinderProtocol]: ... def iter_modules(path: Iterable[str] | None = None, prefix: str = "") -> Iterator[ModuleInfo]: ... def read_code(stream: SupportsRead[bytes]) -> Any: ... # undocumented def walk_packages( diff --git a/mypy/typeshed/stdlib/shelve.pyi b/mypy/typeshed/stdlib/shelve.pyi index 59abeafe6fca..654c2ea097f7 100644 --- a/mypy/typeshed/stdlib/shelve.pyi +++ b/mypy/typeshed/stdlib/shelve.pyi @@ -1,3 +1,5 @@ +import sys +from _typeshed import StrOrBytesPath from collections.abc import Iterator, MutableMapping from dbm import _TFlags from types import TracebackType @@ -41,6 +43,17 @@ class BsdDbShelf(Shelf[_VT]): def last(self) -> tuple[str, _VT]: ... class DbfilenameShelf(Shelf[_VT]): - def __init__(self, filename: str, flag: _TFlags = "c", protocol: int | None = None, writeback: bool = False) -> None: ... + if sys.version_info >= (3, 11): + def __init__( + self, filename: StrOrBytesPath, flag: _TFlags = "c", protocol: int | None = None, writeback: bool = False + ) -> None: ... + else: + def __init__(self, filename: str, flag: _TFlags = "c", protocol: int | None = None, writeback: bool = False) -> None: ... -def open(filename: str, flag: _TFlags = "c", protocol: int | None = None, writeback: bool = False) -> Shelf[Any]: ... +if sys.version_info >= (3, 11): + def open( + filename: StrOrBytesPath, flag: _TFlags = "c", protocol: int | None = None, writeback: bool = False + ) -> Shelf[Any]: ... + +else: + def open(filename: str, flag: _TFlags = "c", protocol: int | None = None, writeback: bool = False) -> Shelf[Any]: ... diff --git a/mypy/typeshed/stdlib/socket.pyi b/mypy/typeshed/stdlib/socket.pyi index a309bac9370a..b626409d2dde 100644 --- a/mypy/typeshed/stdlib/socket.pyi +++ b/mypy/typeshed/stdlib/socket.pyi @@ -474,6 +474,13 @@ if sys.version_info >= (3, 12): ETHERTYPE_VLAN as ETHERTYPE_VLAN, ) + if sys.platform == "linux": + from _socket import ETH_P_ALL as ETH_P_ALL + + if sys.platform != "linux" and sys.platform != "win32" and sys.platform != "darwin": + # FreeBSD >= 14.0 + from _socket import PF_DIVERT as PF_DIVERT + # Re-exported from errno EBADF: int EAGAIN: int @@ -525,6 +532,9 @@ class AddressFamily(IntEnum): AF_BLUETOOTH = 32 if sys.platform == "win32" and sys.version_info >= (3, 12): AF_HYPERV = 34 + if sys.platform != "linux" and sys.platform != "win32" and sys.platform != "darwin" and sys.version_info >= (3, 12): + # FreeBSD >= 14.0 + AF_DIVERT = 44 AF_INET = AddressFamily.AF_INET AF_INET6 = AddressFamily.AF_INET6 @@ -577,6 +587,9 @@ if sys.platform != "win32" or sys.version_info >= (3, 9): if sys.platform == "win32" and sys.version_info >= (3, 12): AF_HYPERV = AddressFamily.AF_HYPERV +if sys.platform != "linux" and sys.platform != "win32" and sys.platform != "darwin" and sys.version_info >= (3, 12): + # FreeBSD >= 14.0 + AF_DIVERT = AddressFamily.AF_DIVERT class SocketKind(IntEnum): SOCK_STREAM = 1 diff --git a/mypy/typeshed/stdlib/sys/__init__.pyi b/mypy/typeshed/stdlib/sys/__init__.pyi index 353e20c4b2e1..5867c9a9d510 100644 --- a/mypy/typeshed/stdlib/sys/__init__.pyi +++ b/mypy/typeshed/stdlib/sys/__init__.pyi @@ -1,9 +1,8 @@ import sys from _typeshed import OptExcInfo, ProfileFunction, TraceFunction, structseq +from _typeshed.importlib import MetaPathFinderProtocol, PathEntryFinderProtocol from builtins import object as _object from collections.abc import AsyncGenerator, Callable, Sequence -from importlib.abc import PathEntryFinder -from importlib.machinery import ModuleSpec from io import TextIOWrapper from types import FrameType, ModuleType, TracebackType from typing import Any, Final, Literal, NoReturn, Protocol, TextIO, TypeVar, final @@ -15,10 +14,6 @@ _T = TypeVar("_T") _ExitCode: TypeAlias = str | int | None _OptExcInfo: TypeAlias = OptExcInfo # noqa: Y047 # TODO: obsolete, remove fall 2022 or later -# Intentionally omits one deprecated and one optional method of `importlib.abc.MetaPathFinder` -class _MetaPathFinder(Protocol): - def find_spec(self, fullname: str, path: Sequence[str] | None, target: ModuleType | None = ..., /) -> ModuleSpec | None: ... - # ----- sys variables ----- if sys.platform != "win32": abiflags: str @@ -44,13 +39,13 @@ if sys.version_info >= (3, 12): last_exc: BaseException # or undefined. maxsize: int maxunicode: int -meta_path: list[_MetaPathFinder] +meta_path: list[MetaPathFinderProtocol] modules: dict[str, ModuleType] if sys.version_info >= (3, 10): orig_argv: list[str] path: list[str] -path_hooks: list[Callable[[str], PathEntryFinder]] -path_importer_cache: dict[str, PathEntryFinder | None] +path_hooks: list[Callable[[str], PathEntryFinderProtocol]] +path_importer_cache: dict[str, PathEntryFinderProtocol | None] platform: str if sys.version_info >= (3, 9): platlibdir: str diff --git a/mypy/typeshed/stdlib/tempfile.pyi b/mypy/typeshed/stdlib/tempfile.pyi index ce8f2f1f5929..b66369926404 100644 --- a/mypy/typeshed/stdlib/tempfile.pyi +++ b/mypy/typeshed/stdlib/tempfile.pyi @@ -374,7 +374,11 @@ class SpooledTemporaryFile(IO[AnyStr], _SpooledTemporaryFileBase): def readlines(self, hint: int = ..., /) -> list[AnyStr]: ... # type: ignore[override] def seek(self, offset: int, whence: int = ...) -> int: ... def tell(self) -> int: ... - def truncate(self, size: int | None = None) -> None: ... # type: ignore[override] + if sys.version_info >= (3, 11): + def truncate(self, size: int | None = None) -> int: ... + else: + def truncate(self, size: int | None = None) -> None: ... # type: ignore[override] + @overload def write(self: SpooledTemporaryFile[str], s: str) -> int: ... @overload diff --git a/mypy/typeshed/stdlib/types.pyi b/mypy/typeshed/stdlib/types.pyi index f2d79b7f3ade..38940b4345c8 100644 --- a/mypy/typeshed/stdlib/types.pyi +++ b/mypy/typeshed/stdlib/types.pyi @@ -1,5 +1,6 @@ import sys from _typeshed import SupportsKeysAndGetItem +from _typeshed.importlib import LoaderProtocol from collections.abc import ( AsyncGenerator, Awaitable, @@ -16,7 +17,7 @@ from collections.abc import ( from importlib.machinery import ModuleSpec # pytype crashes if types.MappingProxyType inherits from collections.abc.Mapping instead of typing.Mapping -from typing import Any, ClassVar, Literal, Mapping, Protocol, TypeVar, final, overload # noqa: Y022 +from typing import Any, ClassVar, Literal, Mapping, TypeVar, final, overload # noqa: Y022 from typing_extensions import ParamSpec, Self, TypeVarTuple, deprecated __all__ = [ @@ -64,18 +65,11 @@ _T2 = TypeVar("_T2") _KT = TypeVar("_KT") _VT_co = TypeVar("_VT_co", covariant=True) -@final -class _Cell: - def __new__(cls, contents: object = ..., /) -> Self: ... - def __eq__(self, value: object, /) -> bool: ... - __hash__: ClassVar[None] # type: ignore[assignment] - cell_contents: Any - # Make sure this class definition stays roughly in line with `builtins.function` @final class FunctionType: @property - def __closure__(self) -> tuple[_Cell, ...] | None: ... + def __closure__(self) -> tuple[CellType, ...] | None: ... __code__: CodeType __defaults__: tuple[Any, ...] | None __dict__: dict[str, Any] @@ -98,7 +92,7 @@ class FunctionType: globals: dict[str, Any], name: str | None = ..., argdefs: tuple[object, ...] | None = ..., - closure: tuple[_Cell, ...] | None = ..., + closure: tuple[CellType, ...] | None = ..., ) -> Self: ... def __call__(self, *args: Any, **kwargs: Any) -> Any: ... @overload @@ -318,15 +312,12 @@ class SimpleNamespace: def __setattr__(self, name: str, value: Any, /) -> None: ... def __delattr__(self, name: str, /) -> None: ... -class _LoaderProtocol(Protocol): - def load_module(self, fullname: str, /) -> ModuleType: ... - class ModuleType: __name__: str __file__: str | None @property def __dict__(self) -> dict[str, Any]: ... # type: ignore[override] - __loader__: _LoaderProtocol | None + __loader__: LoaderProtocol | None __package__: str | None __path__: MutableSequence[str] __spec__: ModuleSpec | None @@ -336,6 +327,12 @@ class ModuleType: # using `builtins.__import__` or `importlib.import_module` less painful def __getattr__(self, name: str) -> Any: ... +@final +class CellType: + def __new__(cls, contents: object = ..., /) -> Self: ... + __hash__: ClassVar[None] # type: ignore[assignment] + cell_contents: Any + _YieldT_co = TypeVar("_YieldT_co", covariant=True) _SendT_contra = TypeVar("_SendT_contra", contravariant=True) _ReturnT_co = TypeVar("_ReturnT_co", covariant=True) @@ -405,7 +402,7 @@ class CoroutineType(Coroutine[_YieldT_co, _SendT_contra, _ReturnT_co]): @final class MethodType: @property - def __closure__(self) -> tuple[_Cell, ...] | None: ... # inherited from the added function + def __closure__(self) -> tuple[CellType, ...] | None: ... # inherited from the added function @property def __defaults__(self) -> tuple[Any, ...] | None: ... # inherited from the added function @property @@ -570,8 +567,6 @@ def coroutine(func: Callable[_P, Generator[Any, Any, _R]]) -> Callable[_P, Await @overload def coroutine(func: _Fn) -> _Fn: ... -CellType = _Cell - if sys.version_info >= (3, 9): class GenericAlias: @property diff --git a/mypy/typeshed/stdlib/typing.pyi b/mypy/typeshed/stdlib/typing.pyi index 4b80397bdd7a..d047f1c87621 100644 --- a/mypy/typeshed/stdlib/typing.pyi +++ b/mypy/typeshed/stdlib/typing.pyi @@ -8,7 +8,6 @@ import typing_extensions from _collections_abc import dict_items, dict_keys, dict_values from _typeshed import IdentityFunction, ReadableBuffer, SupportsKeysAndGetItem from abc import ABCMeta, abstractmethod -from contextlib import AbstractAsyncContextManager, AbstractContextManager from re import Match as Match, Pattern as Pattern from types import ( BuiltinFunctionType, @@ -24,10 +23,10 @@ from types import ( ) from typing_extensions import Never as _Never, ParamSpec as _ParamSpec -if sys.version_info >= (3, 10): - from types import UnionType if sys.version_info >= (3, 9): from types import GenericAlias +if sys.version_info >= (3, 10): + from types import UnionType __all__ = [ "AbstractSet", @@ -402,8 +401,8 @@ class Reversible(Iterable[_T_co], Protocol[_T_co]): def __reversed__(self) -> Iterator[_T_co]: ... _YieldT_co = TypeVar("_YieldT_co", covariant=True) -_SendT_contra = TypeVar("_SendT_contra", contravariant=True) -_ReturnT_co = TypeVar("_ReturnT_co", covariant=True) +_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None) +_ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None) class Generator(Iterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra, _ReturnT_co]): def __next__(self) -> _YieldT_co: ... @@ -428,24 +427,28 @@ class Generator(Iterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra, _Return @property def gi_yieldfrom(self) -> Generator[Any, Any, Any] | None: ... -# NOTE: Technically we would like this to be able to accept a second parameter as well, just -# like it's counterpart in contextlib, however `typing._SpecialGenericAlias` enforces the -# correct number of arguments at runtime, so we would be hiding runtime errors. -@runtime_checkable -class ContextManager(AbstractContextManager[_T_co, bool | None], Protocol[_T_co]): ... +# NOTE: Prior to Python 3.13 these aliases are lacking the second _ExitT_co parameter +if sys.version_info >= (3, 13): + from contextlib import AbstractAsyncContextManager as AsyncContextManager, AbstractContextManager as ContextManager +else: + from contextlib import AbstractAsyncContextManager, AbstractContextManager -# NOTE: Technically we would like this to be able to accept a second parameter as well, just -# like it's counterpart in contextlib, however `typing._SpecialGenericAlias` enforces the -# correct number of arguments at runtime, so we would be hiding runtime errors. -@runtime_checkable -class AsyncContextManager(AbstractAsyncContextManager[_T_co, bool | None], Protocol[_T_co]): ... + @runtime_checkable + class ContextManager(AbstractContextManager[_T_co, bool | None], Protocol[_T_co]): ... + + @runtime_checkable + class AsyncContextManager(AbstractAsyncContextManager[_T_co, bool | None], Protocol[_T_co]): ... @runtime_checkable class Awaitable(Protocol[_T_co]): @abstractmethod def __await__(self) -> Generator[Any, Any, _T_co]: ... -class Coroutine(Awaitable[_ReturnT_co], Generic[_YieldT_co, _SendT_contra, _ReturnT_co]): +# Non-default variations to accommodate couroutines, and `AwaitableGenerator` having a 4th type parameter. +_SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True) +_ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True) + +class Coroutine(Awaitable[_ReturnT_co_nd], Generic[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]): __name__: str __qualname__: str @property @@ -457,7 +460,7 @@ class Coroutine(Awaitable[_ReturnT_co], Generic[_YieldT_co, _SendT_contra, _Retu @property def cr_running(self) -> bool: ... @abstractmethod - def send(self, value: _SendT_contra, /) -> _YieldT_co: ... + def send(self, value: _SendT_contra_nd, /) -> _YieldT_co: ... @overload @abstractmethod def throw( @@ -473,9 +476,9 @@ class Coroutine(Awaitable[_ReturnT_co], Generic[_YieldT_co, _SendT_contra, _Retu # The parameters correspond to Generator, but the 4th is the original type. @type_check_only class AwaitableGenerator( - Awaitable[_ReturnT_co], - Generator[_YieldT_co, _SendT_contra, _ReturnT_co], - Generic[_YieldT_co, _SendT_contra, _ReturnT_co, _S], + Awaitable[_ReturnT_co_nd], + Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], + Generic[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd, _S], metaclass=ABCMeta, ): ...