diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py index 4b39f11c49..337cce9c59 100644 --- a/ai_diffusion/client.py +++ b/ai_diffusion/client.py @@ -34,6 +34,8 @@ class ClientMessage(NamedTuple): images: ImageCollection | None = None result: dict | None = None error: str | None = None + # jobs queued before our next one + queue_length: int | None = None class User(QObject, ObservableProperties): diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index 0d4f7fec0f..061101c658 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -7,6 +7,7 @@ from enum import Enum from collections import deque from itertools import chain, product +from operator import itemgetter from typing import NamedTuple, Optional, Sequence from .api import WorkflowInput @@ -264,6 +265,7 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr if msg["type"] == "status": await self._report(ClientEvent.connected, "") + await self._poll_server_queue() if msg["type"] == "execution_start": id = msg["data"]["prompt_id"] @@ -340,9 +342,33 @@ async def clear_queue(self): except asyncio.QueueEmpty: break - await self._post("queue", {"clear": True}) + remote_ids = [await job.get_remote_id() for job in self._jobs] + await self._post("api/queue", {"delete": remote_ids}) + self._jobs.clear() + async def _poll_server_queue(self): + queue = await self._get("api/queue") + server_jobs = queue["queue_running"] + queue["queue_pending"] + # why are they unsorted to start with...? + server_jobs = sorted(server_jobs, key=itemgetter(0)) + server_jobs = [entry[1] for entry in server_jobs] + if not (self._jobs or self._active): + return + + if self._active: + first_job = self._active + else: + first_job = self._jobs[0] + # as we got the job from `_jobs` or `_active`, this field must have been set (in `_run_job`). + first_remote_id = util.ensure(await first_job.get_remote_id()) + try: + offset = server_jobs.index(first_remote_id) + except ValueError: + # probably just haven't gotten the notification yet + return + await self._report(ClientEvent.queued, first_job.local_id, queue_length=offset) + async def disconnect(self): if self._is_connected: self._is_connected = False @@ -443,7 +469,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]: return self._active elif self._active: log.warning(f"Received message for job {remote_id}, but job {self._active} is active") - if len(self._jobs) == 0: + if not self._jobs: log.warning(f"Received unknown job {remote_id}") return None active = next((j for j in self._jobs if j.remote_id == remote_id), None) @@ -454,7 +480,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]: async def _start_job(self, remote_id: str): if self._active is not None: log.warning(f"Started job {remote_id}, but {self._active} was never finished") - if len(self._jobs) == 0: + if not self._jobs: log.warning(f"Received unknown job {remote_id}") return None diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 805c3b19e3..8fea304996 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -72,6 +72,7 @@ class Model(QObject, ObservableProperties): progress = Property(0.0) jobs: JobQueue error = Property("") + queue_length: int = 0 workspace_changed = pyqtSignal(Workspace) style_changed = pyqtSignal(Style) @@ -87,6 +88,7 @@ class Model(QObject, ObservableProperties): error_changed = pyqtSignal(str) has_error_changed = pyqtSignal(bool) modified = pyqtSignal(QObject, str) + queue_length_changed = pyqtSignal(int) def __init__(self, document: Document, connection: Connection): super().__init__() @@ -408,9 +410,13 @@ def handle_message(self, message: ClientMessage): return if message.event is ClientEvent.queued: - self.jobs.notify_started(job) - self.progress = -1 - self.progress_changed.emit(-1) + if message.queue_length is not None: + self.queue_length = message.queue_length + self.queue_length_changed.emit(message.queue_length) + if message.queue_length is None or message.queue_length == 0: + self.jobs.notify_started(job) + self.progress = -1 + self.progress_changed.emit(-1) elif message.event is ClientEvent.progress: self.jobs.notify_started(job) self.progress_kind = ProgressKind.generation diff --git a/ai_diffusion/ui/widget.py b/ai_diffusion/ui/widget.py index 3e95264efc..2ca51f49bc 100644 --- a/ai_diffusion/ui/widget.py +++ b/ai_diffusion/ui/widget.py @@ -209,6 +209,7 @@ def _connect_model(self): self._connections = [ self._model.jobs.count_changed.connect(self._update), self._model.progress_kind_changed.connect(self._update), + self._model.queue_length_changed.connect(self._update), ] def _update(self): @@ -220,6 +221,10 @@ def _update(self): self.setIcon(theme.icon("queue-upload")) self.setToolTip(_("Uploading models.") + f" {queued_msg} {cancel_msg}") count += 1 + elif self._model.queue_length > 0: + self.setIcon(theme.icon("queue-inactive")) + self.setToolTip(_("Server is busy.")) + count = f"+{self.model.queue_length}" elif self._model.jobs.any_executing(): self.setIcon(theme.icon("queue-active")) if count > 0: