Skip to content

Commit 089dc66

Browse files
committed
Load previously used custom graph from document (somehow)
1 parent aed73dc commit 089dc66

File tree

4 files changed

+74
-31
lines changed

4 files changed

+74
-31
lines changed

ai_diffusion/custom_workflow.py

+36-17
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,14 @@ class WorkflowCollection(QAbstractListModel):
4848
_icon_remote = theme.icon("web-connection")
4949
_icon_document = theme.icon("file-kra")
5050

51+
loaded = pyqtSignal()
52+
5153
def __init__(self, connection: Connection, folder: Path | None = None):
5254
super().__init__()
5355
self._connection = connection
5456
self._folder = folder or user_data_dir / "workflows"
5557
self._workflows: list[CustomWorkflow] = []
58+
self._pending_workflows: list[tuple[str, WorkflowSource, dict]] = []
5659

5760
self._connection.state_changed.connect(self._handle_connection)
5861
self._connection.workflow_published.connect(self._process_remote_workflow)
@@ -63,6 +66,10 @@ def _handle_connection(self, state: ConnectionState):
6366
self.clear()
6467

6568
if state is ConnectionState.connected:
69+
for id, source, graph in self._pending_workflows:
70+
self._process_workflow(id, source, graph)
71+
self._pending_workflows.clear()
72+
6673
for file in self._folder.glob("*.json"):
6774
try:
6875
self._process_file(file)
@@ -72,31 +79,36 @@ def _handle_connection(self, state: ConnectionState):
7279
for wf in self._connection.workflows.keys():
7380
self._process_remote_workflow(wf)
7481

82+
self.loaded.emit()
83+
7584
def _node_inputs(self):
7685
return self._connection.client.models.node_inputs
7786

