Remove idle and saturated sets from scheduler #8889

Draft · wants to merge 1 commit into base: main
120 changes: 7 additions & 113 deletions distributed/scheduler.py
@@ -1764,11 +1764,8 @@ def __init__(
self.clients["fire-and-forget"] = ClientState("fire-and-forget")
self.extensions = {}
self.host_info = host_info
self.idle = SortedDict()
self.idle_task_count = set()
self.n_tasks = 0
self.resources = resources
self.saturated = set()
self.tasks = tasks
self.replicated_tasks = {
ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1
@@ -1852,7 +1849,6 @@ def __pdict__(self) -> dict[str, Any]:
return {
"bandwidth": self.bandwidth,
"resources": self.resources,
"saturated": self.saturated,
"unrunnable": self.unrunnable,
"queued": self.queued,
"n_tasks": self.n_tasks,
Expand All @@ -1867,7 +1863,6 @@ def __pdict__(self) -> dict[str, Any]:
"extensions": self.extensions,
"clients": self.clients,
"workers": self.workers,
"idle": self.idle,
"host_info": self.host_info,
}

@@ -2282,7 +2277,7 @@ def decide_worker_rootish_queuing_disabled(
# See root-ish-ness note below in `decide_worker_rootish_queuing_enabled`
assert math.isinf(self.WORKER_SATURATION) or not ts._queueable

pool = self.idle.values() if self.idle else self.running
pool = self.running
if not pool:
return None

@@ -2347,22 +2342,16 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:
# then add that assertion here (and actually pass in the task).
assert not math.isinf(self.WORKER_SATURATION)

if not self.idle_task_count:
# All workers busy? Task gets/stays queued.
if not self.running:
return None

# Just pick the least busy worker.
# NOTE: this will lead to worst-case scheduling with regards to co-assignment.
ws = min(
self.idle_task_count,
key=lambda ws: len(ws.processing) / ws.nthreads,
)
ws = min(self.running, key=lambda ws: len(ws.processing) / ws.nthreads)
if _worker_full(ws, self.WORKER_SATURATION):
return None
if self.validate:
assert self.workers.get(ws.address) is ws
assert not _worker_full(ws, self.WORKER_SATURATION), (
ws,
_task_slots_available(ws, self.WORKER_SATURATION),
)
assert ws in self.running, (ws, self.running)

return ws
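
With the ``idle_task_count`` set gone, this path now only asks whether the least-busy running worker is already full. The diff leans on the existing ``_task_slots_available`` and ``_worker_full`` helpers; below is a minimal sketch of how such helpers can behave (the long-running accounting and exact rounding are assumptions, not taken from this diff):

import math

def _task_slots_available(ws, saturation_factor: float) -> int:
    # Open task slots on this worker before it counts as full.
    # Assumption: long-running (seceded) tasks do not occupy a slot.
    assert not math.isinf(saturation_factor)
    nthreads = ws.nthreads
    return max(math.ceil(saturation_factor * nthreads), 1) - (
        len(ws.processing) - len(ws.long_running)
    )

def _worker_full(ws, saturation_factor: float) -> bool:
    # With queuing disabled (saturation factor of inf), no worker is ever "full".
    if math.isinf(saturation_factor):
        return False
    return _task_slots_available(ws, saturation_factor) <= 0
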
Expand Down Expand Up @@ -2406,7 +2395,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
# dependencies, but its group is also smaller than the cluster.

# Fastpath when there are no related tasks or restrictions
worker_pool = self.idle or self.workers
worker_pool = self.workers
# FIXME idle and workers are SortedDict's declared as dicts
# because sortedcontainers is not annotated
wp_vals = cast("Sequence[WorkerState]", worker_pool.values())
Expand Down Expand Up @@ -2906,7 +2895,6 @@ def _transition_waiting_queued(self, key: Key, stimulus_id: str) -> RecsMsgs:
ts = self.tasks[key]

if self.validate:
assert not self.idle_task_count, (ts, self.idle_task_count)
self._validate_ready(ts)

ts.state = "queued"
Expand Down Expand Up @@ -3094,63 +3082,6 @@ def is_rootish(self, ts: TaskState) -> bool:
and sum(map(len, tg.dependencies)) < 5
)

def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0) -> None:
"""Update the status of the idle and saturated state

The scheduler keeps track of workers that are ..

- Saturated: have enough work to stay busy
- Idle: do not have enough work to stay busy

They are considered saturated if they both have enough tasks to occupy
all of their threads, and if the expected runtime of those tasks is
large enough.

If ``distributed.scheduler.worker-saturation`` is not ``inf``
(scheduler-side queuing is enabled), they are considered idle
if they have fewer tasks processing than the ``worker-saturation``
threshold dictates.

Otherwise, they are considered idle if they have fewer tasks processing
than threads, or if their tasks' total expected runtime is less than half
the expected runtime of the same number of average tasks.

This is useful for load balancing and adaptivity.
"""
if self.total_nthreads == 0 or ws.status == Status.closed:
return
if occ < 0:
occ = ws.occupancy

p = len(ws.processing)

self.saturated.discard(ws)
if ws.status != Status.running:
self.idle.pop(ws.address, None)
elif self.is_unoccupied(ws, occ, p):
self.idle[ws.address] = ws
else:
self.idle.pop(ws.address, None)
nc = ws.nthreads
if p > nc:
pending = occ * (p - nc) / (p * nc)
if 0.4 < pending > 1.9 * (self.total_occupancy / self.total_nthreads):
self.saturated.add(ws)

if not _worker_full(ws, self.WORKER_SATURATION) and ws.status == Status.running:
self.idle_task_count.add(ws)
else:
self.idle_task_count.discard(ws)

def is_unoccupied(
self, ws: WorkerState, occupancy: float, nprocessing: int
) -> bool:
nthreads = ws.nthreads
return (
nprocessing < nthreads
or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2
)

def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float:
"""
Get the estimated communication cost (in s.) to compute the task
Expand Down Expand Up @@ -3357,7 +3288,6 @@ def _add_to_processing(
ts.processing_on = ws
ts.state = "processing"
self.acquire_resources(ts, ws)
self.check_idle_saturated(ws)
self.n_tasks += 1

if ts.actor:
Expand Down Expand Up @@ -3423,7 +3353,6 @@ def _exit_processing_common(self, ts: TaskState) -> WorkerState | None:
if self.workers.get(ws.address) is not ws: # may have been removed
return None

self.check_idle_saturated(ws)
self.release_resources(ts, ws)

return ws
Expand Down Expand Up @@ -4547,10 +4476,6 @@ async def add_worker(
metrics=metrics,
)

# Do not need to adjust self.total_occupancy as self.occupancy[ws] cannot
# exist before this.
self.check_idle_saturated(ws)

self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop)

awaitables = []
Expand Down Expand Up @@ -5167,13 +5092,11 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None:
so any tasks that became runnable are already in ``processing``. Otherwise,
overproduction can occur if queued tasks get scheduled before downstream tasks.

Must be called after `check_idle_saturated`; i.e. `idle_task_count` must be up to date.
"""
if not self.queued:
return
slots_available = sum(
_task_slots_available(ws, self.WORKER_SATURATION)
for ws in self.idle_task_count
_task_slots_available(ws, self.WORKER_SATURATION) for ws in self.running
)
if slots_available == 0:
return
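
As a worked illustration of the slot count above (all numbers hypothetical, using the helper sketch shown earlier in this diff): with the ``worker-saturation`` factor at 1.1, a running worker with 4 threads and 3 tasks processing contributes max(ceil(1.1 * 4), 1) - 3 = 2 open slots, so this stimulus would pop at most 2 queued tasks for it.
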
Expand Down Expand Up @@ -5403,9 +5326,6 @@ async def remove_worker(
self.rpc.remove(address)
del self.stream_comms[address]
del self.aliases[ws.name]
self.idle.pop(ws.address, None)
self.idle_task_count.discard(ws)
self.saturated.discard(ws)
del self.workers[address]
self._workers_removed_total += 1
ws.status = Status.closed
Expand Down Expand Up @@ -5734,23 +5654,6 @@ def validate_state(self, allow_overlap: bool = False) -> None:
if not (set(self.workers) == set(self.stream_comms)):
raise ValueError("Workers not the same in all collections")

assert self.running.issuperset(self.idle.values()), (
self.running.copy(),
set(self.idle.values()),
)
assert self.running.issuperset(self.idle_task_count), (
self.running.copy(),
self.idle_task_count.copy(),
)
assert self.running.issuperset(self.saturated), (
self.running.copy(),
self.saturated.copy(),
)
assert self.saturated.isdisjoint(self.idle.values()), (
self.saturated.copy(),
set(self.idle.values()),
)

task_prefix_counts: defaultdict[str, int] = defaultdict(int)
for w, ws in self.workers.items():
assert isinstance(w, str), (type(w), w)
Expand All @@ -5761,14 +5664,10 @@ def validate_state(self, allow_overlap: bool = False) -> None:
assert ws in self.running
else:
assert ws not in self.running
assert ws.address not in self.idle
assert ws not in self.saturated

assert ws.long_running.issubset(ws.processing)
if not ws.processing:
assert not ws.occupancy
if ws.status == Status.running:
assert ws.address in self.idle
assert not ws.needs_what.keys() & ws.has_what
actual_needs_what: defaultdict[TaskState, int] = defaultdict(int)
for ts in ws.processing:
Expand Down Expand Up @@ -6031,7 +5930,6 @@ def handle_long_running(
ts.prefix.duration_average = (old_duration + compute_duration) / 2

ws.add_to_long_running(ts)
self.check_idle_saturated(ws)

self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

Expand Down Expand Up @@ -6059,16 +5957,12 @@ def handle_worker_status_change(

if ws.status == Status.running:
self.running.add(ws)
self.check_idle_saturated(ws)
self.transitions(
self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id
)
self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)
else:
self.running.discard(ws)
self.idle.pop(ws.address, None)
self.idle_task_count.discard(ws)
self.saturated.discard(ws)
self._refresh_no_workers_since()

def handle_request_refresh_who_has(