Skip to content

Commit

Permalink
Keep TypeVar arguments when narrowing generic subclasses with `isin…
Browse files Browse the repository at this point in the history
…stance` and `issubclass`.
  • Loading branch information
tyralla committed Apr 4, 2024
1 parent 4310586 commit 84e3837
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 1 deletion.
115 changes: 114 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7454,7 +7454,10 @@ def conditional_types(
]
)
remaining_type = restrict_subtype_away(current_type, proposed_precise_type)
return proposed_type, remaining_type
proposed_type_with_data = _transfer_type_var_args_from_current_to_proposed(
current_type, proposed_type
)
return proposed_type_with_data, remaining_type
else:
# An isinstance check, but we don't understand the type
return current_type, default
Expand All @@ -7478,6 +7481,116 @@ def conditional_types_to_typemaps(
return cast(Tuple[TypeMap, TypeMap], tuple(maps))


def _transfer_type_var_args_from_current_to_proposed(current: Type, proposed: Type) -> Type:
"""Check if the current type is among the bases of the proposed type. If so, try to transfer
the type variable arguments of the current type's instance to a copy of the proposed type's
instance. This increases information when narrowing generic classes so that, for example,
Sequence[int] is narrowed to List[int] instead of List[Any]."""

def _get_instance_path_from_current_to_proposed(
this: Instance, target: TypeInfo
) -> list[Instance] | None:
"""Search for the current type among the bases of the proposed type and return the
"instance path" from the current to proposed type. Or None, if the current type is not a
nominal super type. At most one path is returned, which means there is no special handling
of (inconsistent) multiple inheritance."""
if target == this.type:
return [this]
for base in this.type.bases:
path = _get_instance_path_from_current_to_proposed(base, target)
if path is not None:
path.append(this)
return path
return None

# Handle "tuple of Instance" cases, e.g. `isinstance(x, (A, B))`:
proposed = get_proper_type(proposed)
if isinstance(proposed, UnionType):
items = [
_transfer_type_var_args_from_current_to_proposed(current, item)
for item in flatten_nested_unions(proposed.items)
]
return make_simplified_union(items)

# Otherwise handle only Instances:
if not isinstance(proposed, Instance):
return proposed

# Handle union cases like `a: A[int] | A[str]; isinstance(a, B)`:
current = get_proper_type(current)
if isinstance(current, UnionType):
items = [
_transfer_type_var_args_from_current_to_proposed(item, proposed)
for item in flatten_nested_unions(current.items)
]
return make_simplified_union(items)

# Here comes the main logic:
if isinstance(current, Instance):

# Only consider nominal subtyping:
instances = _get_instance_path_from_current_to_proposed(proposed, current.type)
if instances is None:
return proposed
assert len(instances) > 0 # shortest case: proposed type is current type

# Make a list of the proposed type's type variable arguments that allows to replace each
# `Any` with one type variable argument or multiple type variable tuple arguments of the
# current type:
proposed_args: list[Type | tuple[Type, ...]] = list(proposed.args)

# Try to transfer each type variable argument from the current to the base type separately:
for pos1, typevar1 in enumerate(instances[0].args):
if isinstance(typevar1, UnpackType):
typevar1 = typevar1.type
if not isinstance(typevar1, (TypeVarType, TypeVarTupleType)):
continue
# Find the position of the intermediate types' and finally the proposed type's
# related type variable (if not available, `pos2` becomes `None`):
for instance in instances[1:]:
pos2: int | None = None
for pos2, typevar2 in enumerate(instance.type.defn.type_vars):
if typevar1 == typevar2:
if instance.type.has_type_var_tuple_type:
assert (prefix := instance.type.type_var_tuple_prefix) is not None
if pos2 > prefix:
pos2 += len(instance.args) - len(instance.type.defn.type_vars)
typevar1 = instance.args[pos2]
if isinstance(typevar1, UnpackType):
typevar1 = typevar1.type
break
else:
pos2 = None
break

# Transfer the current type's type variable argument or type variable tuple arguments:
if pos2 is not None:
if current.type.has_type_var_tuple_type:
assert (prefix := current.type.type_var_tuple_prefix) is not None
assert (suffix := current.type.type_var_tuple_suffix) is not None
if pos1 < prefix:
proposed_args[pos2] = current.args[pos1]
elif pos1 == prefix:
proposed_args[pos2] = current.args[prefix:len(current.args) - suffix]
else:
middle = len(current.args) - prefix - suffix
proposed_args[pos2] = current.args[pos1 + middle - 1]
else:
proposed_args[pos2] = current.args[pos1]

# Combine all type variable and type variable tuple arguments to a flat list:
flattened_proposed_args: list[Type] = []
for arg in proposed_args:
if isinstance(arg, tuple):
flattened_proposed_args.extend(arg)
else:
flattened_proposed_args.append(arg)

return proposed.copy_modified(args=flattened_proposed_args)

return proposed


def gen_unique_name(base: str, table: SymbolTable) -> str:
"""Generate a name that does not appear in table by appending numbers to base."""
if base not in table:
Expand Down
146 changes: 146 additions & 0 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -2089,3 +2089,149 @@ if isinstance(x, (Z, NoneType)): # E: Subclass of "X" and "Z" cannot exist: "Z"
reveal_type(x) # E: Statement is unreachable

[builtins fixtures/isinstance.pyi]

[case testKeepTypeVarArgsWhenNarrowingGenericsWithIsInstance]
from typing import Generic, Sequence, Tuple, TypeVar, Union

s: Sequence[str]
if isinstance(s, tuple):
reveal_type(s) # N: Revealed type is "builtins.tuple[builtins.str, ...]"
else:
reveal_type(s) # N: Revealed type is "typing.Sequence[builtins.str]"
if isinstance(s, list):
reveal_type(s) # N: Revealed type is "builtins.list[builtins.str]"
else:
reveal_type(s) # N: Revealed type is "typing.Sequence[builtins.str]"

t1: Tuple[str, int]
if isinstance(t1, tuple):
reveal_type(t1) # N: Revealed type is "Tuple[builtins.str, builtins.int]"
else:
reveal_type(t1)

t2: Tuple[str, ...]
if isinstance(t2, tuple):
reveal_type(t2) # N: Revealed type is "builtins.tuple[builtins.str, ...]"
else:
reveal_type(t2)

T1 = TypeVar("T1")
T2 = TypeVar("T2")
class A(Generic[T1]): ...
class B(A[T1], Generic[T1, T2]):...
a: A[str]
if isinstance(a, B):
reveal_type(a) # N: Revealed type is "__main__.B[builtins.str, Any]"
else:
reveal_type(a) # N: Revealed type is "__main__.A[builtins.str]"
class C(A[str], Generic[T1]):...
if isinstance(a, C):
reveal_type(a) # N: Revealed type is "__main__.C[Any]"
else:
reveal_type(a) # N: Revealed type is "__main__.A[builtins.str]"

class AA(Generic[T1]): ...
class BB(A[T1], AA[T1], Generic[T1, T2]):...
aa: Union[A[int], Union[AA[str], AA[int]]]
if isinstance(aa, BB):
reveal_type(aa) # N: Revealed type is "Union[__main__.BB[builtins.int, Any], __main__.BB[builtins.str, Any]]"
else:
reveal_type(aa) # N: Revealed type is "Union[__main__.A[builtins.int], __main__.AA[builtins.str], __main__.AA[builtins.int]]"

T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
T7 = TypeVar("T7")
T8 = TypeVar("T8")
T9 = TypeVar("T9")
T10 = TypeVar("T10")
T11 = TypeVar("T11")
class A1(Generic[T1, T2]): ...
class A2(Generic[T3, T4]): ...
class B1(A1[T5, T6]):...
class B2(A2[T7, T8]):...
class C1(B1[T9, T10], B2[T11, T9]):...
a2: A2[str, int]
if isinstance(a2, C1):
reveal_type(a2) # N: Revealed type is "__main__.C1[builtins.int, Any, builtins.str]"
else:
reveal_type(a2) # N: Revealed type is "__main__.A2[builtins.str, builtins.int]"
[builtins fixtures/tuple.pyi]

[case testKeepTypeVarArgsWhenNarrowingGenericsWithIsInstanceAndTuples]
from typing import Generic, TypeVar, Union

T1 = TypeVar("T1")
T2 = TypeVar("T2")
class A(Generic[T1]): ...
class B(A[T1], Generic[T1, T2]):...
class C(A[T2], Generic[T1, T2]):...
a: Union[A[str], A[int]]
if isinstance(a, (B, C)):
reveal_type(a) # N: Revealed type is "Union[__main__.B[builtins.str, Any], __main__.C[Any, builtins.str], __main__.B[builtins.int, Any], __main__.C[Any, builtins.int]]"
else:
reveal_type(a) # N: Revealed type is "Union[__main__.A[builtins.str], __main__.A[builtins.int]]"
[builtins fixtures/isinstance.pyi]


[case testKeepTypeVarArgsWhenNarrowingGenericsWithIsSubclass]
from typing import Generic, Sequence, Type, TypeVar

T1 = TypeVar("T1")
T2 = TypeVar("T2")
class A(Generic[T1]): ...
class B(A[T1], Generic[T1, T2]):...
a: Type[A[str]]
if issubclass(a, B):
reveal_type(a) # N: Revealed type is "Type[__main__.B[builtins.str, Any]]"
else:
reveal_type(a) # N: Revealed type is "Type[__main__.A[builtins.str]]"
class C(A[str], Generic[T1]):...
if issubclass(a, C):
reveal_type(a) # N: Revealed type is "Type[__main__.C[Any]]"
else:
reveal_type(a) # N: Revealed type is "Type[__main__.A[builtins.str]]"
[builtins fixtures/isinstance.pyi]

[case testKeepTypeVarTupleArgsWhenNarrowingGenericsWithIsInstance]
from typing import Generic, Sequence, Tuple, TypeVar
from typing_extensions import TypeVarTuple, Unpack

TP = TypeVarTuple("TP")
class A(Generic[Unpack[TP]]): ...
class B(A[Unpack[TP]]): ...
a: A[str, int]
if isinstance(a, B):
reveal_type(a) # N: Revealed type is "__main__.B[builtins.str, builtins.int]"
else:
reveal_type(a) # N: Revealed type is "__main__.A[builtins.str, builtins.int]"

def f1(a: A[*Tuple[str, ...]]):
if isinstance(a, B):
reveal_type(a) # N: Revealed type is "__main__.B[Unpack[builtins.tuple[builtins.str, ...]]]"

T = TypeVar("T")
def f2(a: A[T, str, T]):
if isinstance(a, B):
reveal_type(a) # N: Revealed type is "__main__.B[T`-1, builtins.str, T`-1]"

T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
class C(Generic[T1, Unpack[TP], T2]): ...
class D(C[T1, Unpack[TP], T2], Generic[T2, T4, T6, Unpack[TP], T5, T3, T1]): ...
class E(D[T1, T2, float, Unpack[TP], float, T3, T4]): ...
c: C[int, str, int, str]
if isinstance(c, E):
reveal_type(c) # N: Revealed type is "__main__.E[builtins.str, Any, builtins.str, builtins.int, Any, builtins.int]"
else:
reveal_type(c) # N: Revealed type is "__main__.C[builtins.int, builtins.str, builtins.int, builtins.str]"

class F(E[T1, T2, str, int, T3, T4]): ...
if isinstance(c, F):
reveal_type(c) # N: Revealed type is "__main__.F[builtins.str, Any, Any, builtins.int]"
[builtins fixtures/tuple.pyi]

0 comments on commit 84e3837

Please sign in to comment.