From 831675f7d8ecb2f622409e2f66f6f5695285d53a Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 13 Sep 2024 11:16:45 +0200
Subject: [PATCH] Refactor and add bi-directional checking for adaptive stopping
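
AdaptiveCore is now an abstract base class: plan, requested and observed
become abstract read-only properties, and target, scale_up and scale_down
become abstract methods. The periodic callback and its lifecycle
(interval, state, _start and stop) move out of AdaptiveCore into
Adaptive. The callback now checks in both directions before adapting: it
returns early unless the adaptive state is "running", and it stops
itself, recording a reason, once the cluster has left Status.running.
Stopping now carries a reason, e.g. stop(reason="cluster-not-running"),
and Adaptive.__del__ stops with reason "adaptive-deleted". Errors raised
while adapting are logged by the callback rather than inside
AdaptiveCore.adapt.

A minimal sketch of the new subclass contract, modelled on the test
suite's MyAdaptiveCore (the class name and backing attributes below are
illustrative, not part of the API):

    from distributed.deploy.adaptive_core import AdaptiveCore

    class MinimalAdaptive(AdaptiveCore):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Backing state for the now-abstract read-only properties
            self._observed = set()
            self._plan = set()
            self._requested = set()

        @property
        def observed(self):
            return self._observed

        @property
        def plan(self):
            return self._plan

        @property
        def requested(self):
            return self._requested

        async def target(self) -> int:
            # A real implementation derives this from scheduler load
            return 0

        async def scale_up(self, n=0):
            self._plan = self._requested = set(range(n))

        async def scale_down(self, workers=()):
            for collection in (self._plan, self._requested, self._observed):
                for w in workers:
                    collection.discard(w)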
---
 distributed/deploy/adaptive.py            | 115 +++++++--
 distributed/deploy/adaptive_core.py       | 114 ++-------
 distributed/deploy/tests/test_adaptive.py | 235 +++++++++++++++++-
 .../deploy/tests/test_adaptive_core.py    | 138 ++--------
 4 files changed, 368 insertions(+), 234 deletions(-)

diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py
index 1638659db4..08c4212add 100644
--- a/distributed/deploy/adaptive.py
+++ b/distributed/deploy/adaptive.py
@@ -1,20 +1,39 @@
 from __future__ import annotations
 
 import logging
+from collections.abc import Hashable
+from datetime import timedelta
 from inspect import isawaitable
+from typing import TYPE_CHECKING, Any, Callable, Literal, cast
 
 from tornado.ioloop import IOLoop
 
 import dask.config
 from dask.utils import parse_timedelta
 
+from distributed.compatibility import PeriodicCallback
+from distributed.core import Status
 from distributed.deploy.adaptive_core import AdaptiveCore
 from distributed.protocol import pickle
 from distributed.utils import log_errors
 
+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
+    from distributed.deploy.cluster import Cluster
+    from distributed.scheduler import WorkerState
+
 logger = logging.getLogger(__name__)
 
+AdaptiveStateState: TypeAlias = Literal[
+    "starting",
+    "running",
+    "stopped",
+    "inactive",
+]
+
 
 class Adaptive(AdaptiveCore):
     '''
     Adaptively allocate workers based on scheduler load.  A superclass.
@@ -81,16 +100,21 @@ class Adaptive(AdaptiveCore):
     specified in the dask config under the distributed.adaptive key.
     '''
 
+    interval: float | None
+    periodic_callback: PeriodicCallback | None
+    #: Whether this adaptive strategy is periodically adapting
+    state: AdaptiveStateState
+
     def __init__(
         self,
-        cluster=None,
-        interval=None,
-        minimum=None,
-        maximum=None,
-        wait_count=None,
-        target_duration=None,
-        worker_key=None,
-        **kwargs,
+        cluster: Cluster,
+        interval: str | float | timedelta | None = None,
+        minimum: int | None = None,
+        maximum: float | None = None,
+        wait_count: int | None = None,
+        target_duration: str | float | timedelta | None = None,
+        worker_key: Callable[[WorkerState], Hashable] | None = None,
+        **kwargs: Any,
     ):
         self.cluster = cluster
         self.worker_key = worker_key
@@ -99,20 +123,78 @@ def __init__(
         if interval is None:
             interval = dask.config.get("distributed.adaptive.interval")
         if minimum is None:
-            minimum = dask.config.get("distributed.adaptive.minimum")
+            minimum = cast(int, dask.config.get("distributed.adaptive.minimum"))
         if maximum is None:
-            maximum = dask.config.get("distributed.adaptive.maximum")
+            maximum = cast(float, dask.config.get("distributed.adaptive.maximum"))
         if wait_count is None:
-            wait_count = dask.config.get("distributed.adaptive.wait-count")
+            wait_count = cast(int, dask.config.get("distributed.adaptive.wait-count"))
         if target_duration is None:
-            target_duration = dask.config.get("distributed.adaptive.target-duration")
+            target_duration = cast(
+                str, dask.config.get("distributed.adaptive.target-duration")
+            )
+
+        self.interval = parse_timedelta(interval, "seconds")
+        self.periodic_callback = None
+
+        if self.interval and self.cluster:
+            import weakref
+
+            self_ref = weakref.ref(self)
+
+            async def _adapt():
+                adaptive = self_ref()
+                if not adaptive or adaptive.state != "running":
+                    return
+                if adaptive.cluster.status != Status.running:
+                    adaptive.stop(reason="cluster-not-running")
+                    return
+                try:
+                    await adaptive.adapt()
+                except Exception:
+                    logger.warning(
+                        "Adaptive encountered an error while adapting", exc_info=True
+                    )
+
+            self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
+            self.state = "starting"
+            self.loop.add_callback(self._start)
+        else:
+            self.state = "inactive"
 
         self.target_duration = parse_timedelta(target_duration)
 
-        super().__init__(
-            minimum=minimum, maximum=maximum, wait_count=wait_count, interval=interval
+        super().__init__(minimum=minimum, maximum=maximum, wait_count=wait_count)
+
+    def _start(self) -> None:
+        if self.state != "starting":
+            return
+
+        assert self.periodic_callback is not None
+        self.periodic_callback.start()
+        self.state = "running"
+        logger.info(
+            "Adaptive scaling started: minimum=%s maximum=%s",
+            self.minimum,
+            self.maximum,
         )
 
+    def stop(self, reason: str = "unknown") -> None:
+        if self.state in ("inactive", "stopped"):
+            return
+
+        if self.state == "running":
+            assert self.periodic_callback is not None
+            self.periodic_callback.stop()
+            logger.info(
+                "Adaptive scaling stopped: minimum=%s maximum=%s. Reason: %s",
+                self.minimum,
+                self.maximum,
+                reason,
+            )
+
+        self.periodic_callback = None
+        self.state = "stopped"
+
     @property
     def scheduler(self):
         return self.cluster.scheduler_comm
@@ -210,6 +292,9 @@ async def scale_up(self, n):
     def loop(self) -> IOLoop:
         """Override Adaptive.loop"""
         if self.cluster:
-            return self.cluster.loop
+            return self.cluster.loop  # type: ignore[return-value]
         else:
             return IOLoop.current()
+
+    def __del__(self):
+        self.stop(reason="adaptive-deleted")
diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py
index 90254f4d9c..ccbfb0ccf8 100644
--- a/distributed/deploy/adaptive_core.py
+++ b/distributed/deploy/adaptive_core.py
@@ -2,37 +2,24 @@
 
 import logging
 import math
+from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Iterable
-from datetime import timedelta
-from typing import TYPE_CHECKING, Literal, cast
+from typing import TYPE_CHECKING, cast
 
 import tlz as toolz
-from tornado.ioloop import IOLoop
 
 import dask.config
-from dask.utils import parse_timedelta
 
-from distributed.compatibility import PeriodicCallback
 from distributed.metrics import time
 
 if TYPE_CHECKING:
-    from typing_extensions import TypeAlias
-
     from distributed.scheduler import WorkerState
 
 logger = logging.getLogger(__name__)
 
-AdaptiveStateState: TypeAlias = Literal[
-    "starting",
-    "running",
-    "stopped",
-    "inactive",
-]
-
-
-class AdaptiveCore:
+class AdaptiveCore(ABC):
     """
     The core logic for adaptive deployments, with none of the cluster details
 
@@ -89,56 +76,24 @@ class AdaptiveCore:
     """
 
     minimum: int
-    maximum: int | float
+    maximum: float
     wait_count: int
-    interval: int | float
-    periodic_callback: PeriodicCallback | None
-    plan: set[WorkerState]
-    requested: set[WorkerState]
-    observed: set[WorkerState]
     close_counts: defaultdict[WorkerState, int]
     log: deque[tuple[float, dict]]
-    #: Whether this adaptive strategy is periodically adapting
-    state: AdaptiveStateState
     _adapting: bool
 
     def __init__(
         self,
         minimum: int = 0,
-        maximum: int | float = math.inf,
+        maximum: float = math.inf,
         wait_count: int = 3,
-        interval: str | int | float | timedelta = "1s",
     ):
         if not isinstance(maximum, int) and not math.isinf(maximum):
             raise TypeError(f"maximum must be int or inf; got {maximum}")
 
         self.minimum = minimum
         self.maximum = maximum
         self.wait_count = wait_count
-        self.interval = parse_timedelta(interval, "seconds")
-        self.periodic_callback = None
-
-        if self.interval:
-            import weakref
-
-            self_ref = weakref.ref(self)
-
-            async def _adapt():
-                core = self_ref()
-                if core:
-                    await core.adapt()
-
-            self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
-            self.state = "starting"
-            self.loop.add_callback(self._start)
-        else:
-            self.state = "inactive"
-
-        try:
-            self.plan = set()
-            self.requested = set()
-            self.observed = set()
-        except Exception:
-            pass
 
         # internal state
         self.close_counts = defaultdict(int)
@@ -147,38 +102,22 @@ async def _adapt():
             maxlen=dask.config.get("distributed.admin.low-level-log-length")
         )
 
-    def _start(self) -> None:
-        if self.state != "starting":
-            return
-
-        assert self.periodic_callback is not None
-        self.periodic_callback.start()
-        self.state = "running"
-        logger.info(
-            "Adaptive scaling started: minimum=%s maximum=%s",
-            self.minimum,
-            self.maximum,
-        )
-
-    def stop(self) -> None:
-        if self.state in ("inactive", "stopped"):
-            return
+    @property
+    @abstractmethod
+    def plan(self) -> set[WorkerState]: ...
 
-        if self.state == "running":
-            assert self.periodic_callback is not None
-            self.periodic_callback.stop()
-            logger.info(
-                "Adaptive scaling stopped: minimum=%s maximum=%s",
-                self.minimum,
-                self.maximum,
-            )
+    @property
+    @abstractmethod
+    def requested(self) -> set[WorkerState]: ...
 
-        self.periodic_callback = None
-        self.state = "stopped"
+    @property
+    @abstractmethod
+    def observed(self) -> set[WorkerState]: ...
 
+    @abstractmethod
     async def target(self) -> int:
         """The target number of workers that should exist"""
-        raise NotImplementedError()
+        ...
 
     async def workers_to_close(self, target: int) -> list:
         """
@@ -198,11 +137,11 @@ async def safe_target(self) -> int:
 
         return n
 
-    async def scale_down(self, n: int) -> None:
-        raise NotImplementedError()
+    @abstractmethod
+    async def scale_down(self, n: int) -> None: ...
 
-    async def scale_up(self, workers: Iterable) -> None:
-        raise NotImplementedError()
+    @abstractmethod
+    async def scale_up(self, workers: Iterable) -> None: ...
 
     async def recommendations(self, target: int) -> dict:
         """
@@ -270,16 +209,5 @@ async def adapt(self) -> None:
                 await self.scale_up(**recommendations)
             if status == "down":
                 await self.scale_down(**recommendations)
-        except Exception:
-            logger.warning(
-                "Adaptive encountered an error while adapting", exc_info=True
-            )
         finally:
             self._adapting = False
-
-    def __del__(self):
-        self.stop()
-
-    @property
-    def loop(self) -> IOLoop:
-        return IOLoop.current()
diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py
index 7513f2e884..441c1c609d 100644
--- a/distributed/deploy/tests/test_adaptive.py
+++ b/distributed/deploy/tests/test_adaptive.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import logging
 import math
 from time import sleep
 
@@ -17,9 +18,16 @@
     Worker,
     wait,
 )
+from distributed.core import Status
 from distributed.deploy.cluster import Cluster
 from distributed.metrics import time
-from distributed.utils_test import async_poll_for, gen_cluster, gen_test, slowinc
+from distributed.utils_test import (
+    async_poll_for,
+    captured_logger,
+    gen_cluster,
+    gen_test,
+    slowinc,
+)
 
 
 def test_adaptive_local_cluster(loop):
@@ -368,17 +376,23 @@ async def test_adapt_cores_memory():
 
 @gen_test()
 async def test_adaptive_config():
-    with dask.config.set(
-        {"distributed.adaptive.minimum": 10, "distributed.adaptive.wait-count": 8}
-    ):
-        try:
-            adapt = Adaptive(interval="5s")
-            assert adapt.minimum == 10
-            assert adapt.maximum == math.inf
-            assert adapt.interval == 5
-            assert adapt.wait_count == 8
-        finally:
-            adapt.stop()
+    async with LocalCluster(
+        n_workers=0,
+        asynchronous=True,
+        silence_logs=False,
+        dashboard_address=":0",
+    ) as cluster:
+        with dask.config.set(
+            {"distributed.adaptive.minimum": 10, "distributed.adaptive.wait-count": 8}
+        ):
+            try:
+                adapt = Adaptive(cluster, interval="5s")
+                assert adapt.minimum == 10
+                assert adapt.maximum == math.inf
+                assert adapt.interval == 5
+                assert adapt.wait_count == 8
+            finally:
+                adapt.stop()
 
 
 @gen_test()
@@ -526,3 +540,200 @@ async def test_respect_average_nthreads(c, s, w):
         await asyncio.sleep(0.001)
 
     assert s.adaptive_target() == 40
+
+
+class MyAdaptive(Adaptive):
+    def __init__(self, *args, interval=None, **kwargs):
+        super().__init__(*args, interval=interval, **kwargs)
+        self._target = 0
+        self._log = []
+        self._observed = set()
+        self._plan = set()
+        self._requested = set()
+
+    @property
+    def observed(self):
+        return self._observed
+
+    @property
+    def plan(self):
+        return self._plan
+
+    @property
+    def requested(self):
+        return self._requested
+
+    async def target(self):
+        return self._target
+
+    async def scale_up(self, n=0):
+        self._plan = self._requested = set(range(n))
+
+    async def scale_down(self, workers=()):
+        for collection in [self.plan, self.requested, self.observed]:
+            for w in workers:
+                collection.discard(w)
+
+
+@gen_test()
+async def test_adaptive_stops_on_cluster_status_change():
+    async with LocalCluster(
+        n_workers=0,
+        asynchronous=True,
+        silence_logs=False,
+        dashboard_address=":0",
+    ) as cluster:
+        adapt = Adaptive(cluster, interval="100 ms")
+        assert adapt.state == "starting"
+        await async_poll_for(lambda: adapt.state == "running", timeout=5)
+
+        assert adapt.periodic_callback
+        assert adapt.periodic_callback.is_running()
+
+        try:
+            cluster.status = Status.closing
+
+            await async_poll_for(lambda: adapt.state != "running", timeout=5)
+            assert adapt.state == "stopped"
+            assert not adapt.periodic_callback
+        finally:
+            # Set back to running to let normal shutdown do its thing
+            cluster.status = Status.running
+
+
+@gen_test()
+async def test_interval():
+    async with LocalCluster(
+        n_workers=0,
+        asynchronous=True,
+        silence_logs=False,
+        dashboard_address=":0",
+    ) as cluster:
+        adapt = MyAdaptive(cluster=cluster, interval="100 ms")
+        assert not adapt.plan
+
+        for i in [0, 3, 1]:
+            start = time()
+            adapt._target = i
+            while len(adapt.plan) != i:
+                await asyncio.sleep(0.01)
+            assert time() < start + 2
+
+        adapt.stop()
+        await asyncio.sleep(0.05)
+
+        adapt._target = 10
+        await asyncio.sleep(0.02)
+        assert len(adapt.plan) == 1  # last value from before, unchanged
+
+
+@gen_test()
+async def test_adapt_logs_error_in_safe_target():
+    class BadAdaptive(MyAdaptive):
+        """Adaptive subclass which raises an OSError when attempting to adapt
+
+        We use this to check that error handling works properly
+        """
+
+        def safe_target(self):
+            raise OSError()
+
+    async with LocalCluster(
+        n_workers=0,
+        asynchronous=True,
+        silence_logs=False,
+        dashboard_address=":0",
+    ) as cluster:
+        with captured_logger(
+            "distributed.deploy.adaptive", level=logging.WARNING
+        ) as log:
+            adapt = cluster.adapt(
+                Adaptive=BadAdaptive, minimum=1, maximum=4, interval="10ms"
+            )
+            while "encountered an error" not in log.getvalue():
+                await asyncio.sleep(0.01)
+            assert "stop" not in log.getvalue()
+            assert adapt.state == "running"
+            assert adapt.periodic_callback
+            assert adapt.periodic_callback.is_running()
+
+
+@gen_test()
+async def test_adapt_callback_logs_error_in_scale_down():
+    class BadAdaptive(MyAdaptive):
+        async def scale_down(self, workers=None):
+            raise OSError()
+
+    async with LocalCluster(
+        n_workers=0,
+        asynchronous=True,
+        silence_logs=False,
+        dashboard_address=":0",
+    ) as cluster:
+        adapt = cluster.adapt(
+            Adaptive=BadAdaptive, minimum=1, maximum=4, wait_count=0, interval="10ms"
+        )
+        adapt._target = 2
+        await async_poll_for(lambda: adapt.state == "running", timeout=5)
+        assert adapt.periodic_callback.is_running()
+        await adapt.adapt()
+        assert len(adapt.plan) == 2
+        assert len(adapt.requested) == 2
+        with captured_logger(
+            "distributed.deploy.adaptive", level=logging.WARNING
+        ) as log:
+            adapt._target = 0
+            while "encountered an error" not in log.getvalue():
+                await asyncio.sleep(0.01)
+        assert "stop" not in log.getvalue()
+        assert not adapt._adapting
+        assert adapt.periodic_callback
+        assert adapt.periodic_callback.is_running()
+
+
+@pytest.mark.parametrize("wait_until_running", [True, False])
+@gen_test()
+async def test_adaptive_logs_stopping_once(wait_until_running):
+    async with LocalCluster(
+        n_workers=0,
+        asynchronous=True,
+        silence_logs=False,
+        dashboard_address=":0",
+    ) as cluster:
+        with captured_logger("distributed.deploy.adaptive") as log:
+            adapt = cluster.adapt(Adaptive=MyAdaptive, interval="100ms")
+            if wait_until_running:
+                await async_poll_for(lambda: adapt.state == "running", timeout=5)
+                assert adapt.periodic_callback
+                assert adapt.periodic_callback.is_running()
+                pc = adapt.periodic_callback
+            else:
+                assert adapt.periodic_callback
+                assert not adapt.periodic_callback.is_running()
+                pc = adapt.periodic_callback
+
+            adapt.stop()
+            adapt.stop()
+        assert adapt.state == "stopped"
+        assert not adapt.periodic_callback
+        assert not pc.is_running()
+        lines = log.getvalue().splitlines()
+        assert sum("Adaptive scaling stopped" in line for line in lines) == 1
+
+
+@gen_test()
+async def test_adapt_stop_del():
+    async with LocalCluster(
+        n_workers=0,
+        asynchronous=True,
+        silence_logs=False,
+        dashboard_address=":0",
+    ) as cluster:
+        adapt = cluster.adapt(Adaptive=MyAdaptive, interval="100ms")
+        pc = adapt.periodic_callback
+        await async_poll_for(lambda: adapt.state == "running", timeout=5)
+
+        # Remove reference of adaptive object from cluster
+        cluster._adaptive = None
+        del adapt
+        await async_poll_for(lambda: not pc.is_running(), timeout=5)
diff --git a/distributed/deploy/tests/test_adaptive_core.py b/distributed/deploy/tests/test_adaptive_core.py
index e2e7844026..cc2336e76a 100644
--- a/distributed/deploy/tests/test_adaptive_core.py
+++ b/distributed/deploy/tests/test_adaptive_core.py
@@ -1,24 +1,35 @@
 from __future__ import annotations
 
-import asyncio
-import logging
-
 from distributed.deploy.adaptive_core import AdaptiveCore
-from distributed.metrics import time
-from distributed.utils_test import captured_logger, gen_test
+from distributed.utils_test import gen_test
 
 
-class MyAdaptive(AdaptiveCore):
-    def __init__(self, *args, interval=None, **kwargs):
-        super().__init__(*args, interval=interval, **kwargs)
+class MyAdaptiveCore(AdaptiveCore):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._observed = set()
+        self._plan = set()
+        self._requested = set()
         self._target = 0
         self._log = []
 
+    @property
+    def observed(self):
+        return self._observed
+
+    @property
+    def plan(self):
+        return self._plan
+
+    @property
+    def requested(self):
+        return self._requested
+
     async def target(self):
         return self._target
 
     async def scale_up(self, n=0):
-        self.plan = self.requested = set(range(n))
+        self._plan = self._requested = set(range(n))
 
     async def scale_down(self, workers=()):
         for collection in [self.plan, self.requested, self.observed]:
@@ -28,7 +39,7 @@ async def scale_down(self, workers=()):
 
 @gen_test()
 async def test_safe_target():
-    adapt = MyAdaptive(minimum=1, maximum=4)
+    adapt = MyAdaptiveCore(minimum=1, maximum=4)
     assert await adapt.safe_target() == 1
     adapt._target = 10
     assert await adapt.safe_target() == 4
@@ -36,7 +47,7 @@ async def test_safe_target():
 
 @gen_test()
 async def test_scale_up():
-    adapt = MyAdaptive(minimum=1, maximum=4)
+    adapt = MyAdaptiveCore(minimum=1, maximum=4)
     await adapt.adapt()
     assert adapt.log[-1][1] == {"status": "up", "n": 1}
     assert adapt.plan == {0}
@@ -49,12 +60,12 @@ async def test_scale_up():
 
 @gen_test()
 async def test_scale_down():
-    adapt = MyAdaptive(minimum=1, maximum=4, wait_count=2)
+    adapt = MyAdaptiveCore(minimum=1, maximum=4, wait_count=2)
     adapt._target = 10
     await adapt.adapt()
     assert len(adapt.log) == 1
 
-    adapt.observed = {0, 1, 3}  # all but 2 have arrived
+    adapt._observed = {0, 1, 3}  # all but 2 have arrived
 
     adapt._target = 2
     await adapt.adapt()
@@ -71,104 +82,3 @@ async def test_scale_down():
     await adapt.adapt()
     await adapt.adapt()
     assert list(adapt.log) == old
-
-
-@gen_test()
-async def test_interval():
-    adapt = MyAdaptive(interval="5 ms")
-    assert not adapt.plan
-
-    for i in [0, 3, 1]:
-        start = time()
-        adapt._target = i
-        while len(adapt.plan) != i:
-            await asyncio.sleep(0.001)
-        assert time() < start + 2
-
-    adapt.stop()
-    await asyncio.sleep(0.05)
-
-    adapt._target = 10
-    await asyncio.sleep(0.02)
-    assert len(adapt.plan) == 1  # last value from before, unchanged
-
-
-@gen_test()
-async def test_adapt_logs_error_in_safe_target():
-    class BadAdaptive(MyAdaptive):
-        """AdaptiveCore subclass which raises an OSError when attempting to adapt
-
-        We use this to check that error handling works properly
-        """
-
-        def safe_target(self):
-            raise OSError()
-
-    with captured_logger(
-        "distributed.deploy.adaptive_core", level=logging.WARNING
-    ) as log:
-        adapt = BadAdaptive(minimum=1, maximum=4, interval="10ms")
-        while "encountered an error" not in log.getvalue():
-            await asyncio.sleep(0.01)
-        assert "stop" not in log.getvalue()
-        assert adapt.state == "running"
-        assert adapt.periodic_callback
-        assert adapt.periodic_callback.is_running()
-
-
-@gen_test()
-async def test_adapt_logs_errors():
-    class BadAdaptive(MyAdaptive):
-        async def scale_down(self, workers=None):
-            raise OSError()
-
-    adapt = BadAdaptive(minimum=1, maximum=4, wait_count=0, interval="10ms")
-    adapt._target = 2
-    while adapt.state != "running":
-        await asyncio.sleep(0.01)
-    assert adapt.periodic_callback.is_running()
-    await adapt.adapt()
-    assert len(adapt.plan) == 2
-    assert len(adapt.requested) == 2
-    with captured_logger(
-        "distributed.deploy.adaptive_core", level=logging.WARNING
-    ) as log:
-        adapt._target = 0
-        await adapt.adapt()
-    text = log.getvalue()
-    assert "encountered an error" in text
-    assert not adapt._adapting
-    assert adapt.periodic_callback
-    assert adapt.periodic_callback.is_running()
-    adapt.stop()
-
-
-@gen_test()
-async def test_adaptive_logs_stopping_once():
-    with captured_logger("distributed.deploy.adaptive_core") as log:
-        adapt = MyAdaptive(interval="100ms")
-        while adapt.state != "running":
-            await asyncio.sleep(0.01)
-        assert adapt.periodic_callback
-        assert adapt.periodic_callback.is_running()
-        pc = adapt.periodic_callback
-
-        adapt.stop()
-        adapt.stop()
-    assert adapt.state == "stopped"
-    assert not adapt.periodic_callback
-    assert not pc.is_running()
-    lines = log.getvalue().splitlines()
-    assert sum("Adaptive scaling stopped" in line for line in lines) == 1
-
-
-@gen_test()
-async def test_adapt_stop_del():
-    adapt = MyAdaptive(interval="100ms")
-    pc = adapt.periodic_callback
-    while adapt.state != "running":
-        await asyncio.sleep(0.01)
-
-    del adapt
-    while pc.is_running():
-        await asyncio.sleep(0.01)