diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 9aa592e1d2..4e6c2b696f 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -4,9 +4,10 @@ from dataclasses import replace from pathlib import Path from enum import Enum -from typing import Any, NamedTuple +import time +from typing import NamedTuple from PyQt5.QtCore import QObject, QUuid, pyqtSignal, Qt -from PyQt5.QtGui import QImage, QPainter, QColor, QBrush +from PyQt5.QtGui import QPainter, QColor, QBrush import uuid from . import eventloop, workflow, util @@ -28,7 +29,7 @@ from .connection import Connection from .properties import Property, ObservableProperties from .jobs import Job, JobKind, JobParams, JobQueue, JobState, JobRegion -from .control import ControlLayer, ControlLayerList +from .control import ControlLayer from .region import Region, RegionLink, RootRegion, process_regions, get_region_inpaint_mask from .resources import ControlMode from .resolution import compute_bounds, compute_relative_bounds @@ -294,9 +295,10 @@ def estimate_cost(self, kind=JobKind.diffusion): return 0 def generate_live(self): - eventloop.run(_report_errors(self, self._generate_live())) + input, job_params = self._prepare_live_workflow() + eventloop.run(_report_errors(self, self._generate_live(input, job_params))) - async def _generate_live(self, last_input: WorkflowInput | None = None): + def _prepare_live_workflow(self): strength = self.live.strength workflow_kind = WorkflowKind.generate if strength == 1.0 else WorkflowKind.refine client = self._connection.client @@ -344,13 +346,12 @@ async def _generate_live(self, last_input: WorkflowInput | None = None): inpaint=inpaint if mask else None, is_live=True, ) - if input != last_input: - self.clear_error() - params = JobParams(bounds, conditioning.positive, regions=job_regions) - await self.enqueue_jobs(input, JobKind.live_preview, params) - return input + params = JobParams(bounds, conditioning.positive, regions=job_regions) + return input, params - return None + async def _generate_live(self, input: WorkflowInput, job_params: JobParams): + self.clear_error() + await self.enqueue_jobs(input, JobKind.live_preview, job_params) async def _generate_custom(self, previous_input: WorkflowInput | None): if self.workspace is not Workspace.custom or not self.document.is_active: @@ -837,6 +838,7 @@ class LiveWorkspace(QObject, ObservableProperties): _model: Model _last_input: WorkflowInput | None = None + _last_change: float = 0 _result: Image | None = None _result_composition: Image | None = None _result_params: JobParams | None = None @@ -883,12 +885,17 @@ def handle_job_finished(self, job: Job): eventloop.run(_report_errors(self._model, self._continue_generating())) async def _continue_generating(self): - while self.is_active and self._model.document.is_active: - new_input = await self._model._generate_live(self._last_input) - if new_input is not None: # frame was scheduled - self._last_input = new_input - return - # no changes in input data + while self.is_active: + if self._model.document.is_active: + new_input, job_params = self._model._prepare_live_workflow() + if self._last_input != new_input: + now = time.monotonic() + if self._last_change + settings.live_redraw_grace_period <= now: + await self._model._generate_live(new_input, job_params) + self._last_input = new_input + return + else: + self._last_change = time.monotonic() await asyncio.sleep(self._poll_rate) def apply_result(self, layer_only=False): diff --git a/ai_diffusion/settings.py b/ai_diffusion/settings.py index e5fb466b8f..c834d5e700 100644 --- a/ai_diffusion/settings.py +++ b/ai_diffusion/settings.py @@ -155,6 +155,13 @@ class Settings(QObject): _("Pick a new seed after copying the result to the canvas in Live mode"), ) + live_redraw_grace_period: float + _live_redraw_grace_period = Setting( + _("Live: Redraw grace period"), + 0.0, + _("How long to delay scheduling the live preview job for after a change is made"), + ) + prompt_translation: str _prompt_translation = Setting( _("Prompt Translation"), diff --git a/ai_diffusion/ui/settings.py b/ai_diffusion/ui/settings.py index 7e48ed6975..796f701c05 100644 --- a/ai_diffusion/ui/settings.py +++ b/ai_diffusion/ui/settings.py @@ -439,6 +439,10 @@ def __init__(self): self.add("auto_preview", SwitchSetting(S._auto_preview, parent=self)) self.add("show_steps", SwitchSetting(S._show_steps, parent=self)) self.add("new_seed_after_apply", SwitchSetting(S._new_seed_after_apply, parent=self)) + self.add( + "live_redraw_grace_period", + SliderSetting(S._live_redraw_grace_period, self, 0.0, 3.0, "{} s"), + ) self.add("debug_dump_workflow", SwitchSetting(S._debug_dump_workflow, parent=self)) languages = [(lang.name, lang.id) for lang in Localization.available]