Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6635,6 +6635,7 @@ export class Checker extends ParseTreeWalker {
}

const baseClass = baseClassAndSymbol.classType;
const baseClassSelf = ClassType.cloneAsInstance(selfSpecializeClass(baseClass, { useBoundTypeVars: true }));
const childClassSelf = ClassType.cloneAsInstance(
selfSpecializeClass(childClassType, { useBoundTypeVars: true })
);
Expand All @@ -6643,7 +6644,7 @@ export class Checker extends ParseTreeWalker {
this._evaluator.getEffectiveTypeOfSymbol(baseClassAndSymbol.symbol),
baseClass,
this._evaluator.getTypeClassType(),
childClassSelf
baseClassSelf
);

overrideType = partiallySpecializeType(
Expand Down
93 changes: 35 additions & 58 deletions packages/pyright-internal/src/tests/samples/methodOverride2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,112 +6,89 @@


class Base1:
def f1(self, *, kwarg0: int) -> None:
...
def f1(self, *, kwarg0: int) -> None: ...

def f2(self, *, kwarg0: int) -> None:
...
def f2(self, *, kwarg0: int) -> None: ...

def f3(self, *, kwarg0: int) -> None:
...
def f3(self, *, kwarg0: int) -> None: ...

def f4(self, *, kwarg0: int) -> None:
...
def f4(self, *, kwarg0: int) -> None: ...

def g1(self, a: int, /, b: str, *, kwarg0: int) -> None:
...
def g1(self, a: int, /, b: str, *, kwarg0: int) -> None: ...

def g2(self, a: int, /, b: str, *, kwarg0: int) -> None:
...
def g2(self, a: int, /, b: str, *, kwarg0: int) -> None: ...

def g3(self, a: int, /, b: str, *, kwarg0: int) -> None:
...
def g3(self, a: int, /, b: str, *, kwarg0: int) -> None: ...

def g4(self, a: int, /, b: str, *, kwarg0: int) -> None:
...
def g4(self, a: int, /, b: str, *, kwarg0: int) -> None: ...

def g5(self, a: int, /, b: str, *, kwarg0: int) -> None:
...
def g5(self, a: int, /, b: str, *, kwarg0: int) -> None: ...

def g6(self, a: int, /, b: str, *, kwarg0: int) -> None:
...
def g6(self, a: int, /, b: str, *, kwarg0: int) -> None: ...

def h1(self, a: int, *args: int) -> None:
...
def h1(self, a: int, *args: int) -> None: ...


class Derived1(Base1):
def f1(self, arg0: int = 0, *, kwarg0: int, kwarg1: int = 0) -> None:
...
def f1(self, arg0: int = 0, *, kwarg0: int, kwarg1: int = 0) -> None: ...

# This should generate an error because of a positional parameter mismatch.
def f2(self, arg0: int, *, kwarg0: int, kwarg1: int = 0) -> None:
...
def f2(self, arg0: int, *, kwarg0: int, kwarg1: int = 0) -> None: ...

# This should generate an error because of a missing kwarg1.
def f3(self, arg0: int = 0, *, kwarg0: int, kwarg1: int) -> None:
...
def f3(self, arg0: int = 0, *, kwarg0: int, kwarg1: int) -> None: ...

# This should generate an error because kwarg0 is the wrong type.
def f4(self, arg0: int = 0, *kwarg0: str) -> None:
...
def f4(self, arg0: int = 0, *kwarg0: str) -> None: ...

def g1(self, xxx: int, /, b: str, *, kwarg0: int) -> None:
...
def g1(self, xxx: int, /, b: str, *, kwarg0: int) -> None: ...

def g2(self, __a: int, b: str, *, kwarg0: int) -> None:
...
def g2(self, __a: int, b: str, *, kwarg0: int) -> None: ...

# This should generate an error because of a name mismatch between b and c.
def g3(self, __a: int, c: str, *, kwarg0: int) -> None:
...
def g3(self, __a: int, c: str, *, kwarg0: int) -> None: ...

# This should generate an error because of a type mismatch for b.
def g4(self, __a: int, b: int, *, kwarg0: int) -> None:
...
def g4(self, __a: int, b: int, *, kwarg0: int) -> None: ...

def g5(self, __a: int, b: str = "hi", *, kwarg0: int) -> None:
...
def g5(self, __a: int, b: str = "hi", *, kwarg0: int) -> None: ...

def g6(self, __a: int, b: str, c: str = "hi", *, kwarg0: int) -> None:
...
def g6(self, __a: int, b: str, c: str = "hi", *, kwarg0: int) -> None: ...


P = ParamSpec("P")
R = TypeVar("R")


class Base2(Generic[P, R]):
def method1(self, *args: P.args, **kwargs: P.kwargs) -> R:
...
def method1(self, *args: P.args, **kwargs: P.kwargs) -> R: ...

def method2(self, *args: P.args, **kwargs: P.kwargs) -> R:
...
def method2(self, *args: P.args, **kwargs: P.kwargs) -> R: ...


class Derived2(Base2[P, R]):
def method1(self, *args: P.args, **kwargs: P.kwargs) -> R:
...
def method1(self, *args: P.args, **kwargs: P.kwargs) -> R: ...

def method2(self, *args: Any, **kwargs: Any) -> R:
...
def method2(self, *args: Any, **kwargs: Any) -> R: ...


T = TypeVar("T")


class Base3:
def method1(self, x: Self) -> Self:
...
def method1(self, x: Self) -> Self: ...

def method2(self, x: Self) -> Self:
...
def method2(self, x: Self) -> Self: ...

def method3(self, x: Self) -> Self: ...


class Derived3(Generic[T], Base3):
def method1(self, x: "Derived3[T]") -> "Derived3[T]":
...
# This should generate an error.
def method1(self, x: "Derived3[T]") -> "Derived3[T]": ...

# This should generate an error.
def method2(self, x: "Derived3[int]") -> "Derived3[int]": ...

# This should generate an error.
def method2(self, x: "Derived3[int]") -> "Derived3[int]":
...
def method3(self, x: Self) -> Self: ...
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/typeEvaluator3.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ test('MethodOverride2', () => {

configOptions.diagnosticRuleSet.reportIncompatibleMethodOverride = 'error';
analysisResults = TestUtils.typeAnalyzeSampleFiles(['methodOverride2.py'], configOptions);
TestUtils.validateResults(analysisResults, 6);
TestUtils.validateResults(analysisResults, 8);
});

test('MethodOverride3', () => {
Expand Down