diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index a765c008f472..a2dc7cb7c980 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -41,6 +41,7 @@ import { FunctionParam, FunctionParamFlags, FunctionType, + FunctionTypeFlags, isAnyOrUnknown, isClass, isClassInstance, @@ -1695,8 +1696,32 @@ function narrowTypeForInstance( } else { filteredTypes.push(convertToInstance(filterType)); } - } else if (evaluator.assignType(convertToInstance(convertVarTypeToFree(concreteFilterType)), varType)) { - filteredTypes.push(convertToInstance(varType)); + } else { + const filterTypeInstance = convertToInstance(convertVarTypeToFree(concreteFilterType)); + if (evaluator.assignType(filterTypeInstance, varType)) { + filteredTypes.push(convertToInstance(varType)); + } else { + // If this is a class instance that's not callable and it's not @final, + // a subclass could be compatible with the filter type. + if (isClassInstance(filterTypeInstance) && !ClassType.isFinal(filterTypeInstance)) { + const gradualFunc = FunctionType.createSynthesizedInstance( + '', + FunctionTypeFlags.GradualCallableForm + ); + FunctionType.addDefaultParams(gradualFunc); + + // If the class is callable (i.e. can be assigned to the generic gradual + // function signature), then the assignment check above didn't fail because + // of a signature mismatch. It failed because the class is not callable. + // We assume therefore that a subclass might be. + if (!evaluator.assignType(gradualFunc, filterTypeInstance)) { + // The resulting type should be an intersection of the filter type and + // the subtype, but we don't have a way to encode that yet. For now, + // we'll use the filter type. + filteredTypes.push(convertToInstance(filterType)); + } + } + } } } } else { diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingIsinstance5.py b/packages/pyright-internal/src/tests/samples/typeNarrowingIsinstance5.py index d910315348a7..4981b135fd62 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingIsinstance5.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingIsinstance5.py @@ -1,24 +1,20 @@ # This sample tests isinstance type narrowing when the class list # includes "Callable". -from typing import Callable, Sequence, TypeVar +from typing import Callable, Sequence, TypeVar, final -class A: - ... +class A: ... class B: - def __call__(self, x: str) -> int: - ... + def __call__(self, x: str) -> int: ... -class C: - ... +class C: ... -class D(C): - ... +class D(C): ... TCall1 = TypeVar("TCall1", bound=Callable[..., int]) @@ -30,7 +26,7 @@ def func1( if isinstance(obj, (Callable, Sequence, C)): reveal_type( obj, - expected_text="((int, str) -> int) | list[int] | B | C | D | TCall1@func1", + expected_text="((int, str) -> int) | Sequence[Unknown] | C | list[int] | B | D | TCall1@func1", ) else: reveal_type(obj, expected_text="A") @@ -38,4 +34,33 @@ def func1( if isinstance(obj, Callable): reveal_type(obj, expected_text="((int, str) -> int) | B | TCall1@func1") else: - reveal_type(obj, expected_text="list[int] | C | D | A") + reveal_type(obj, expected_text="Sequence[Unknown] | C | list[int] | D | A") + + +class CB1: + def __call__(self, x: str) -> None: ... + + +def func2(c1: Callable[[int], None], c2: Callable[..., None]): + if isinstance(c1, CB1): + reveal_type(c1, expected_text="Never") + + if isinstance(c2, CB1): + reveal_type(c2, expected_text="CB1") + + +class IsNotFinal: ... + + +def func3(c1: Callable[[int], None]): + if isinstance(c1, IsNotFinal): + reveal_type(c1, expected_text="IsNotFinal") + + +@final +class IsFinal: ... + + +def func4(c1: Callable[[int], None]): + if isinstance(c1, IsFinal): + reveal_type(c1, expected_text="Never")