Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a grace period between detecting a change and triggering generation in live preview #1412

Merged
merged 4 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from dataclasses import replace
from pathlib import Path
from enum import Enum
from typing import Any, NamedTuple
import time
from typing import NamedTuple
from PyQt5.QtCore import QObject, QUuid, pyqtSignal, Qt
from PyQt5.QtGui import QImage, QPainter, QColor, QBrush
from PyQt5.QtGui import QPainter, QColor, QBrush
import uuid

from . import eventloop, workflow, util
Expand All @@ -28,7 +29,7 @@
from .connection import Connection
from .properties import Property, ObservableProperties
from .jobs import Job, JobKind, JobParams, JobQueue, JobState, JobRegion
from .control import ControlLayer, ControlLayerList
from .control import ControlLayer
from .region import Region, RegionLink, RootRegion, process_regions, get_region_inpaint_mask
from .resources import ControlMode
from .resolution import compute_bounds, compute_relative_bounds
Expand Down Expand Up @@ -294,9 +295,10 @@ def estimate_cost(self, kind=JobKind.diffusion):
return 0

def generate_live(self):
eventloop.run(_report_errors(self, self._generate_live()))
input, job_params = self._prepare_live_workflow()
eventloop.run(_report_errors(self, self._generate_live(input, job_params)))

async def _generate_live(self, last_input: WorkflowInput | None = None):
def _prepare_live_workflow(self):
strength = self.live.strength
workflow_kind = WorkflowKind.generate if strength == 1.0 else WorkflowKind.refine
client = self._connection.client
Expand Down Expand Up @@ -344,13 +346,12 @@ async def _generate_live(self, last_input: WorkflowInput | None = None):
inpaint=inpaint if mask else None,
is_live=True,
)
if input != last_input:
self.clear_error()
params = JobParams(bounds, conditioning.positive, regions=job_regions)
await self.enqueue_jobs(input, JobKind.live_preview, params)
return input
params = JobParams(bounds, conditioning.positive, regions=job_regions)
return input, params

return None
async def _generate_live(self, input: WorkflowInput, job_params: JobParams):
self.clear_error()
await self.enqueue_jobs(input, JobKind.live_preview, job_params)

async def _generate_custom(self, previous_input: WorkflowInput | None):
if self.workspace is not Workspace.custom or not self.document.is_active:
Expand Down Expand Up @@ -837,6 +838,7 @@ class LiveWorkspace(QObject, ObservableProperties):

_model: Model
_last_input: WorkflowInput | None = None
_last_change: float = 0
_result: Image | None = None
_result_composition: Image | None = None
_result_params: JobParams | None = None
Expand Down Expand Up @@ -883,12 +885,17 @@ def handle_job_finished(self, job: Job):
eventloop.run(_report_errors(self._model, self._continue_generating()))

async def _continue_generating(self):
while self.is_active and self._model.document.is_active:
new_input = await self._model._generate_live(self._last_input)
if new_input is not None: # frame was scheduled
self._last_input = new_input
return
# no changes in input data
while self.is_active:
if self._model.document.is_active:
new_input, job_params = self._model._prepare_live_workflow()
if self._last_input != new_input:
now = time.monotonic()
if self._last_change + settings.live_redraw_grace_period <= now:
await self._model._generate_live(new_input, job_params)
self._last_input = new_input
return
else:
self._last_change = time.monotonic()
await asyncio.sleep(self._poll_rate)

def apply_result(self, layer_only=False):
Expand Down
7 changes: 7 additions & 0 deletions ai_diffusion/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ class Settings(QObject):
_("Pick a new seed after copying the result to the canvas in Live mode"),
)

live_redraw_grace_period: float
_live_redraw_grace_period = Setting(
_("Live: Redraw grace period"),
0.0,
_("How long to delay scheduling the live preview job for after a change is made"),
)

prompt_translation: str
_prompt_translation = Setting(
_("Prompt Translation"),
Expand Down
4 changes: 4 additions & 0 deletions ai_diffusion/ui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,10 @@ def __init__(self):
self.add("auto_preview", SwitchSetting(S._auto_preview, parent=self))
self.add("show_steps", SwitchSetting(S._show_steps, parent=self))
self.add("new_seed_after_apply", SwitchSetting(S._new_seed_after_apply, parent=self))
self.add(
"live_redraw_grace_period",
SliderSetting(S._live_redraw_grace_period, self, 0.0, 3.0, "{} s"),
)
self.add("debug_dump_workflow", SwitchSetting(S._debug_dump_workflow, parent=self))

languages = [(lang.name, lang.id) for lang in Localization.available]
Expand Down