From 9258de34e425b69c2952d46d81a68df8a4ce34b5 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 20 May 2024 12:48:09 +0100 Subject: [PATCH 1/2] Detect invalid uses of Generic/Protocol in new-style generic classes --- mypy/semanal.py | 16 ++++++++++++---- test-data/unit/check-python312.test | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index 7d6c75b274ee..943ff223f658 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2058,11 +2058,19 @@ class Foo(Bar, Generic[T]): ... continue result = self.analyze_class_typevar_declaration(base) if result is not None: - if declared_tvars: - self.fail("Only single Generic[...] or Protocol[...] can be in bases", context) - removed.append(i) tvars = result[0] is_protocol |= result[1] + if declared_tvars: + if defn.type_args: + if is_protocol: + self.fail('No arguments expected for "Protocol" base class', context) + else: + self.fail("Generic[...] base class is redundant", context) + else: + self.fail( + "Only single Generic[...] or Protocol[...] can be in bases", context + ) + removed.append(i) declared_tvars.extend(tvars) if isinstance(base, UnboundType): sym = self.lookup_qualified(base.name, base) @@ -2074,7 +2082,7 @@ class Foo(Bar, Generic[T]): ... all_tvars = self.get_all_bases_tvars(base_type_exprs, removed) if declared_tvars: - if len(remove_dups(declared_tvars)) < len(declared_tvars): + if len(remove_dups(declared_tvars)) < len(declared_tvars) and not defn.type_args: self.fail("Duplicate type variables in Generic[...] or Protocol[...]", context) declared_tvars = remove_dups(declared_tvars) if not set(all_tvars).issubset(set(declared_tvars)): diff --git a/test-data/unit/check-python312.test b/test-data/unit/check-python312.test index cce22634df6d..1d5f8ea34c5c 100644 --- a/test-data/unit/check-python312.test +++ b/test-data/unit/check-python312.test @@ -1161,3 +1161,23 @@ def decorator(x: str) -> Any: ... @decorator(T) # E: Argument 1 to "decorator" has incompatible type "int"; expected "str" class C[T]: pass + +[case testPEP695InvalidGenericOrProtocolBaseClass] +# mypy: enable-incomplete-feature=NewGenericSyntax + +from typing import Generic, Protocol, TypeVar + +S = TypeVar("S") + +class C[T](Generic[T]): # E: Generic[...] base class is redundant + pass +class C2[T](Generic[S]): # E: Generic[...] base class is redundant + pass + +a: C[int] +b: C2[int, str] + +class P[T](Protocol[T]): # E: No arguments expected for "Protocol" base class + pass +class P2[T](Protocol[S]): # E: No arguments expected for "Protocol" base class + pass From 268eec82e7c4d21908155a515f0f1bcc4319f793 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 20 May 2024 14:49:18 +0100 Subject: [PATCH 2/2] Generate error if mixing old-style and new-style type vars --- mypy/messages.py | 4 ++++ mypy/semanal.py | 24 +++++++++++++++++++----- test-data/unit/check-python312.test | 22 +++++++++++++++++++++- 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/mypy/messages.py b/mypy/messages.py index 199b7c42b11b..8f923462c789 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -2421,6 +2421,10 @@ def annotation_in_unchecked_function(self, context: Context) -> None: code=codes.ANNOTATION_UNCHECKED, ) + def type_parameters_should_be_declared(self, undeclared: list[str], context: Context) -> None: + names = ", ".join('"' + n + '"' for n in undeclared) + self.fail(f"All type parameters should be declared ({names} not declared)", context) + def quote_type_string(type_string: str) -> str: """Quotes a type representation for use in messages.""" diff --git a/mypy/semanal.py b/mypy/semanal.py index 943ff223f658..b77f5faaaa1e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1101,6 +1101,14 @@ def update_function_type_variables(self, fun_type: CallableType, defn: FuncItem) fun_type.variables, has_self_type = a.bind_function_type_variables(fun_type, defn) if has_self_type and self.type is not None: self.setup_self_type() + if defn.type_args: + bound_fullnames = {v.fullname for v in fun_type.variables} + declared_fullnames = {self.qualified_name(p.name) for p in defn.type_args} + extra = sorted(bound_fullnames - declared_fullnames) + if extra: + self.msg.type_parameters_should_be_declared( + [n.split(".")[-1] for n in extra], defn + ) return has_self_type def setup_self_type(self) -> None: @@ -2086,11 +2094,17 @@ class Foo(Bar, Generic[T]): ... self.fail("Duplicate type variables in Generic[...] or Protocol[...]", context) declared_tvars = remove_dups(declared_tvars) if not set(all_tvars).issubset(set(declared_tvars)): - self.fail( - "If Generic[...] or Protocol[...] is present" - " it should list all type variables", - context, - ) + if defn.type_args: + undeclared = sorted(set(all_tvars) - set(declared_tvars)) + self.msg.type_parameters_should_be_declared( + [tv[0] for tv in undeclared], context + ) + else: + self.fail( + "If Generic[...] or Protocol[...] is present" + " it should list all type variables", + context, + ) # In case of error, Generic tvars will go first declared_tvars = remove_dups(declared_tvars + all_tvars) else: diff --git a/test-data/unit/check-python312.test b/test-data/unit/check-python312.test index 1d5f8ea34c5c..f5d9fd195f04 100644 --- a/test-data/unit/check-python312.test +++ b/test-data/unit/check-python312.test @@ -1164,7 +1164,6 @@ class C[T]: [case testPEP695InvalidGenericOrProtocolBaseClass] # mypy: enable-incomplete-feature=NewGenericSyntax - from typing import Generic, Protocol, TypeVar S = TypeVar("S") @@ -1181,3 +1180,24 @@ class P[T](Protocol[T]): # E: No arguments expected for "Protocol" base class pass class P2[T](Protocol[S]): # E: No arguments expected for "Protocol" base class pass + +[case testPEP695MixNewAndOldStyleGenerics] +# mypy: enable-incomplete-feature=NewGenericSyntax +from typing import TypeVar + +S = TypeVar("S") +U = TypeVar("U") + +def f[T](x: T, y: S) -> T | S: ... # E: All type parameters should be declared ("S" not declared) +def g[T](x: S, y: U) -> T | S | U: ... # E: All type parameters should be declared ("S", "U" not declared) + +def h[S: int](x: S) -> S: + a: int = x + return x + +class C[T]: + def m[X, S](self, x: S, y: U) -> X | S | U: ... # E: All type parameters should be declared ("U" not declared) + def m2(self, x: T, y: S) -> T | S: ... + +class D[T](C[S]): # E: All type parameters should be declared ("S" not declared) + pass