Skip to content

Commit

Permalink
Merge branch 'master' into generate-inline-generic
Browse files Browse the repository at this point in the history
  • Loading branch information
InvincibleRMC committed Jul 25, 2024
2 parents 53a5979 + db9837f commit f4cfbc8
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
# show_authors = False

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"
# pygments_style = "sphinx"

# A list of ignored prefixes for module index sorting.
# modindex_common_prefix = []
Expand Down
15 changes: 10 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6011,11 +6011,16 @@ def has_no_custom_eq_checks(t: Type) -> bool:
if_map, else_map = {}, {}

if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_overlapping_none(item_type):
collection_item_type = get_proper_type(
builtin_item_type(iterable_type)
)
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
# Narrow if the collection is a subtype
if (
collection_item_type is not None
and collection_item_type != item_type
and is_subtype(collection_item_type, item_type)
):
if_map[operands[left_index]] = collection_item_type
# Try and narrow away 'None'
elif is_overlapping_none(item_type):
if (
collection_item_type is not None
and not is_overlapping_none(collection_item_type)
Expand Down
47 changes: 47 additions & 0 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,12 @@ def ast3_parse(
if sys.version_info >= (3, 12):
ast_TypeAlias = ast3.TypeAlias
ast_ParamSpec = ast3.ParamSpec
ast_TypeVar = ast3.TypeVar
ast_TypeVarTuple = ast3.TypeVarTuple
else:
ast_TypeAlias = Any
ast_ParamSpec = Any
ast_TypeVar = Any
ast_TypeVarTuple = Any

N = TypeVar("N", bound=Node)
Expand Down Expand Up @@ -345,6 +347,15 @@ def is_no_type_check_decorator(expr: ast3.expr) -> bool:
return False


def find_disallowed_expression_in_annotation_scope(expr: ast3.expr | None) -> ast3.expr | None:
if expr is None:
return None
for node in ast3.walk(expr):
if isinstance(node, (ast3.Yield, ast3.YieldFrom, ast3.NamedExpr, ast3.Await)):
return node
return None


class ASTConverter:
def __init__(
self,
Expand Down Expand Up @@ -1180,6 +1191,29 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
self.class_and_function_stack.pop()
return cdef

def validate_type_param(self, type_param: ast_TypeVar) -> None:
incorrect_expr = find_disallowed_expression_in_annotation_scope(type_param.bound)
if incorrect_expr is None:
return
if isinstance(incorrect_expr, (ast3.Yield, ast3.YieldFrom)):
self.fail(
message_registry.TYPE_VAR_YIELD_EXPRESSION_IN_BOUND,
type_param.lineno,
type_param.col_offset,
)
if isinstance(incorrect_expr, ast3.NamedExpr):
self.fail(
message_registry.TYPE_VAR_NAMED_EXPRESSION_IN_BOUND,
type_param.lineno,
type_param.col_offset,
)
if isinstance(incorrect_expr, ast3.Await):
self.fail(
message_registry.TYPE_VAR_AWAIT_EXPRESSION_IN_BOUND,
type_param.lineno,
type_param.col_offset,
)

def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
explicit_type_params = []
for p in type_params:
Expand All @@ -1202,6 +1236,7 @@ def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
conv = TypeConverter(self.errors, line=p.lineno)
values = [conv.visit(t) for t in p.bound.elts]
elif p.bound is not None:
self.validate_type_param(p)
bound = TypeConverter(self.errors, line=p.lineno).visit(p.bound)
explicit_type_params.append(TypeParam(p.name, TYPE_VAR_KIND, bound, values))
return explicit_type_params
Expand Down Expand Up @@ -1791,11 +1826,23 @@ def visit_MatchOr(self, n: MatchOr) -> OrPattern:
node = OrPattern([self.visit(pattern) for pattern in n.patterns])
return self.set_line(node, n)

def validate_type_alias(self, n: ast_TypeAlias) -> None:
incorrect_expr = find_disallowed_expression_in_annotation_scope(n.value)
if incorrect_expr is None:
return
if isinstance(incorrect_expr, (ast3.Yield, ast3.YieldFrom)):
self.fail(message_registry.TYPE_ALIAS_WITH_YIELD_EXPRESSION, n.lineno, n.col_offset)
if isinstance(incorrect_expr, ast3.NamedExpr):
self.fail(message_registry.TYPE_ALIAS_WITH_NAMED_EXPRESSION, n.lineno, n.col_offset)
if isinstance(incorrect_expr, ast3.Await):
self.fail(message_registry.TYPE_ALIAS_WITH_AWAIT_EXPRESSION, n.lineno, n.col_offset)

# TypeAlias(identifier name, type_param* type_params, expr value)
def visit_TypeAlias(self, n: ast_TypeAlias) -> TypeAliasStmt | AssignmentStmt:
node: TypeAliasStmt | AssignmentStmt
if NEW_GENERIC_SYNTAX in self.options.enable_incomplete_feature:
type_params = self.translate_type_params(n.type_params)
self.validate_type_alias(n)
value = self.visit(n.value)
# Since the value is evaluated lazily, wrap the value inside a lambda.
# This helps mypyc.
Expand Down
24 changes: 24 additions & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,27 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
TYPE_VAR_TOO_FEW_CONSTRAINED_TYPES: Final = ErrorMessage(
"Type variable must have at least two constrained types", codes.MISC
)

TYPE_VAR_YIELD_EXPRESSION_IN_BOUND: Final = ErrorMessage(
"Yield expression cannot be used as a type variable bound", codes.SYNTAX
)

TYPE_VAR_NAMED_EXPRESSION_IN_BOUND: Final = ErrorMessage(
"Named expression cannot be used as a type variable bound", codes.SYNTAX
)

TYPE_VAR_AWAIT_EXPRESSION_IN_BOUND: Final = ErrorMessage(
"Await expression cannot be used as a type variable bound", codes.SYNTAX
)

TYPE_ALIAS_WITH_YIELD_EXPRESSION: Final = ErrorMessage(
"Yield expression cannot be used within a type alias", codes.SYNTAX
)

TYPE_ALIAS_WITH_NAMED_EXPRESSION: Final = ErrorMessage(
"Named expression cannot be used within a type alias", codes.SYNTAX
)

TYPE_ALIAS_WITH_AWAIT_EXPRESSION: Final = ErrorMessage(
"Await expression cannot be used within a type alias", codes.SYNTAX
)
2 changes: 1 addition & 1 deletion mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def type_context(self) -> list[Type | None]:

@abstractmethod
def fail(
self, msg: str | ErrorMessage, ctx: Context, *, code: ErrorCode | None = None
self, msg: str | ErrorMessage, ctx: Context, /, *, code: ErrorCode | None = None
) -> None:
"""Emit an error message at given location."""
raise NotImplementedError
Expand Down
11 changes: 9 additions & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2900,12 +2900,19 @@ def relevant_items(self) -> list[Type]:
return [i for i in self.items if not isinstance(get_proper_type(i), NoneType)]

def serialize(self) -> JsonDict:
return {".class": "UnionType", "items": [t.serialize() for t in self.items]}
return {
".class": "UnionType",
"items": [t.serialize() for t in self.items],
"uses_pep604_syntax": self.uses_pep604_syntax,
}

@classmethod
def deserialize(cls, data: JsonDict) -> UnionType:
assert data[".class"] == "UnionType"
return UnionType([deserialize_type(t) for t in data["items"]])
return UnionType(
[deserialize_type(t) for t in data["items"]],
uses_pep604_syntax=data["uses_pep604_syntax"],
)


class PartialType(ProperType):
Expand Down
17 changes: 17 additions & 0 deletions test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -6726,3 +6726,20 @@ from typing_extensions import TypeIs
def guard(x: object) -> TypeIs[int]:
pass
[builtins fixtures/tuple.pyi]

[case testStartUsingPEP604Union]
# flags: --python-version 3.10
import a
[file a.py]
import lib

[file a.py.2]
from lib import IntOrStr
assert isinstance(1, IntOrStr)

[file lib.py]
from typing_extensions import TypeAlias

IntOrStr: TypeAlias = int | str
assert isinstance(1, IntOrStr)
[builtins fixtures/type.pyi]
112 changes: 110 additions & 2 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1376,13 +1376,13 @@ else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"

if val in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "None"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
if val not in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
reveal_type(val) # N: Revealed type is "None"
[builtins fixtures/primitives.pyi]

[case testNarrowingWithTupleOfTypes]
Expand Down Expand Up @@ -2114,3 +2114,111 @@ else:

[typing fixtures/typing-medium.pyi]
[builtins fixtures/ops.pyi]


[case testTypeNarrowingStringInLiteralUnion]
from typing import Literal, Tuple
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInLiteralUnionSubset]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b')
strIn: str = "b"
strOut: str = "c"
if strIn in typeAlpha:
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strIn) # N: Revealed type is "builtins.str"
if strOut in typeAlpha:
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strOut) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testNarrowingStringNotInLiteralUnion]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c')
strIn: str = "c"
strOut: str = "d"
if strIn not in typeAlpha:
reveal_type(strIn) # N: Revealed type is "builtins.str"
else:
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
if strOut in typeAlpha:
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
else:
reveal_type(strOut) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testNarrowingStringInLiteralUnionDontExpand]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c')
strIn: Literal['c'] = "c"
reveal_type(strIn) # N: Revealed type is "Literal['c']"
#Check we don't expand a Literal into the Union type
if strIn not in typeAlpha:
reveal_type(strIn) # N: Revealed type is "Literal['c']"
else:
reveal_type(strIn) # N: Revealed type is "Literal['c']"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInMixedUnion]
from typing import Literal, Tuple
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInSet]
from typing import Literal, Set
typ: Set[Literal['a', 'b']] = {'a', 'b'}
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
if x not in typ:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
[builtins fixtures/narrowing.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInList]
from typing import Literal, List
typ: List[Literal['a', 'b']] = ['a', 'b']
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
if x not in typ:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
[builtins fixtures/narrowing.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingUnionStringFloat]
from typing import Union
def foobar(foo: Union[str, float]):
if foo in ['a', 'b']:
reveal_type(foo) # N: Revealed type is "builtins.str"
else:
reveal_type(foo) # N: Revealed type is "Union[builtins.str, builtins.float]"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]
46 changes: 46 additions & 0 deletions test-data/unit/check-python312.test
Original file line number Diff line number Diff line change
Expand Up @@ -1667,3 +1667,49 @@ if x["other"] is not None:
type Y[T] = {"item": T, **Y[T]} # E: Overwriting TypedDict field "item" while merging
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]

[case testPEP695UsingIncorrectExpressionsInTypeVariableBound]
# flags: --enable-incomplete-feature=NewGenericSyntax

type X[T: (yield 1)] = Any # E: Yield expression cannot be used as a type variable bound
type Y[T: (yield from [])] = Any # E: Yield expression cannot be used as a type variable bound
type Z[T: (a := 1)] = Any # E: Named expression cannot be used as a type variable bound
type K[T: (await 1)] = Any # E: Await expression cannot be used as a type variable bound

type XNested[T: (1 + (yield 1))] = Any # E: Yield expression cannot be used as a type variable bound
type YNested[T: (1 + (yield from []))] = Any # E: Yield expression cannot be used as a type variable bound
type ZNested[T: (1 + (a := 1))] = Any # E: Named expression cannot be used as a type variable bound
type KNested[T: (1 + (await 1))] = Any # E: Await expression cannot be used as a type variable bound

class FooX[T: (yield 1)]: pass # E: Yield expression cannot be used as a type variable bound
class FooY[T: (yield from [])]: pass # E: Yield expression cannot be used as a type variable bound
class FooZ[T: (a := 1)]: pass # E: Named expression cannot be used as a type variable bound
class FooK[T: (await 1)]: pass # E: Await expression cannot be used as a type variable bound

