diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 45fc7f2fc80e..a765c008f472 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -2514,31 +2514,41 @@ function narrowTypeForLiteralComparison( } return subtype; - } else if (isClassInstance(subtype) && ClassType.isSameGenericClass(literalType, subtype)) { + } + + if (isClassInstance(subtype) && ClassType.isSameGenericClass(literalType, subtype)) { if (subtype.priv.literalValue !== undefined) { const literalValueMatches = ClassType.isLiteralValueSame(subtype, literalType); if (isPositiveTest) { return literalValueMatches ? subtype : undefined; - } else { - const isEnumOrBool = ClassType.isEnumClass(literalType) || ClassType.isBuiltIn(literalType, 'bool'); - - // For negative tests, we can eliminate the literal value if it doesn't match, - // but only for equality tests or for 'is' tests that involve enums or bools. - return literalValueMatches && (isEnumOrBool || !isIsOperator) ? undefined : subtype; } - } else if (isPositiveTest) { + + const isEnumOrBool = ClassType.isEnumClass(literalType) || ClassType.isBuiltIn(literalType, 'bool'); + + // For negative tests, we can eliminate the literal value if it doesn't match, + // but only for equality tests or for 'is' tests that involve enums or bools. + return literalValueMatches && (isEnumOrBool || !isIsOperator) ? undefined : subtype; + } + + if (isPositiveTest) { return literalType; - } else { - // If we're able to enumerate all possible literal values - // (for bool or enum), we can eliminate all others in a negative test. - const allLiteralTypes = enumerateLiteralsForType(evaluator, subtype); - if (allLiteralTypes && allLiteralTypes.length > 0) { - return combineTypes( - allLiteralTypes.filter((type) => !ClassType.isLiteralValueSame(type, literalType)) - ); - } } - } else if (isPositiveTest) { + + // If we're able to enumerate all possible literal values + // (for bool or enum), we can eliminate all others in a negative test. + const allLiteralTypes = enumerateLiteralsForType(evaluator, subtype); + if (allLiteralTypes && allLiteralTypes.length > 0) { + return combineTypes(allLiteralTypes.filter((type) => !ClassType.isLiteralValueSame(type, literalType))); + } + + return subtype; + } + + if (isPositiveTest) { + if (isClassInstance(subtype) && ClassType.isBuiltIn(subtype, 'LiteralString')) { + return literalType; + } + if (isIsOperator || isNoneInstance(subtype)) { const isSubtype = evaluator.assignType(subtype, literalType); return isSubtype ? literalType : undefined; diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py index fffe7251350e..a0e90dc2ac62 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py @@ -1,7 +1,7 @@ # This sample tests the type analyzer's type narrowing # logic for literals. -from typing import Literal, TypeVar, Union +from typing import Literal, LiteralString, TypeVar, Union def func1(p1: Literal["a", "b", "c"]): @@ -52,3 +52,10 @@ def func5(x: S) -> S: else: reveal_type(x, expected_text="Literal['b']") return x + + +def func6(x: LiteralString): + if x == "a": + reveal_type(x, expected_text="Literal['a']") + else: + reveal_type(x, expected_text="LiteralString")