Skip to content

Commit

Permalink
Narrow based on collection containment (#17344)
Browse files Browse the repository at this point in the history
Enables the narrowing of variable types when checking a variable is "in"
a collection, and the collection type is a subtype of the variable type.

Fixes #3229 

This PR updates the type narrowing for the "in" operator and allows it
to narrow the type of a variable to the type of the collection's items -
if the collection item type is a subtype of the variable (as defined by
is_subtype).

Examples
```python
def foobar(foo: Union[str, float]):
    if foo in ['a', 'b']:
        reveal_type(foo)  # N: Revealed type is "builtins.str"
    else:
        reveal_type(foo)  # N: Revealed type is "Union[builtins.str, builtins.float]"
```
```python
typ: List[Literal['a', 'b']] = ['a', 'b']
x: str = "hi!"
if x in typ:
    reveal_type(x)  # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
    reveal_type(x)  # N: Revealed type is "builtins.str"
```

One existing test was updated, which compared `Optional[A]` with "in" to
`(None,)`. Piror to this change that resulted in `Union[__main__.A,
None]`, which now narrows to `None`. Test cases have been added for
"in", "not in", Sets, Lists, and Tuples.

I did add to the existing narrowing.pyi fixture for the test cases. A
search of the *.test files shows it was only used in the narrowing
tests, so there shouldn't be any speed impact in other areas.

---------

Co-authored-by: Jordandev678 <[email protected]>
  • Loading branch information
Jordandev678 and Jordandev678 committed Jul 24, 2024
1 parent 18965d6 commit ed0cd4a
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 8 deletions.
15 changes: 10 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6011,11 +6011,16 @@ def has_no_custom_eq_checks(t: Type) -> bool:
if_map, else_map = {}, {}

if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_overlapping_none(item_type):
collection_item_type = get_proper_type(
builtin_item_type(iterable_type)
)
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
# Narrow if the collection is a subtype
if (
collection_item_type is not None
and collection_item_type != item_type
and is_subtype(collection_item_type, item_type)
):
if_map[operands[left_index]] = collection_item_type
# Try and narrow away 'None'
elif is_overlapping_none(item_type):
if (
collection_item_type is not None
and not is_overlapping_none(collection_item_type)
Expand Down
112 changes: 110 additions & 2 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1376,13 +1376,13 @@ else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"

if val in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "None"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
if val not in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "None"
[builtins fixtures/primitives.pyi]

[case testNarrowingWithTupleOfTypes]
Expand Down Expand Up @@ -2114,3 +2114,111 @@ else:

[typing fixtures/typing-medium.pyi]
[builtins fixtures/ops.pyi]


[case testTypeNarrowingStringInLiteralUnion]
from typing import Literal, Tuple
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInLiteralUnionSubset]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b')
strIn: str = "b"
strOut: str = "c"
if strIn in typeAlpha:
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strIn) # N: Revealed type is "builtins.str"
if strOut in typeAlpha:
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strOut) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testNarrowingStringNotInLiteralUnion]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c')
strIn: str = "c"
strOut: str = "d"
if strIn not in typeAlpha:
reveal_type(strIn) # N: Revealed type is "builtins.str"
else:
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
if strOut in typeAlpha:
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strOut) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testNarrowingStringInLiteralUnionDontExpand]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c')
strIn: Literal['c'] = "c"
reveal_type(strIn) # N: Revealed type is "Literal['c']"
#Check we don't expand a Literal into the Union type
if strIn not in typeAlpha:
reveal_type(strIn) # N: Revealed type is "Literal['c']"
else:
reveal_type(strIn) # N: Revealed type is "Literal['c']"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInMixedUnion]
from typing import Literal, Tuple
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInSet]
from typing import Literal, Set
typ: Set[Literal['a', 'b']] = {'a', 'b'}
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
if x not in typ:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
[builtins fixtures/narrowing.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInList]
from typing import Literal, List
typ: List[Literal['a', 'b']] = ['a', 'b']
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
if x not in typ:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
[builtins fixtures/narrowing.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingUnionStringFloat]
from typing import Union
def foobar(foo: Union[str, float]):
if foo in ['a', 'b']:
reveal_type(foo) # N: Revealed type is "builtins.str"
else:
reveal_type(foo) # N: Revealed type is "Union[builtins.str, builtins.float]"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]
9 changes: 8 additions & 1 deletion test-data/unit/fixtures/narrowing.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Builtins stub used in check-narrowing test cases.
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable


Tco = TypeVar('Tco', covariant=True)
Expand All @@ -15,6 +15,13 @@ class function: pass
class ellipsis: pass
class int: pass
class str: pass
class float: pass
class dict(Generic[KT, VT]): pass

def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass

class list(Sequence[Tco]):
def __contains__(self, other: object) -> bool: pass
class set(Iterable[Tco], Generic[Tco]):
def __init__(self, iterable: Iterable[Tco] = ...) -> None: ...
def __contains__(self, item: object) -> bool: pass

0 comments on commit ed0cd4a

Please sign in to comment.