Skip to content

Commit

Permalink
Improve yield from inference for unions of generators (#16717)
Browse files Browse the repository at this point in the history
Fixes #15141, closes #15168
  • Loading branch information
hauntsaninja committed Mar 21, 2024
1 parent a505e5f commit 394d17b
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 18 deletions.
5 changes: 3 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,8 +989,9 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty
# AwaitableGenerator, Generator: tr is args[2].
return return_type.args[2]
else:
# Supertype of Generator (Iterator, Iterable, object): tr is any.
return AnyType(TypeOfAny.special_form)
# We have a supertype of Generator (Iterator, Iterable, object)
# Treat `Iterator[X]` as a shorthand for `Generator[X, Any, None]`.
return NoneType()

def visit_func_def(self, defn: FuncDef) -> None:
if not self.recurse_into_functions:
Expand Down
12 changes: 1 addition & 11 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5963,17 +5963,7 @@ def visit_yield_from_expr(self, e: YieldFromExpr, allow_none_return: bool = Fals

# Determine the type of the entire yield from expression.
iter_type = get_proper_type(iter_type)
if isinstance(iter_type, Instance) and iter_type.type.fullname == "typing.Generator":
expr_type = self.chk.get_generator_return_type(iter_type, False)
else:
# Non-Generators don't return anything from `yield from` expressions.
# However special-case Any (which might be produced by an error).
actual_item_type = get_proper_type(actual_item_type)
if isinstance(actual_item_type, AnyType):
expr_type = AnyType(TypeOfAny.from_another_any, source_any=actual_item_type)
else:
# Treat `Iterator[X]` as a shorthand for `Generator[X, None, Any]`.
expr_type = NoneType()
expr_type = self.chk.get_generator_return_type(iter_type, is_coroutine=False)

if not allow_none_return and isinstance(get_proper_type(expr_type), NoneType):
self.chk.msg.does_not_return_value(None, e)
Expand Down
9 changes: 5 additions & 4 deletions mypyc/test-data/run-generators.test
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,12 @@ assert run_generator(another_triple()()) == ((1,), None)
assert run_generator(outer()) == ((0, 1, 2, 3, 4), None)

[case testYieldThrow]
from typing import Generator, Iterable, Any
from typing import Generator, Iterable, Any, Union
from traceback import print_tb
from contextlib import contextmanager
import wrapsys

def generator() -> Iterable[int]:
def generator() -> Generator[int, None, Union[int, None]]:
try:
yield 1
yield 2
Expand All @@ -264,6 +264,7 @@ def generator() -> Iterable[int]:
else:
print('caught exception without value')
return 0
return None

def no_except() -> Iterable[int]:
yield 1
Expand Down Expand Up @@ -355,11 +356,11 @@ with ctx_manager() as c:
raise Exception
File "native.py", line 10, in generator
yield 3
File "native.py", line 30, in wrapper
File "native.py", line 31, in wrapper
return (yield from x)
File "native.py", line 9, in generator
yield 2
File "native.py", line 30, in wrapper
File "native.py", line 31, in wrapper
return (yield from x)
caught exception without value
caught exception with value some string
Expand Down
47 changes: 46 additions & 1 deletion test-data/unit/check-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def f() -> Generator[int, None, None]:
from typing import Iterator
def f() -> Iterator[int]:
yield 1
return "foo"
return "foo" # E: No return value expected
[out]


Expand Down Expand Up @@ -2231,6 +2231,51 @@ class B: pass
def foo(x: int) -> Union[Generator[A, None, None], Generator[B, None, None]]:
yield x # E: Incompatible types in "yield" (actual type "int", expected type "Union[A, B]")

[case testYieldFromUnionOfGenerators]
from typing import Generator, Union

class T: pass

def foo(arg: Union[Generator[int, None, T], Generator[str, None, T]]) -> Generator[Union[int, str], None, T]:
return (yield from arg)

[case testYieldFromInvalidUnionReturn]
from typing import Generator, Union

class A: pass
class B: pass

def foo(arg: Union[A, B]) -> Generator[Union[int, str], None, A]:
return (yield from arg) # E: "yield from" can't be applied to "Union[A, B]"

[case testYieldFromUnionOfGeneratorWithIterableStr]
from typing import Generator, Union, Iterable, Optional

def foo(arg: Union[Generator[int, None, bytes], Iterable[str]]) -> Generator[Union[int, str], None, Optional[bytes]]:
return (yield from arg)

def bar(arg: Generator[str, None, str]) -> Generator[str, None, str]:
return foo(arg) # E: Incompatible return value type (got "Generator[Union[int, str], None, Optional[bytes]]", expected "Generator[str, None, str]")

def launder(arg: Iterable[str]) -> Generator[Union[int, str], None, Optional[bytes]]:
return foo(arg)

def baz(arg: Generator[str, None, str]) -> Generator[Union[int, str], None, Optional[bytes]]:
# this is unsound, the Generator return type will actually be str
return launder(arg)
[builtins fixtures/tuple.pyi]

[case testYieldIteratorReturn]
from typing import Iterator

def get_strings(foo: bool) -> Iterator[str]:
if foo:
return ["foo1", "foo2"] # E: No return value expected
else:
yield "bar1"
yield "bar2"
[builtins fixtures/tuple.pyi]

[case testNoCrashOnStarRightHandSide]
x = *(1, 2, 3) # E: can't use starred expression here
[builtins fixtures/tuple.pyi]
Expand Down

0 comments on commit 394d17b

Please sign in to comment.