Skip to content

Commit

Permalink
Use polymorphic inference in unification (#17348)
Browse files Browse the repository at this point in the history
Moving towards #15907
Fixes #17206

This PR enables polymorphic inference during unification. This will
allow us to handle even more tricky situations involving generic
higher-order functions (see a random example I added in tests).
Implementation is mostly straightforward; a few notes:
* This uncovered another issue with unions in the solver: unfortunately,
the current constraint inference algorithm can sometimes infer weird
constraints like `T <: Union[T, int]` that later confuse the solver.
* This uncovered another possible type variable clash scenario that was
not handled properly. In overloaded generic function, each overload
should have a different namespace for type variables (currently they all
just get function name). I use `module.some_func#0` etc. for overloads
namespaces instead.
* Another thing with overloads is that the switch caused unsafe overlap
check to change: after some back and forth I am keeping it mostly the
same to avoid possible regressions (unfortunately this requires some
extra refreshing of type variables).
* This makes another `ParamSpec` crash happen more often, so I fix it
in this same PR.
* Finally this uncovered a bug in handling of overloaded `__init__()`
that I am fixing here as well.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ilevkivskyi and pre-commit-ci[bot] authored Jun 10, 2024
1 parent 5ae9e69 commit 83d54ff
Show file tree
Hide file tree
Showing 12 changed files with 253 additions and 65 deletions.
16 changes: 14 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,9 +791,21 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
if impl_type is not None:
assert defn.impl is not None

# This is what we want from implementation, it should accept all arguments
# of an overload, but the return types should go the opposite way.
if is_callable_compatible(
impl_type,
sig1,
is_compat=is_subtype,
is_proper_subtype=False,
is_compat_return=lambda l, r: is_subtype(r, l),
):
continue
# If the above check didn't work, we repeat some key steps in
# is_callable_compatible() to give a better error message.

# We perform a unification step that's very similar to what
# 'is_callable_compatible' does -- the only difference is that
# we check and see if the impl_type's return value is a
# *supertype* of the overload alternative, not a *subtype*.
#
Expand Down
9 changes: 7 additions & 2 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,14 +688,19 @@ def visit_unpack_type(self, template: UnpackType) -> list[Constraint]:

def visit_parameters(self, template: Parameters) -> list[Constraint]:
    """Infer constraints for a bare Parameters template.

    Constraining Any against C[P] turns into infer_against_any([P], Any).
    With polymorphic inference enabled, a Parameters template may also be
    matched against concrete Parameters or a ParamSpecType to produce
    secondary constraints.
    """
    # Note: the stale pre-change `raise RuntimeError("Parameters cannot be
    # constrained to")` line from the merged diff is removed here -- it made
    # the branches below unreachable; the intended behavior is to fall
    # through and ignore unpatched types.
    if isinstance(self.actual, AnyType):
        return self.infer_against_any(template.arg_types, self.actual)
    if type_state.infer_polymorphic and isinstance(self.actual, Parameters):
        # For polymorphic inference we need to be able to infer secondary constraints
        # in situations like [x: T] <: P <: [x: int].
        return infer_callable_arguments_constraints(template, self.actual, self.direction)
    if type_state.infer_polymorphic and isinstance(self.actual, ParamSpecType):
        # Similar for [x: T] <: Q <: Concatenate[int, P].
        return infer_callable_arguments_constraints(
            template, self.actual.prefix, self.direction
        )
    # There also may be unpatched types after a user error, simply ignore them.
    return []

# Non-leaf types

Expand Down
4 changes: 4 additions & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
)
INVALID_UNPACK: Final = "{} cannot be unpacked (must be tuple or TypeVarTuple)"
INVALID_UNPACK_POSITION: Final = "Unpack is only valid in a variadic position"
INVALID_PARAM_SPEC_LOCATION: Final = "Invalid location for ParamSpec {}"
INVALID_PARAM_SPEC_LOCATION_NOTE: Final = (
'You can use ParamSpec as the first argument to Callable, e.g., "Callable[{}, int]"'
)

