Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 113 additions & 1 deletion packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19180,6 +19180,91 @@ export function createTypeEvaluator(
FunctionType.isConstructorMethod(functionType));
const firstNonClsSelfParamIndex = isFirstParamClsOrSelf ? 1 : 0;

// Infer implementation signature from prior overloads if present.
// Collect prior @overload function types for this symbol.
let overloadParamTypeUnions: (Type | undefined)[] | undefined;
let overloadReturnTypeUnion: Type | undefined;
do {
const scope = ScopeUtils.getScopeForNode(node);
const functionSymbol = scope?.lookUpSymbolRecursive(node.d.name.d.value);
if (!functionSymbol || !functionDecl) break;
const decls = functionSymbol.symbol.getDeclarations();
const declIndex = decls.findIndex((d) => d === functionDecl);
if (declIndex <= 0) break;

const priorOverloadFuncTypes: FunctionType[] = [];
for (let i = 0; i < declIndex; i++) {
const d = decls[i];
if (d.type !== DeclarationType.Function) continue;
const prevInfo = getTypeOfFunction(d.node);
if (!prevInfo) continue;
const prevFunc = prevInfo.functionType;
if (isFunction(prevFunc) && FunctionType.isOverloaded(prevFunc)) {
priorOverloadFuncTypes.push(prevFunc);
}
}

if (priorOverloadFuncTypes.length === 0) break;

// Prepare unions per parameter and for return type.
overloadParamTypeUnions = new Array(node.d.params.length).fill(undefined);
const implTypeParams = functionType.shared.typeParams;

const mapOverloadTypeToImpl = (srcType: Type, prevFunc: FunctionType): Type => {
const prevParams = prevFunc.shared.typeParams;
if (prevParams.length === 0 || implTypeParams.length === 0) {
return srcType;
}
// If counts differ, don't attempt mapping.
if (prevParams.length !== implTypeParams.length) {
return srcType;
}
return mapSubtypes(srcType, (subtype) => {
if (isTypeVar(subtype)) {
const idx = prevParams.findIndex((tv) => tv === subtype);
if (idx >= 0 && idx < implTypeParams.length) {
return implTypeParams[idx];
}
}
return undefined;
});
};

for (const prevFunc of priorOverloadFuncTypes) {
// Allow prior overloads that have the same or fewer params than the implementation.
// We will combine only up to the smaller parameter count.
if (prevFunc.shared.parameters.length > node.d.params.length) {
continue;
}
// Build unions for each param index that exists in the prior overload.
let categoriesMismatch = false;
const count = Math.min(prevFunc.shared.parameters.length, node.d.params.length);
for (let pi = 0; pi < count; pi++) {
// Only combine if param categories align with the implementation's parse tree.
if (prevFunc.shared.parameters[pi].category !== node.d.params[pi].d.category) {
categoriesMismatch = true;
break;
}
const prevParamType = FunctionType.getParamType(prevFunc, pi);
const mapped = mapOverloadTypeToImpl(prevParamType, prevFunc);
overloadParamTypeUnions![pi] = overloadParamTypeUnions![pi]
? combineTypes([overloadParamTypeUnions![pi]!, mapped])
: mapped;
}
if (categoriesMismatch) {
continue;
}

// Combine return types.
if (prevFunc.shared.declaredReturnType) {
const mappedRet = mapOverloadTypeToImpl(prevFunc.shared.declaredReturnType, prevFunc);
overloadReturnTypeUnion = overloadReturnTypeUnion
? combineTypes([overloadReturnTypeUnion, mappedRet])
: mappedRet;
}
}
} while (false);

node.d.params.forEach((param, index) => {
let paramType: Type | undefined;
let annotatedType: Type | undefined;
Expand Down Expand Up @@ -19344,12 +19429,23 @@ export function createTypeEvaluator(

// If there was no annotation for the parameter, infer its type if possible.
let isTypeInferred = false;
let usedOverloadInference = false;
if (!paramTypeNode) {
isTypeInferred = true;
const inferredType = inferParamType(node, functionType.shared.flags, index, containingClassType);
if (inferredType) {
const overloadInferred = overloadParamTypeUnions ? overloadParamTypeUnions[index] : undefined;
if (overloadInferred) {
paramType = overloadInferred;
usedOverloadInference = true;
} else if (inferredType) {
paramType = inferredType;
}

// After inferring from defaults and/or overloads, if there is a default of None
// and no explicit annotation, include None in the inferred implementation type.
if (param.d.defaultValue?.nodeType === ParseNodeType.Constant && param.d.defaultValue.d.constType === KeywordType.None) {
paramType = paramType ? combineTypes([paramType, getNoneType()]) : getNoneType();
}
}

paramType = paramType ?? UnknownType.create();
Expand Down Expand Up @@ -19457,6 +19553,12 @@ export function createTypeEvaluator(
}
}

// If there was no explicit return annotation and this is a source file,
// try to infer from prior overloads' return types.
if (!returnTypeAnnotationNode && !fileInfo.isStubFile && overloadReturnTypeUnion) {
functionType.shared.declaredReturnType = overloadReturnTypeUnion;
}

// Accumulate any type parameters used in the return type.
if (functionType.shared.declaredReturnType && returnTypeAnnotationNode) {
addTypeVarsToListIfUnique(
Expand Down Expand Up @@ -19683,6 +19785,16 @@ export function createTypeEvaluator(
return ClassType.cloneForPacked(type);
}

// For unions, distribute over the tuple constructor so that
// we produce a union of homogeneous tuple types rather than
// a single tuple of a union element type.
if (isUnion(type)) {
const tupleVariants = type.priv.subtypes.map((sub: Type) =>
makeTupleObject(evaluatorInterface, [{ type: sub, isUnbounded: !isTypeVarTuple(sub) }])
);
return combineTypes(tupleVariants);
}

return makeTupleObject(evaluatorInterface, [{ type, isUnbounded: !isTypeVarTuple(type) }]);
}

Expand Down
113 changes: 113 additions & 0 deletions packages/pyright-internal/src/tests/samples/overloadImplInference1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# This sample tests inference of implementation types from overloads
# including generics, positional-only, keyword-only, varargs and kwargs.

from typing import Any, Tuple, Iterable, overload
from typing_extensions import assert_type # pyright: ignore[reportMissingModuleSource]


# Basic: union inference for param and return with generic
@overload
def f(a: int) -> str: ...

@overload
def f[T](a: T) -> T: ...

def f[T](a):
# Inferred parameter type should be int | T
assert_type(a, int | T)
# Return type should be str | T, so returning None should error
return None # pyright: ignore[reportReturnType]


# Positional-only with alpha-equivalence of type parameters
@overload
def po1[T](x: T, /) -> T: ...

@overload
def po1[U](x: U, y: U ,/) -> U: ...

def po1[V](x, y: V | None = None, /): # pyright: ignore[reportInvalidTypeVarUse] https://github.com/DetachHead/basedpyright/issues/1500
# Both overloads are alpha-equivalent; inferred type should be V (not V | V)
assert_type(x, V)
return x

# TODO: impl does not have type parameters
# @overload
# def po2(x: int) -> int: ...
#
# @overload
# def po2[T](x: T) -> T: ...
#
# def po2(x, /):
# reveal_type(x) # expect int | T
# return x

# Positional-only union across concrete and generic
@overload
def g(x: int, /) -> int: ...

@overload
def g[T](x: T, /) -> T: ...

def g[T](x, /):
assert_type(x, int | T)
# Returning x should be OK because return type is int | T
return x


# Keyword-only parameter
@overload
def h(*, x: int) -> str: ...

@overload
def h[T](*, x: T) -> T: ...

def h[T](*, x):
assert_type(x, int | T)
return ""


# Variadic positional parameters
@overload
def va1(*args: int) -> str: ...

@overload
def va1[T](*args: T) -> T: ...


def va1[T](*args):
# Inferred "args" type should be tuple[T, ...] | tuple[int, ...]
assert_type(args, tuple[T, ...] | tuple[int, ...])
return ""


# Variadic keyword parameters
@overload
def kw2(**kwargs: int) -> str: ...

@overload
def kw2[T](**kwargs: T) -> T: ...

def kw2[T](**kwargs):
# The variable inside body is a dict[str, <value type>]
assert_type(kwargs, dict[str, T | int])
return ""


@overload
def f2(i: int) -> int: ...
@overload
def f2(i: str, j: str) -> str: ...
def f2(i, j=None):
assert_type(i, int | str)
assert_type(j, str | None)
return ""

@overload
def f3(i: int) -> int: ...
@overload
def f3(i: str, j: str) -> str: ...
def f3(i, j: str=None): # pyright: ignore[reportArgumentType, reportInconsistentOverload]
assert_type(i, int | str)
assert_type(j, str)
return ""
7 changes: 7 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator6.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ test('OverloadImpl2', () => {
TestUtils.validateResults(analysisResults, 2);
});

test('OverloadImplInference1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['overloadImplInference1.py']);
TestUtils.validateResultsButBased(analysisResults, {
infos: [{ line: 42, message: 'revealed type is "int | T"' }],
});
});

test('OverloadOverlap1', () => {
const configOptions = new ConfigOptions(Uri.empty());

Expand Down
Loading