Skip to content

Commit

Permalink
Fix some custom parameters being reset when the graph changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Oct 18, 2024
1 parent 6b37477 commit da39321
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 37 deletions.
11 changes: 2 additions & 9 deletions ai_diffusion/comfy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .image import Bounds, Extent, Image
from .resources import Arch, ControlMode
from .util import base_type_match


class ComfyRunMode(Enum):
Expand Down Expand Up @@ -40,21 +41,13 @@ def input(self, key: str, default: None = None) -> Input | None: ...

def input(self, key: str, default: T | None = None) -> T | Input | None:
result = self.inputs.get(key, default)
assert (
default is None
or type(result) == type(default)
or (isnumber(result) and isnumber(default))
)
assert default is None or base_type_match(result, default)
return result

def output(self, index=0) -> Output:
return Output(int(self.id), index)


def isnumber(x):
return isinstance(x, (int, float))


class ComfyWorkflow:
"""Builder for workflows which can be sent to the ComfyUI prompt API."""

Expand Down
6 changes: 4 additions & 2 deletions ai_diffusion/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .jobs import Job, JobParams, JobQueue, JobKind
from .properties import Property, ObservableProperties
from .style import Styles
from .util import user_data_dir, client_logger as log
from .util import base_type_match, user_data_dir, client_logger as log
from .ui import theme
from . import eventloop

Expand Down Expand Up @@ -492,7 +492,9 @@ def live_result(self):

def _coerce(params: dict[str, Any], types: list[CustomParam]):
def use(value, default):
if value is None or not type(value) == type(default):
if default is None:
return value
if value is None or not base_type_match(value, default):
return default
return value

Expand Down
57 changes: 31 additions & 26 deletions ai_diffusion/ui/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..root import root
from ..settings import settings
from ..localization import translate as _
from ..util import ensure, clamp
from ..util import ensure, clamp, base_type_match
from .generation import GenerateButton, ProgressBar, QueueButton, HistoryWidget, create_error_label
from .live import LivePreviewArea
from .switch import SwitchWidget
Expand Down Expand Up @@ -51,8 +51,11 @@ def _update(self):
assert False, f"Unknown filter: {self.filter}"

for l in layers:
if self.findData(l.id) == -1:
index = self.findData(l.id)
if index == -1:
self.addItem(l.name, l.id)
elif self.itemText(index) != l.name:
self.setItemText(index, l.name)
i = 0
while i < self.count():
if self.itemData(i) not in (l.id for l in layers):
Expand Down Expand Up @@ -116,8 +119,8 @@ def value(self):
return self._widget.value()

@value.setter
def value(self, value: int):
self._widget.setValue(value)
def value(self, value: int | float):
self._widget.setValue(int(value))


class FloatParamWidget(QWidget):
Expand Down Expand Up @@ -164,11 +167,11 @@ def value(self):
return self._widget.value()

@value.setter
def value(self, value: float):
def value(self, value: float | int):
if isinstance(self._widget, QSlider):
self._widget.setValue(round(value * 100))
else:
self._widget.setValue(value)
self._widget.setValue(float(value))


class BoolParamWidget(QWidget):
Expand Down Expand Up @@ -321,25 +324,27 @@ def value(self, value: str):


def _create_param_widget(param: CustomParam, parent: QWidget) -> CustomParamWidget:
if param.kind is ParamKind.image_layer:
return LayerSelect("image", parent)
if param.kind is ParamKind.mask_layer:
return LayerSelect("mask", parent)
if param.kind is ParamKind.number_int:
return IntParamWidget(param, parent)
if param.kind is ParamKind.number_float:
return FloatParamWidget(param, parent)
if param.kind is ParamKind.toggle:
return BoolParamWidget(param, parent)
if param.kind is ParamKind.text:
return TextParamWidget(param, parent)
if param.kind in [ParamKind.prompt_positive, ParamKind.prompt_negative]:
return PromptParamWidget(param, parent)
if param.kind is ParamKind.choice:
return ChoiceParamWidget(param, parent)
if param.kind is ParamKind.style:
return StyleParamWidget(parent)
assert False, f"Unknown param kind: {param.kind}"
match param.kind:
case ParamKind.image_layer:
return LayerSelect("image", parent)
case ParamKind.mask_layer:
return LayerSelect("mask", parent)
case ParamKind.number_int:
return IntParamWidget(param, parent)
case ParamKind.number_float:
return FloatParamWidget(param, parent)
case ParamKind.toggle:
return BoolParamWidget(param, parent)
case ParamKind.text:
return TextParamWidget(param, parent)
case ParamKind.prompt_positive | ParamKind.prompt_negative:
return PromptParamWidget(param, parent)
case ParamKind.choice:
return ChoiceParamWidget(param, parent)
case ParamKind.style:
return StyleParamWidget(parent)
case _:
assert False, f"Unknown param kind: {param.kind}"


class WorkflowParamsWidget(QWidget):
Expand Down Expand Up @@ -375,7 +380,7 @@ def value(self):
def value(self, values: dict[str, Any]):
for name, value in values.items():
if widget := self._widgets.get(name):
if type(widget.value) == type(value):
if base_type_match(widget.value, value):
widget.value = value


Expand Down
8 changes: 8 additions & 0 deletions ai_diffusion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ def median_or_zero(values: Iterable[float]) -> float:
return 0


def isnumber(x):
return isinstance(x, (int, float))


def base_type_match(a, b):
return type(a) == type(b) or (isnumber(a) and isnumber(b))


def unique(seq: Sequence[T], key) -> list[T]:
seen = set()
return [x for x in seq if (k := key(x)) not in seen and not seen.add(k)]
Expand Down

0 comments on commit da39321

Please sign in to comment.