diff --git a/mypy/checker.py b/mypy/checker.py index a8cb2b862fbc..0c27da8b5ac8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -216,7 +216,7 @@ is_literal_type, is_named_instance, ) -from mypy.types_utils import is_optional, remove_optional, store_argument_type, strip_type +from mypy.types_utils import is_overlapping_none, remove_optional, store_argument_type, strip_type from mypy.typetraverser import TypeTraverserVisitor from mypy.typevars import fill_typevars, fill_typevars_with_any, has_no_typevars from mypy.util import is_dunder, is_sunder, is_typeshed_file @@ -5660,13 +5660,13 @@ def has_no_custom_eq_checks(t: Type) -> bool: if left_index in narrowable_operand_index_to_hash: # We only try and narrow away 'None' for now - if is_optional(item_type): + if is_overlapping_none(item_type): collection_item_type = get_proper_type( builtin_item_type(iterable_type) ) if ( collection_item_type is not None - and not is_optional(collection_item_type) + and not is_overlapping_none(collection_item_type) and not ( isinstance(collection_item_type, Instance) and collection_item_type.type.fullname == "builtins.object" @@ -6073,7 +6073,7 @@ def refine_away_none_in_comparison( non_optional_types = [] for i in chain_indices: typ = operand_types[i] - if not is_optional(typ): + if not is_overlapping_none(typ): non_optional_types.append(typ) # Make sure we have a mixture of optional and non-optional types. @@ -6083,7 +6083,7 @@ def refine_away_none_in_comparison( if_map = {} for i in narrowable_operand_indices: expr_type = operand_types[i] - if not is_optional(expr_type): + if not is_overlapping_none(expr_type): continue if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types): if_map[operands[i]] = remove_optional(expr_type) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 62e2298ba59d..114cde8327e0 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -169,7 +169,12 @@ is_named_instance, split_with_prefix_and_suffix, ) -from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional +from mypy.types_utils import ( + is_generic_instance, + is_overlapping_none, + is_self_type_like, + remove_optional, +) from mypy.typestate import type_state from mypy.typevars import fill_typevars from mypy.typevartuples import find_unpack_in_list @@ -1809,7 +1814,7 @@ def infer_function_type_arguments_using_context( # valid results. erased_ctx = replace_meta_vars(ctx, ErasedType()) ret_type = callable.ret_type - if is_optional(ret_type) and is_optional(ctx): + if is_overlapping_none(ret_type) and is_overlapping_none(ctx): # If both the context and the return type are optional, unwrap the optional, # since in 99% cases this is what a user expects. In other words, we replace # Optional[T] <: Optional[int] diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 65d967577bea..55f2870cadb4 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -43,7 +43,7 @@ deserialize_type, get_proper_type, ) -from mypy.types_utils import is_optional +from mypy.types_utils import is_overlapping_none from mypy.typevars import fill_typevars from mypy.util import get_unique_redefinition_name @@ -141,7 +141,7 @@ def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> break elif ( arg_none - and not is_optional(arg_type) + and not is_overlapping_none(arg_type) and not ( isinstance(arg_type, Instance) and arg_type.type.fullname == "builtins.object" diff --git a/mypy/suggestions.py b/mypy/suggestions.py index 8e1225f00a2f..268f3032fc9b 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -79,7 +79,7 @@ UnionType, get_proper_type, ) -from mypy.types_utils import is_optional, remove_optional +from mypy.types_utils import is_overlapping_none, remove_optional from mypy.util import split_target @@ -752,7 +752,7 @@ def score_type(self, t: Type, arg_pos: bool) -> int: return 20 if any(has_any_type(x) for x in t.items): return 15 - if not is_optional(t): + if not is_overlapping_none(t): return 10 if isinstance(t, CallableType) and (has_any_type(t) or is_tricky_callable(t)): return 10 @@ -868,7 +868,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> str: return t.fallback.accept(self) def visit_union_type(self, t: UnionType) -> str: - if len(t.items) == 2 and is_optional(t): + if len(t.items) == 2 and is_overlapping_none(t): return f"Optional[{remove_optional(t).accept(self)}]" else: return super().visit_union_type(t) diff --git a/mypy/types_utils.py b/mypy/types_utils.py index 43bca05d6bf9..7f2e38ef3753 100644 --- a/mypy/types_utils.py +++ b/mypy/types_utils.py @@ -101,10 +101,10 @@ def is_generic_instance(tp: Type) -> bool: return isinstance(tp, Instance) and bool(tp.args) -def is_optional(t: Type) -> bool: +def is_overlapping_none(t: Type) -> bool: t = get_proper_type(t) - return isinstance(t, UnionType) and any( - isinstance(get_proper_type(e), NoneType) for e in t.items + return isinstance(t, NoneType) or ( + isinstance(t, UnionType) and any(isinstance(get_proper_type(e), NoneType) for e in t.items) ) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index b763e0ff3b68..291f73a45230 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1263,6 +1263,32 @@ def g() -> None: [builtins fixtures/dict.pyi] +[case testNarrowingOptionalEqualsNone] +from typing import Optional + +class A: ... + +val: Optional[A] + +if val == None: + reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" +else: + reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" +if val != None: + reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" +else: + reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + +if val in (None,): + reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" +else: + reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" +if val not in (None,): + reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" +else: + reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" +[builtins fixtures/primitives.pyi] + [case testNarrowingWithTupleOfTypes] from typing import Tuple, Type diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index b74252857d6f..c9b1e3f4e983 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -45,7 +45,8 @@ class memoryview(Sequence[int]): def __iter__(self) -> Iterator[int]: pass def __contains__(self, other: object) -> bool: pass def __getitem__(self, item: int) -> int: pass -class tuple(Generic[T]): pass +class tuple(Generic[T]): + def __contains__(self, other: object) -> bool: pass class list(Sequence[T]): def __iter__(self) -> Iterator[T]: pass def __contains__(self, other: object) -> bool: pass