Skip to content

Commit 4e0df91

Browse files
committed
enabled based overload consistency checks
1 parent a309b68 commit 4e0df91

File tree

12 files changed

+77
-27
lines changed

12 files changed

+77
-27
lines changed

packages/pyright-internal/src/analyzer/checker.ts

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ import {
207207
partiallySpecializeType,
208208
selfSpecializeClass,
209209
transformPossibleRecursiveTypeAlias,
210+
someSubtypes,
210211
} from './typeUtils';
211212

212213
interface TypeVarUsageInfo {
@@ -2830,25 +2831,28 @@ export class Checker extends ParseTreeWalker {
28302831
constraints
28312832
);
28322833

2833-
const returnDiag = new DiagnosticAddendum();
2834-
if (
2835-
!isNever(overloadReturnType) &&
2836-
!this._evaluator.assignType(
2834+
// based overload consistency part 1: The implementation return type must overlap with each overload
2835+
// when distributing unions (i.e., at least one member of the impl union is assignable to the overload).
2836+
if (!isNever(overloadReturnType)) {
2837+
const returnDiag = new DiagnosticAddendum();
2838+
// Succeeds if any union member of the implementation overlaps with the overload return type.
2839+
const ok = someSubtypes(
28372840
implReturnType,
2838-
overloadReturnType,
2839-
returnDiag.createAddendum(),
2840-
constraints,
2841-
AssignTypeFlags.Default
2842-
)
2843-
) {
2844-
returnDiag.addMessage(
2845-
LocAddendum.functionReturnTypeMismatch().format({
2846-
sourceType: this._evaluator.printType(overloadReturnType),
2847-
destType: this._evaluator.printType(implReturnType),
2848-
})
2841+
(sub) =>
2842+
this._evaluator.assignType(overloadReturnType, sub) ||
2843+
this._evaluator.assignType(sub, overloadReturnType)
28492844
);
2850-
diag?.addAddendum(returnDiag);
2851-
isConsistent = false;
2845+
2846+
if (!ok) {
2847+
returnDiag.addMessage(
2848+
LocAddendum.functionReturnTypeMismatch().format({
2849+
sourceType: this._evaluator.printType(implReturnType),
2850+
destType: this._evaluator.printType(overloadReturnType),
2851+
})
2852+
);
2853+
diag?.addAddendum(returnDiag);
2854+
isConsistent = false;
2855+
}
28522856
}
28532857

28542858
return isConsistent;
@@ -3307,6 +3311,48 @@ export class Checker extends ParseTreeWalker {
33073311
}
33083312
}
33093313
});
3314+
3315+
// based overload consistency part 2: Implementation return type must be a subtype of the union of overload return types.
3316+
if (
3317+
this._importResolver.getConfigOptions().strictOverloadConsistency === 'strict' &&
3318+
implementation &&
3319+
isFunction(implementation)
3320+
) {
3321+
const implNode = implementation.shared.declaration?.node?.parent;
3322+
let implBound = implementation;
3323+
if (implNode) {
3324+
const liveScopeIds = ParseTreeUtils.getTypeVarScopesForNode(implNode);
3325+
implBound = makeTypeVarsBound(implementation, liveScopeIds);
3326+
}
3327+
3328+
const implReturnType =
3329+
FunctionType.getEffectiveReturnType(implBound) ?? this._evaluator.getInferredReturnType(implBound);
3330+
3331+
const mappedReturnUnion = combineTypes(
3332+
OverloadedType.getOverloads(type).map((overloadType) => {
3333+
const result =
3334+
FunctionType.getEffectiveReturnType(overloadType) ??
3335+
this._evaluator.getInferredReturnType(overloadType);
3336+
// special case CoroutineType, as it's a known instance of a "single covariant" type parameter
3337+
// see: https://github.com/DetachHead/basedpyright/issues/1523
3338+
if (isClass(result) && result.shared.fullName === 'types.CoroutineType') {
3339+
return result.shared.typeParams[2];
3340+
}
3341+
return result;
3342+
})
3343+
);
3344+
3345+
const extraDiag = new DiagnosticAddendum();
3346+
const isAssignable = this._evaluator.assignType(mappedReturnUnion, implReturnType, extraDiag);
3347+
3348+
if (!isAssignable && implementation.shared.declaration) {
3349+
this._evaluator.addDiagnostic(
3350+
DiagnosticRule.reportInconsistentOverload,
3351+
LocMessage.overloadImplementationTooWide() + extraDiag.getString(),
3352+
implementation.shared.declaration.node.d.name
3353+
);
3354+
}
3355+
}
33103356
}
33113357

