Improve concurrent close for scheduler (#8829)
* Improve concurrent close for scheduler

* Fix test
hendrikmakait authored Aug 13, 2024
1 parent 86dc83c commit fd92ab8
Showing 3 changed files with 34 additions and 13 deletions.
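The core of the change: Scheduler.close() can be triggered from several places at once (a Jupyter /api/shutdown request, the idle-timeout check, or an explicit call), and previously the scheduler only switched its status to closing after awaiting the plugins' before_close hooks, so overlapping calls could each run those hooks and log their own shutdown message. Moving the status flip and the log line ahead of any await means later callers hit the existing early-return path and simply wait for the first close to finish. Below is a minimal, runnable sketch of that guard pattern; MiniScheduler, _finished, and before_close_calls are illustrative names for this sketch, not the actual distributed API.

import asyncio


class MiniScheduler:
    """Toy illustration of an idempotent, concurrency-safe close()."""

    def __init__(self):
        self.status = "running"
        self._finished = asyncio.Event()
        self.before_close_calls = 0  # stands in for plugin.before_close hooks

    async def close(self, reason="unknown"):
        if self.status != "running":
            # Another close() is already in flight: wait for it, then return.
            await self._finished.wait()
            return

        # Flip the status (and log the reason) *before* the first await,
        # so concurrent callers take the early-return branch above.
        self.status = "closing"
        print(f"Closing scheduler. Reason: {reason}")

        self.before_close_calls += 1
        await asyncio.sleep(0)  # remaining teardown work would happen here
        self.status = "closed"
        self._finished.set()


async def main():
    s = MiniScheduler()
    # Five concurrent close requests collapse into a single close sequence.
    await asyncio.gather(*(s.close(reason="test") for _ in range(5)))
    assert s.before_close_calls == 1


asyncio.run(main())

The real scheduler does the same thing with Status.closing / Status.closed, logger.info, and its finished() coroutine, as the scheduler.py hunks below show.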
16 changes: 9 additions & 7 deletions distributed/scheduler.py
@@ -3849,7 +3849,7 @@ async def post(self):
"""Shut down the server."""
self.log.info("Shutting down on /api/shutdown request.")

await scheduler.close(reason="shutdown requested via Jupyter")
await scheduler.close(reason="jupyter-requested-shutdown")

j = ServerApp.instance(
config=Config(
@@ -4274,7 +4274,7 @@ def del_scheduler_file() -> None:
setproctitle(f"dask scheduler [{self.address}]")
return self

- async def close(self, fast=None, close_workers=None, reason=""):
+ async def close(self, fast=None, close_workers=None, reason="unknown"):
"""Send cleanup signal to all coroutines then wait until finished
See Also
@@ -4291,6 +4291,10 @@ async def close(self, fast=None, close_workers=None, reason="unknown"):
await self.finished()
return

+ self.status = Status.closing
+ logger.info("Closing scheduler. Reason: %s", reason)
+ setproctitle("dask scheduler [closing]")

async def log_errors(func):
try:
await func()
@@ -4301,10 +4305,6 @@ async def log_errors(func):
*[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]
)

- self.status = Status.closing
- logger.info("Scheduler closing due to %s...", reason or "unknown reason")
- setproctitle("dask scheduler [closing]")

await self.preloads.teardown()

await asyncio.gather(
@@ -8652,7 +8652,9 @@ def check_idle(self) -> float | None:
"Scheduler closing after being idle for %s",
format_time(self.idle_timeout),
)
- self._ongoing_background_tasks.call_soon(self.close)
+ self._ongoing_background_tasks.call_soon(
+     self.close, reason="idle-timeout-exceeded"
+ )
return self.idle_since

def _check_no_workers(self) -> None:
4 changes: 2 additions & 2 deletions distributed/tests/test_jupyter.py
@@ -146,8 +146,8 @@ def test_shutsdown_cleanly(requires_default_ports):
stderr = subprocess_fut.result().stderr
assert "Traceback" not in stderr
assert (
"distributed.scheduler - INFO - Scheduler closing due to shutdown "
"requested via Jupyter...\n" in stderr
"distributed.scheduler - INFO - Closing scheduler. Reason: jupyter-requested-shutdown"
in stderr
)
assert "Shutting down on /api/shutdown request.\n" in stderr
assert (
27 changes: 23 additions & 4 deletions distributed/tests/test_scheduler.py
@@ -2394,7 +2394,7 @@ async def test_idle_timeout(c, s, a, b):
_idle_since = s.check_idle()
assert _idle_since == s.idle_since

- with captured_logger("distributed.scheduler") as logs:
+ with captured_logger("distributed.scheduler") as caplog:
start = time()
while s.status != Status.closed:
await asyncio.sleep(0.01)
@@ -2405,9 +2405,11 @@ async def test_idle_timeout(c, s, a, b):
await asyncio.sleep(0.01)
assert time() < start + 1

assert "idle" in logs.getvalue()
assert "500" in logs.getvalue()
assert "ms" in logs.getvalue()
logs = caplog.getvalue()
assert "idle" in logs
assert "500" in logs
assert "ms" in logs
assert "idle-timeout-exceeded" in logs
assert s.idle_since > beginning
pc.stop()

@@ -5270,3 +5272,20 @@ async def test_stimulus_from_erred_task(c, s, a):
logger.getvalue()
== "Task f marked as failed because 1 workers died while trying to run it\n"
)


+ @gen_cluster(client=True)
+ async def test_concurrent_close_requests(c, s, *workers):
+     class BeforeCloseCounterPlugin(SchedulerPlugin):
+         async def start(self, scheduler):
+             self.call_count = 0
+
+         async def before_close(self):
+             self.call_count += 1
+
+     await c.register_plugin(BeforeCloseCounterPlugin(), name="before_close")
+     with captured_logger("distributed.scheduler", level=logging.INFO) as caplog:
+         await asyncio.gather(*[s.close(reason="test-reason") for _ in range(5)])
+     assert s.plugins["before_close"].call_count == 1
+     lines = caplog.getvalue().split("\n")
+     assert sum("Closing scheduler" in line for line in lines) == 1
