diff --git a/mypy/solve.py b/mypy/solve.py index efe8e487c506..9770364bf892 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -6,7 +6,7 @@ from typing import Iterable, Sequence from typing_extensions import TypeAlias as _TypeAlias -from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op from mypy.expandtype import expand_type from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.join import join_types @@ -69,6 +69,10 @@ def solve_constraints( extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars]) originals.update({v.id: v for v in c.extra_tvars if v.id not in originals}) + if allow_polymorphic: + # Constraints inferred from unions require special handling in polymorphic inference. + constraints = skip_reverse_union_constraints(constraints) + # Collect a list of constraints for each type variable. cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars} for con in constraints: @@ -431,10 +435,7 @@ def transitive_closure( uppers[l] |= uppers[upper] for lt in lowers[lower]: for ut in uppers[upper]: - # TODO: what if secondary constraints result in inference - # against polymorphic actual (also in below branches)? - remaining |= set(infer_constraints(lt, ut, SUBTYPE_OF)) - remaining |= set(infer_constraints(ut, lt, SUPERTYPE_OF)) + add_secondary_constraints(remaining, lt, ut) elif c.op == SUBTYPE_OF: if c.target in uppers[c.type_var]: continue @@ -442,8 +443,7 @@ def transitive_closure( if (l, c.type_var) in graph: uppers[l].add(c.target) for lt in lowers[c.type_var]: - remaining |= set(infer_constraints(lt, c.target, SUBTYPE_OF)) - remaining |= set(infer_constraints(c.target, lt, SUPERTYPE_OF)) + add_secondary_constraints(remaining, lt, c.target) else: assert c.op == SUPERTYPE_OF if c.target in lowers[c.type_var]: @@ -452,11 +452,24 @@ def transitive_closure( if (c.type_var, u) in graph: lowers[u].add(c.target) for ut in uppers[c.type_var]: - remaining |= set(infer_constraints(ut, c.target, SUPERTYPE_OF)) - remaining |= set(infer_constraints(c.target, ut, SUBTYPE_OF)) + add_secondary_constraints(remaining, c.target, ut) return graph, lowers, uppers +def add_secondary_constraints(cs: set[Constraint], lower: Type, upper: Type) -> None: + """Add secondary constraints inferred between lower and upper (in place).""" + if isinstance(get_proper_type(upper), UnionType) and isinstance( + get_proper_type(lower), UnionType + ): + # When both types are unions, this can lead to inferring spurious constraints, + # for example Union[T, int] <: S <: Union[T, int] may infer T <: int. + # To avoid this, just skip them for now. + return + # TODO: what if secondary constraints result in inference against polymorphic actual? + cs.update(set(infer_constraints(lower, upper, SUBTYPE_OF))) + cs.update(set(infer_constraints(upper, lower, SUPERTYPE_OF))) + + def compute_dependencies( tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds ) -> dict[TypeVarId, list[TypeVarId]]: @@ -494,6 +507,28 @@ def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool: return True +def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]: + """Avoid ambiguities for constraints inferred from unions during polymorphic inference. + + Polymorphic inference implicitly relies on assumption that a reverse of a linear constraint + is a linear constraint. This is however not true in presence of union types, for example + T :> Union[S, int] vs S <: T. Trying to solve such constraints would be detected ambiguous + as (T, S) form a non-linear SCC. However, simply removing the linear part results in a valid + solution T = Union[S, int], S = . + + TODO: a cleaner solution may be to avoid inferring such constraints in first place, but + this would require passing around a flag through all infer_constraints() calls. + """ + reverse_union_cs = set() + for c in cs: + p_target = get_proper_type(c.target) + if isinstance(p_target, UnionType): + for item in p_target.items: + if isinstance(item, TypeVarType): + reverse_union_cs.add(Constraint(item, neg_op(c.op), c.origin_type_var)) + return [c for c in cs if c not in reverse_union_cs] + + def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]: """Find type variables for which we are solving in a target type.""" return {tv.id for tv in get_all_type_vars(target)} & set(vars) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 0d162238450a..6c98ba2088b1 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3767,3 +3767,24 @@ def f(values: List[T]) -> T: ... x = foo(f([C()])) reveal_type(x) # N: Revealed type is "__main__.C" [builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableUnion] +from typing import Callable, TypeVar, List, Union + +T = TypeVar("T") +S = TypeVar("S") + +def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]: ... +@dec +def func(arg: T) -> Union[T, str]: + ... +reveal_type(func) # N: Revealed type is "def [S] (S`1) -> builtins.list[Union[S`1, builtins.str]]" +reveal_type(func(42)) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]" + +def dec2(f: Callable[[S], List[T]]) -> Callable[[S], T]: ... +@dec2 +def func2(arg: T) -> List[Union[T, str]]: + ... +reveal_type(func2) # N: Revealed type is "def [S] (S`4) -> Union[S`4, builtins.str]" +reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]" +[builtins fixtures/list.pyi] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index db8c76fd21e9..eb6fbf07f045 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2086,3 +2086,25 @@ reveal_type(d(b, f1)) # E: Cannot infer type argument 1 of "d" \ # N: Revealed type is "def (*Any, **Any)" reveal_type(d(b, f2)) # N: Revealed type is "def (builtins.int)" [builtins fixtures/paramspec.pyi] + +[case testInferenceAgainstGenericCallableUnionParamSpec] +from typing import Callable, TypeVar, List, Union +from typing_extensions import ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") + +def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... +@dec +def func(arg: T) -> Union[T, str]: + ... +reveal_type(func) # N: Revealed type is "def [T] (arg: T`-1) -> builtins.list[Union[T`-1, builtins.str]]" +reveal_type(func(42)) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]" + +def dec2(f: Callable[P, List[T]]) -> Callable[P, T]: ... +@dec2 +def func2(arg: T) -> List[Union[T, str]]: + ... +reveal_type(func2) # N: Revealed type is "def [T] (arg: T`-1) -> Union[T`-1, builtins.str]" +reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]" +[builtins fixtures/paramspec.pyi]