Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a setting to cancel live jobs upon change #1248

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,11 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr
ClientEvent.progress, self._active.local_id, progress.value
)
else:
log.error(f"Received message {msg} but there is no active job")
if settings.live_cancel_jobs_on_change:
# likely to be an out-of-order progress from cancelled/interrupted job so don't log these
pass
else:
log.error(f"Received message {msg} but there is no active job")
Copy link
Contributor Author

@modelflat modelflat Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why are we logging these at all, especially at the Error level


if msg["type"] == "executed":
job = self._get_active_job(msg["data"]["prompt_id"])
Expand Down
19 changes: 17 additions & 2 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ async def _generate_live(self, last_input: WorkflowInput | None = None):
if input != last_input:
self.clear_error()
params = JobParams(bounds, conditioning.positive, regions=job_regions)
if settings.live_cancel_jobs_on_change:
await self._cancel_everything()
await self.enqueue_jobs(input, JobKind.live_preview, params)
return input

Expand All @@ -363,6 +365,17 @@ def _get_current_image(self, bounds: Bounds):
exclude.append(self._layer)
return self._doc.get_image(bounds, exclude_layers=exclude)

async def _cancel_everything(self):
await self._connection.client.clear_queue()
await self._connection.client.interrupt()
to_remove = [
job
for job in self.jobs
if job.state is JobState.queued or job.state is JobState.executing
]
for job in to_remove:
self.jobs.remove(job)

def generate_control_layer(self, control: ControlLayer):
ok, msg = self._doc.check_color_mode()
if not ok and msg:
Expand Down Expand Up @@ -811,14 +824,16 @@ def handle_job_finished(self, job: Job):
if len(job.results) > 0:
self.set_result(job.results[0], job.params)
self.is_active = self._is_active and self._model.document.is_active
eventloop.run(_report_errors(self._model, self._continue_generating()))
if not settings.live_cancel_jobs_on_change:
eventloop.run(_report_errors(self._model, self._continue_generating()))

async def _continue_generating(self):
while self.is_active and self._model.document.is_active:
new_input = await self._model._generate_live(self._last_input)
if new_input is not None: # frame was scheduled
self._last_input = new_input
return
if not settings.live_cancel_jobs_on_change:
return
# no changes in input data
await asyncio.sleep(self._poll_rate)

Expand Down
7 changes: 7 additions & 0 deletions ai_diffusion/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ class Settings(QObject):
_("NSFW Filter"), 0.0, _("Attempt to filter out images with explicit content")
)

live_cancel_jobs_on_change: bool
_live_cancel_jobs_on_change = Setting(
_("Live: Cancel incomplete jobs on change"),
False,
_("Prevents intermediate results from being shown in Live Mode"),
)

new_seed_after_apply: bool
_new_seed_after_apply = Setting(
_("Live: New Seed after Apply"),
Expand Down
4 changes: 4 additions & 0 deletions ai_diffusion/ui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,10 @@ def __init__(self):
self.add("auto_preview", SwitchSetting(S._auto_preview, parent=self))
self.add("show_steps", SwitchSetting(S._show_steps, parent=self))
self.add("new_seed_after_apply", SwitchSetting(S._new_seed_after_apply, parent=self))
self.add(
"live_cancel_jobs_on_change",
SwitchSetting(S._live_cancel_jobs_on_change, parent=self),
)
self.add("debug_dump_workflow", SwitchSetting(S._debug_dump_workflow, parent=self))

languages = [(lang.name, lang.id) for lang in Localization.available]
Expand Down