Skip to content

Commit

Permalink
More generation metadata and option to select style used from history #…
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Sep 12, 2024
1 parent e41ce92 commit 6124149
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 7 deletions.
9 changes: 9 additions & 0 deletions ai_diffusion/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .image import Bounds, ImageCollection
from .settings import settings
from .style import Style
from .util import ensure
from . import control

Expand Down Expand Up @@ -49,6 +50,9 @@ class JobParams:
regions: list[JobRegion] = field(default_factory=list)
strength: float = 1.0
seed: int = 0
style: str = ""
checkpoint: str = ""
sampler: str = ""
has_mask: bool = False
frame: tuple[int, int, int] = (0, 0, 0)
animation_id: str = ""
Expand All @@ -66,6 +70,11 @@ def equal_ignore_seed(cls, a: JobParams | None, b: JobParams | None):
field_names = (f.name for f in fields(cls) if not f.name == "seed")
return all(getattr(a, name) == getattr(b, name) for name in field_names)

def set_style(self, style: Style):
self.style = style.filename
self.checkpoint = style.sd_checkpoint
self.sampler = f"{style.sampler} ({style.sampler_steps} / {style.cfg_scale})"


class Job:
id: str | None
Expand Down
1 change: 1 addition & 0 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def _prepare_workflow(self, dryrun=False):
inpaint=inpaint,
)
job_params = JobParams(bounds, prompt, regions=job_regions)
job_params.set_style(self.style)
return input, job_params

async def enqueue_jobs(
Expand Down
43 changes: 36 additions & 7 deletions ai_diffusion/ui/generation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from textwrap import wrap as wrap_text
from PyQt5.QtCore import Qt, QMetaObject, QSize, QPoint, QUuid, pyqtSignal
from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor
from PyQt5.QtWidgets import (
Expand All @@ -25,10 +26,11 @@
from ..image import Bounds, Extent, Image
from ..jobs import Job, JobQueue, JobState, JobKind, JobParams
from ..model import Model, InpaintContext, RootRegion, ProgressKind
from ..style import Styles
from ..root import root
from ..workflow import InpaintMode, FillMode
from ..localization import translate as _
from ..util import ensure
from ..util import ensure, flatten
from .widget import WorkspaceSelectWidget, StyleSelectWidget, StrengthWidget, QueueButton
from .widget import GenerateButton, create_wide_tool_button
from .region import RegionPromptWidget
Expand Down Expand Up @@ -130,10 +132,10 @@ def add(self, job: Job):
scroll_to_bottom = (
scrollbar and scrollbar.isVisible() and scrollbar.value() >= scrollbar.maximum() - 4
)
prompt = job.params.prompt if job.params.prompt != "" else "<no prompt>"

if not JobParams.equal_ignore_seed(self._last_job_params, job.params):
self._last_job_params = job.params
prompt = job.params.prompt if job.params.prompt != "" else "<no prompt>"
strength = f"{job.params.strength*100:.0f}% - " if job.params.strength != 1.0 else ""

header = QListWidgetItem(f"{job.timestamp:%H:%M} - {strength}{prompt}")
Expand All @@ -148,16 +150,34 @@ def add(self, job: Job):
item = QListWidgetItem(self._image_thumbnail(job, i), None) # type: ignore (text can be None)
item.setData(Qt.ItemDataRole.UserRole, job.id)
item.setData(Qt.ItemDataRole.UserRole + 1, i)
item.setData(
Qt.ItemDataRole.ToolTipRole,
f"{prompt} @ {job.params.strength*100:.0f}% strength\n"
+ _("Click to toggle preview, double-click to apply."),
)
item.setData(Qt.ItemDataRole.ToolTipRole, self._job_info(job.params))
self.addItem(item)

if scroll_to_bottom:
self.scrollToBottom()

def _job_info(self, params: JobParams):
prompt = params.prompt if params.prompt != "" else "<no prompt>"
if len(prompt) > 70:
prompt = prompt[:66] + "..."
style = Styles.list().find(params.style)
positive = _("Prompt") + f": {params.prompt or '-'}"
negative = _("Negative Prompt") + f": {params.negative_prompt or '-'}"
strings = [
f"{prompt} @ {params.strength*100:.0f}%\n",
_("Click to toggle preview, double-click to apply."),
"",
_("Style") + f": {style.name if style else params.style}",
wrap_text(positive, 80, subsequent_indent=" "),
wrap_text(negative, 80, subsequent_indent=" "),
_("Strength") + f": {params.strength*100:.0f}%",
_("Model") + f": {params.checkpoint}",
_("Sampler") + f": {params.sampler}",
_("Seed") + f": {params.seed}",
f"{params.bounds}",
]
return "\n".join(flatten(strings))

def remove(self, job: Job):
self._remove_items(ensure(job.id))

Expand Down Expand Up @@ -313,9 +333,13 @@ def _image_thumbnail(self, job: Job, index: int):
def _show_context_menu(self, pos: QPoint):
item = self.itemAt(pos)
if item is not None:
job = self._model.jobs.find(self._item_data(item).job)
menu = QMenu(self)
menu.addAction(_("Copy Prompt"), self._copy_prompt)
menu.addAction(_("Copy Strength"), self._copy_strength)
style_action = ensure(menu.addAction(_("Copy Style"), self._copy_style))
if job is None or Styles.list().find(job.params.style) is None:
style_action.setEnabled(False)
menu.addAction(_("Copy Seed"), self._copy_seed)
menu.addSeparator()
save_action = ensure(menu.addAction(_("Save Image"), self._save_image))
Expand Down Expand Up @@ -348,6 +372,11 @@ def _copy_strength(self):
if job := self.selected_job:
self._model.strength = job.params.strength

def _copy_style(self):
if job := self.selected_job:
if style := Styles.list().find(job.params.style):
self._model.style = style

def _copy_seed(self):
if job := self.selected_job:
self._model.fixed_seed = True
Expand Down
9 changes: 9 additions & 0 deletions ai_diffusion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import asdict, is_dataclass
from itertools import islice
from pathlib import Path
from typing import Generator
import asyncio
import importlib.util
import os
Expand Down Expand Up @@ -124,6 +125,14 @@ def unique(seq: Sequence[T], key) -> list[T]:
return [x for x in seq if (k := key(x)) not in seen and not seen.add(k)]


def flatten(seq: Sequence[T | list[T]]) -> Generator[T, None, None]:
for x in seq:
if isinstance(x, list):
yield from x
else:
yield x


def trim_text(text: str, max_length: int) -> str:
if len(text) > max_length:
return text[: max_length - 3] + "..."
Expand Down

0 comments on commit 6124149

Please sign in to comment.