Skip to content

Commit

Permalink
Infer unions for ternary expressions (#17427)
Browse files Browse the repository at this point in the history
Ref #12056

cc @JukkaL 

Again, let's check the primer...
  • Loading branch information
ilevkivskyi committed Jul 2, 2024
1 parent f297917 commit 55a0812
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 65 deletions.
19 changes: 9 additions & 10 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5766,16 +5766,15 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
context=if_type_fallback,
allow_none_return=allow_none_return,
)

# Only create a union type if the type context is a union, to be mostly
# compatible with older mypy versions where we always did a join.
#
# TODO: Always create a union or at least in more cases?
if isinstance(get_proper_type(self.type_context[-1]), UnionType):
res: Type = make_simplified_union([if_type, full_context_else_type])
else:
res = join.join_types(if_type, else_type)

res: Type = make_simplified_union([if_type, else_type])
if has_uninhabited_component(res) and not isinstance(
get_proper_type(self.type_context[-1]), UnionType
):
# In rare cases with empty collections join may give a better result.
alternative = join.join_types(if_type, else_type)
p_alt = get_proper_type(alternative)
if not isinstance(p_alt, Instance) or p_alt.type.fullname != "builtins.object":
res = alternative
return res

def analyze_cond_branch(
Expand Down
4 changes: 3 additions & 1 deletion mypyc/test-data/irbuild-any.test
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def f4(a, n, b):
a :: object
n :: int
b :: bool
r0, r1, r2, r3 :: object
r0 :: union[object, int]
r1, r2 :: object
r3 :: union[int, object]
r4 :: int
L0:
if b goto L1 else goto L2 :: bool
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-errorcodes.test
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ a: D = {'x': ''} # E: Incompatible types (expression has type "str", TypedDict
b: D = {'y': ''} # E: Missing key "x" for TypedDict "D" [typeddict-item] \
# E: Extra key "y" for TypedDict "D" [typeddict-unknown-key]
c = D(x=0) if int() else E(x=0, y=0)
c = {} # E: Expected TypedDict key "x" but found no keys [typeddict-item]
c = {} # E: Missing key "x" for TypedDict "D" [typeddict-item]
d: D = {'x': '', 'y': 1} # E: Extra key "y" for TypedDict "D" [typeddict-unknown-key] \
# E: Incompatible types (expression has type "str", TypedDict item "x" has type "int") [typeddict-item]

Expand Down
17 changes: 8 additions & 9 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1470,10 +1470,9 @@ if int():

[case testConditionalExpressionUnion]
from typing import Union
reveal_type(1 if bool() else 2) # N: Revealed type is "builtins.int"
reveal_type(1 if bool() else '') # N: Revealed type is "builtins.object"
x: Union[int, str] = reveal_type(1 if bool() else '') \
# N: Revealed type is "Union[Literal[1]?, Literal['']?]"
reveal_type(1 if bool() else 2) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]"
reveal_type(1 if bool() else '') # N: Revealed type is "Union[Literal[1]?, Literal['']?]"
x: Union[int, str] = reveal_type(1 if bool() else '') # N: Revealed type is "Union[Literal[1]?, Literal['']?]"
class A:
pass
class B(A):
Expand All @@ -1487,17 +1486,17 @@ b = B()
c = C()
d = D()
reveal_type(a if bool() else b) # N: Revealed type is "__main__.A"
reveal_type(b if bool() else c) # N: Revealed type is "builtins.object"
reveal_type(c if bool() else b) # N: Revealed type is "builtins.object"
reveal_type(c if bool() else a) # N: Revealed type is "builtins.object"
reveal_type(d if bool() else b) # N: Revealed type is "__main__.A"
reveal_type(b if bool() else c) # N: Revealed type is "Union[__main__.B, __main__.C]"
reveal_type(c if bool() else b) # N: Revealed type is "Union[__main__.C, __main__.B]"
reveal_type(c if bool() else a) # N: Revealed type is "Union[__main__.C, __main__.A]"
reveal_type(d if bool() else b) # N: Revealed type is "Union[__main__.D, __main__.B]"
[builtins fixtures/bool.pyi]

[case testConditionalExpressionUnionWithAny]
from typing import Union, Any
a: Any
x: Union[int, str] = reveal_type(a if int() else 1) # N: Revealed type is "Union[Any, Literal[1]?]"
reveal_type(a if int() else 1) # N: Revealed type is "Any"
reveal_type(a if int() else 1) # N: Revealed type is "Union[Any, Literal[1]?]"

[case testConditionalExpressionStatementNoReturn]
from typing import List, Union
Expand Down
17 changes: 15 additions & 2 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2250,13 +2250,26 @@ def dec(f: Callable[[A, str], None]) -> Callable[[A, int], None]: pass
[out]

