From fd92ab83bf96a6b0090d64942e1d6e16a726d263 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 13 Aug 2024 15:11:33 +0200 Subject: [PATCH] Improve concurrent close for scheduler (#8829) * Improve concurrent close for scheduler * Fix test --- distributed/scheduler.py | 16 +++++++++------- distributed/tests/test_jupyter.py | 4 ++-- distributed/tests/test_scheduler.py | 27 +++++++++++++++++++++++---- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 363b33bb418..a989aa7fe37 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3849,7 +3849,7 @@ async def post(self): """Shut down the server.""" self.log.info("Shutting down on /api/shutdown request.") - await scheduler.close(reason="shutdown requested via Jupyter") + await scheduler.close(reason="jupyter-requested-shutdown") j = ServerApp.instance( config=Config( @@ -4274,7 +4274,7 @@ def del_scheduler_file() -> None: setproctitle(f"dask scheduler [{self.address}]") return self - async def close(self, fast=None, close_workers=None, reason=""): + async def close(self, fast=None, close_workers=None, reason="unknown"): """Send cleanup signal to all coroutines then wait until finished See Also @@ -4291,6 +4291,10 @@ async def close(self, fast=None, close_workers=None, reason=""): await self.finished() return + self.status = Status.closing + logger.info("Closing scheduler. Reason: %s", reason) + setproctitle("dask scheduler [closing]") + async def log_errors(func): try: await func() @@ -4301,10 +4305,6 @@ async def log_errors(func): *[log_errors(plugin.before_close) for plugin in list(self.plugins.values())] ) - self.status = Status.closing - logger.info("Scheduler closing due to %s...", reason or "unknown reason") - setproctitle("dask scheduler [closing]") - await self.preloads.teardown() await asyncio.gather( @@ -8652,7 +8652,9 @@ def check_idle(self) -> float | None: "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - self._ongoing_background_tasks.call_soon(self.close) + self._ongoing_background_tasks.call_soon( + self.close, reason="idle-timeout-exceeded" + ) return self.idle_since def _check_no_workers(self) -> None: diff --git a/distributed/tests/test_jupyter.py b/distributed/tests/test_jupyter.py index 5280d11093c..3f45678f812 100644 --- a/distributed/tests/test_jupyter.py +++ b/distributed/tests/test_jupyter.py @@ -146,8 +146,8 @@ def test_shutsdown_cleanly(requires_default_ports): stderr = subprocess_fut.result().stderr assert "Traceback" not in stderr assert ( - "distributed.scheduler - INFO - Scheduler closing due to shutdown " - "requested via Jupyter...\n" in stderr + "distributed.scheduler - INFO - Closing scheduler. Reason: jupyter-requested-shutdown" + in stderr ) assert "Shutting down on /api/shutdown request.\n" in stderr assert ( diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a3c33169993..d492b769eb6 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2394,7 +2394,7 @@ async def test_idle_timeout(c, s, a, b): _idle_since = s.check_idle() assert _idle_since == s.idle_since - with captured_logger("distributed.scheduler") as logs: + with captured_logger("distributed.scheduler") as caplog: start = time() while s.status != Status.closed: await asyncio.sleep(0.01) @@ -2405,9 +2405,11 @@ async def test_idle_timeout(c, s, a, b): await asyncio.sleep(0.01) assert time() < start + 1 - assert "idle" in logs.getvalue() - assert "500" in logs.getvalue() - assert "ms" in logs.getvalue() + logs = caplog.getvalue() + assert "idle" in logs + assert "500" in logs + assert "ms" in logs + assert "idle-timeout-exceeded" in logs assert s.idle_since > beginning pc.stop() @@ -5270,3 +5272,20 @@ async def test_stimulus_from_erred_task(c, s, a): logger.getvalue() == "Task f marked as failed because 1 workers died while trying to run it\n" ) + + +@gen_cluster(client=True) +async def test_concurrent_close_requests(c, s, *workers): + class BeforeCloseCounterPlugin(SchedulerPlugin): + async def start(self, scheduler): + self.call_count = 0 + + async def before_close(self): + self.call_count += 1 + + await c.register_plugin(BeforeCloseCounterPlugin(), name="before_close") + with captured_logger("distributed.scheduler", level=logging.INFO) as caplog: + await asyncio.gather(*[s.close(reason="test-reason") for _ in range(5)]) + assert s.plugins["before_close"].call_count == 1 + lines = caplog.getvalue().split("\n") + assert sum("Closing scheduler" in line for line in lines) == 1