Skip to content

Commit

Permalink
Address CR
Browse files Browse the repository at this point in the history
  • Loading branch information
ilevkivskyi committed Jun 18, 2023
1 parent 96d0f39 commit 66b4567
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
31 changes: 27 additions & 4 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def solve_non_linear(
The whole algorithm consists of five steps:
* Propagate via linear constraints to get all possible constraints for each variable
* Find dependencies between type variables, group them in SCCs, and sor topologically
* Find dependencies between type variables, group them in SCCs, and sort topologically
* Check all SCC are intrinsically linear, we can't solve (express) T <: List[T]
* Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC)
* Solve constraints iteratively starting from leafs, updating targets after each step.
Expand All @@ -112,11 +112,18 @@ def solve_non_linear(
leafs = raw_batches[0]
free_vars = []
for scc in leafs:
# If all constrain targets in this SCC are type variables within the
# same SCC then the only meaningful solution we can express, is that
# each variable is equal to a new free variable. For example if we
# have T <: S, S <: U, we deduce: T = S = U = <free>.
if all(
isinstance(c.target, TypeVarType) and c.target.id in vars
for tv in scc
for c in cmap[tv]
):
# For convenience with current type application machinery, we randomly
# choose one of the existing type variables in SCC and designate it as free
# instead of defining a new type variable as a common solution.
# TODO: be careful about upper bounds (or values) when introducing free vars.
free_vars.append(sorted(scc, key=lambda x: x.raw_id)[0])

Expand Down Expand Up @@ -146,7 +153,17 @@ def solve_non_linear(
def solve_iteratively(
batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId]
) -> dict[TypeVarId, Type | None]:
"""Solve constraints for type variables sequentially, updating targets after each step."""
"""Solve constraints sequentially, updating constraint targets after each step.
We solve for type variables that appear in `batch`. If a constraint target is not constant
(i.e. constraint looks like T :> F[S, ...]), we substitute solutions found so far in
the target F[S, ...]. This way we can gradually solve for all variables in the batch taking
one solvable variable at a time (i.e. such a variable that has at least one constant bound).
Importantly, variables in free_vars are considered constants, so for example if we have just
one initial constraint T <: List[S], we will have two SCCs {T} and {S}, then we first
designate S as free, and therefore T = List[S] is a valid solution for T.
"""
solutions = {}
relevant_constraints = []
for tv in batch:
Expand Down Expand Up @@ -293,8 +310,14 @@ def transitive_closure(
) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]:
"""Find transitive closure for given constraints on type variables.
Transitive closure gives maximal set of lower/upper bounds for each type variable, such
we cannot deduce any further bounds by chaining other existing bounds.
Transitive closure gives maximal set of lower/upper bounds for each type variable,
such that we cannot deduce any further bounds by chaining other existing bounds.
For example if we have initial constraints [T <: S, S <: U, U <: int], the transitive
closure is given by:
* {} <: T <: {S, U, int}
* {T} <: S <: {U, int}
* {T, S} <: U <: {int}
"""
# TODO: merge propagate_constraints_for() into this function.
# TODO: add secondary constraints here to make the algorithm complete.
Expand Down
21 changes: 21 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,9 @@ dict2 = {"a": C1(), **{x: C2() for x in dict1}}
reveal_type(dict2) # N: Revealed type is "builtins.dict[Any, __main__.B]"
[builtins fixtures/dict.pyi]

-- Type inference for generic decorators applied to generic callables
-- ------------------------------------------------------------------

[case testInferenceAgainstGenericCallable]
# flags: --new-type-inference
from typing import TypeVar, Callable, List
Expand Down Expand Up @@ -2794,6 +2797,12 @@ def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]:
def id(x: U) -> U:
...
reveal_type(dec(id)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]"

@dec
def same(x: U) -> U:
...
reveal_type(same) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]"
reveal_type(same(42)) # N: Revealed type is "builtins.list[builtins.int]"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericCallableGenericReverse]
Expand All @@ -2809,6 +2818,12 @@ def dec(f: Callable[[S], List[T]]) -> Callable[[S], T]:
def id(x: U) -> U:
...
reveal_type(dec(id)) # N: Revealed type is "def [T] (builtins.list[T`2]) -> T`2"

@dec
def same(x: U) -> U:
...
reveal_type(same) # N: Revealed type is "def [T] (builtins.list[T`4]) -> T`4"
reveal_type(same([42])) # N: Revealed type is "builtins.int"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericCallableGenericArg]
Expand All @@ -2824,6 +2839,12 @@ def dec(f: Callable[[S], T]) -> Callable[[S], T]:
def test(x: U) -> List[U]:
...
reveal_type(dec(test)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]"

@dec
def single(x: U) -> List[U]:
...
reveal_type(single) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]"
reveal_type(single(42)) # N: Revealed type is "builtins.list[builtins.int]"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericCallableGenericChain]
Expand Down

0 comments on commit 66b4567

Please sign in to comment.