[case testUnknownFunctionNotCallable]
from typing import TypeVar

def f() -> None:
pass
def g(x: int) -> None:
pass
h = f if bool() else g
reveal_type(h) # N: Revealed type is "builtins.function"
h(7) # E: Cannot call function of unknown type
reveal_type(h) # N: Revealed type is "Union[def (), def (x: builtins.int)]"
h(7) # E: Too many arguments for "f"

T = TypeVar("T")
def join(x: T, y: T) -> T: ...

h2 = join(f, g)
reveal_type(h2) # N: Revealed type is "builtins.function"
h2(7) # E: Cannot call function of unknown type

h3 = join(g, f)
reveal_type(h3) # N: Revealed type is "builtins.function"
h3(7) # E: Cannot call function of unknown type
[builtins fixtures/bool.pyi]

[case testFunctionWithNameUnderscore]
Expand Down
6 changes: 3 additions & 3 deletions test-data/unit/check-inference-context.test
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ class A: pass
class B(A): pass
class C(A): pass
def f(func: Callable[[T], S], *z: T, r: Optional[S] = None) -> S: pass
reveal_type(f(lambda x: 0 if isinstance(x, B) else 1)) # N: Revealed type is "builtins.int"
reveal_type(f(lambda x: 0 if isinstance(x, B) else 1)) # N: Revealed type is "Union[Literal[0]?, Literal[1]?]"
f(lambda x: 0 if isinstance(x, B) else 1, A())() # E: "int" not callable
f(lambda x: x if isinstance(x, B) else B(), A(), r=B())() # E: "B" not callable
f(
Expand Down Expand Up @@ -1391,15 +1391,15 @@ from typing import Union, List, Any

def f(x: Union[List[str], Any]) -> None:
a = x if x else []
reveal_type(a) # N: Revealed type is "Union[builtins.list[Union[builtins.str, Any]], builtins.list[builtins.str], Any]"
reveal_type(a) # N: Revealed type is "Union[builtins.list[builtins.str], Any, builtins.list[Union[builtins.str, Any]]]"
[builtins fixtures/list.pyi]

[case testConditionalExpressionWithEmptyIteableAndUnionWithAny]
from typing import Union, Iterable, Any

def f(x: Union[Iterable[str], Any]) -> None:
a = x if x else []
reveal_type(a) # N: Revealed type is "Union[builtins.list[Union[builtins.str, Any]], typing.Iterable[builtins.str], Any]"
reveal_type(a) # N: Revealed type is "Union[typing.Iterable[builtins.str], Any, builtins.list[Union[builtins.str, Any]]]"
[builtins fixtures/list.pyi]

[case testInferMultipleAnyUnionCovariant]
Expand Down
10 changes: 7 additions & 3 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -1438,18 +1438,22 @@ class Wrapper:

def f(cond: bool) -> Any:
f = Wrapper if cond else lambda x: x
reveal_type(f) # N: Revealed type is "def (x: Any) -> Any"
reveal_type(f) # N: Revealed type is "Union[def (x: Any) -> __main__.Wrapper, def (x: Any) -> Any]"
return f(3)

def g(cond: bool) -> Any:
f = lambda x: x if cond else Wrapper
reveal_type(f) # N: Revealed type is "def (x: Any) -> Any"
reveal_type(f) # N: Revealed type is "def (x: Any) -> Union[Any, def (x: Any) -> __main__.Wrapper]"
return f(3)

def h(cond: bool) -> Any:
f = (lambda x: x) if cond else Wrapper
reveal_type(f) # N: Revealed type is "Union[def (x: Any) -> Any, def (x: Any) -> __main__.Wrapper]"
return f(3)

-- Boolean operators
-- -----------------


[case testOrOperationWithGenericOperands]
from typing import List
a: List[A]
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-optional.test
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def lookup_field(name, obj):
attr = None

[case testTernaryWithNone]
reveal_type(None if bool() else 0) # N: Revealed type is "Union[Literal[0]?, None]"
reveal_type(None if bool() else 0) # N: Revealed type is "Union[None, Literal[0]?]"
[builtins fixtures/bool.pyi]

[case testListWithNone]
Expand Down
87 changes: 52 additions & 35 deletions test-data/unit/check-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -1228,68 +1228,76 @@ x, y = g(z) # E: Argument 1 to "g" has incompatible type "int"; expected "Tuple[
[out]

[case testFixedTupleJoinVarTuple]
from typing import Tuple
from typing import Tuple, TypeVar

class A: pass
class B(A): pass

fixtup: Tuple[B, B]

T = TypeVar("T")
def join(x: T, y: T) -> T: ...

vartup_b: Tuple[B, ...]
reveal_type(fixtup if int() else vartup_b) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
reveal_type(vartup_b if int() else fixtup) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
reveal_type(join(fixtup, vartup_b)) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
reveal_type(join(vartup_b, fixtup)) # N: Revealed type is "builtins.tuple[__main__.B, ...]"

vartup_a: Tuple[A, ...]
reveal_type(fixtup if int() else vartup_a) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(vartup_a if int() else fixtup) # N: Revealed type is "builtins.tuple[__main__.A, ...]"

reveal_type(join(fixtup, vartup_a)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(vartup_a, fixtup)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"

[builtins fixtures/tuple.pyi]
[out]

[case testFixedTupleJoinList]
from typing import Tuple, List
from typing import Tuple, List, TypeVar

class A: pass
class B(A): pass

fixtup: Tuple[B, B]

T = TypeVar("T")
def join(x: T, y: T) -> T: ...

lst_b: List[B]
reveal_type(fixtup if int() else lst_b) # N: Revealed type is "typing.Sequence[__main__.B]"
reveal_type(lst_b if int() else fixtup) # N: Revealed type is "typing.Sequence[__main__.B]"
reveal_type(join(fixtup, lst_b)) # N: Revealed type is "typing.Sequence[__main__.B]"
reveal_type(join(lst_b, fixtup)) # N: Revealed type is "typing.Sequence[__main__.B]"

lst_a: List[A]
reveal_type(fixtup if int() else lst_a) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(lst_a if int() else fixtup) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(join(fixtup, lst_a)) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(join(lst_a, fixtup)) # N: Revealed type is "typing.Sequence[__main__.A]"

[builtins fixtures/tuple.pyi]
[out]

[case testEmptyTupleJoin]
from typing import Tuple, List
from typing import Tuple, List, TypeVar

class A: pass

empty = ()

T = TypeVar("T")
def join(x: T, y: T) -> T: ...

fixtup: Tuple[A]
reveal_type(fixtup if int() else empty) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(empty if int() else fixtup) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(fixtup, empty)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(empty, fixtup)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"

vartup: Tuple[A, ...]
reveal_type(empty if int() else vartup) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(vartup if int() else empty) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(vartup, empty)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(empty, vartup)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"

lst: List[A]
reveal_type(empty if int() else lst) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(lst if int() else empty) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(join(empty, lst)) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(join(lst, empty)) # N: Revealed type is "typing.Sequence[__main__.A]"

[builtins fixtures/tuple.pyi]
[out]

[case testTupleSubclassJoin]
from typing import Tuple, NamedTuple
from typing import Tuple, NamedTuple, TypeVar

class NTup(NamedTuple):
a: bool
Expand All @@ -1302,32 +1310,38 @@ ntup: NTup
subtup: SubTuple
vartup: SubVarTuple

reveal_type(ntup if int() else vartup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(subtup if int() else vartup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
T = TypeVar("T")
def join(x: T, y: T) -> T: ...

reveal_type(join(ntup, vartup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(subtup, vartup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

[builtins fixtures/tuple.pyi]
[out]

[case testTupleJoinIrregular]
from typing import Tuple
from typing import Tuple, TypeVar

tup1: Tuple[bool, int]
tup2: Tuple[bool]

reveal_type(tup1 if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(tup2 if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
T = TypeVar("T")
def join(x: T, y: T) -> T: ...

reveal_type(join(tup1, tup2)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(tup2, tup1)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

reveal_type(tup1 if int() else ()) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(() if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(tup1, ())) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join((), tup1)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

reveal_type(tup2 if int() else ()) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(() if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(join(tup2, ())) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(join((), tup2)) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"

[builtins fixtures/tuple.pyi]
[out]

[case testTupleSubclassJoinIrregular]
from typing import Tuple, NamedTuple
from typing import Tuple, NamedTuple, TypeVar

class NTup1(NamedTuple):
a: bool
Expand All @@ -1342,14 +1356,17 @@ tup1: NTup1
tup2: NTup2
subtup: SubTuple

reveal_type(tup1 if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(tup2 if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
T = TypeVar("T")
def join(x: T, y: T) -> T: ...

reveal_type(join(tup1, tup2)) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(join(tup2, tup1)) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"

reveal_type(tup1 if int() else subtup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(subtup if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(tup1, subtup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(subtup, tup1)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

reveal_type(tup2 if int() else subtup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(subtup if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(tup2, subtup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(subtup, tup2)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

[builtins fixtures/tuple.pyi]
[out]
Expand Down

0 comments on commit 55a0812

Please sign in to comment.