Skip to content

Commit

Permalink
Fix @patch when new is missing (#10459)
Browse files Browse the repository at this point in the history
  • Loading branch information
srittau authored Jul 14, 2023
1 parent 1d7f0d0 commit 7ea173c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
27 changes: 21 additions & 6 deletions stdlib/unittest/mock.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ class _patch(Generic[_T]):
def copy(self) -> _patch[_T]: ...
@overload
def __call__(self, func: _TT) -> _TT: ...
# If new==DEFAULT, this should add a MagicMock parameter to the function
# arguments. See the _patch_default_new class below for this functionality.
@overload
def __call__(self, func: Callable[_P, _R]) -> Callable[_P, _R]: ...
if sys.version_info >= (3, 8):
Expand All @@ -257,6 +259,22 @@ class _patch(Generic[_T]):
def start(self) -> _T: ...
def stop(self) -> None: ...

if sys.version_info >= (3, 8):
_Mock: TypeAlias = MagicMock | AsyncMock
else:
_Mock: TypeAlias = MagicMock

# This class does not exist at runtime, it's a hack to make this work:
# @patch("foo")
# def bar(..., mock: MagicMock) -> None: ...
class _patch_default_new(_patch[_Mock]):
@overload
def __call__(self, func: _TT) -> _TT: ...
# Can't use the following as ParamSpec is only allowed as last parameter:
# def __call__(self, func: Callable[_P, _R]) -> Callable[Concatenate[_P, MagicMock], _R]: ...
@overload
def __call__(self, func: Callable[..., _R]) -> Callable[..., _R]: ...

class _patch_dict:
in_dict: Any
values: Any
Expand All @@ -273,11 +291,8 @@ class _patch_dict:
start: Any
stop: Any

if sys.version_info >= (3, 8):
_Mock: TypeAlias = MagicMock | AsyncMock
else:
_Mock: TypeAlias = MagicMock

# This class does not exist at runtime, it's a hack to add methods to the
# patch() function.
class _patcher:
TEST_PREFIX: str
dict: type[_patch_dict]
Expand Down Expand Up @@ -307,7 +322,7 @@ class _patcher:
autospec: Any | None = ...,
new_callable: Any | None = ...,
**kwargs: Any,
) -> _patch[_Mock]: ...
) -> _patch_default_new: ...
@overload
@staticmethod
def object( # type: ignore[misc]
Expand Down
17 changes: 12 additions & 5 deletions test_cases/stdlib/check_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from decimal import Decimal
from fractions import Fraction
from typing_extensions import assert_type
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

case = unittest.TestCase()

Expand Down Expand Up @@ -94,13 +94,20 @@ def __gt__(self, other: Bacon) -> bool:
###


@patch("sys.exit", new=Mock())
def f(i: int) -> str:
@patch("sys.exit")
def f_default_new(i: int, mock: MagicMock) -> str:
return "asdf"


@patch("sys.exit", new=42)
def f_explicit_new(i: int) -> str:
return "asdf"


assert_type(f(1), str)
f("a") # type: ignore
assert_type(f_default_new(1), str)
f_default_new("a") # Not an error due to ParamSpec limitations
assert_type(f_explicit_new(1), str)
f_explicit_new("a") # type: ignore[arg-type]


@patch("sys.exit", new=Mock())
Expand Down

0 comments on commit 7ea173c

Please sign in to comment.