Skip to content

Commit

Permalink
merged master
Browse files Browse the repository at this point in the history
  • Loading branch information
bluenote10 committed Jan 13, 2024
2 parents 90c7674 + 1fd29ac commit 0aff047
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 11 deletions.
2 changes: 1 addition & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def parse_type_string(
string expression "blah" using this function.
"""
try:
_, node = parse_type_comment(expr_string.strip(), line=line, column=column, errors=None)
_, node = parse_type_comment(f"({expr_string})", line=line, column=column, errors=None)
if isinstance(node, UnboundType) and node.original_str_expr is None:
node.original_str_expr = expr_string
node.original_str_fallback = expr_fallback_name
Expand Down
16 changes: 10 additions & 6 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,12 +530,14 @@ def is_classmethod(self, class_info: ClassInfo, name: str, obj: object) -> bool:
return inspect.ismethod(obj)

def is_staticmethod(self, class_info: ClassInfo | None, name: str, obj: object) -> bool:
if self.is_c_module:
if class_info is None:
return False
elif self.is_c_module:
raw_lookup: Mapping[str, Any] = getattr(class_info.cls, "__dict__") # noqa: B009
raw_value = raw_lookup.get(name, obj)
return isinstance(raw_value, staticmethod)
else:
return class_info is not None and isinstance(
inspect.getattr_static(class_info.cls, name), staticmethod
)
return isinstance(inspect.getattr_static(class_info.cls, name), staticmethod)

@staticmethod
def is_abstract_method(obj: object) -> bool:
Expand Down Expand Up @@ -761,7 +763,7 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
The result lines will be appended to 'output'. If necessary, any
required names will be added to 'imports'.
"""
raw_lookup = getattr(cls, "__dict__") # noqa: B009
raw_lookup: Mapping[str, Any] = getattr(cls, "__dict__") # noqa: B009
items = self.get_members(cls)
if self.resort_members:
items = sorted(items, key=lambda x: method_name_sort_key(x[0]))
Expand Down Expand Up @@ -793,7 +795,9 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
continue
attr = "__init__"
# FIXME: make this nicer
if self.is_classmethod(class_info, attr, value):
if self.is_staticmethod(class_info, attr, value):
class_info.self_var = ""
elif self.is_classmethod(class_info, attr, value):
class_info.self_var = "cls"
else:
class_info.self_var = "self"
Expand Down
7 changes: 6 additions & 1 deletion mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,10 @@ def from_overloadedfuncdef(stub: nodes.OverloadedFuncDef) -> Signature[nodes.Arg
# argument. To accomplish this, we just make up a fake index-based name.
name = (
f"__{index}"
if arg.variable.name.startswith("__") or assume_positional_only
if arg.variable.name.startswith("__")
or arg.pos_only
or assume_positional_only
or arg.variable.name.strip("_") == "self"
else arg.variable.name
)
all_args.setdefault(name, []).append((arg, index))
Expand Down Expand Up @@ -870,6 +873,7 @@ def get_kind(arg_name: str) -> nodes.ArgKind:
type_annotation=None,
initializer=None,
kind=get_kind(arg_name),
pos_only=all(arg.pos_only for arg, _ in all_args[arg_name]),
)
if arg.kind.is_positional():
sig.pos.append(arg)
Expand Down Expand Up @@ -905,6 +909,7 @@ def _verify_signature(
if (
runtime_arg.kind != inspect.Parameter.POSITIONAL_ONLY
and (stub_arg.pos_only or stub_arg.variable.name.startswith("__"))
and stub_arg.variable.name.strip("_") != "self"
and not is_dunder(function_name, exclude_special=True) # noisy for dunder methods
):
yield (
Expand Down
58 changes: 57 additions & 1 deletion mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,13 @@ def test(*args: Any, **kwargs: Any) -> None:
)

actual_errors = set(output.splitlines())
assert actual_errors == expected_errors, output
if actual_errors != expected_errors:
output = run_stubtest(
stub="\n\n".join(textwrap.dedent(c.stub.lstrip("\n")) for c in cases),
runtime="\n\n".join(textwrap.dedent(c.runtime.lstrip("\n")) for c in cases),
options=[],
)
assert actual_errors == expected_errors, output

return test

Expand Down Expand Up @@ -660,6 +666,56 @@ def f6(self, x, /): pass
""",
error=None,
)
yield Case(
stub="""
@overload
def f7(a: int, /) -> int: ...
@overload
def f7(b: str, /) -> str: ...
""",
runtime="def f7(x, /): pass",
error=None,
)
yield Case(
stub="""
@overload
def f8(a: int, c: int = 0, /) -> int: ...
@overload
def f8(b: str, d: int, /) -> str: ...
""",
runtime="def f8(x, y, /): pass",
error="f8",
)
yield Case(
stub="""
@overload
def f9(a: int, c: int = 0, /) -> int: ...
@overload
def f9(b: str, d: int, /) -> str: ...
""",
runtime="def f9(x, y=0, /): pass",
error=None,
)
yield Case(
stub="""
class Bar:
@overload
def f1(self) -> int: ...
@overload
def f1(self, a: int, /) -> int: ...
@overload
def f2(self, a: int, /) -> int: ...
@overload
def f2(self, a: str, /) -> int: ...
""",
runtime="""
class Bar:
def f1(self, *a) -> int: ...
def f2(self, *a) -> int: ...
""",
error=None,
)

@collect_cases
def test_property(self) -> Iterator[Case]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import os
from . import demo as demo
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, overload

class StaticMethods:
def __init__(self, *args, **kwargs) -> None: ...
@overload
@staticmethod
def overloaded_static_method(value: int) -> int: ...
@overload
@staticmethod
def overloaded_static_method(value: float) -> float: ...
@staticmethod
def some_static_method(a: int, b: int) -> int: ...

class TestStruct:
field_readwrite: int
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
import os
from . import demo as demo
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, overload

class StaticMethods:
def __init__(self, *args, **kwargs) -> None:
"""Initialize self. See help(type(self)) for accurate signature."""
@overload
@staticmethod
def overloaded_static_method(value: int) -> int:
"""overloaded_static_method(*args, **kwargs)
Overloaded function.
1. overloaded_static_method(value: int) -> int
2. overloaded_static_method(value: float) -> float
"""
@overload
@staticmethod
def overloaded_static_method(value: float) -> float:
"""overloaded_static_method(*args, **kwargs)
Overloaded function.
1. overloaded_static_method(value: int) -> int
2. overloaded_static_method(value: float) -> float
"""
@staticmethod
def some_static_method(a: int, b: int) -> int:
"""some_static_method(a: int, b: int) -> int
None
"""

class TestStruct:
field_readwrite: int
Expand Down
21 changes: 21 additions & 0 deletions test-data/pybind11_fixtures/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ struct TestStruct
int field_readonly;
};

struct StaticMethods
{
static int some_static_method(int a, int b) { return 42; }
static int overloaded_static_method(int value) { return 42; }
static double overloaded_static_method(double value) { return 1.0; }
};

// Bindings

void bind_test_cases(py::module& m) {
Expand All @@ -115,6 +122,20 @@ void bind_test_cases(py::module& m) {
return x.field_readonly;
},
"some docstring");

// Static methods
py::class_<StaticMethods> pyStaticMethods(m, "StaticMethods");

pyStaticMethods
.def_static(
"some_static_method",
&StaticMethods::some_static_method, R"#(None)#", py::arg("a"), py::arg("b"))
.def_static(
"overloaded_static_method",
py::overload_cast<int>(&StaticMethods::overloaded_static_method), py::arg("value"))
.def_static(
"overloaded_static_method",
py::overload_cast<double>(&StaticMethods::overloaded_static_method), py::arg("value"));
}

// ----------------------------------------------------------------------------
Expand Down
16 changes: 16 additions & 0 deletions test-data/unit/check-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,19 @@ s2: str = 42 # E: Incompatible types in assignment (expression has type "int",
s3: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[file c.py]
s3: str = 'foo'

[case testMultilineQuotedAnnotation]
x: """

int |
str

"""
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]"

y: """(
int |
str
)
"""
reveal_type(y) # N: Revealed type is "Union[builtins.int, builtins.str]"

0 comments on commit 0aff047

Please sign in to comment.