diff --git a/mypy/checker.py b/mypy/checker.py index 56be3db3f9e7..9f987cb5ccdf 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4,7 +4,7 @@ import itertools from collections import defaultdict -from contextlib import contextmanager, nullcontext +from contextlib import ExitStack, contextmanager from typing import ( AbstractSet, Callable, @@ -526,17 +526,11 @@ def check_second_pass( # print("XXX in pass %d, class %s, function %s" % # (self.pass_num, type_name, node.fullname or node.name)) done.add(node) - with ( - self.tscope.class_scope(active_typeinfo) - if active_typeinfo - else nullcontext() - ): - with ( - self.scope.push_class(active_typeinfo) - if active_typeinfo - else nullcontext() - ): - self.check_partial(node) + with ExitStack() as stack: + if active_typeinfo: + stack.enter_context(self.tscope.class_scope(active_typeinfo)) + stack.enter_context(self.scope.push_class(active_typeinfo)) + self.check_partial(node) return True def check_partial(self, node: DeferredNodeType | FineGrainedDeferredNodeType) -> None: