Skip to content

Commit

Permalink
Teach stubgen to work with complex and unary expressions (#15661)
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn authored Jul 13, 2023
1 parent 3983381 commit 2ebd51e
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 5 deletions.
46 changes: 44 additions & 2 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
CallExpr,
ClassDef,
ComparisonExpr,
ComplexExpr,
Decorator,
DictExpr,
EllipsisExpr,
Expand Down Expand Up @@ -1396,6 +1397,8 @@ def is_private_member(self, fullname: str) -> bool:
def get_str_type_of_node(
self, rvalue: Expression, can_infer_optional: bool = False, can_be_any: bool = True
) -> str:
rvalue = self.maybe_unwrap_unary_expr(rvalue)

if isinstance(rvalue, IntExpr):
return "int"
if isinstance(rvalue, StrExpr):
Expand All @@ -1404,8 +1407,13 @@ def get_str_type_of_node(
return "bytes"
if isinstance(rvalue, FloatExpr):
return "float"
if isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr):
return "int"
if isinstance(rvalue, ComplexExpr): # 1j
return "complex"
if isinstance(rvalue, OpExpr) and rvalue.op in ("-", "+"): # -1j + 1
if isinstance(self.maybe_unwrap_unary_expr(rvalue.left), ComplexExpr) or isinstance(
self.maybe_unwrap_unary_expr(rvalue.right), ComplexExpr
):
return "complex"
if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"):
return "bool"
if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None":
Expand All @@ -1417,6 +1425,40 @@ def get_str_type_of_node(
else:
return ""

def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression:
"""Unwrap (possibly nested) unary expressions.
But, some unary expressions can change the type of expression.
While we want to preserve it. For example, `~True` is `int`.
So, we only allow a subset of unary expressions to be unwrapped.
"""
if not isinstance(expr, UnaryExpr):
return expr

# First, try to unwrap `[+-]+ (int|float|complex)` expr:
math_ops = ("+", "-")
if expr.op in math_ops:
while isinstance(expr, UnaryExpr):
if expr.op not in math_ops or not isinstance(
expr.expr, (IntExpr, FloatExpr, ComplexExpr, UnaryExpr)
):
break
expr = expr.expr
return expr

# Next, try `not bool` expr:
if expr.op == "not":
while isinstance(expr, UnaryExpr):
if expr.op != "not" or not isinstance(expr.expr, (NameExpr, UnaryExpr)):
break
if isinstance(expr.expr, NameExpr) and expr.expr.name not in ("True", "False"):
break
expr = expr.expr
return expr

# This is some other unary expr, we cannot do anything with it (yet?).
return expr

def print_annotation(self, t: Type) -> str:
printer = AnnotationPrinter(self)
return t.accept(printer)
Expand Down
48 changes: 45 additions & 3 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,52 @@ class A:

def g() -> None: ...

[case testVariable]
x = 1
[case testVariables]
i = 1
s = 'a'
f = 1.5
c1 = 1j
c2 = 0j + 1
bl1 = True
bl2 = False
bts = b''
[out]
i: int
s: str
f: float
c1: complex
c2: complex
bl1: bool
bl2: bool
bts: bytes

[case testVariablesWithUnary]
i = +-1
f = -1.5
c1 = -1j
c2 = -1j + 1
bl1 = not True
bl2 = not not False
[out]
i: int
f: float
c1: complex
c2: complex
bl1: bool
bl2: bool

[case testVariablesWithUnaryWrong]
i = not +1
bl1 = -True
bl2 = not -False
bl3 = -(not False)
[out]
x: int
from _typeshed import Incomplete

i: Incomplete
bl1: Incomplete
bl2: Incomplete
bl3: Incomplete

[case testAnnotatedVariable]
x: int = 1
Expand Down

0 comments on commit 2ebd51e

Please sign in to comment.