Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 1137: Display ComfyUI Queue State #1156

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class ClientMessage(NamedTuple):
images: ImageCollection | None = None
result: dict | None = None
error: str | None = None
# jobs queued before our next one
queue_length: int | None = None


class User(QObject, ObservableProperties):
Expand Down
32 changes: 29 additions & 3 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum
from collections import deque
from itertools import chain, product
from operator import itemgetter
from typing import NamedTuple, Optional, Sequence

from .api import WorkflowInput
Expand Down Expand Up @@ -264,6 +265,7 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr

if msg["type"] == "status":
await self._report(ClientEvent.connected, "")
await self._poll_server_queue()

if msg["type"] == "execution_start":
id = msg["data"]["prompt_id"]
Expand Down Expand Up @@ -340,9 +342,33 @@ async def clear_queue(self):
except asyncio.QueueEmpty:
break

await self._post("queue", {"clear": True})
remote_ids = [await job.get_remote_id() for job in self._jobs]
await self._post("api/queue", {"delete": remote_ids})

self._jobs.clear()

async def _poll_server_queue(self):
queue = await self._get("api/queue")
server_jobs = queue["queue_running"] + queue["queue_pending"]
# why are they unsorted to start with...?
server_jobs = sorted(server_jobs, key=itemgetter(0))
server_jobs = [entry[1] for entry in server_jobs]
if not (self._jobs or self._active):
FeepingCreature marked this conversation as resolved.
Show resolved Hide resolved
return

if self._active:
first_job = self._active
else:
first_job = self._jobs[0]
# as we got the job from `_jobs` or `_active`, this field must have been set (in `_run_job`).
first_remote_id = util.ensure(await first_job.get_remote_id())
try:
offset = server_jobs.index(first_remote_id)
except ValueError:
# probably just haven't gotten the notification yet
return
await self._report(ClientEvent.queued, first_job.local_id, queue_length=offset)

async def disconnect(self):
if self._is_connected:
self._is_connected = False
Expand Down Expand Up @@ -443,7 +469,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]:
return self._active
elif self._active:
log.warning(f"Received message for job {remote_id}, but job {self._active} is active")
if len(self._jobs) == 0:
if not self._jobs:
log.warning(f"Received unknown job {remote_id}")
return None
active = next((j for j in self._jobs if j.remote_id == remote_id), None)
Expand All @@ -454,7 +480,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]:
async def _start_job(self, remote_id: str):
if self._active is not None:
log.warning(f"Started job {remote_id}, but {self._active} was never finished")
if len(self._jobs) == 0:
if not self._jobs:
log.warning(f"Received unknown job {remote_id}")
return None

Expand Down
12 changes: 9 additions & 3 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class Model(QObject, ObservableProperties):
progress = Property(0.0)
jobs: JobQueue
error = Property("")
queue_length: int = 0

workspace_changed = pyqtSignal(Workspace)
style_changed = pyqtSignal(Style)
Expand All @@ -87,6 +88,7 @@ class Model(QObject, ObservableProperties):
error_changed = pyqtSignal(str)
has_error_changed = pyqtSignal(bool)
modified = pyqtSignal(QObject, str)
queue_length_changed = pyqtSignal(int)

def __init__(self, document: Document, connection: Connection):
super().__init__()
Expand Down Expand Up @@ -408,9 +410,13 @@ def handle_message(self, message: ClientMessage):
return

if message.event is ClientEvent.queued:
self.jobs.notify_started(job)
self.progress = -1
self.progress_changed.emit(-1)
if message.queue_length is not None:
self.queue_length = message.queue_length
self.queue_length_changed.emit(message.queue_length)
if message.queue_length is None or message.queue_length == 0:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where the "don't start the job until it's actually running" thing moved to.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest to consider it started in all cases, like before (just add the queue length if available).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I don't see what it buys you. Why should we consider job started if the job is not started?

self.jobs.notify_started(job)
self.progress = -1
self.progress_changed.emit(-1)
elif message.event is ClientEvent.progress:
self.jobs.notify_started(job)
self.progress_kind = ProgressKind.generation
Expand Down
5 changes: 5 additions & 0 deletions ai_diffusion/ui/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def _connect_model(self):
self._connections = [
self._model.jobs.count_changed.connect(self._update),
self._model.progress_kind_changed.connect(self._update),
self._model.queue_length_changed.connect(self._update),
]

def _update(self):
Expand All @@ -220,6 +221,10 @@ def _update(self):
self.setIcon(theme.icon("queue-upload"))
self.setToolTip(_("Uploading models.") + f" {queued_msg} {cancel_msg}")
count += 1
elif self._model.queue_length > 0:
self.setIcon(theme.icon("queue-inactive"))
self.setToolTip(_("Server is busy."))
count = f"+{self.model.queue_length}"
elif self._model.jobs.any_executing():
self.setIcon(theme.icon("queue-active"))
if count > 0:
Expand Down
Loading