Skip to content

Commit

Permalink
Auto-configure live grace period depending on avg generation time
Browse files Browse the repository at this point in the history
- no delay if generation time is fast
- use default values otherwise, removed setting
  • Loading branch information
Acly committed Nov 30, 2024
1 parent 842108b commit 5b286e1
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 44 deletions.
94 changes: 61 additions & 33 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
from copy import copy
from collections import deque
from dataclasses import replace
from pathlib import Path
from enum import Enum
Expand Down Expand Up @@ -841,6 +842,54 @@ def params(self):
_unblur_strength_map = {0: 0.0, 1: 0.5, 2: 1.0}


class LiveScheduler:
poll_rate = 0.1
default_grace_period = 0.25 # seconds to delay after most recent document edit
max_wait_time = 3.0 # maximum seconds to delay over total editing time
delay_threshold = 1.5 # use delay only if average generation time exceeds this value

def __init__(self):
self._last_input: WorkflowInput | None = None
self._last_change = 0.0
self._oldest_change = 0.0
self._has_changes = True
self._generation_start_time = 0.0
self._generation_times: deque[float] = deque(maxlen=10)

def should_generate(self, input: WorkflowInput):
now = time.monotonic()
if self._last_input != input:
self._last_input = input
self._last_change = now
if not self._has_changes:
self._oldest_change = now
self._has_changes = True

time_since_last_change = now - self._last_change
time_since_oldest_change = now - self._oldest_change
return self._has_changes and (
time_since_last_change >= self.grace_period
or time_since_oldest_change >= self.max_wait_time
)

def notify_generation_started(self):
self._generation_start_time = time.monotonic()
self._has_changes = False

def notify_generation_finished(self):
self._generation_times.append(time.monotonic() - self._generation_start_time)

@property
def average_generation_time(self):
return sum(self._generation_times) / max(1, len(self._generation_times))

@property
def grace_period(self):
if self.average_generation_time > self.delay_threshold:
return self.default_grace_period
return 0.0


class LiveWorkspace(QObject, ObservableProperties):
is_active = Property(False, setter="toggle")
is_recording = Property(False, setter="toggle_record")
Expand All @@ -855,27 +904,17 @@ class LiveWorkspace(QObject, ObservableProperties):
result_available = pyqtSignal(Image)
modified = pyqtSignal(QObject, str)

_model: Model
_last_input: WorkflowInput | None = None
_last_change: float = 0
_oldest_change: float = 0
_has_changes: bool = True
_result: Image | None = None
_result_composition: Image | None = None
_result_params: JobParams | None = None
_keyframes_folder: Path | None = None
_keyframe_start = 0
_keyframe_index = 0
_keyframes: list[Path]

_poll_rate = 0.1
_grace_period = 0.25
_max_wait_time = 3.0

def __init__(self, model: Model):
super().__init__()
self._model = model
self._keyframes = []
self._scheduler = LiveScheduler()
self._result: Image | None = None
self._result_composition: Image | None = None
self._result_params: JobParams | None = None
self._keyframes_folder: Path | None = None
self._keyframe_start = 0
self._keyframe_index = 0
self._keyframes: list[Path] = []
model.jobs.job_finished.connect(self.handle_job_finished)

def toggle(self, active: bool):
Expand Down Expand Up @@ -905,29 +944,18 @@ def handle_job_finished(self, job: Job):
if len(job.results) > 0:
self.set_result(job.results[0], job.params)
self.is_active = self._is_active and self._model.document.is_active
self._scheduler.notify_generation_finished()
eventloop.run(_report_errors(self._model, self._continue_generating()))

async def _continue_generating(self):
while self.is_active:
if self._model.document.is_active:
new_input, job_params = self._model._prepare_live_workflow()
now = time.monotonic()
if self._last_input != new_input:
self._last_input = new_input
self._last_change = now
if not self._has_changes:
self._oldest_change = now
self._has_changes = True
time_since_last_change = now - self._last_change
time_since_oldest_change = now - self._oldest_change
if self._has_changes and (
time_since_last_change >= self._grace_period
or time_since_oldest_change >= self._max_wait_time
):
if self._scheduler.should_generate(new_input):
await self._model._generate_live(new_input, job_params)
self._has_changes = False
self._scheduler.notify_generation_started()
return
await asyncio.sleep(self._poll_rate)
await asyncio.sleep(self._scheduler.poll_rate)

def apply_result(self, layer_only=False):
assert self.result is not None and self._result_params is not None
Expand Down
7 changes: 0 additions & 7 deletions ai_diffusion/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,6 @@ class Settings(QObject):
_("Pick a new seed after copying the result to the canvas in Live mode"),
)

live_redraw_grace_period: float
_live_redraw_grace_period = Setting(
_("Live: Redraw grace period"),
0.0,
_("How long to delay scheduling the live preview job for after a change is made"),
)

prompt_translation: str
_prompt_translation = Setting(
_("Prompt Translation"),
Expand Down
4 changes: 0 additions & 4 deletions ai_diffusion/ui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,6 @@ def __init__(self):
self.add("auto_preview", SwitchSetting(S._auto_preview, parent=self))
self.add("show_steps", SwitchSetting(S._show_steps, parent=self))
self.add("new_seed_after_apply", SwitchSetting(S._new_seed_after_apply, parent=self))
self.add(
"live_redraw_grace_period",
SliderSetting(S._live_redraw_grace_period, self, 0.0, 3.0, "{} s"),
)
self.add("debug_dump_workflow", SwitchSetting(S._debug_dump_workflow, parent=self))

languages = [(lang.name, lang.id) for lang in Localization.available]
Expand Down

0 comments on commit 5b286e1

Please sign in to comment.