Skip to content

Commit

Permalink
Use namespaces for function type variables (#17311)
Browse files Browse the repository at this point in the history
Fixes #16582

IMO this is long overdue. Currently, type variable IDs are 99% unique,
but when they accidentally clash, it causes hard-to-debug issues. The
implementation is generally straightforward, but it uncovered a whole
bunch of unrelated bugs. A few notes:
* This still doesn't fix the type variables in nested generic callable
types (those that appear in return types of another generic callable).
It is non-trivial to put a namespace there, and luckily this situation is
already special-cased in `checkexpr.py` to avoid ID clashes.
* This uncovered a bug in the handling of overloaded dunder overrides; the
fix is simple.
* This also uncovered a deeper problem in the unsafe overload overlap logic
(w.r.t. partial parameter overlap). A proper fix here would be hard, so
instead I tweaked the current logic so that it does not cause false
positives, at the cost of possible false negatives.
* This makes explicit that we use somewhat ad-hoc logic for join/meet
of generic callables. FWIW I decided to keep it, since it seems to work
reasonably well.
* This accidentally highlighted two bugs in error message locations. The
first is a very old one related to type aliases; I fixed the newly discovered
cases by extending a previous partial fix. Second, the error locations
generated by the `partial` plugin were completely off (you can see examples
in `mypy_primer` where there were errors on empty lines, etc.).
* This PR (naturally) causes a significant number of new valid errors
(fixed false negatives). To improve the error messages, I extended the
name disambiguation logic to include type variables (and also type
aliases, while I am at it); previously it only applied to `Instance`s.
Note that I use the notation `TypeVar@namespace`, which is the semantic
equivalent of a qualified name for type variables. For now, I shorten the
namespace to only the last component, to make errors less verbose. We
can reconsider this if it causes confusion.
* Finally, this PR will hopefully allow a more principled implementation
of #15907

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ilevkivskyi and pre-commit-ci[bot] authored Jun 5, 2024
1 parent 6bdd854 commit 6682563
Show file tree
Hide file tree
Showing 26 changed files with 332 additions and 133 deletions.
4 changes: 3 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,7 +2171,9 @@ def bind_and_map_method(
def get_op_other_domain(self, tp: FunctionLike) -> Type | None:
if isinstance(tp, CallableType):
if tp.arg_kinds and tp.arg_kinds[0] == ARG_POS:
return tp.arg_types[0]
# For generic methods, domain comparison is tricky, as a first
# approximation erase all remaining type variables to bounds.
return erase_typevars(tp.arg_types[0], {v.id for v in tp.variables})
return None
elif isinstance(tp, Overloaded):
raw_items = [self.get_op_other_domain(it) for it in tp.items]
Expand Down
13 changes: 7 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
Expand Down Expand Up @@ -4933,7 +4934,7 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
tv = TypeVarType(
"T",
"T",
id=-1,
id=TypeVarId(-1, namespace="<lst>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
Expand Down Expand Up @@ -5164,15 +5165,15 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
kt = TypeVarType(
"KT",
"KT",
id=-1,
id=TypeVarId(-1, namespace="<dict>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
vt = TypeVarType(
"VT",
"VT",
id=-2,
id=TypeVarId(-2, namespace="<dict>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
Expand Down Expand Up @@ -5564,7 +5565,7 @@ def check_generator_or_comprehension(
tv = TypeVarType(
"T",
"T",
id=-1,
id=TypeVarId(-1, namespace="<genexp>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
Expand All @@ -5591,15 +5592,15 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:
ktdef = TypeVarType(
"KT",
"KT",
id=-1,
id=TypeVarId(-1, namespace="<dict>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
vtdef = TypeVarType(
"VT",
"VT",
id=-2,
id=TypeVarId(-2, namespace="<dict>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
Expand Down
2 changes: 1 addition & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l
new_unpack: Type
if isinstance(var_arg_type, Instance):
# we have something like Unpack[Tuple[Any, ...]]
new_unpack = var_arg
new_unpack = UnpackType(var_arg.type.accept(self))
elif isinstance(var_arg_type, TupleType):
# We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
expanded_tuple = var_arg_type.accept(self)
Expand Down
31 changes: 31 additions & 0 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Sequence, overload

import mypy.typeops
from mypy.expandtype import expand_type
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY
from mypy.state import state
Expand Down Expand Up @@ -36,6 +37,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
Expand Down Expand Up @@ -718,7 +720,35 @@ def is_similar_callables(t: CallableType, s: CallableType) -> bool:
)


def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType:
    """Return a copy of ``c`` whose type variables carry the given ``ids``.

    The variables of ``c`` are paired positionally with ``ids`` (pairing stops
    at the shorter of the two). Each paired variable is copied with its new id,
    every reference to the old id inside ``c`` is replaced via ``expand_type``,
    and the resulting callable is given the re-identified variables.
    """
    # Copies of the paired variables, each carrying its replacement id.
    renamed = [old.copy_modified(id=fresh) for old, fresh in zip(c.variables, ids)]
    # Substitution from every old id to the corresponding re-identified variable.
    substitution = {old.id: new for old, new in zip(c.variables, renamed)}
    expanded = expand_type(c, substitution)
    return expanded.copy_modified(variables=renamed)


def match_generic_callables(t: CallableType, s: CallableType) -> tuple[CallableType, CallableType]:
    """Give two callables a shared set of type variable ids.

    If either callable has no type variables, both are returned unchanged.
    Otherwise fresh ids are allocated (as many as the longer variable list
    needs) and both callables are rewritten to use them, position by position.
    """
    # Combining/joining/meeting similar callables needs special care when both
    # are generic. A more principled solution may involve
    # unify_generic_callable(), but it would have two problems:
    # * It adds a risk of infinite recursion: e.g. join -> unification -> solver -> join
    # * Unification is an incorrect thing for meets, as it "widens" the types
    # This effectively falls back to the old behaviour from before namespaces
    # were added to type variables, which worked relatively well.
    t_count = len(t.variables)
    s_count = len(s.variables)
    if not t_count or not s_count:
        # At most one side is generic: nothing to align.
        return t, s
    shared_ids = [TypeVarId.new(meta_level=0) for _ in range(max(t_count, s_count))]
    # Note: this relies on variables being in the order they appear in the
    # function definition.
    return update_callable_ids(t, shared_ids), update_callable_ids(s, shared_ids)


def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
t, s = match_generic_callables(t, s)
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(safe_meet(t.arg_types[i], s.arg_types[i]))
Expand Down Expand Up @@ -771,6 +801,7 @@ def safe_meet(t: Type, s: Type) -> Type:


def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
t, s = match_generic_callables(t, s)
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(safe_join(t.arg_types[i], s.arg_types[i]))
Expand Down
3 changes: 2 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,9 @@ def default(self, typ: Type) -> ProperType:


def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType:
from mypy.join import safe_join
from mypy.join import match_generic_callables, safe_join

t, s = match_generic_callables(t, s)
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(safe_join(t.arg_types[i], s.arg_types[i]))
Expand Down
76 changes: 57 additions & 19 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
TypeOfAny,
TypeStrVisitor,
TypeType,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
Expand Down Expand Up @@ -2502,14 +2503,16 @@ def format_literal_value(typ: LiteralType) -> str:
return typ.value_repr()

if isinstance(typ, TypeAliasType) and typ.is_recursive:
# TODO: find balance here, str(typ) doesn't support custom verbosity, and may be
# too verbose for user messages, OTOH it nicely shows structure of recursive types.
if verbosity < 2:
type_str = typ.alias.name if typ.alias else "<alias (unfixed)>"
if typ.alias is None:
type_str = "<alias (unfixed)>"
else:
if verbosity >= 2 or (fullnames and typ.alias.fullname in fullnames):
type_str = typ.alias.fullname
else:
type_str = typ.alias.name
if typ.args:
type_str += f"[{format_list(typ.args)}]"
return type_str
return str(typ)
return type_str

# TODO: always mention type alias names in errors.
typ = get_proper_type(typ)
Expand Down Expand Up @@ -2550,9 +2553,15 @@ def format_literal_value(typ: LiteralType) -> str:
return f"Unpack[{format(typ.type)}]"
elif isinstance(typ, TypeVarType):
# This is similar to non-generic instance types.
fullname = scoped_type_var_name(typ)
if verbosity >= 2 or (fullnames and fullname in fullnames):
return fullname
return typ.name
elif isinstance(typ, TypeVarTupleType):
# This is similar to non-generic instance types.
fullname = scoped_type_var_name(typ)
if verbosity >= 2 or (fullnames and fullname in fullnames):
return fullname
return typ.name
elif isinstance(typ, ParamSpecType):
# Concatenate[..., P]
Expand All @@ -2563,6 +2572,7 @@ def format_literal_value(typ: LiteralType) -> str:

return f"[{args}, **{typ.name_with_suffix()}]"
else:
# TODO: better disambiguate ParamSpec name clashes.
return typ.name_with_suffix()
elif isinstance(typ, TupleType):
# Prefer the name of the fallback class (if not tuple), as it's more informative.
Expand Down Expand Up @@ -2680,29 +2690,51 @@ def format_literal_value(typ: LiteralType) -> str:
return "object"


def collect_all_instances(t: Type) -> list[Instance]:
"""Return all instances that `t` contains (including `t`).
def collect_all_named_types(t: Type) -> list[Type]:
"""Return all instances/aliases/type variables that `t` contains (including `t`).
This is similar to collect_all_inner_types from typeanal but only
returns instances and will recurse into fallbacks.
"""
visitor = CollectAllInstancesQuery()
visitor = CollectAllNamedTypesQuery()
t.accept(visitor)
return visitor.instances
return visitor.types


class CollectAllInstancesQuery(TypeTraverserVisitor):
class CollectAllNamedTypesQuery(TypeTraverserVisitor):
def __init__(self) -> None:
self.instances: list[Instance] = []
self.types: list[Type] = []

def visit_instance(self, t: Instance) -> None:
self.instances.append(t)
self.types.append(t)
super().visit_instance(t)

def visit_type_alias_type(self, t: TypeAliasType) -> None:
if t.alias and not t.is_recursive:
t.alias.target.accept(self)
super().visit_type_alias_type(t)
get_proper_type(t).accept(self)
else:
self.types.append(t)
super().visit_type_alias_type(t)

def visit_type_var(self, t: TypeVarType) -> None:
self.types.append(t)
super().visit_type_var(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
self.types.append(t)
super().visit_type_var_tuple(t)

def visit_param_spec(self, t: ParamSpecType) -> None:
self.types.append(t)
super().visit_param_spec(t)


def scoped_type_var_name(t: TypeVarLikeType) -> str:
    """Return the name of ``t`` qualified by the last component of its namespace.

    The result has the form ``name@suffix``, where ``suffix`` is the part of
    ``t.id.namespace`` after the final dot. When the namespace is empty, the
    bare name is returned.
    """
    namespace = t.id.namespace
    if namespace:
        # TODO: support rare cases when both TypeVar name and namespace suffix coincide.
        last_component = namespace.rsplit(".", 1)[-1]
        return f"{t.name}@{last_component}"
    return t.name


def find_type_overlaps(*types: Type) -> set[str]:
Expand All @@ -2713,8 +2745,14 @@ def find_type_overlaps(*types: Type) -> set[str]:
"""
d: dict[str, set[str]] = {}
for type in types:
for inst in collect_all_instances(type):
d.setdefault(inst.type.name, set()).add(inst.type.fullname)
for t in collect_all_named_types(type):
if isinstance(t, ProperType) and isinstance(t, Instance):
d.setdefault(t.type.name, set()).add(t.type.fullname)
elif isinstance(t, TypeAliasType) and t.alias:
d.setdefault(t.alias.name, set()).add(t.alias.fullname)
else:
assert isinstance(t, TypeVarLikeType)
d.setdefault(t.name, set()).add(scoped_type_var_name(t))
for shortname in d.keys():
if f"typing.{shortname}" in TYPES_FOR_UNIMPORTED_HINTS:
d[shortname].add(f"typing.{shortname}")
Expand All @@ -2732,7 +2770,7 @@ def format_type(
"""
Convert a type to a relatively short string suitable for error messages.
`verbosity` is a coarse grained control on the verbosity of the type
`verbosity` is a coarse-grained control on the verbosity of the type
This function returns a string appropriate for unmodified use in error
messages; this means that it will be quoted in most cases. If
Expand All @@ -2748,7 +2786,7 @@ def format_type_bare(
"""
Convert a type to a relatively short string suitable for error messages.
`verbosity` is a coarse grained control on the verbosity of the type
`verbosity` is a coarse-grained control on the verbosity of the type
`fullnames` specifies a set of names that should be printed in full
This function will return an unquoted string. If a caller doesn't need to
Expand Down
17 changes: 9 additions & 8 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
Type,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarType,
UninhabitedType,
UnionType,
Expand Down Expand Up @@ -807,25 +808,25 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
# AT = TypeVar('AT')
# def __lt__(self: AT, other: AT) -> bool
# This way comparisons with subclasses will work correctly.
fullname = f"{ctx.cls.info.fullname}.{SELF_TVAR_NAME}"
tvd = TypeVarType(
SELF_TVAR_NAME,
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
id=-1,
fullname,
# Namespace is patched per-method below.
id=TypeVarId(-1, namespace=""),
values=[],
upper_bound=object_type,
default=AnyType(TypeOfAny.from_omitted_generics),
)
self_tvar_expr = TypeVarExpr(
SELF_TVAR_NAME,
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
[],
object_type,
AnyType(TypeOfAny.from_omitted_generics),
SELF_TVAR_NAME, fullname, [], object_type, AnyType(TypeOfAny.from_omitted_generics)
)
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

args = [Argument(Var("other", tvd), tvd, None, ARG_POS)]
for method in ["__lt__", "__le__", "__gt__", "__ge__"]:
namespace = f"{ctx.cls.info.fullname}.{method}"
tvd = tvd.copy_modified(id=TypeVarId(tvd.id.raw_id, namespace=namespace))
args = [Argument(Var("other", tvd), tvd, None, ARG_POS)]
adder.add_method(method, args, bool_type, self_type=tvd, tvd=tvd)


Expand Down
5 changes: 3 additions & 2 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
TupleType,
Type,
TypeOfAny,
TypeVarId,
TypeVarType,
UninhabitedType,
UnionType,
Expand Down Expand Up @@ -314,8 +315,8 @@ def transform(self) -> bool:
obj_type = self._api.named_type("builtins.object")
order_tvar_def = TypeVarType(
SELF_TVAR_NAME,
info.fullname + "." + SELF_TVAR_NAME,
id=-1,
f"{info.fullname}.{SELF_TVAR_NAME}",
id=TypeVarId(-1, namespace=f"{info.fullname}.{method_name}"),
values=[],
upper_bound=obj_type,
default=AnyType(TypeOfAny.from_omitted_generics),
Expand Down
14 changes: 12 additions & 2 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import mypy.checker
import mypy.plugin
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import (
AnyType,
Expand Down Expand Up @@ -151,12 +151,22 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
actual_types = [a for param in ctx.arg_types[1:] for a in param]

# Create a valid context for various ad-hoc inspections in check_call().
call_expr = CallExpr(
callee=ctx.args[0][0],
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
analyzed=ctx.context.analyzed if isinstance(ctx.context, CallExpr) else None,
)
call_expr.set_line(ctx.context)

_, bound = ctx.api.expr_checker.check_call(
callee=defaulted,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=defaulted,
context=call_expr,
)
bound = get_proper_type(bound)
if not isinstance(bound, CallableType):
Expand Down
Loading

0 comments on commit 6682563

Please sign in to comment.