78-
def _create_workflow(
87+
def _process_workflow(
7988
self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None
8089
):
81-
wf = ComfyWorkflow.import_graph(graph, self._node_inputs())
82-
return CustomWorkflow(id, source, wf, path)
83-
84-
def _process_remote_workflow(self, id: str):
85-
graph = self._connection.workflows[id]
86-
self._process(self._create_workflow(id, WorkflowSource.remote, graph))
87-
88-
def _process_file(self, file: Path):
89-
with file.open("r") as f:
90-
graph = json.load(f)
91-
self._process(self._create_workflow(file.stem, WorkflowSource.local, graph, file))
90+
if self._connection.state is not ConnectionState.connected:
91+
self._pending_workflows.append((id, source, graph))
92+
return
9293

93-
def _process(self, workflow: CustomWorkflow):
94+
comfy_flow = ComfyWorkflow.import_graph(graph, self._node_inputs())
95+
workflow = CustomWorkflow(id, source, comfy_flow, path)
9496
idx = self.find_index(workflow.id)
9597
if idx.isValid():
9698
self._workflows[idx.row()] = workflow
9799
self.dataChanged.emit(idx, idx)
98100
else:
99101
self.append(workflow)
102+
return idx
103+
104+
def _process_remote_workflow(self, id: str):
105+
graph = self._connection.workflows[id]
106+
self._process_workflow(id, WorkflowSource.remote, graph)
107+
108+
def _process_file(self, file: Path):
109+
with file.open("r") as f:
110+
graph = json.load(f)
111+
self._process_workflow(file.stem, WorkflowSource.local, graph, file)
100112

101113
def rowCount(self, parent=QModelIndex()):
102114
return len(self._workflows)
@@ -121,7 +133,7 @@ def append(self, item: CustomWorkflow):
121133
self.endInsertRows()
122134

123135
def add_from_document(self, id: str, graph: dict):
124-
self.append(self._create_workflow(id, WorkflowSource.document, graph))
136+
self._process_workflow(id, WorkflowSource.document, graph)
125137

126138
def remove(self, id: str):
127139
idx = self.find_index(id)
@@ -154,7 +166,7 @@ def save_as(self, id: str, graph: dict):
154166
self._folder.mkdir(exist_ok=True)
155167
path = self._folder / f"{id}.json"
156168
path.write_text(json.dumps(graph, indent=2))
157-
self.append(self._create_workflow(id, WorkflowSource.local, graph, path))
169+
self._process_workflow(id, WorkflowSource.local, graph, path)
158170
return id
159171

160172
def import_file(self, filepath: Path):
@@ -336,12 +348,16 @@ def __init__(self, workflows: WorkflowCollection, generator: ImageGenerator, job
336348

337349
jobs.job_finished.connect(self._handle_job_finished)
338350
workflows.dataChanged.connect(self._update_workflow)
339-
workflows.rowsInserted.connect(self._set_default_workflow)
351+
workflows.loaded.connect(self._set_default_workflow)
340352
self._set_default_workflow()
341353

342354
def _set_default_workflow(self):
343355
if not self.workflow_id and len(self._workflows) > 0:
344356
self.workflow_id = self._workflows[0].id
357+
else:
358+
current_index = self._workflows.find_index(self.workflow_id)
359+
if current_index.isValid():
360+
self._update_workflow(current_index, QModelIndex())
345361

346362
def _update_workflow(self, idx: QModelIndex, _: QModelIndex):
347363
wf = self._workflows[idx.row()]
@@ -358,10 +374,13 @@ def _set_workflow_id(self, id: str):
358374
self._workflow_id = id
359375
self.workflow_id_changed.emit(id)
360376
self.modified.emit(self, "workflow_id")
361-
self._update_workflow(self._workflows.find_index(id), QModelIndex())
377+
index = self._workflows.find_index(id)
378+
if index.isValid(): # might be invalid when loading document before connecting
379+
self._update_workflow(index, QModelIndex())
362380

363381
def set_graph(self, id: str, graph: dict):
364382
if self._workflows.find(id) is None:
383+
id = "Document Workflow (embedded)"
365384
self._workflows.add_from_document(id, graph)
366385
self.workflow_id = id
367386

ai_diffusion/persistence.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from PyQt5.QtWidgets import QMessageBox
99

1010
from .api import InpaintMode, FillMode
11-
from .image import Bounds, Image, ImageCollection, ImageFileFormat
11+
from .image import ImageCollection
1212
from .model import Model, InpaintContext
13+
from .custom_workflow import CustomWorkspace
1314
from .control import ControlLayer, ControlLayerList
1415
from .region import RootRegion, Region
1516
from .jobs import Job, JobKind, JobParams, JobQueue
@@ -132,7 +133,7 @@ def _save(self):
132133
state["upscale"] = _serialize(model.upscale)
133134
state["live"] = _serialize(model.live)
134135
state["animation"] = _serialize(model.animation)
135-
state["custom"] = _serialize(model.custom)
136+
state["custom"] = _serialize_custom(model.custom)
136137
state["history"] = [asdict(h) for h in self._history]
137138
state["root"] = _serialize(model.regions)
138139
state["control"] = [_serialize(c) for c in model.regions.control]
@@ -151,7 +152,7 @@ def _load(self, model: Model, state_bytes: bytes):
151152
_deserialize(model.upscale, state.get("upscale", {}))
152153
_deserialize(model.live, state.get("live", {}))
153154
_deserialize(model.animation, state.get("animation", {}))
154-
_deserialize(model.custom, state.get("custom", {}))
155+
_deserialize_custom(model.custom, state.get("custom", {}))
155156
_deserialize(model.regions, state.get("root", {}))
156157
for control_state in state.get("control", []):
157158
_deserialize(model.regions.control.emplace(), control_state)
@@ -264,6 +265,21 @@ def converter(type, value):
264265
return deserialize(obj, data, converter)
265266

266267

268+
def _serialize_custom(custom: CustomWorkspace):
269+
result = _serialize(custom)
270+
result["workflow_id"] = custom.workflow_id
271+
result["graph"] = custom.graph.root if custom.graph else None
272+
return result
273+
274+
275+
def _deserialize_custom(custom: CustomWorkspace, data: dict[str, Any]):
276+
_deserialize(custom, data)
277+
workflow_id = data.get("workflow_id", "")
278+
graph = data.get("graph", None)
279+
if workflow_id and graph:
280+
custom.set_graph(workflow_id, graph)
281+
282+
267283
def _find_annotation(document, name: str):
268284
if result := document.find_annotation(name):
269285
return result

ai_diffusion/ui/custom_workflow.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None):
8787
self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4)
8888
self._widget.valueChanged.connect(self._notify)
8989
self._label = QLabel(self)
90-
self._label.setFixedWidth(40)
90+
self._label.setFixedWidth(32)
9191
self._label.setAlignment(Qt.AlignmentFlag.AlignRight)
9292
layout.addWidget(self._widget)
9393
layout.addWidget(self._label)
@@ -135,7 +135,7 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None):
135135
self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4)
136136
self._widget.valueChanged.connect(self._notify)
137137
self._label = QLabel(self)
138-
self._label.setFixedWidth(40)
138+
self._label.setFixedWidth(32)
139139
self._label.setAlignment(Qt.AlignmentFlag.AlignRight)
140140
layout.addWidget(self._widget)
141141
layout.addWidget(self._label)

