diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 7b6a55324741..a23be464b825 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -202,7 +202,7 @@ def visit_value_pattern(self, o: ValuePattern) -> PatternType: typ = self.chk.expr_checker.accept(o.expr) typ = coerce_to_literal(typ) narrowed_type, rest_type = self.chk.conditional_types_with_intersection( - current_type, [get_type_range(typ)], o, default=current_type + current_type, [get_type_range(typ)], o, default=get_proper_type(typ) ) if not isinstance(get_proper_type(narrowed_type), (LiteralType, UninhabitedType)): return PatternType(narrowed_type, UnionType.make_union([narrowed_type, rest_type]), {}) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index b0e27fe1e3a0..3a040d94d7ba 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1369,6 +1369,27 @@ match m3: reveal_type(m3) # N: Revealed type is "Tuple[Union[builtins.int, builtins.str]]" [builtins fixtures/tuple.pyi] +[case testMatchEnumSingleChoice] +from enum import Enum +from typing import NoReturn + +def assert_never(x: NoReturn) -> None: ... + +class Medal(Enum): + gold = 1 + +def f(m: Medal) -> None: + always_assigned: int | None = None + match m: + case Medal.gold: + always_assigned = 1 + reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]" + case _: + assert_never(m) + + reveal_type(always_assigned) # N: Revealed type is "builtins.int" +[builtins fixtures/bool.pyi] + [case testMatchLiteralPatternEnumNegativeNarrowing] from enum import Enum class Medal(Enum): @@ -1388,10 +1409,13 @@ def f(m: Medal) -> int: def g(m: Medal) -> int: match m: case Medal.gold: + reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]" return 0 case Medal.silver: + reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.silver]" return 1 case Medal.bronze: + reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.bronze]" return 2 [case testMatchLiteralPatternEnumCustomEquals-skip]