From 0f222687446202d9918c9b3f473ec80994b0b10f Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Fri, 21 Jun 2024 17:45:59 -0400 Subject: [PATCH] InspectionStubGenerator respects kw only args --- mypy/stubgenc.py | 53 ++++++++++++++++++++++++++++------------ mypy/test/teststubgen.py | 28 +++++++++++++++++++++ 2 files changed, 66 insertions(+), 15 deletions(-) diff --git a/mypy/stubgenc.py b/mypy/stubgenc.py index bacb68f6d1c73..ebb48c0640d2b 100755 --- a/mypy/stubgenc.py +++ b/mypy/stubgenc.py @@ -12,7 +12,7 @@ import keyword import os.path from types import FunctionType, ModuleType -from typing import Any, Mapping +from typing import Any, Callable, Mapping from mypy.fastparse import parse_type_comment from mypy.moduleinspect import is_c_module @@ -291,6 +291,8 @@ def get_default_function_sig(self, func: object, ctx: FunctionContext) -> Functi varargs = argspec.varargs kwargs = argspec.varkw annotations = argspec.annotations + kwonlyargs = argspec.kwonlyargs + kwonlydefaults = argspec.kwonlydefaults def get_annotation(key: str) -> str | None: if key not in annotations: @@ -303,27 +305,48 @@ def get_annotation(key: str) -> str | None: return argtype arglist: list[ArgSig] = [] + # Add the arguments to the signature - for i, arg in enumerate(args): - # Check if the argument has a default value - if defaults and i >= len(args) - len(defaults): - default_value = defaults[i - (len(args) - len(defaults))] - if arg in annotations: - argtype = annotations[arg] + def add_args(args, get_default_value: Callable[[int, str], object | None]): + for i, arg in enumerate(args): + # Check if the argument has a default value + if default_value := get_default_value(i, arg): + if arg in annotations: + argtype = annotations[arg] + else: + argtype = self.get_type_annotation(default_value) + if argtype == "None": + # None is not a useful annotation, but we can infer that the arg + # is optional + incomplete = self.add_name("_typeshed.Incomplete") + argtype = f"{incomplete} | None" + + arglist.append(ArgSig(arg, argtype, default=True)) else: - argtype = self.get_type_annotation(default_value) - if argtype == "None": - # None is not a useful annotation, but we can infer that the arg - # is optional - incomplete = self.add_name("_typeshed.Incomplete") - argtype = f"{incomplete} | None" - arglist.append(ArgSig(arg, argtype, default=True)) + arglist.append(ArgSig(arg, get_annotation(arg), default=False)) + + def get_pos_default(i: int, _arg: str) -> object | None: + if defaults and i >= len(args) - len(defaults): + return defaults[i - (len(args) - len(defaults))] else: - arglist.append(ArgSig(arg, get_annotation(arg), default=False)) + return None + + add_args(args, get_pos_default) # Add *args if present if varargs: arglist.append(ArgSig(f"*{varargs}", get_annotation(varargs))) + # if we have keyword only args, then wee need to add "*" + elif kwonlyargs: + arglist.append(ArgSig("*")) + + def get_kw_default(_i: int, arg: str) -> object | None: + if kwonlydefaults: + return kwonlydefaults.get(arg) + else: + return None + + add_args(kwonlyargs, get_kw_default) # Add **kwargs if present if kwargs: diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py index 3669772854cb6..bf2f814942501 100644 --- a/mypy/test/teststubgen.py +++ b/mypy/test/teststubgen.py @@ -845,6 +845,34 @@ class TestClassVariableCls: assert_equal(gen.get_imports().splitlines(), ["from typing import ClassVar"]) assert_equal(output, ["class C:", " x: ClassVar[int] = ..."]) + + def test_non_c_generate_signature_with_kw_only_args(self) -> None: + class TestClass: + def test(self, arg0, *, keyword_only : str, keyword_only_with_default : int = 7): + pass + + output: list[str] = [] + mod = ModuleType(TestClass.__module__, "") + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.is_c_module = False + gen.generate_function_stub( + "test", + TestClass.test, + output=output, + class_info=ClassInfo( + self_var="self", + cls=TestClass, + name="TestClass", + docstring=getattr(TestClass, "__doc__", None), + ), + ) + assert_equal( + output, + [ + "def test(self, arg0, *, keyword_only: str, keyword_only_with_default: int = ...): ..." + ], + ) + def test_generate_c_type_inheritance(self) -> None: class TestClass(KeyError): pass