diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py index 4b39f11c49..59540508f4 100644 --- a/ai_diffusion/client.py +++ b/ai_diffusion/client.py @@ -25,6 +25,7 @@ class ClientEvent(Enum): disconnected = 5 queued = 6 upload = 7 + foreign_jobs = 8 class ClientMessage(NamedTuple): @@ -34,6 +35,7 @@ class ClientMessage(NamedTuple): images: ImageCollection | None = None result: dict | None = None error: str | None = None + foreign_jobs: int | None = None class User(QObject, ObservableProperties): diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index 0d4f7fec0f..6500800fd7 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -7,6 +7,7 @@ from enum import Enum from collections import deque from itertools import chain, product +from operator import itemgetter from typing import NamedTuple, Optional, Sequence from .api import WorkflowInput @@ -99,6 +100,7 @@ class ComfyClient(Client): _websocket_listener: asyncio.Task _supported_sd_versions: list[SDVersion] _supported_languages: list[TranslationPackage] + _server_job_ids: list[str] = [] @staticmethod async def connect(url=default_url, access_token=""): @@ -264,6 +266,7 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr if msg["type"] == "status": await self._report(ClientEvent.connected, "") + await self.update_server_queue() if msg["type"] == "execution_start": id = msg["data"]["prompt_id"] @@ -340,9 +343,34 @@ async def clear_queue(self): except asyncio.QueueEmpty: break - await self._post("queue", {"clear": True}) + remote_ids = [await job.get_remote_id() for job in self._jobs] + await self._post("api/queue", {"delete": remote_ids}) + self._jobs.clear() + async def update_server_queue(self): + queue = await self._get("api/queue") + server_jobs = queue["queue_running"] + queue["queue_pending"] + # why are they unsorted to start with...? + server_jobs = sorted(server_jobs, key=itemgetter(0)) + self._server_job_ids = [entry[1] for entry in server_jobs] + if not (self._jobs or self._active): + await self._report(ClientEvent.foreign_jobs, "", foreign_jobs=len(self._server_job_ids)) + return + + try: + if self._active: + first_remote_id = await self._active.get_remote_id() + else: + first_remote_id = await self._jobs[0].get_remote_id() + # if we got it from _jobs or _active, this field must have been set (in _run_job). + first_remote_id = util.ensure(first_remote_id) + offset = self._server_job_ids.index(first_remote_id) + await self._report(ClientEvent.foreign_jobs, "", foreign_jobs=offset) + except ValueError: + # probably just haven't gotten the notification yet + pass + async def disconnect(self): if self._is_connected: self._is_connected = False @@ -443,7 +471,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]: return self._active elif self._active: log.warning(f"Received message for job {remote_id}, but job {self._active} is active") - if len(self._jobs) == 0: + if not self._jobs: log.warning(f"Received unknown job {remote_id}") return None active = next((j for j in self._jobs if j.remote_id == remote_id), None) @@ -454,7 +482,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]: async def _start_job(self, remote_id: str): if self._active is not None: log.warning(f"Started job {remote_id}, but {self._active} was never finished") - if len(self._jobs) == 0: + if not self._jobs: log.warning(f"Received unknown job {remote_id}") return None diff --git a/ai_diffusion/connection.py b/ai_diffusion/connection.py index 302b8d06b8..1de15c835c 100644 --- a/ai_diffusion/connection.py +++ b/ai_diffusion/connection.py @@ -31,11 +31,13 @@ class Connection(QObject, ObservableProperties): state = Property(ConnectionState.disconnected) error = Property("") missing_resource: MissingResource | None = None + foreign_jobs: int = 0 state_changed = pyqtSignal(ConnectionState) error_changed = pyqtSignal(str) models_changed = pyqtSignal() message_received = pyqtSignal(ClientMessage) + foreign_jobs_changed = pyqtSignal(int) _client: Client | None = None _task: asyncio.Task | None = None @@ -169,6 +171,9 @@ async def _handle_messages(self): if temporary_disconnect: temporary_disconnect = False self.error = "" + elif msg.event is ClientEvent.foreign_jobs: + self.foreign_jobs = util.ensure(msg.foreign_jobs) + self.foreign_jobs_changed.emit(self.foreign_jobs) else: self.message_received.emit(msg) except asyncio.CancelledError: diff --git a/ai_diffusion/ui/widget.py b/ai_diffusion/ui/widget.py index 3e95264efc..3863d162f3 100644 --- a/ai_diffusion/ui/widget.py +++ b/ai_diffusion/ui/widget.py @@ -209,6 +209,7 @@ def _connect_model(self): self._connections = [ self._model.jobs.count_changed.connect(self._update), self._model.progress_kind_changed.connect(self._update), + root.connection.foreign_jobs_changed.connect(self._update), ] def _update(self): @@ -220,6 +221,10 @@ def _update(self): self.setIcon(theme.icon("queue-upload")) self.setToolTip(_("Uploading models.") + f" {queued_msg} {cancel_msg}") count += 1 + elif root.connection.foreign_jobs > 0: + self.setIcon(theme.icon("queue-inactive")) + self.setToolTip(_("Server is busy.")) + count = f"+{root.connection.foreign_jobs}" elif self._model.jobs.any_executing(): self.setIcon(theme.icon("queue-active")) if count > 0: