diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 1bd02f1cc6..a34848123c 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -19180,6 +19180,91 @@ export function createTypeEvaluator( FunctionType.isConstructorMethod(functionType)); const firstNonClsSelfParamIndex = isFirstParamClsOrSelf ? 1 : 0; + // Infer implementation signature from prior overloads if present. + // Collect prior @overload function types for this symbol. + let overloadParamTypeUnions: (Type | undefined)[] | undefined; + let overloadReturnTypeUnion: Type | undefined; + do { + const scope = ScopeUtils.getScopeForNode(node); + const functionSymbol = scope?.lookUpSymbolRecursive(node.d.name.d.value); + if (!functionSymbol || !functionDecl) break; + const decls = functionSymbol.symbol.getDeclarations(); + const declIndex = decls.findIndex((d) => d === functionDecl); + if (declIndex <= 0) break; + + const priorOverloadFuncTypes: FunctionType[] = []; + for (let i = 0; i < declIndex; i++) { + const d = decls[i]; + if (d.type !== DeclarationType.Function) continue; + const prevInfo = getTypeOfFunction(d.node); + if (!prevInfo) continue; + const prevFunc = prevInfo.functionType; + if (isFunction(prevFunc) && FunctionType.isOverloaded(prevFunc)) { + priorOverloadFuncTypes.push(prevFunc); + } + } + + if (priorOverloadFuncTypes.length === 0) break; + + // Prepare unions per parameter and for return type. + overloadParamTypeUnions = new Array(node.d.params.length).fill(undefined); + const implTypeParams = functionType.shared.typeParams; + + const mapOverloadTypeToImpl = (srcType: Type, prevFunc: FunctionType): Type => { + const prevParams = prevFunc.shared.typeParams; + if (prevParams.length === 0 || implTypeParams.length === 0) { + return srcType; + } + // If counts differ, don't attempt mapping. + if (prevParams.length !== implTypeParams.length) { + return srcType; + } + return mapSubtypes(srcType, (subtype) => { + if (isTypeVar(subtype)) { + const idx = prevParams.findIndex((tv) => tv === subtype); + if (idx >= 0 && idx < implTypeParams.length) { + return implTypeParams[idx]; + } + } + return undefined; + }); + }; + + for (const prevFunc of priorOverloadFuncTypes) { + // Allow prior overloads that have the same or fewer params than the implementation. + // We will combine only up to the smaller parameter count. + if (prevFunc.shared.parameters.length > node.d.params.length) { + continue; + } + // Build unions for each param index that exists in the prior overload. + let categoriesMismatch = false; + const count = Math.min(prevFunc.shared.parameters.length, node.d.params.length); + for (let pi = 0; pi < count; pi++) { + // Only combine if param categories align with the implementation's parse tree. + if (prevFunc.shared.parameters[pi].category !== node.d.params[pi].d.category) { + categoriesMismatch = true; + break; + } + const prevParamType = FunctionType.getParamType(prevFunc, pi); + const mapped = mapOverloadTypeToImpl(prevParamType, prevFunc); + overloadParamTypeUnions![pi] = overloadParamTypeUnions![pi] + ? combineTypes([overloadParamTypeUnions![pi]!, mapped]) + : mapped; + } + if (categoriesMismatch) { + continue; + } + + // Combine return types. + if (prevFunc.shared.declaredReturnType) { + const mappedRet = mapOverloadTypeToImpl(prevFunc.shared.declaredReturnType, prevFunc); + overloadReturnTypeUnion = overloadReturnTypeUnion + ? combineTypes([overloadReturnTypeUnion, mappedRet]) + : mappedRet; + } + } + } while (false); + node.d.params.forEach((param, index) => { let paramType: Type | undefined; let annotatedType: Type | undefined; @@ -19344,12 +19429,23 @@ export function createTypeEvaluator( // If there was no annotation for the parameter, infer its type if possible. let isTypeInferred = false; + let usedOverloadInference = false; if (!paramTypeNode) { isTypeInferred = true; const inferredType = inferParamType(node, functionType.shared.flags, index, containingClassType); - if (inferredType) { + const overloadInferred = overloadParamTypeUnions ? overloadParamTypeUnions[index] : undefined; + if (overloadInferred) { + paramType = overloadInferred; + usedOverloadInference = true; + } else if (inferredType) { paramType = inferredType; } + + // After inferring from defaults and/or overloads, if there is a default of None + // and no explicit annotation, include None in the inferred implementation type. + if (param.d.defaultValue?.nodeType === ParseNodeType.Constant && param.d.defaultValue.d.constType === KeywordType.None) { + paramType = paramType ? combineTypes([paramType, getNoneType()]) : getNoneType(); + } } paramType = paramType ?? UnknownType.create(); @@ -19457,6 +19553,12 @@ export function createTypeEvaluator( } } + // If there was no explicit return annotation and this is a source file, + // try to infer from prior overloads' return types. + if (!returnTypeAnnotationNode && !fileInfo.isStubFile && overloadReturnTypeUnion) { + functionType.shared.declaredReturnType = overloadReturnTypeUnion; + } + // Accumulate any type parameters used in the return type. if (functionType.shared.declaredReturnType && returnTypeAnnotationNode) { addTypeVarsToListIfUnique( @@ -19683,6 +19785,16 @@ export function createTypeEvaluator( return ClassType.cloneForPacked(type); } + // For unions, distribute over the tuple constructor so that + // we produce a union of homogeneous tuple types rather than + // a single tuple of a union element type. + if (isUnion(type)) { + const tupleVariants = type.priv.subtypes.map((sub: Type) => + makeTupleObject(evaluatorInterface, [{ type: sub, isUnbounded: !isTypeVarTuple(sub) }]) + ); + return combineTypes(tupleVariants); + } + return makeTupleObject(evaluatorInterface, [{ type, isUnbounded: !isTypeVarTuple(type) }]); } diff --git a/packages/pyright-internal/src/tests/samples/overloadImplInference1.py b/packages/pyright-internal/src/tests/samples/overloadImplInference1.py new file mode 100644 index 0000000000..38211a876e --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/overloadImplInference1.py @@ -0,0 +1,113 @@ +# This sample tests inference of implementation types from overloads +# including generics, positional-only, keyword-only, varargs and kwargs. + +from typing import Any, Tuple, Iterable, overload +from typing_extensions import assert_type # pyright: ignore[reportMissingModuleSource] + + +# Basic: union inference for param and return with generic +@overload +def f(a: int) -> str: ... + +@overload +def f[T](a: T) -> T: ... + +def f[T](a): + # Inferred parameter type should be int | T + assert_type(a, int | T) + # Return type should be str | T, so returning None should error + return None # pyright: ignore[reportReturnType] + + +# Positional-only with alpha-equivalence of type parameters +@overload +def po1[T](x: T, /) -> T: ... + +@overload +def po1[U](x: U, y: U ,/) -> U: ... + +def po1[V](x, y: V | None = None, /): # pyright: ignore[reportInvalidTypeVarUse] https://github.com/DetachHead/basedpyright/issues/1500 + # Both overloads are alpha-equivalent; inferred type should be V (not V | V) + assert_type(x, V) + return x + +# TODO: impl does not have type parameters +# @overload +# def po2(x: int) -> int: ... +# +# @overload +# def po2[T](x: T) -> T: ... +# +# def po2(x, /): +# reveal_type(x) # expect int | T +# return x + +# Positional-only union across concrete and generic +@overload +def g(x: int, /) -> int: ... + +@overload +def g[T](x: T, /) -> T: ... + +def g[T](x, /): + assert_type(x, int | T) + # Returning x should be OK because return type is int | T + return x + + +# Keyword-only parameter +@overload +def h(*, x: int) -> str: ... + +@overload +def h[T](*, x: T) -> T: ... + +def h[T](*, x): + assert_type(x, int | T) + return "" + + +# Variadic positional parameters +@overload +def va1(*args: int) -> str: ... + +@overload +def va1[T](*args: T) -> T: ... + + +def va1[T](*args): + # Inferred "args" type should be tuple[T, ...] | tuple[int, ...] + assert_type(args, tuple[T, ...] | tuple[int, ...]) + return "" + + +# Variadic keyword parameters +@overload +def kw2(**kwargs: int) -> str: ... + +@overload +def kw2[T](**kwargs: T) -> T: ... + +def kw2[T](**kwargs): + # The variable inside body is a dict[str, ] + assert_type(kwargs, dict[str, T | int]) + return "" + + +@overload +def f2(i: int) -> int: ... +@overload +def f2(i: str, j: str) -> str: ... +def f2(i, j=None): + assert_type(i, int | str) + assert_type(j, str | None) + return "" + +@overload +def f3(i: int) -> int: ... +@overload +def f3(i: str, j: str) -> str: ... +def f3(i, j: str=None): # pyright: ignore[reportArgumentType, reportInconsistentOverload] + assert_type(i, int | str) + assert_type(j, str) + return "" diff --git a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts index 96270f74cd..04b2bfc93c 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts @@ -109,6 +109,13 @@ test('OverloadImpl2', () => { TestUtils.validateResults(analysisResults, 2); }); +test('OverloadImplInference1', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['overloadImplInference1.py']); + TestUtils.validateResultsButBased(analysisResults, { + infos: [{ line: 42, message: 'revealed type is "int | T"' }], + }); +}); + test('OverloadOverlap1', () => { const configOptions = new ConfigOptions(Uri.empty());