class FooXNested[T: (1 + (yield 1))]: pass # E: Yield expression cannot be used as a type variable bound
class FooYNested[T: (1 + (yield from []))]: pass # E: Yield expression cannot be used as a type variable bound
class FooZNested[T: (1 + (a := 1))]: pass # E: Named expression cannot be used as a type variable bound
class FooKNested[T: (1 + (await 1))]: pass # E: Await expression cannot be used as a type variable bound

def foox[T: (yield 1)](): pass # E: Yield expression cannot be used as a type variable bound
def fooy[T: (yield from [])](): pass # E: Yield expression cannot be used as a type variable bound
def fooz[T: (a := 1)](): pass # E: Named expression cannot be used as a type variable bound
def fook[T: (await 1)](): pass # E: Await expression cannot be used as a type variable bound

def foox_nested[T: (1 + (yield 1))](): pass # E: Yield expression cannot be used as a type variable bound
def fooy_nested[T: (1 + (yield from []))](): pass # E: Yield expression cannot be used as a type variable bound
def fooz_nested[T: (1 + (a := 1))](): pass # E: Named expression cannot be used as a type variable bound
def fook_nested[T: (1 +(await 1))](): pass # E: Await expression cannot be used as a type variable bound

[case testPEP695UsingIncorrectExpressionsInTypeAlias]
# flags: --enable-incomplete-feature=NewGenericSyntax

type X = (yield 1) # E: Yield expression cannot be used within a type alias
type Y = (yield from []) # E: Yield expression cannot be used within a type alias
type Z = (a := 1) # E: Named expression cannot be used within a type alias
type K = (await 1) # E: Await expression cannot be used within a type alias

type XNested = (1 + (yield 1)) # E: Yield expression cannot be used within a type alias
type YNested = (1 + (yield from [])) # E: Yield expression cannot be used within a type alias
type ZNested = (1 + (a := 1)) # E: Named expression cannot be used within a type alias
type KNested = (1 + (await 1)) # E: Await expression cannot be used within a type alias
Loading

0 comments on commit f4cfbc8

Please sign in to comment.