diff --git a/docs/source/conf.py b/docs/source/conf.py index fa76734054ac..5934c7474536 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -92,7 +92,7 @@ # show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" +# pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] diff --git a/mypy/checker.py b/mypy/checker.py index 0ae499916ec6..59de599006a8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6011,11 +6011,16 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_map, else_map = {}, {} if left_index in narrowable_operand_index_to_hash: - # We only try and narrow away 'None' for now - if is_overlapping_none(item_type): - collection_item_type = get_proper_type( - builtin_item_type(iterable_type) - ) + collection_item_type = get_proper_type(builtin_item_type(iterable_type)) + # Narrow if the collection is a subtype + if ( + collection_item_type is not None + and collection_item_type != item_type + and is_subtype(collection_item_type, item_type) + ): + if_map[operands[left_index]] = collection_item_type + # Try and narrow away 'None' + elif is_overlapping_none(item_type): if ( collection_item_type is not None and not is_overlapping_none(collection_item_type) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 75c4bd46550c..363fc8375259 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -181,10 +181,12 @@ def ast3_parse( if sys.version_info >= (3, 12): ast_TypeAlias = ast3.TypeAlias ast_ParamSpec = ast3.ParamSpec + ast_TypeVar = ast3.TypeVar ast_TypeVarTuple = ast3.TypeVarTuple else: ast_TypeAlias = Any ast_ParamSpec = Any + ast_TypeVar = Any ast_TypeVarTuple = Any N = TypeVar("N", bound=Node) @@ -345,6 +347,15 @@ def is_no_type_check_decorator(expr: ast3.expr) -> bool: return False +def find_disallowed_expression_in_annotation_scope(expr: ast3.expr | None) -> ast3.expr | None: + if expr is None: + return None + for node in ast3.walk(expr): + if isinstance(node, (ast3.Yield, ast3.YieldFrom, ast3.NamedExpr, ast3.Await)): + return node + return None + + class ASTConverter: def __init__( self, @@ -1180,6 +1191,29 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef: self.class_and_function_stack.pop() return cdef + def validate_type_param(self, type_param: ast_TypeVar) -> None: + incorrect_expr = find_disallowed_expression_in_annotation_scope(type_param.bound) + if incorrect_expr is None: + return + if isinstance(incorrect_expr, (ast3.Yield, ast3.YieldFrom)): + self.fail( + message_registry.TYPE_VAR_YIELD_EXPRESSION_IN_BOUND, + type_param.lineno, + type_param.col_offset, + ) + if isinstance(incorrect_expr, ast3.NamedExpr): + self.fail( + message_registry.TYPE_VAR_NAMED_EXPRESSION_IN_BOUND, + type_param.lineno, + type_param.col_offset, + ) + if isinstance(incorrect_expr, ast3.Await): + self.fail( + message_registry.TYPE_VAR_AWAIT_EXPRESSION_IN_BOUND, + type_param.lineno, + type_param.col_offset, + ) + def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]: explicit_type_params = [] for p in type_params: @@ -1202,6 +1236,7 @@ def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]: conv = TypeConverter(self.errors, line=p.lineno) values = [conv.visit(t) for t in p.bound.elts] elif p.bound is not None: + self.validate_type_param(p) bound = TypeConverter(self.errors, line=p.lineno).visit(p.bound) explicit_type_params.append(TypeParam(p.name, TYPE_VAR_KIND, bound, values)) return explicit_type_params @@ -1791,11 +1826,23 @@ def visit_MatchOr(self, n: MatchOr) -> OrPattern: node = OrPattern([self.visit(pattern) for pattern in n.patterns]) return self.set_line(node, n) + def validate_type_alias(self, n: ast_TypeAlias) -> None: + incorrect_expr = find_disallowed_expression_in_annotation_scope(n.value) + if incorrect_expr is None: + return + if isinstance(incorrect_expr, (ast3.Yield, ast3.YieldFrom)): + self.fail(message_registry.TYPE_ALIAS_WITH_YIELD_EXPRESSION, n.lineno, n.col_offset) + if isinstance(incorrect_expr, ast3.NamedExpr): + self.fail(message_registry.TYPE_ALIAS_WITH_NAMED_EXPRESSION, n.lineno, n.col_offset) + if isinstance(incorrect_expr, ast3.Await): + self.fail(message_registry.TYPE_ALIAS_WITH_AWAIT_EXPRESSION, n.lineno, n.col_offset) + # TypeAlias(identifier name, type_param* type_params, expr value) def visit_TypeAlias(self, n: ast_TypeAlias) -> TypeAliasStmt | AssignmentStmt: node: TypeAliasStmt | AssignmentStmt if NEW_GENERIC_SYNTAX in self.options.enable_incomplete_feature: type_params = self.translate_type_params(n.type_params) + self.validate_type_alias(n) value = self.visit(n.value) # Since the value is evaluated lazily, wrap the value inside a lambda. # This helps mypyc. diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 06199e70d6b4..29d539faaed6 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -338,3 +338,27 @@ def with_additional_msg(self, info: str) -> ErrorMessage: TYPE_VAR_TOO_FEW_CONSTRAINED_TYPES: Final = ErrorMessage( "Type variable must have at least two constrained types", codes.MISC ) + +TYPE_VAR_YIELD_EXPRESSION_IN_BOUND: Final = ErrorMessage( + "Yield expression cannot be used as a type variable bound", codes.SYNTAX +) + +TYPE_VAR_NAMED_EXPRESSION_IN_BOUND: Final = ErrorMessage( + "Named expression cannot be used as a type variable bound", codes.SYNTAX +) + +TYPE_VAR_AWAIT_EXPRESSION_IN_BOUND: Final = ErrorMessage( + "Await expression cannot be used as a type variable bound", codes.SYNTAX +) + +TYPE_ALIAS_WITH_YIELD_EXPRESSION: Final = ErrorMessage( + "Yield expression cannot be used within a type alias", codes.SYNTAX +) + +TYPE_ALIAS_WITH_NAMED_EXPRESSION: Final = ErrorMessage( + "Named expression cannot be used within a type alias", codes.SYNTAX +) + +TYPE_ALIAS_WITH_AWAIT_EXPRESSION: Final = ErrorMessage( + "Await expression cannot be used within a type alias", codes.SYNTAX +) diff --git a/mypy/plugin.py b/mypy/plugin.py index 858795addb7f..a1af7fa76350 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -240,7 +240,7 @@ def type_context(self) -> list[Type | None]: @abstractmethod def fail( - self, msg: str | ErrorMessage, ctx: Context, *, code: ErrorCode | None = None + self, msg: str | ErrorMessage, ctx: Context, /, *, code: ErrorCode | None = None ) -> None: """Emit an error message at given location.""" raise NotImplementedError diff --git a/mypy/types.py b/mypy/types.py index 91b40536f1cf..52b3121f9fb3 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2900,12 +2900,19 @@ def relevant_items(self) -> list[Type]: return [i for i in self.items if not isinstance(get_proper_type(i), NoneType)] def serialize(self) -> JsonDict: - return {".class": "UnionType", "items": [t.serialize() for t in self.items]} + return { + ".class": "UnionType", + "items": [t.serialize() for t in self.items], + "uses_pep604_syntax": self.uses_pep604_syntax, + } @classmethod def deserialize(cls, data: JsonDict) -> UnionType: assert data[".class"] == "UnionType" - return UnionType([deserialize_type(t) for t in data["items"]]) + return UnionType( + [deserialize_type(t) for t in data["items"]], + uses_pep604_syntax=data["uses_pep604_syntax"], + ) class PartialType(ProperType): diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index 24292bce3e21..173265e48e6f 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -6726,3 +6726,20 @@ from typing_extensions import TypeIs def guard(x: object) -> TypeIs[int]: pass [builtins fixtures/tuple.pyi] + +[case testStartUsingPEP604Union] +# flags: --python-version 3.10 +import a +[file a.py] +import lib + +[file a.py.2] +from lib import IntOrStr +assert isinstance(1, IntOrStr) + +[file lib.py] +from typing_extensions import TypeAlias + +IntOrStr: TypeAlias = int | str +assert isinstance(1, IntOrStr) +[builtins fixtures/type.pyi] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 8612df9bc663..e142fdd5d060 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1376,13 +1376,13 @@ else: reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" if val in (None,): - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "None" else: reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" if val not in (None,): reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" else: - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "None" [builtins fixtures/primitives.pyi] [case testNarrowingWithTupleOfTypes] @@ -2114,3 +2114,111 @@ else: [typing fixtures/typing-medium.pyi] [builtins fixtures/ops.pyi] + + +[case testTypeNarrowingStringInLiteralUnion] +from typing import Literal, Tuple +typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b') +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInLiteralUnionSubset] +from typing import Literal, Tuple +typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b') +strIn: str = "b" +strOut: str = "c" +if strIn in typeAlpha: + reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +else: + reveal_type(strIn) # N: Revealed type is "builtins.str" +if strOut in typeAlpha: + reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +else: + reveal_type(strOut) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testNarrowingStringNotInLiteralUnion] +from typing import Literal, Tuple +typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c') +strIn: str = "c" +strOut: str = "d" +if strIn not in typeAlpha: + reveal_type(strIn) # N: Revealed type is "builtins.str" +else: + reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +if strOut in typeAlpha: + reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +else: + reveal_type(strOut) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testNarrowingStringInLiteralUnionDontExpand] +from typing import Literal, Tuple +typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c') +strIn: Literal['c'] = "c" +reveal_type(strIn) # N: Revealed type is "Literal['c']" +#Check we don't expand a Literal into the Union type +if strIn not in typeAlpha: + reveal_type(strIn) # N: Revealed type is "Literal['c']" +else: + reveal_type(strIn) # N: Revealed type is "Literal['c']" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInMixedUnion] +from typing import Literal, Tuple +typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b') +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInSet] +from typing import Literal, Set +typ: Set[Literal['a', 'b']] = {'a', 'b'} +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +if x not in typ: + reveal_type(x) # N: Revealed type is "builtins.str" +else: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +[builtins fixtures/narrowing.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInList] +from typing import Literal, List +typ: List[Literal['a', 'b']] = ['a', 'b'] +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +if x not in typ: + reveal_type(x) # N: Revealed type is "builtins.str" +else: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +[builtins fixtures/narrowing.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingUnionStringFloat] +from typing import Union +def foobar(foo: Union[str, float]): + if foo in ['a', 'b']: + reveal_type(foo) # N: Revealed type is "builtins.str" + else: + reveal_type(foo) # N: Revealed type is "Union[builtins.str, builtins.float]" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] diff --git a/test-data/unit/check-python312.test b/test-data/unit/check-python312.test index 073ef7f4bdec..a3f4c87120cd 100644 --- a/test-data/unit/check-python312.test +++ b/test-data/unit/check-python312.test @@ -1667,3 +1667,49 @@ if x["other"] is not None: type Y[T] = {"item": T, **Y[T]} # E: Overwriting TypedDict field "item" while merging [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] + +[case testPEP695UsingIncorrectExpressionsInTypeVariableBound] +# flags: --enable-incomplete-feature=NewGenericSyntax + +type X[T: (yield 1)] = Any # E: Yield expression cannot be used as a type variable bound +type Y[T: (yield from [])] = Any # E: Yield expression cannot be used as a type variable bound +type Z[T: (a := 1)] = Any # E: Named expression cannot be used as a type variable bound +type K[T: (await 1)] = Any # E: Await expression cannot be used as a type variable bound + +type XNested[T: (1 + (yield 1))] = Any # E: Yield expression cannot be used as a type variable bound +type YNested[T: (1 + (yield from []))] = Any # E: Yield expression cannot be used as a type variable bound +type ZNested[T: (1 + (a := 1))] = Any # E: Named expression cannot be used as a type variable bound +type KNested[T: (1 + (await 1))] = Any # E: Await expression cannot be used as a type variable bound + +class FooX[T: (yield 1)]: pass # E: Yield expression cannot be used as a type variable bound +class FooY[T: (yield from [])]: pass # E: Yield expression cannot be used as a type variable bound +class FooZ[T: (a := 1)]: pass # E: Named expression cannot be used as a type variable bound +class FooK[T: (await 1)]: pass # E: Await expression cannot be used as a type variable bound + +class FooXNested[T: (1 + (yield 1))]: pass # E: Yield expression cannot be used as a type variable bound +class FooYNested[T: (1 + (yield from []))]: pass # E: Yield expression cannot be used as a type variable bound +class FooZNested[T: (1 + (a := 1))]: pass # E: Named expression cannot be used as a type variable bound +class FooKNested[T: (1 + (await 1))]: pass # E: Await expression cannot be used as a type variable bound + +def foox[T: (yield 1)](): pass # E: Yield expression cannot be used as a type variable bound +def fooy[T: (yield from [])](): pass # E: Yield expression cannot be used as a type variable bound +def fooz[T: (a := 1)](): pass # E: Named expression cannot be used as a type variable bound +def fook[T: (await 1)](): pass # E: Await expression cannot be used as a type variable bound + +def foox_nested[T: (1 + (yield 1))](): pass # E: Yield expression cannot be used as a type variable bound +def fooy_nested[T: (1 + (yield from []))](): pass # E: Yield expression cannot be used as a type variable bound +def fooz_nested[T: (1 + (a := 1))](): pass # E: Named expression cannot be used as a type variable bound +def fook_nested[T: (1 +(await 1))](): pass # E: Await expression cannot be used as a type variable bound + +[case testPEP695UsingIncorrectExpressionsInTypeAlias] +# flags: --enable-incomplete-feature=NewGenericSyntax + +type X = (yield 1) # E: Yield expression cannot be used within a type alias +type Y = (yield from []) # E: Yield expression cannot be used within a type alias +type Z = (a := 1) # E: Named expression cannot be used within a type alias +type K = (await 1) # E: Await expression cannot be used within a type alias + +type XNested = (1 + (yield 1)) # E: Yield expression cannot be used within a type alias +type YNested = (1 + (yield from [])) # E: Yield expression cannot be used within a type alias +type ZNested = (1 + (a := 1)) # E: Named expression cannot be used within a type alias +type KNested = (1 + (await 1)) # E: Await expression cannot be used within a type alias diff --git a/test-data/unit/fixtures/narrowing.pyi b/test-data/unit/fixtures/narrowing.pyi index 89ee011c1c80..a36ac7f29bd2 100644 --- a/test-data/unit/fixtures/narrowing.pyi +++ b/test-data/unit/fixtures/narrowing.pyi @@ -1,5 +1,5 @@ # Builtins stub used in check-narrowing test cases. -from typing import Generic, Sequence, Tuple, Type, TypeVar, Union +from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable Tco = TypeVar('Tco', covariant=True) @@ -15,6 +15,13 @@ class function: pass class ellipsis: pass class int: pass class str: pass +class float: pass class dict(Generic[KT, VT]): pass def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass + +class list(Sequence[Tco]): + def __contains__(self, other: object) -> bool: pass +class set(Iterable[Tco], Generic[Tco]): + def __init__(self, iterable: Iterable[Tco] = ...) -> None: ... + def __contains__(self, item: object) -> bool: pass