diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index aa59052e5a..2aaa4542b5 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -207,6 +207,7 @@ import { partiallySpecializeType, selfSpecializeClass, transformPossibleRecursiveTypeAlias, + someSubtypes, } from './typeUtils'; interface TypeVarUsageInfo { @@ -2830,25 +2831,28 @@ export class Checker extends ParseTreeWalker { constraints ); - const returnDiag = new DiagnosticAddendum(); - if ( - !isNever(overloadReturnType) && - !this._evaluator.assignType( + // based overload consistency part 1: The implementation return type must overlap with each overload + // when distributing unions (i.e., at least one member of the impl union is assignable to the overload). + if (!isNever(overloadReturnType)) { + const returnDiag = new DiagnosticAddendum(); + // Succeeds if any union member of the implementation overlaps with the overload return type. + const ok = someSubtypes( implReturnType, - overloadReturnType, - returnDiag.createAddendum(), - constraints, - AssignTypeFlags.Default - ) - ) { - returnDiag.addMessage( - LocAddendum.functionReturnTypeMismatch().format({ - sourceType: this._evaluator.printType(overloadReturnType), - destType: this._evaluator.printType(implReturnType), - }) + (sub) => + this._evaluator.assignType(overloadReturnType, sub) || + this._evaluator.assignType(sub, overloadReturnType) ); - diag?.addAddendum(returnDiag); - isConsistent = false; + + if (!ok) { + returnDiag.addMessage( + LocAddendum.functionReturnTypeMismatch().format({ + sourceType: this._evaluator.printType(implReturnType), + destType: this._evaluator.printType(overloadReturnType), + }) + ); + diag?.addAddendum(returnDiag); + isConsistent = false; + } } return isConsistent; @@ -3307,6 +3311,48 @@ export class Checker extends ParseTreeWalker { } } }); + + // based overload consistency part 2: Implementation return type must be a subtype of the union of overload return types. + if ( + this._importResolver.getConfigOptions().strictOverloadConsistency && + implementation && + isFunction(implementation) + ) { + const implNode = implementation.shared.declaration?.node?.parent; + let implBound = implementation; + if (implNode) { + const liveScopeIds = ParseTreeUtils.getTypeVarScopesForNode(implNode); + implBound = makeTypeVarsBound(implementation, liveScopeIds); + } + + const implReturnType = + FunctionType.getEffectiveReturnType(implBound) ?? this._evaluator.getInferredReturnType(implBound); + + const mappedReturnUnion = combineTypes( + OverloadedType.getOverloads(type).map((overloadType) => { + const result = + FunctionType.getEffectiveReturnType(overloadType) ?? + this._evaluator.getInferredReturnType(overloadType); + // special case CoroutineType, as it's a known instance of a "single covariant" type parameter + // see: https://github.com/DetachHead/basedpyright/issues/1523 + if (isClass(result) && result.shared.fullName === 'types.CoroutineType') { + return result.shared.typeParams[2]; + } + return result; + }) + ); + + const extraDiag = new DiagnosticAddendum(); + const isAssignable = this._evaluator.assignType(mappedReturnUnion, implReturnType, extraDiag); + + if (!isAssignable && implementation.shared.declaration) { + this._evaluator.addDiagnostic( + DiagnosticRule.reportInconsistentOverload, + LocMessage.overloadImplementationTooWide() + extraDiag.getString(), + implementation.shared.declaration.node.d.name + ); + } + } } private _reportFinalInLoop(symbol: Symbol) { diff --git a/packages/pyright-internal/src/common/configOptions.ts b/packages/pyright-internal/src/common/configOptions.ts index 0b0a23dad9..d4ebeb3853 100644 --- a/packages/pyright-internal/src/common/configOptions.ts +++ b/packages/pyright-internal/src/common/configOptions.ts @@ -1469,6 +1469,8 @@ export class ConfigOptions { // Overrides the default timeout for file enumeration operations. fileEnumerationTimeoutInSec?: number; + strictOverloadConsistency = true; + // https://github.com/microsoft/TypeScript/issues/3841 declare ['constructor']: typeof ConfigOptions; diff --git a/packages/pyright-internal/src/localization/localize.ts b/packages/pyright-internal/src/localization/localize.ts index 70459d4789..4bd56511bb 100644 --- a/packages/pyright-internal/src/localization/localize.ts +++ b/packages/pyright-internal/src/localization/localize.ts @@ -775,6 +775,7 @@ export namespace Localizer { new ParameterizedString<{ name: string; index: number }>( getRawString('Diagnostic.overloadImplementationMismatch') ); + export const overloadImplementationTooWide = () => getRawString('Diagnostic.overloadImplementationTooWide'); export const overloadOverrideImpl = () => getRawString('Diagnostic.overloadOverrideImpl'); export const overloadOverrideNoImpl = () => getRawString('Diagnostic.overloadOverrideNoImpl'); export const overloadReturnTypeMismatch = () => diff --git a/packages/pyright-internal/src/localization/package.nls.en-us.json b/packages/pyright-internal/src/localization/package.nls.en-us.json index 5424970b39..61beb942ee 100644 --- a/packages/pyright-internal/src/localization/package.nls.en-us.json +++ b/packages/pyright-internal/src/localization/package.nls.en-us.json @@ -957,6 +957,7 @@ "comment": "{Locked='@final'}" }, "overloadImplementationMismatch": "Overloaded implementation is not consistent with signature of overload {index}", + "overloadImplementationTooWide": "Implementation return type is too wide", "overloadOverrideImpl": { "message": "@override decorator should be applied only to the implementation", "comment": "{Locked='@override'}" diff --git a/packages/pyright-internal/src/tests/samples/dataclassConverter1.py b/packages/pyright-internal/src/tests/samples/dataclassConverter1.py index fe0cd36347..01b3ba912d 100644 --- a/packages/pyright-internal/src/tests/samples/dataclassConverter1.py +++ b/packages/pyright-internal/src/tests/samples/dataclassConverter1.py @@ -84,7 +84,7 @@ def overloaded_converter(s: str) -> int: ... def overloaded_converter(s: list[str]) -> int: ... -def overloaded_converter(s: float | str | list[str], *args: str) -> int | float | str: +def overloaded_converter(s: float | str | list[str], *args: str) -> int | str: return 0 @@ -141,8 +141,8 @@ def wrong_converter_overload(s: float) -> str: ... def wrong_converter_overload(s: str) -> str: ... -def wrong_converter_overload(s: float | str) -> int | str: - return 1 +def wrong_converter_overload(s: float | str) -> str: + return "" class Errors(ModelBase): diff --git a/packages/pyright-internal/src/tests/samples/decorator2.py b/packages/pyright-internal/src/tests/samples/decorator2.py index 0d76d98869..a3959fd12d 100644 --- a/packages/pyright-internal/src/tests/samples/decorator2.py +++ b/packages/pyright-internal/src/tests/samples/decorator2.py @@ -14,8 +14,8 @@ def atomic(*, savepoint: bool = True) -> Callable[[F], F]: ... def atomic( - __func: Optional[Callable[..., None]] = None, *, savepoint: bool = True -) -> Union[Callable[[], None], Callable[[F], F]]: ... + __func: F | None = None, *, savepoint: bool = True +) -> Union[F, Callable[[F], F]]: ... @atomic diff --git a/packages/pyright-internal/src/tests/samples/methodOverride6.py b/packages/pyright-internal/src/tests/samples/methodOverride6.py index 20b2bdbfc6..936b9e8011 100644 --- a/packages/pyright-internal/src/tests/samples/methodOverride6.py +++ b/packages/pyright-internal/src/tests/samples/methodOverride6.py @@ -42,7 +42,7 @@ def m1(self, x: bool) -> int: ... @overload def m1(self, x: str) -> str: ... - def m1(self, x: bool | str) -> int | float | str: + def m1(self, x: bool | str) -> int | str: return x @@ -55,7 +55,7 @@ def m1(self, x: bool) -> int: ... # This should generate an error because the overloads are # in the wrong order. - def m1(self, x: bool | str) -> int | float | str: + def m1(self, x: bool | str) -> int | str: return x diff --git a/packages/pyright-internal/src/tests/samples/overload2.py b/packages/pyright-internal/src/tests/samples/overload2.py index dce67ec361..883ed878aa 100644 --- a/packages/pyright-internal/src/tests/samples/overload2.py +++ b/packages/pyright-internal/src/tests/samples/overload2.py @@ -74,7 +74,7 @@ def deco2( def deco2( x: Callable[[], T | None] = lambda: None, -) -> Callable[[Callable[P, T]], Callable[P, T | None]]: ... +) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[[Callable[P, T]], Callable[P, T | None]]: ... @deco2(x=dict) diff --git a/packages/pyright-internal/src/tests/samples/overloadCall4.py b/packages/pyright-internal/src/tests/samples/overloadCall4.py index 1f1ed998fb..80868341de 100644 --- a/packages/pyright-internal/src/tests/samples/overloadCall4.py +++ b/packages/pyright-internal/src/tests/samples/overloadCall4.py @@ -25,7 +25,7 @@ def overloaded1(x: A) -> str: ... def overloaded1(x: _T1) -> _T1: ... -def overloaded1(x: A | B) -> str | B: ... +def overloaded1[T: B](x: A | T) -> str | T: ... def func1(a: A | B, b: A | B | C): diff --git a/packages/pyright-internal/src/tests/samples/overloadCall5.py b/packages/pyright-internal/src/tests/samples/overloadCall5.py index fb2730a226..73287330f4 100644 --- a/packages/pyright-internal/src/tests/samples/overloadCall5.py +++ b/packages/pyright-internal/src/tests/samples/overloadCall5.py @@ -22,7 +22,7 @@ def func1(__iter1: Iterable[_T1], __iter2: Iterable[_T2]) -> Tuple[_T1, _T2]: .. def func1(*iterables: Iterable[_T1]) -> float: ... -def func1(*iterables: Iterable[_T1 | _T2]) -> Tuple[_T1 | _T2, ...] | float: ... +def func1(*iterables: Iterable[_T1 | _T2]) -> Tuple[_T1 | _T2, ...] | float: ... # pyright: ignore the too wide error def test1(x: Iterable[int]): diff --git a/packages/pyright-internal/src/tests/samples/overloadOverlap1.py b/packages/pyright-internal/src/tests/samples/overloadOverlap1.py index ce8f6876e0..c73a4789d9 100644 --- a/packages/pyright-internal/src/tests/samples/overloadOverlap1.py +++ b/packages/pyright-internal/src/tests/samples/overloadOverlap1.py @@ -326,7 +326,7 @@ def func20(choices: AnyStr) -> AnyStr: ... def func20(choices: AllStr) -> AllStr: ... -def func20(choices: AllStr) -> AllStr: ... +def func20(choices: AnyStr | AllStr) -> AnyStr | AllStr: ... # This should generate an overlapping overload error. diff --git a/packages/pyright-internal/src/tests/samples/typeIs1.py b/packages/pyright-internal/src/tests/samples/typeIs1.py index d98c5854bb..6681f961bc 100644 --- a/packages/pyright-internal/src/tests/samples/typeIs1.py +++ b/packages/pyright-internal/src/tests/samples/typeIs1.py @@ -163,7 +163,7 @@ def func10( ) -> TypeIs[tuple[int, ...]]: ... -def func10(v: tuple[int | str, ...], b: bool = True) -> bool: ... +def func10(v: tuple[int | str, ...], b: bool = True) -> TypeIs[tuple[int, ...]] | TypeIs[tuple[str, ...]]: ... v0 = is_int(int) diff --git a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts index 96270f74cd..f55fbdbfff 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts @@ -101,7 +101,7 @@ test('OverloadOverride1', () => { test('OverloadImpl1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['overloadImpl1.py']); - TestUtils.validateResults(analysisResults, 6); + TestUtils.validateResults(analysisResults, 7); }); test('OverloadImpl2', () => {