# TypeVar
INCOMPATIBLE_TYPEVAR_VALUE: Final = 'Value of type variable "{}" of {} cannot be {}'
Expand Down
38 changes: 29 additions & 9 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ def __init__(
# new uses of this, as this may cause leaking `UnboundType`s to type checking.
self.allow_unbound_tvars = False

# Used to pass information about current overload index to visit_func_def().
self.current_overload_item: int | None = None

# mypyc doesn't properly handle implementing an abstractproperty
# with a regular attribute so we make them properties
@property
Expand Down Expand Up @@ -869,6 +872,11 @@ def visit_func_def(self, defn: FuncDef) -> None:
with self.scope.function_scope(defn):
self.analyze_func_def(defn)

def function_fullname(self, fullname: str) -> str:
    """Return fullname, namespaced by the current overload item index.

    When an overload item is being analyzed, append "#<index>" so each
    overload item gets its own type variable namespace (e.g.
    "module.some_func#0"); otherwise return the name unchanged.
    """
    idx = self.current_overload_item
    return fullname if idx is None else f"{fullname}#{idx}"

def analyze_func_def(self, defn: FuncDef) -> None:
if self.push_type_args(defn.type_args, defn) is None:
self.defer(defn)
Expand All @@ -895,17 +903,16 @@ def analyze_func_def(self, defn: FuncDef) -> None:
self.prepare_method_signature(defn, self.type, has_self_type)

# Analyze function signature
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname = self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
if defn.type:
self.check_classvar_in_signature(defn.type)
assert isinstance(defn.type, CallableType)
# Signature must be analyzed in the surrounding scope so that
# class-level imported names and type variables are in scope.
analyzer = self.type_analyzer()
tag = self.track_incomplete_refs()
result = analyzer.visit_callable_type(
defn.type, nested=False, namespace=defn.fullname
)
result = analyzer.visit_callable_type(defn.type, nested=False, namespace=fullname)
# Don't store not ready types (including placeholders).
if self.found_incomplete_ref(tag) or has_placeholder(result):
self.defer(defn)
Expand Down Expand Up @@ -1117,7 +1124,8 @@ def update_function_type_variables(self, fun_type: CallableType, defn: FuncItem)
if defn is generic. Return True, if the signature contains typing.Self
type, or False otherwise.
"""
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname = self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
a = self.type_analyzer()
fun_type.variables, has_self_type = a.bind_function_type_variables(fun_type, defn)
if has_self_type and self.type is not None:
Expand Down Expand Up @@ -1175,6 +1183,14 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
with self.scope.function_scope(defn):
self.analyze_overloaded_func_def(defn)

@contextmanager
def overload_item_set(self, item: int | None) -> Iterator[None]:
    """Temporarily set the index of the overload item being analyzed.

    visit_func_def() reads self.current_overload_item (through
    function_fullname()) to give each overload item a distinct type
    variable namespace such as "module.some_func#0".

    NOTE(review): the finally clause resets the field to None rather
    than restoring the previous value -- this assumes overload items
    are never analyzed nested inside one another; confirm with callers.
    """
    self.current_overload_item = item
    try:
        yield
    finally:
        self.current_overload_item = None

def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
# OverloadedFuncDef refers to any legitimate situation where you have
# more than one declaration for the same function in a row. This occurs
Expand All @@ -1187,7 +1203,8 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:

first_item = defn.items[0]
first_item.is_overload = True
first_item.accept(self)
with self.overload_item_set(0):
first_item.accept(self)

if isinstance(first_item, Decorator) and first_item.func.is_property:
# This is a property.
Expand Down Expand Up @@ -1272,7 +1289,8 @@ def analyze_overload_sigs_and_impl(
if i != 0:
# Assume that the first item was already visited
item.is_overload = True
item.accept(self)
with self.overload_item_set(i if i < len(defn.items) - 1 else None):
item.accept(self)
# TODO: support decorated overloaded functions properly
if isinstance(item, Decorator):
callable = function_type(item.func, self.named_type("builtins.function"))
Expand Down Expand Up @@ -1444,15 +1462,17 @@ def add_function_to_symbol_table(self, func: FuncDef | OverloadedFuncDef) -> Non
self.add_symbol(func.name, func, func)

def analyze_arg_initializers(self, defn: FuncItem) -> None:
    """Analyze default argument expressions inside the function's type
    variable scope.

    The scope frame is keyed by the overload-aware fullname so that, for
    an overload item, type variables resolve in that item's own
    namespace (e.g. "module.some_func#0").
    """
    # The stale pre-change `with ... method_frame(defn.fullname)` line left
    # over from the merged diff is dropped: it duplicated the context
    # manager and bypassed the overload namespacing below.
    fullname = self.function_fullname(defn.fullname)
    with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
        # Analyze default arguments
        for arg in defn.arguments:
            if arg.initializer:
                arg.initializer.accept(self)

def analyze_function_body(self, defn: FuncItem) -> None:
is_method = self.is_class_scope()
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname = self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
# Bind the type variables again to visit the body.
if defn.type:
a = self.type_analyzer()
Expand Down
22 changes: 18 additions & 4 deletions mypy/semanal_typeargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mypy import errorcodes as codes, message_registry
from mypy.errorcodes import ErrorCode
from mypy.errors import Errors
from mypy.message_registry import INVALID_PARAM_SPEC_LOCATION, INVALID_PARAM_SPEC_LOCATION_NOTE
from mypy.messages import format_type
from mypy.mixedtraverser import MixedTraverserVisitor
from mypy.nodes import ARG_STAR, Block, ClassDef, Context, FakeInfo, FuncItem, MypyFile
Expand Down Expand Up @@ -146,13 +147,25 @@ def validate_args(
for (i, arg), tvar in zip(enumerate(args), type_vars):
if isinstance(tvar, TypeVarType):
if isinstance(arg, ParamSpecType):
# TODO: Better message
is_error = True
self.fail(f'Invalid location for ParamSpec "{arg.name}"', ctx)
self.fail(
INVALID_PARAM_SPEC_LOCATION.format(format_type(arg, self.options)),
ctx,
code=codes.VALID_TYPE,
)
self.note(
"You can use ParamSpec as the first argument to Callable, e.g., "
"'Callable[{}, int]'".format(arg.name),
INVALID_PARAM_SPEC_LOCATION_NOTE.format(arg.name),
ctx,
code=codes.VALID_TYPE,
)
continue
if isinstance(arg, Parameters):
is_error = True
self.fail(
f"Cannot use {format_type(arg, self.options)} for regular type variable,"
" only for ParamSpec",
ctx,
code=codes.VALID_TYPE,
)
continue
if tvar.values:
Expand Down Expand Up @@ -204,6 +217,7 @@ def validate_args(
"Can only replace ParamSpec with a parameter types list or"
f" another ParamSpec, got {format_type(arg, self.options)}",
ctx,
code=codes.VALID_TYPE,
)
return is_error

Expand Down
9 changes: 8 additions & 1 deletion mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
is a linear constraint. This is however not true in presence of union types, for example
T :> Union[S, int] vs S <: T. Trying to solve such constraints would be detected ambiguous
as (T, S) form a non-linear SCC. However, simply removing the linear part results in a valid
solution T = Union[S, int], S = <free>.
solution T = Union[S, int], S = <free>. A similar scenario is when we get T <: Union[T, int],
such constraints carry no information, and will equally confuse linearity check.
TODO: a cleaner solution may be to avoid inferring such constraints in first place, but
this would require passing around a flag through all infer_constraints() calls.
Expand All @@ -525,7 +526,13 @@ def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
if isinstance(p_target, UnionType):
for item in p_target.items:
if isinstance(item, TypeVarType):
if item == c.origin_type_var and c.op == SUBTYPE_OF:
reverse_union_cs.add(c)
continue
# These two forms are semantically identical, but are different from
# the point of view of Constraint.__eq__().
reverse_union_cs.add(Constraint(item, neg_op(c.op), c.origin_type_var))
reverse_union_cs.add(Constraint(c.origin_type_var, c.op, item))
return [c for c in cs if c not in reverse_union_cs]


Expand Down
16 changes: 14 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
import mypy.constraints
import mypy.typeops
from mypy.erasetype import erase_type
from mypy.expandtype import expand_self_type, expand_type, expand_type_by_instance
from mypy.expandtype import (
expand_self_type,
expand_type,
expand_type_by_instance,
freshen_function_type_vars,
)
from mypy.maptype import map_instance_to_supertype

# Circular import; done in the function instead.
Expand Down Expand Up @@ -1860,6 +1865,11 @@ def unify_generic_callable(
"""
import mypy.solve

if set(type.type_var_ids()) & {v.id for v in mypy.typeops.get_all_type_vars(target)}:
# Overload overlap check does nasty things like unifying in opposite direction.
# This can easily create type variable clashes, so we need to refresh.
type = freshen_function_type_vars(type)

if return_constraint_direction is None:
return_constraint_direction = mypy.constraints.SUBTYPE_OF

Expand All @@ -1882,7 +1892,9 @@ def unify_generic_callable(
constraints = [
c for c in constraints if not isinstance(get_proper_type(c.target), NoneType)
]
inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints)
inferred_vars, _ = mypy.solve.solve_constraints(
type.variables, constraints, allow_polymorphic=True
)
if None in inferred_vars:
return None
non_none_inferred_vars = cast(List[Type], inferred_vars)
Expand Down
17 changes: 13 additions & 4 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from mypy import errorcodes as codes, message_registry, nodes
from mypy.errorcodes import ErrorCode
from mypy.expandtype import expand_type
from mypy.messages import MessageBuilder, format_type_bare, quote_type_string, wrong_type_arg_count
from mypy.message_registry import INVALID_PARAM_SPEC_LOCATION, INVALID_PARAM_SPEC_LOCATION_NOTE
from mypy.messages import (
MessageBuilder,
format_type,
format_type_bare,
quote_type_string,
wrong_type_arg_count,
)
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
Expand Down Expand Up @@ -1782,12 +1789,14 @@ def anal_type(
analyzed = AnyType(TypeOfAny.from_error)
else:
self.fail(
f'Invalid location for ParamSpec "{analyzed.name}"', t, code=codes.VALID_TYPE
INVALID_PARAM_SPEC_LOCATION.format(format_type(analyzed, self.options)),
t,
code=codes.VALID_TYPE,
)
self.note(
"You can use ParamSpec as the first argument to Callable, e.g., "
"'Callable[{}, int]'".format(analyzed.name),
INVALID_PARAM_SPEC_LOCATION_NOTE.format(analyzed.name),
t,
code=codes.VALID_TYPE,
)
analyzed = AnyType(TypeOfAny.from_error)
return analyzed
Expand Down
33 changes: 25 additions & 8 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,14 @@ def type_object_type_from_function(
# ...
#
# We need to map B's __init__ to the type (List[T]) -> None.
signature = bind_self(signature, original_type=default_self, is_classmethod=is_new)
signature = bind_self(
signature,
original_type=default_self,
is_classmethod=is_new,
# Explicit instance self annotations have special handling in class_callable(),
# we don't need to bind any type variables in them if they are generic.
ignore_instances=True,
)
signature = cast(FunctionLike, map_type_from_supertype(signature, info, def_info))

special_sig: str | None = None
Expand Down Expand Up @@ -244,7 +251,9 @@ class C(D[E[T]], Generic[T]): ...
return expand_type_by_instance(typ, inst_type)


def supported_self_type(typ: ProperType, allow_callable: bool = True) -> bool:
def supported_self_type(
typ: ProperType, allow_callable: bool = True, allow_instances: bool = True
) -> bool:
"""Is this a supported kind of explicit self-types?
Currently, this means an X or Type[X], where X is an instance or
Expand All @@ -257,14 +266,19 @@ def supported_self_type(typ: ProperType, allow_callable: bool = True) -> bool:
# as well as callable self for callback protocols.
return True
return isinstance(typ, TypeVarType) or (
isinstance(typ, Instance) and typ != fill_typevars(typ.type)
allow_instances and isinstance(typ, Instance) and typ != fill_typevars(typ.type)
)


F = TypeVar("F", bound=FunctionLike)


def bind_self(method: F, original_type: Type | None = None, is_classmethod: bool = False) -> F:
def bind_self(
method: F,
original_type: Type | None = None,
is_classmethod: bool = False,
ignore_instances: bool = False,
) -> F:
"""Return a copy of `method`, with the type of its first parameter (usually
self or cls) bound to original_type.
Expand All @@ -288,9 +302,10 @@ class B(A): pass
"""
if isinstance(method, Overloaded):
return cast(
F, Overloaded([bind_self(c, original_type, is_classmethod) for c in method.items])
)
items = [
bind_self(c, original_type, is_classmethod, ignore_instances) for c in method.items
]
return cast(F, Overloaded(items))
assert isinstance(method, CallableType)
func = method
if not func.arg_types:
Expand All @@ -310,7 +325,9 @@ class B(A): pass
# this special-casing looks not very principled, there is nothing meaningful we can infer
# from such definition, since it is inherently indefinitely recursive.
allow_callable = func.name is None or not func.name.startswith("__call__ of")
if func.variables and supported_self_type(self_param_type, allow_callable=allow_callable):
if func.variables and supported_self_type(
self_param_type, allow_callable=allow_callable, allow_instances=not ignore_instances
):
from mypy.infer import infer_type_arguments

if original_type is None:
Expand Down
Loading

0 comments on commit 83d54ff

Please sign in to comment.