From 4e9bab7123350bde0a9f70213aef42048db5a50f Mon Sep 17 00:00:00 2001 From: Acly Date: Sun, 17 Nov 2024 12:15:59 +0100 Subject: [PATCH] Update upscale button text/enabled based on context #1269 --- ai_diffusion/model.py | 25 +++++++++++++++++++------ ai_diffusion/ui/upscale.py | 7 +++++++ ai_diffusion/ui/widget.py | 19 ++++++++++++++----- 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 377db6159..9aa592e1d 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -277,7 +277,7 @@ def upscale_image(self): eventloop.run(_report_errors(self, self._enqueue_job(job, inputs))) self._doc.resize(job.params.bounds.extent) - self.upscale.can_generate = False + self.upscale.set_in_progress(True) self.upscale.target_extent_changed.emit(self.upscale.target_extent) def estimate_cost(self, kind=JobKind.diffusion): @@ -482,7 +482,7 @@ def handle_message(self, message: ClientMessage): def _finish_job(self, job: Job, event: ClientEvent): if job.kind is JobKind.upscaling: - self.upscale.can_generate = True + self.upscale.set_in_progress(False) if event is ClientEvent.finished: self.jobs.notify_finished(job) @@ -757,7 +757,7 @@ class UpscaleParams(NamedTuple): class UpscaleWorkspace(QObject, ObservableProperties): upscaler = Property("", persist=True) - factor = Property(2.0, persist=True) + factor = Property(2.0, persist=True, setter="_set_factor") use_diffusion = Property(True, persist=True) strength = Property(0.3, persist=True) unblur_strength = Property(1, persist=True) @@ -774,12 +774,11 @@ class UpscaleWorkspace(QObject, ObservableProperties): can_generate_changed = pyqtSignal(bool) modified = pyqtSignal(QObject, str) - _model: Model - def __init__(self, model: Model): super().__init__() self._model = model - self.factor_changed.connect(lambda _: self.target_extent_changed.emit(self.target_extent)) + self._in_progress = False + self.use_diffusion_changed.connect(self._update_can_generate) self._init_model() model._connection.models_changed.connect(self._init_model) @@ -788,6 +787,20 @@ def _init_model(self): if self.upscaler not in client.models.upscalers: self.upscaler = client.models.default_upscaler + def set_in_progress(self, in_progress: bool): + self._in_progress = in_progress + self._update_can_generate() + + def _set_factor(self, value: float): + if self._factor != value: + self._factor = value + self.factor_changed.emit(value) + self.target_extent_changed.emit(self.target_extent) + self._update_can_generate() + + def _update_can_generate(self): + self.can_generate = not self._in_progress and (self.factor > 1.0 or self.use_diffusion) + @property def target_extent(self): return self._model.document.extent * self.factor diff --git a/ai_diffusion/ui/upscale.py b/ai_diffusion/ui/upscale.py index 77d800f3c..1f43a6d94 100644 --- a/ai_diffusion/ui/upscale.py +++ b/ai_diffusion/ui/upscale.py @@ -123,6 +123,7 @@ def __init__(self): layout.addLayout(model_layout) self.factor_widget = FactorWidget(self) + self.factor_widget.value_changed.connect(self._update_factor) layout.addWidget(self.factor_widget) self.refinement_checkbox = QGroupBox(_("Refine upscaled image"), self) @@ -282,6 +283,12 @@ def _update_prompt(self): self.prompt_warning.hide() set_text_clipped(self.prompt_label, text, padding=padding) + def _update_factor(self): + if self.factor_widget.value == 1.0 and self.model.upscale.use_diffusion: + self.upscale_button.operation = _("Refine") + else: + self.upscale_button.operation = _("Upscale") + def _upscaler_order(filename: str): return { diff --git a/ai_diffusion/ui/widget.py b/ai_diffusion/ui/widget.py index f85470224..6a6ff4aa0 100644 --- a/ai_diffusion/ui/widget.py +++ b/ai_diffusion/ui/widget.py @@ -739,7 +739,7 @@ def _create_action(self, name: str, workspace: Workspace): class GenerateButton(QPushButton): model: Model - operation: str + _operation: str _kind: JobKind _cost: int = 0 _cost_icon: QIcon @@ -747,14 +747,23 @@ class GenerateButton(QPushButton): def __init__(self, kind: JobKind, parent: QWidget): super().__init__(parent) self.model = root.active_model - self.operation = _("Generate") + self._operation = _("Generate") self._kind = kind self._cost_icon = theme.icon("interstice") self.setAttribute(Qt.WidgetAttribute.WA_Hover) + @property + def operation(self): + return self._operation + + @operation.setter + def operation(self, value: str): + self._operation = value + self.update() + def minimumSizeHint(self): fm = self.fontMetrics() - return QSize(fm.width(self.operation) + 40, 12 + int(1.3 * fm.height())) + return QSize(fm.width(self._operation) + 40, 12 + int(1.3 * fm.height())) def enterEvent(self, a0: QEvent | None): if client := root.connection.client_if_connected: @@ -776,12 +785,12 @@ def paintEvent(self, a0: QPaintEvent | None) -> None: is_hover = int(opt.state) & QStyle.StateFlag.State_MouseOver element = QStyle.PrimitiveElement.PE_PanelButtonCommand vcenter = Qt.AlignmentFlag.AlignVCenter - content_width = fm.width(self.operation) + 5 + pixmap.width() + content_width = fm.width(self._operation) + 5 + pixmap.width() content_rect = rect.adjusted(int(0.5 * (rect.width() - content_width)), 0, 0, 0) style.drawPrimitive(element, opt, painter, self) style.drawItemPixmap(painter, content_rect, vcenter, pixmap) content_rect = content_rect.adjusted(pixmap.width() + 5, 0, 0, 0) - style.drawItemText(painter, content_rect, vcenter, self.palette(), True, self.operation) + style.drawItemText(painter, content_rect, vcenter, self.palette(), True, self._operation) if is_hover and self._cost > 0: cost_width = fm.width(str(self._cost))