diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index b3e1ace9d2dc..de9d4b6b35fd 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -2060,7 +2060,7 @@ export function getElementTypeForContainerNarrowing(containerType: Type) { export function narrowTypeForContainerElementType(evaluator: TypeEvaluator, referenceType: Type, elementType: Type) { return evaluator.mapSubtypesExpandTypeVars(referenceType, /* options */ undefined, (referenceSubtype) => { return mapSubtypes(elementType, (elementSubtype) => { - if (isAnyOrUnknown(referenceSubtype)) { + if (isAnyOrUnknown(elementSubtype)) { return referenceSubtype; } diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py index a7a670a086ec..51f4f5363597 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py @@ -1,40 +1,22 @@ # This sample tests type narrowing for the "in" operator. -from typing import Callable, Generic, Literal, ParamSpec, TypeVar, TypedDict +from typing import Any, Callable, Generic, Literal, ParamSpec, TypeVar, TypedDict import random -def verify_str(p: str) -> None: ... - - -def verify_int(p: int) -> None: ... - - -def verify_none(p: None) -> None: ... - - -x: str | None -y: int | str -if random.random() < 0.5: - x = None - y = 1 -else: - x = "2" - y = "2" - -if x in ["2"]: - verify_str(x) - - # This should generate an error because x should - # be narrowed to a str. - verify_none(x) +def func0(x: str | None, y: int | str): + if random.random() < 0.5: + x = None + y = 1 + else: + x = "2" + y = "2" -if y in [2]: - verify_int(y) + if x in ["2"]: + reveal_type(x, expected_text="Literal['2']") - # This should generate an error because y should - # be narrowed to an int. - verify_str(y) + if y in [1]: + reveal_type(y, expected_text="Literal[1]") def func1(x: int | str | None, y: Literal[1, 2, "b"], b: int): @@ -184,3 +166,13 @@ def func13(x: type[T13]) -> type[T13]: reveal_type(x, expected_text="type[str]* | type[int]* | type[float]*") return x + + +def func14(x: str, y: dict[Any, Any]): + if x in y: + reveal_type(x, expected_text="str") + + +def func15(x: Any, y: dict[str, str]): + if x in y: + reveal_type(x, expected_text="str") diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index ac60a3013972..a2e2f0c06038 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -496,7 +496,7 @@ test('TypeNarrowingTupleLength1', () => { test('TypeNarrowingIn1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingIn1.py']); - TestUtils.validateResults(analysisResults, 2); + TestUtils.validateResults(analysisResults, 0); }); test('TypeNarrowingIn2', () => {