33123358
private _reportFinalInLoop(symbol: Symbol) {

packages/pyright-internal/src/common/configOptions.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,6 +1469,8 @@ export class ConfigOptions {
14691469
// Overrides the default timeout for file enumeration operations.
14701470
fileEnumerationTimeoutInSec?: number;
14711471

1472+
strictOverloadConsistency: 'allow_wide' | 'strict' = 'strict';
1473+
14721474
// https://github.com/microsoft/TypeScript/issues/3841
14731475
declare ['constructor']: typeof ConfigOptions;
14741476

packages/pyright-internal/src/localization/localize.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,7 @@ export namespace Localizer {
775775
new ParameterizedString<{ name: string; index: number }>(
776776
getRawString('Diagnostic.overloadImplementationMismatch')
777777
);
778+
export const overloadImplementationTooWide = () => getRawString('Diagnostic.overloadImplementationTooWide');
778779
export const overloadOverrideImpl = () => getRawString('Diagnostic.overloadOverrideImpl');
779780
export const overloadOverrideNoImpl = () => getRawString('Diagnostic.overloadOverrideNoImpl');
780781
export const overloadReturnTypeMismatch = () =>

packages/pyright-internal/src/localization/package.nls.en-us.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,7 @@
957957
"comment": "{Locked='@final'}"
958958
},
959959
"overloadImplementationMismatch": "Overloaded implementation is not consistent with signature of overload {index}",
960+
"overloadImplementationTooWide": "Implementation return type is too wide",
960961
"overloadOverrideImpl": {
961962
"message": "@override decorator should be applied only to the implementation",
962963
"comment": "{Locked='@override'}"

packages/pyright-internal/src/tests/samples/decorator2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def atomic(*, savepoint: bool = True) -> Callable[[F], F]: ...
1414

1515

1616
def atomic(
17-
__func: Optional[Callable[..., None]] = None, *, savepoint: bool = True
18-
) -> Union[Callable[[], None], Callable[[F], F]]: ...
17+
__func: F | None = None, *, savepoint: bool = True
18+
) -> Union[F, Callable[[F], F]]: ...
1919

2020

2121
@atomic

packages/pyright-internal/src/tests/samples/methodOverride6.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def m1(self, x: bool) -> int: ...
4242
@overload
4343
def m1(self, x: str) -> str: ...
4444

45-
def m1(self, x: bool | str) -> int | float | str:
45+
def m1(self, x: bool | str) -> int | str:
4646
return x
4747

4848

@@ -55,7 +55,7 @@ def m1(self, x: bool) -> int: ...
5555

5656
# This should generate an error because the overloads are
5757
# in the wrong order.
58-
def m1(self, x: bool | str) -> int | float | str:
58+
def m1(self, x: bool | str) -> int | str:
5959
return x
6060

6161

packages/pyright-internal/src/tests/samples/overload2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def deco2(
7474

7575
def deco2(
7676
x: Callable[[], T | None] = lambda: None,
77-
) -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
77+
) -> Callable[[Callable[P, T]], Callable[P, T]] | Callable[[Callable[P, T]], Callable[P, T | None]]: ...
7878

7979

8080
@deco2(x=dict)

packages/pyright-internal/src/tests/samples/overloadCall4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def overloaded1(x: A) -> str: ...
2525
def overloaded1(x: _T1) -> _T1: ...
2626

2727

28-
def overloaded1(x: A | B) -> str | B: ...
28+
def overloaded1[T: B](x: A | T) -> str | T: ...
2929

3030

3131
def func1(a: A | B, b: A | B | C):

packages/pyright-internal/src/tests/samples/overloadCall5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def func1(__iter1: Iterable[_T1], __iter2: Iterable[_T2]) -> Tuple[_T1, _T2]: ..
2222
def func1(*iterables: Iterable[_T1]) -> float: ...
2323

2424

25-
def func1(*iterables: Iterable[_T1 | _T2]) -> Tuple[_T1 | _T2, ...] | float: ...
25+
def func1(*iterables: Iterable[_T1 | _T2]) -> Tuple[_T1 | _T2, ...] | float: ... # pyright: ignore the too wide error
2626

2727

2828
def test1(x: Iterable[int]):

packages/pyright-internal/src/tests/samples/overloadOverlap1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def func20(choices: AnyStr) -> AnyStr: ...
326326
def func20(choices: AllStr) -> AllStr: ...
327327

328328

329-
def func20(choices: AllStr) -> AllStr: ...
329+
def func20(choices: AnyStr | AllStr) -> AnyStr | AllStr: ...
330330

331331

332332
# This should generate an overlapping overload error.

0 commit comments

Comments
 (0)