tests/test_custom_workflow.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,24 @@ def test_collection(tmp_path: Path):
8383
connection = create_mock_connection(connection_workflows, state=ConnectionState.disconnected)
8484

8585
collection = WorkflowCollection(connection, tmp_path)
86+
events = []
87+
8688
assert len(collection) == 0
89+
90+
def on_loaded():
91+
events.append("loaded")
92+
93+
collection.loaded.connect(on_loaded)
94+
doc_graph = {"0": {"class_type": "D1", "inputs": {}}}
95+
collection.add_from_document("doc1", doc_graph)
96+
8797
connection.state = ConnectionState.connected
88-
assert len(collection) == 3
98+
assert len(collection) == 4
99+
assert events == ["loaded"]
89100
_assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph, file1)
90101
_assert_has_workflow(collection, "file2", WorkflowSource.local, file2_graph, file2)
91102
_assert_has_workflow(collection, "connection1", WorkflowSource.remote, connection_graph)
92-
93-
events = []
103+
_assert_has_workflow(collection, "doc1", WorkflowSource.document, doc_graph)
94104

95105
def on_begin_insert(index, first, last):
96106
events.append(("begin_insert", first))
@@ -109,15 +119,13 @@ def on_data_changed(start, end):
109119
connection_workflows["connection2"] = connection2_graph
110120
connection.workflow_published.emit("connection2")
111121

112-
assert len(collection) == 4
122+
assert len(collection) == 5
113123
_assert_has_workflow(collection, "connection2", WorkflowSource.remote, connection2_graph)
114124

115125
file1_graph_changed = {"0": {"class_type": "F3", "inputs": {}}}
116-
collection.set_graph(collection.index(0), file1_graph_changed)
126+
collection.set_graph(collection.find_index("file1"), file1_graph_changed)
117127
_assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph_changed, file1)
118-
assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)]
119-
120-
collection.add_from_document("doc1", {"0": {"class_type": "D1", "inputs": {}}})
128+
assert events == ["loaded", ("begin_insert", 4), "end_insert", ("data_changed", 1)]
121129

122130
sorted = SortedWorkflows(collection)
123131
assert sorted[0].source is WorkflowSource.document
@@ -207,7 +215,7 @@ def test_workspace():
207215
}
208216
}
209217
workspace.set_graph("doc1", doc_graph)
210-
assert workspace.workflow_id == "doc1"
218+
assert workspace.workflow_id == "Document Workflow (embedded)"
211219
assert workspace.workflow and workspace.workflow.source is WorkflowSource.document
212220
assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter"
213221
assert workspace.metadata[0].name == "param2"

0 commit comments

Comments
 (0)