diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index e5b7282e8a49..089574334a93 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -6635,6 +6635,7 @@ export class Checker extends ParseTreeWalker { } const baseClass = baseClassAndSymbol.classType; + const baseClassSelf = ClassType.cloneAsInstance(selfSpecializeClass(baseClass, { useBoundTypeVars: true })); const childClassSelf = ClassType.cloneAsInstance( selfSpecializeClass(childClassType, { useBoundTypeVars: true }) ); @@ -6643,7 +6644,7 @@ export class Checker extends ParseTreeWalker { this._evaluator.getEffectiveTypeOfSymbol(baseClassAndSymbol.symbol), baseClass, this._evaluator.getTypeClassType(), - childClassSelf + baseClassSelf ); overrideType = partiallySpecializeType( diff --git a/packages/pyright-internal/src/tests/samples/methodOverride2.py b/packages/pyright-internal/src/tests/samples/methodOverride2.py index ad9b4671da4d..a557a44a340a 100644 --- a/packages/pyright-internal/src/tests/samples/methodOverride2.py +++ b/packages/pyright-internal/src/tests/samples/methodOverride2.py @@ -6,75 +6,54 @@ class Base1: - def f1(self, *, kwarg0: int) -> None: - ... + def f1(self, *, kwarg0: int) -> None: ... - def f2(self, *, kwarg0: int) -> None: - ... + def f2(self, *, kwarg0: int) -> None: ... - def f3(self, *, kwarg0: int) -> None: - ... + def f3(self, *, kwarg0: int) -> None: ... - def f4(self, *, kwarg0: int) -> None: - ... + def f4(self, *, kwarg0: int) -> None: ... - def g1(self, a: int, /, b: str, *, kwarg0: int) -> None: - ... + def g1(self, a: int, /, b: str, *, kwarg0: int) -> None: ... - def g2(self, a: int, /, b: str, *, kwarg0: int) -> None: - ... + def g2(self, a: int, /, b: str, *, kwarg0: int) -> None: ... - def g3(self, a: int, /, b: str, *, kwarg0: int) -> None: - ... + def g3(self, a: int, /, b: str, *, kwarg0: int) -> None: ... - def g4(self, a: int, /, b: str, *, kwarg0: int) -> None: - ... + def g4(self, a: int, /, b: str, *, kwarg0: int) -> None: ... - def g5(self, a: int, /, b: str, *, kwarg0: int) -> None: - ... + def g5(self, a: int, /, b: str, *, kwarg0: int) -> None: ... - def g6(self, a: int, /, b: str, *, kwarg0: int) -> None: - ... + def g6(self, a: int, /, b: str, *, kwarg0: int) -> None: ... - def h1(self, a: int, *args: int) -> None: - ... + def h1(self, a: int, *args: int) -> None: ... class Derived1(Base1): - def f1(self, arg0: int = 0, *, kwarg0: int, kwarg1: int = 0) -> None: - ... + def f1(self, arg0: int = 0, *, kwarg0: int, kwarg1: int = 0) -> None: ... # This should generate an error because of a positional parameter mismatch. - def f2(self, arg0: int, *, kwarg0: int, kwarg1: int = 0) -> None: - ... + def f2(self, arg0: int, *, kwarg0: int, kwarg1: int = 0) -> None: ... # This should generate an error because of a missing kwarg1. - def f3(self, arg0: int = 0, *, kwarg0: int, kwarg1: int) -> None: - ... + def f3(self, arg0: int = 0, *, kwarg0: int, kwarg1: int) -> None: ... # This should generate an error because kwarg0 is the wrong type. - def f4(self, arg0: int = 0, *kwarg0: str) -> None: - ... + def f4(self, arg0: int = 0, *kwarg0: str) -> None: ... - def g1(self, xxx: int, /, b: str, *, kwarg0: int) -> None: - ... + def g1(self, xxx: int, /, b: str, *, kwarg0: int) -> None: ... - def g2(self, __a: int, b: str, *, kwarg0: int) -> None: - ... + def g2(self, __a: int, b: str, *, kwarg0: int) -> None: ... # This should generate an error because of a name mismatch between b and c. - def g3(self, __a: int, c: str, *, kwarg0: int) -> None: - ... + def g3(self, __a: int, c: str, *, kwarg0: int) -> None: ... # This should generate an error because of a type mismatch for b. - def g4(self, __a: int, b: int, *, kwarg0: int) -> None: - ... + def g4(self, __a: int, b: int, *, kwarg0: int) -> None: ... - def g5(self, __a: int, b: str = "hi", *, kwarg0: int) -> None: - ... + def g5(self, __a: int, b: str = "hi", *, kwarg0: int) -> None: ... - def g6(self, __a: int, b: str, c: str = "hi", *, kwarg0: int) -> None: - ... + def g6(self, __a: int, b: str, c: str = "hi", *, kwarg0: int) -> None: ... P = ParamSpec("P") @@ -82,36 +61,34 @@ def g6(self, __a: int, b: str, c: str = "hi", *, kwarg0: int) -> None: class Base2(Generic[P, R]): - def method1(self, *args: P.args, **kwargs: P.kwargs) -> R: - ... + def method1(self, *args: P.args, **kwargs: P.kwargs) -> R: ... - def method2(self, *args: P.args, **kwargs: P.kwargs) -> R: - ... + def method2(self, *args: P.args, **kwargs: P.kwargs) -> R: ... class Derived2(Base2[P, R]): - def method1(self, *args: P.args, **kwargs: P.kwargs) -> R: - ... + def method1(self, *args: P.args, **kwargs: P.kwargs) -> R: ... - def method2(self, *args: Any, **kwargs: Any) -> R: - ... + def method2(self, *args: Any, **kwargs: Any) -> R: ... T = TypeVar("T") class Base3: - def method1(self, x: Self) -> Self: - ... + def method1(self, x: Self) -> Self: ... - def method2(self, x: Self) -> Self: - ... + def method2(self, x: Self) -> Self: ... + + def method3(self, x: Self) -> Self: ... class Derived3(Generic[T], Base3): - def method1(self, x: "Derived3[T]") -> "Derived3[T]": - ... + # This should generate an error. + def method1(self, x: "Derived3[T]") -> "Derived3[T]": ... + + # This should generate an error. + def method2(self, x: "Derived3[int]") -> "Derived3[int]": ... # This should generate an error. - def method2(self, x: "Derived3[int]") -> "Derived3[int]": - ... + def method3(self, x: Self) -> Self: ... diff --git a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts index 62c9effcd640..191a039870de 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts @@ -958,7 +958,7 @@ test('MethodOverride2', () => { configOptions.diagnosticRuleSet.reportIncompatibleMethodOverride = 'error'; analysisResults = TestUtils.typeAnalyzeSampleFiles(['methodOverride2.py'], configOptions); - TestUtils.validateResults(analysisResults, 6); + TestUtils.validateResults(analysisResults, 8); }); test('MethodOverride3', () => {