Skip to content

Commit

Permalink
Allow inferring +int to be a Literal (#16910)
Browse files Browse the repository at this point in the history
This makes unary positive on integers preserve the literal value of the
integer, allowing `var: Literal[1] = +1` to be accepted. Basically I
looked for code handling `__neg__` and added a branch for `__pos__` as
well.
Fixes #16728.
  • Loading branch information
TeamSpen210 committed Feb 12, 2024
1 parent b6e91d4 commit c26f129
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 11 deletions.
4 changes: 4 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4437,6 +4437,10 @@ def try_getting_int_literals(self, index: Expression) -> list[int] | None:
operand = index.expr
if isinstance(operand, IntExpr):
return [-1 * operand.value]
if index.op == "+":
operand = index.expr
if isinstance(operand, IntExpr):
return [operand.value]
typ = get_proper_type(self.accept(index))
if isinstance(typ, Instance) and typ.last_known_value is not None:
typ = typ.last_known_value
Expand Down
9 changes: 6 additions & 3 deletions mypy/exprtotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,12 @@ def expr_to_unanalyzed_type(
elif isinstance(expr, UnaryExpr):
typ = expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax)
if isinstance(typ, RawExpressionType):
if isinstance(typ.literal_value, int) and expr.op == "-":
typ.literal_value *= -1
return typ
if isinstance(typ.literal_value, int):
if expr.op == "-":
typ.literal_value *= -1
return typ
elif expr.op == "+":
return typ
raise TypeTranslationError()
elif isinstance(expr, IntExpr):
return RawExpressionType(expr.value, "builtins.int", line=expr.line, column=expr.column)
Expand Down
25 changes: 19 additions & 6 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
return int_pow_callback
elif fullname == "builtins.int.__neg__":
return int_neg_callback
elif fullname == "builtins.int.__pos__":
return int_pos_callback
elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
return tuple_mul_callback
elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}:
Expand Down Expand Up @@ -471,32 +473,43 @@ def int_pow_callback(ctx: MethodContext) -> Type:
return ctx.default_return_type


def int_neg_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for int.__neg__.
def int_neg_callback(ctx: MethodContext, multiplier: int = -1) -> Type:
"""Infer a more precise return type for int.__neg__ and int.__pos__.
This is mainly used to infer the return type as LiteralType
if the original underlying object is a LiteralType object
if the original underlying object is a LiteralType object.
"""
if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
value = ctx.type.last_known_value.value
fallback = ctx.type.last_known_value.fallback
if isinstance(value, int):
if is_literal_type_like(ctx.api.type_context[-1]):
return LiteralType(value=-value, fallback=fallback)
return LiteralType(value=multiplier * value, fallback=fallback)
else:
return ctx.type.copy_modified(
last_known_value=LiteralType(
value=-value, fallback=ctx.type, line=ctx.type.line, column=ctx.type.column
value=multiplier * value,
fallback=ctx.type,
line=ctx.type.line,
column=ctx.type.column,
)
)
elif isinstance(ctx.type, LiteralType):
value = ctx.type.value
fallback = ctx.type.fallback
if isinstance(value, int):
return LiteralType(value=-value, fallback=fallback)
return LiteralType(value=multiplier * value, fallback=fallback)
return ctx.default_return_type


def int_pos_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for int.__pos__.
This is identical to __neg__, except the value is not inverted.
"""
return int_neg_callback(ctx, +1)


def tuple_mul_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.
Expand Down
13 changes: 12 additions & 1 deletion test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -397,29 +397,36 @@ from typing_extensions import Literal
a1: Literal[4]
b1: Literal[0x2a]
c1: Literal[-300]
d1: Literal[+8]

reveal_type(a1) # N: Revealed type is "Literal[4]"
reveal_type(b1) # N: Revealed type is "Literal[42]"
reveal_type(c1) # N: Revealed type is "Literal[-300]"
reveal_type(d1) # N: Revealed type is "Literal[8]"

a2t = Literal[4]
b2t = Literal[0x2a]
c2t = Literal[-300]
d2t = Literal[+8]
a2: a2t
b2: b2t
c2: c2t
d2: d2t

reveal_type(a2) # N: Revealed type is "Literal[4]"
reveal_type(b2) # N: Revealed type is "Literal[42]"
reveal_type(c2) # N: Revealed type is "Literal[-300]"
reveal_type(d2) # N: Revealed type is "Literal[8]"

def f1(x: Literal[4]) -> Literal[4]: pass
def f2(x: Literal[0x2a]) -> Literal[0x2a]: pass
def f3(x: Literal[-300]) -> Literal[-300]: pass
def f4(x: Literal[+8]) -> Literal[+8]: pass

reveal_type(f1) # N: Revealed type is "def (x: Literal[4]) -> Literal[4]"
reveal_type(f2) # N: Revealed type is "def (x: Literal[42]) -> Literal[42]"
reveal_type(f3) # N: Revealed type is "def (x: Literal[-300]) -> Literal[-300]"
reveal_type(f4) # N: Revealed type is "def (x: Literal[8]) -> Literal[8]"
[builtins fixtures/tuple.pyi]
[out]

Expand Down Expand Up @@ -2747,6 +2754,9 @@ d: Literal[1] = 1
e: Literal[2] = 2
f: Literal[+1] = 1
g: Literal[+2] = 2
h: Literal[1] = +1
i: Literal[+2] = 2
j: Literal[+3] = +3

x: Literal[+True] = True # E: Invalid type: Literal[...] cannot contain arbitrary expressions
y: Literal[-True] = -1 # E: Invalid type: Literal[...] cannot contain arbitrary expressions
Expand All @@ -2759,14 +2769,15 @@ from typing_extensions import Literal, Final

ONE: Final = 1
x: Literal[-1] = -ONE
y: Literal[+1] = +ONE

TWO: Final = 2
THREE: Final = 3

err_code = -TWO
if bool():
err_code = -THREE
[builtins fixtures/float.pyi]
[builtins fixtures/ops.pyi]

[case testAliasForEnumTypeAsLiteral]
from typing_extensions import Literal
Expand Down
4 changes: 3 additions & 1 deletion test-data/unit/check-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,12 @@ if int():
b = t1[-1]
if int():
a = t1[(0)]
if int():
b = t1[+1]
if int():
x = t3[0:3] # type (A, B, C)
if int():
y = t3[0:5:2] # type (A, C, E)
y = t3[0:+5:2] # type (A, C, E)
if int():
x = t3[:-2] # type (A, B, C)

Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/tuple.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class classmethod: pass
# We need int and slice for indexing tuples.
class int:
def __neg__(self) -> 'int': pass
def __pos__(self) -> 'int': pass
class float: pass
class slice: pass
class bool(int): pass
Expand Down

0 comments on commit c26f129

Please sign in to comment.