From 84e3837ffdb9d40366738c1639c18b2788b32b15 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Fri, 5 Apr 2024 01:34:33 +0200 Subject: [PATCH] Keep `TypeVar` arguments when narrowing generic subclasses with `isinstance` and `issubclass`. --- mypy/checker.py | 115 +++++++++++++++++++++- test-data/unit/check-narrowing.test | 146 ++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 5d243195d50f..0dfe183f55d1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7454,7 +7454,10 @@ def conditional_types( ] ) remaining_type = restrict_subtype_away(current_type, proposed_precise_type) - return proposed_type, remaining_type + proposed_type_with_data = _transfer_type_var_args_from_current_to_proposed( + current_type, proposed_type + ) + return proposed_type_with_data, remaining_type else: # An isinstance check, but we don't understand the type return current_type, default @@ -7478,6 +7481,116 @@ def conditional_types_to_typemaps( return cast(Tuple[TypeMap, TypeMap], tuple(maps)) +def _transfer_type_var_args_from_current_to_proposed(current: Type, proposed: Type) -> Type: + """Check if the current type is among the bases of the proposed type. If so, try to transfer + the type variable arguments of the current type's instance to a copy of the proposed type's + instance. This increases information when narrowing generic classes so that, for example, + Sequence[int] is narrowed to List[int] instead of List[Any].""" + + def _get_instance_path_from_current_to_proposed( + this: Instance, target: TypeInfo + ) -> list[Instance] | None: + """Search for the current type among the bases of the proposed type and return the + "instance path" from the current to proposed type. Or None, if the current type is not a + nominal super type. At most one path is returned, which means there is no special handling + of (inconsistent) multiple inheritance.""" + if target == this.type: + return [this] + for base in this.type.bases: + path = _get_instance_path_from_current_to_proposed(base, target) + if path is not None: + path.append(this) + return path + return None + + # Handle "tuple of Instance" cases, e.g. `isinstance(x, (A, B))`: + proposed = get_proper_type(proposed) + if isinstance(proposed, UnionType): + items = [ + _transfer_type_var_args_from_current_to_proposed(current, item) + for item in flatten_nested_unions(proposed.items) + ] + return make_simplified_union(items) + + # Otherwise handle only Instances: + if not isinstance(proposed, Instance): + return proposed + + # Handle union cases like `a: A[int] | A[str]; isinstance(a, B)`: + current = get_proper_type(current) + if isinstance(current, UnionType): + items = [ + _transfer_type_var_args_from_current_to_proposed(item, proposed) + for item in flatten_nested_unions(current.items) + ] + return make_simplified_union(items) + + # Here comes the main logic: + if isinstance(current, Instance): + + # Only consider nominal subtyping: + instances = _get_instance_path_from_current_to_proposed(proposed, current.type) + if instances is None: + return proposed + assert len(instances) > 0 # shortest case: proposed type is current type + + # Make a list of the proposed type's type variable arguments that allows to replace each + # `Any` with one type variable argument or multiple type variable tuple arguments of the + # current type: + proposed_args: list[Type | tuple[Type, ...]] = list(proposed.args) + + # Try to transfer each type variable argument from the current to the base type separately: + for pos1, typevar1 in enumerate(instances[0].args): + if isinstance(typevar1, UnpackType): + typevar1 = typevar1.type + if not isinstance(typevar1, (TypeVarType, TypeVarTupleType)): + continue + # Find the position of the intermediate types' and finally the proposed type's + # related type variable (if not available, `pos2` becomes `None`): + for instance in instances[1:]: + pos2: int | None = None + for pos2, typevar2 in enumerate(instance.type.defn.type_vars): + if typevar1 == typevar2: + if instance.type.has_type_var_tuple_type: + assert (prefix := instance.type.type_var_tuple_prefix) is not None + if pos2 > prefix: + pos2 += len(instance.args) - len(instance.type.defn.type_vars) + typevar1 = instance.args[pos2] + if isinstance(typevar1, UnpackType): + typevar1 = typevar1.type + break + else: + pos2 = None + break + + # Transfer the current type's type variable argument or type variable tuple arguments: + if pos2 is not None: + if current.type.has_type_var_tuple_type: + assert (prefix := current.type.type_var_tuple_prefix) is not None + assert (suffix := current.type.type_var_tuple_suffix) is not None + if pos1 < prefix: + proposed_args[pos2] = current.args[pos1] + elif pos1 == prefix: + proposed_args[pos2] = current.args[prefix:len(current.args) - suffix] + else: + middle = len(current.args) - prefix - suffix + proposed_args[pos2] = current.args[pos1 + middle - 1] + else: + proposed_args[pos2] = current.args[pos1] + + # Combine all type variable and type variable tuple arguments to a flat list: + flattened_proposed_args: list[Type] = [] + for arg in proposed_args: + if isinstance(arg, tuple): + flattened_proposed_args.extend(arg) + else: + flattened_proposed_args.append(arg) + + return proposed.copy_modified(args=flattened_proposed_args) + + return proposed + + def gen_unique_name(base: str, table: SymbolTable) -> str: """Generate a name that does not appear in table by appending numbers to base.""" if base not in table: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 4d117687554e..84d71ca93c62 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2089,3 +2089,149 @@ if isinstance(x, (Z, NoneType)): # E: Subclass of "X" and "Z" cannot exist: "Z" reveal_type(x) # E: Statement is unreachable [builtins fixtures/isinstance.pyi] + +[case testKeepTypeVarArgsWhenNarrowingGenericsWithIsInstance] +from typing import Generic, Sequence, Tuple, TypeVar, Union + +s: Sequence[str] +if isinstance(s, tuple): + reveal_type(s) # N: Revealed type is "builtins.tuple[builtins.str, ...]" +else: + reveal_type(s) # N: Revealed type is "typing.Sequence[builtins.str]" +if isinstance(s, list): + reveal_type(s) # N: Revealed type is "builtins.list[builtins.str]" +else: + reveal_type(s) # N: Revealed type is "typing.Sequence[builtins.str]" + +t1: Tuple[str, int] +if isinstance(t1, tuple): + reveal_type(t1) # N: Revealed type is "Tuple[builtins.str, builtins.int]" +else: + reveal_type(t1) + +t2: Tuple[str, ...] +if isinstance(t2, tuple): + reveal_type(t2) # N: Revealed type is "builtins.tuple[builtins.str, ...]" +else: + reveal_type(t2) + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +class A(Generic[T1]): ... +class B(A[T1], Generic[T1, T2]):... +a: A[str] +if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__.B[builtins.str, Any]" +else: + reveal_type(a) # N: Revealed type is "__main__.A[builtins.str]" +class C(A[str], Generic[T1]):... +if isinstance(a, C): + reveal_type(a) # N: Revealed type is "__main__.C[Any]" +else: + reveal_type(a) # N: Revealed type is "__main__.A[builtins.str]" + +class AA(Generic[T1]): ... +class BB(A[T1], AA[T1], Generic[T1, T2]):... +aa: Union[A[int], Union[AA[str], AA[int]]] +if isinstance(aa, BB): + reveal_type(aa) # N: Revealed type is "Union[__main__.BB[builtins.int, Any], __main__.BB[builtins.str, Any]]" +else: + reveal_type(aa) # N: Revealed type is "Union[__main__.A[builtins.int], __main__.AA[builtins.str], __main__.AA[builtins.int]]" + +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") +T6 = TypeVar("T6") +T7 = TypeVar("T7") +T8 = TypeVar("T8") +T9 = TypeVar("T9") +T10 = TypeVar("T10") +T11 = TypeVar("T11") +class A1(Generic[T1, T2]): ... +class A2(Generic[T3, T4]): ... +class B1(A1[T5, T6]):... +class B2(A2[T7, T8]):... +class C1(B1[T9, T10], B2[T11, T9]):... +a2: A2[str, int] +if isinstance(a2, C1): + reveal_type(a2) # N: Revealed type is "__main__.C1[builtins.int, Any, builtins.str]" +else: + reveal_type(a2) # N: Revealed type is "__main__.A2[builtins.str, builtins.int]" +[builtins fixtures/tuple.pyi] + +[case testKeepTypeVarArgsWhenNarrowingGenericsWithIsInstanceAndTuples] +from typing import Generic, TypeVar, Union + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +class A(Generic[T1]): ... +class B(A[T1], Generic[T1, T2]):... +class C(A[T2], Generic[T1, T2]):... +a: Union[A[str], A[int]] +if isinstance(a, (B, C)): + reveal_type(a) # N: Revealed type is "Union[__main__.B[builtins.str, Any], __main__.C[Any, builtins.str], __main__.B[builtins.int, Any], __main__.C[Any, builtins.int]]" +else: + reveal_type(a) # N: Revealed type is "Union[__main__.A[builtins.str], __main__.A[builtins.int]]" +[builtins fixtures/isinstance.pyi] + + +[case testKeepTypeVarArgsWhenNarrowingGenericsWithIsSubclass] +from typing import Generic, Sequence, Type, TypeVar + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +class A(Generic[T1]): ... +class B(A[T1], Generic[T1, T2]):... +a: Type[A[str]] +if issubclass(a, B): + reveal_type(a) # N: Revealed type is "Type[__main__.B[builtins.str, Any]]" +else: + reveal_type(a) # N: Revealed type is "Type[__main__.A[builtins.str]]" +class C(A[str], Generic[T1]):... +if issubclass(a, C): + reveal_type(a) # N: Revealed type is "Type[__main__.C[Any]]" +else: + reveal_type(a) # N: Revealed type is "Type[__main__.A[builtins.str]]" +[builtins fixtures/isinstance.pyi] + +[case testKeepTypeVarTupleArgsWhenNarrowingGenericsWithIsInstance] +from typing import Generic, Sequence, Tuple, TypeVar +from typing_extensions import TypeVarTuple, Unpack + +TP = TypeVarTuple("TP") +class A(Generic[Unpack[TP]]): ... +class B(A[Unpack[TP]]): ... +a: A[str, int] +if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__.B[builtins.str, builtins.int]" +else: + reveal_type(a) # N: Revealed type is "__main__.A[builtins.str, builtins.int]" + +def f1(a: A[*Tuple[str, ...]]): + if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__.B[Unpack[builtins.tuple[builtins.str, ...]]]" + +T = TypeVar("T") +def f2(a: A[T, str, T]): + if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__.B[T`-1, builtins.str, T`-1]" + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") +T6 = TypeVar("T6") +class C(Generic[T1, Unpack[TP], T2]): ... +class D(C[T1, Unpack[TP], T2], Generic[T2, T4, T6, Unpack[TP], T5, T3, T1]): ... +class E(D[T1, T2, float, Unpack[TP], float, T3, T4]): ... +c: C[int, str, int, str] +if isinstance(c, E): + reveal_type(c) # N: Revealed type is "__main__.E[builtins.str, Any, builtins.str, builtins.int, Any, builtins.int]" +else: + reveal_type(c) # N: Revealed type is "__main__.C[builtins.int, builtins.str, builtins.int, builtins.str]" + +class F(E[T1, T2, str, int, T3, T4]): ... +if isinstance(c, F): + reveal_type(c) # N: Revealed type is "__main__.F[builtins.str, Any, Any, builtins.int]" +[builtins fixtures/tuple.pyi]