diff --git a/mypy/stubgen.py b/mypy/stubgen.py index e7fa65d7f949..9084da2053cf 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -80,6 +80,7 @@ CallExpr, ClassDef, ComparisonExpr, + ComplexExpr, Decorator, DictExpr, EllipsisExpr, @@ -1396,6 +1397,8 @@ def is_private_member(self, fullname: str) -> bool: def get_str_type_of_node( self, rvalue: Expression, can_infer_optional: bool = False, can_be_any: bool = True ) -> str: + rvalue = self.maybe_unwrap_unary_expr(rvalue) + if isinstance(rvalue, IntExpr): return "int" if isinstance(rvalue, StrExpr): @@ -1404,8 +1407,13 @@ def get_str_type_of_node( return "bytes" if isinstance(rvalue, FloatExpr): return "float" - if isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr): - return "int" + if isinstance(rvalue, ComplexExpr): # 1j + return "complex" + if isinstance(rvalue, OpExpr) and rvalue.op in ("-", "+"): # -1j + 1 + if isinstance(self.maybe_unwrap_unary_expr(rvalue.left), ComplexExpr) or isinstance( + self.maybe_unwrap_unary_expr(rvalue.right), ComplexExpr + ): + return "complex" if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"): return "bool" if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None": @@ -1417,6 +1425,40 @@ def get_str_type_of_node( else: return "" + def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression: + """Unwrap (possibly nested) unary expressions. + + But, some unary expressions can change the type of expression. + While we want to preserve it. For example, `~True` is `int`. + So, we only allow a subset of unary expressions to be unwrapped. + """ + if not isinstance(expr, UnaryExpr): + return expr + + # First, try to unwrap `[+-]+ (int|float|complex)` expr: + math_ops = ("+", "-") + if expr.op in math_ops: + while isinstance(expr, UnaryExpr): + if expr.op not in math_ops or not isinstance( + expr.expr, (IntExpr, FloatExpr, ComplexExpr, UnaryExpr) + ): + break + expr = expr.expr + return expr + + # Next, try `not bool` expr: + if expr.op == "not": + while isinstance(expr, UnaryExpr): + if expr.op != "not" or not isinstance(expr.expr, (NameExpr, UnaryExpr)): + break + if isinstance(expr.expr, NameExpr) and expr.expr.name not in ("True", "False"): + break + expr = expr.expr + return expr + + # This is some other unary expr, we cannot do anything with it (yet?). + return expr + def print_annotation(self, t: Type) -> str: printer = AnnotationPrinter(self) return t.accept(printer) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 9c7221e7ec54..b387aa840dc9 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -127,10 +127,52 @@ class A: def g() -> None: ... -[case testVariable] -x = 1 +[case testVariables] +i = 1 +s = 'a' +f = 1.5 +c1 = 1j +c2 = 0j + 1 +bl1 = True +bl2 = False +bts = b'' +[out] +i: int +s: str +f: float +c1: complex +c2: complex +bl1: bool +bl2: bool +bts: bytes + +[case testVariablesWithUnary] +i = +-1 +f = -1.5 +c1 = -1j +c2 = -1j + 1 +bl1 = not True +bl2 = not not False +[out] +i: int +f: float +c1: complex +c2: complex +bl1: bool +bl2: bool + +[case testVariablesWithUnaryWrong] +i = not +1 +bl1 = -True +bl2 = not -False +bl3 = -(not False) [out] -x: int +from _typeshed import Incomplete + +i: Incomplete +bl1: Incomplete +bl2: Incomplete +bl3: Incomplete [case testAnnotatedVariable] x: int = 1