Skip to content

Commit aed73dc

Browse files
committed
Only try to analyse workflows after connecting to get node info
1 parent 70dbe98 commit aed73dc

File tree

3 files changed

+81
-30
lines changed

3 files changed

+81
-30
lines changed

ai_diffusion/custom_workflow.py

+37-21
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from PyQt5.QtCore import pyqtSignal
1111

1212
from .api import WorkflowInput
13-
from .comfy_workflow import ComfyWorkflow
14-
from .connection import Connection
13+
from .comfy_workflow import ComfyWorkflow, ComfyNode
14+
from .connection import Connection, ConnectionState
1515
from .image import Bounds, Image
1616
from .jobs import Job, JobParams, JobQueue, JobKind
1717
from .properties import Property, ObservableProperties
@@ -51,23 +51,29 @@ class WorkflowCollection(QAbstractListModel):
5151
def __init__(self, connection: Connection, folder: Path | None = None):
5252
super().__init__()
5353
self._connection = connection
54-
self._workflows: list[CustomWorkflow] = []
55-
5654
self._folder = folder or user_data_dir / "workflows"
57-
for file in self._folder.glob("*.json"):
58-
try:
59-
self._process_file(file)
60-
except Exception as e:
61-
log.exception(f"Error loading workflow from {file}: {e}")
55+
self._workflows: list[CustomWorkflow] = []
6256

57+
self._connection.state_changed.connect(self._handle_connection)
6358
self._connection.workflow_published.connect(self._process_remote_workflow)
64-
for wf in self._connection.workflows.keys():
65-
self._process_remote_workflow(wf)
59+
self._handle_connection(self._connection.state)
60+
61+
def _handle_connection(self, state: ConnectionState):
62+
if state in (ConnectionState.connected, ConnectionState.disconnected):
63+
self.clear()
64+
65+
if state is ConnectionState.connected:
66+
for file in self._folder.glob("*.json"):
67+
try:
68+
self._process_file(file)
69+
except Exception as e:
70+
log.exception(f"Error loading workflow from {file}: {e}")
71+
72+
for wf in self._connection.workflows.keys():
73+
self._process_remote_workflow(wf)
6674

6775
def _node_inputs(self):
68-
if client := self._connection.client_if_connected:
69-
return client.models.node_inputs
70-
return {}
76+
return self._connection.client.models.node_inputs
7177

7278
def _create_workflow(
7379
self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None
@@ -127,6 +133,12 @@ def remove(self, id: str):
127133
self._workflows.pop(idx.row())
128134
self.endRemoveRows()
129135

136+
def clear(self):
137+
if len(self._workflows) > 0:
138+
self.beginResetModel()
139+
self._workflows.clear()
140+
self.endResetModel()
141+
130142
def set_graph(self, index: QModelIndex, graph: dict):
131143
wf = self._workflows[index.row()]
132144
wf.workflow = ComfyWorkflow.import_graph(graph, self._node_inputs())
@@ -266,20 +278,24 @@ def workflow_parameters(w: ComfyWorkflow):
266278
case ("ETN_Parameter", "choice"):
267279
name = node.input("name", "Parameter")
268280
default = node.input("default", "")
269-
connected, input_name = next(w.find_connected(node.output()), (None, ""))
270-
if connected:
271-
if input_type := w.input_type(connected.type, input_name):
272-
if isinstance(input_type[0], list):
273-
yield CustomParam(
274-
ParamKind.choice, name, choices=input_type[0], default=default
275-
)
281+
if choices := _get_choices(w, node):
282+
yield CustomParam(ParamKind.choice, name, choices=choices, default=default)
276283
else:
277284
yield CustomParam(ParamKind.text, name, default=default)
278285
case ("ETN_Parameter", unknown_type) if unknown_type != "auto":
279286
unknown = node.input("name", "?") + ": " + unknown_type
280287
log.warning(f"Custom workflow has an unsupported parameter type {unknown}")
281288

282289

290+
def _get_choices(w: ComfyWorkflow, node: ComfyNode):
291+
connected, input_name = next(w.find_connected(node.output()), (None, ""))
292+
if connected:
293+
if input_type := w.input_type(connected.type, input_name):
294+
if isinstance(input_type[0], list):
295+
return input_type[0]
296+
return None
297+
298+
283299
ImageGenerator = Callable[[WorkflowInput | None], Awaitable[None | Literal[False] | WorkflowInput]]
284300

285301

ai_diffusion/root.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def init(self):
4040
self._files = FileLibrary.load()
4141
self._workflows = WorkflowCollection(self._connection)
4242
self._models = []
43+
self._null_model = Model(Document(), self._connection, self._workflows)
4344
self._recent = RecentlyUsedSync.from_settings()
4445
self._auto_update = AutoUpdate()
4546
if settings.auto_update:
@@ -96,7 +97,7 @@ def auto_update(self) -> AutoUpdate:
9697
def active_model(self):
9798
if model := self.model_for_active_document():
9899
return model
99-
return Model(Document(), self._connection, self._workflows)
100+
return self._null_model
100101

101102
async def autostart(self, signal_server_change: Callable):
102103
connection = self._connection

tests/test_custom_workflow.py

+42-8
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from pathlib import Path
44
from PyQt5.QtCore import Qt
55

6-
from ai_diffusion.api import CustomWorkflowInput, ImageInput, SamplingInput
7-
from ai_diffusion.client import ClientModels, CheckpointInfo
8-
from ai_diffusion.connection import Connection
6+
from ai_diffusion.api import CustomWorkflowInput, ImageInput, WorkflowInput
7+
from ai_diffusion.client import Client, ClientModels, CheckpointInfo
8+
from ai_diffusion.connection import Connection, ConnectionState
99
from ai_diffusion.comfy_workflow import ComfyNode, ComfyWorkflow, Output
1010
from ai_diffusion.custom_workflow import CustomWorkflow, WorkflowSource, WorkflowCollection
1111
from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace
@@ -19,6 +19,40 @@
1919
from .config import test_dir
2020

2121

22+
class MockClient(Client):
23+
def __init__(self, node_inputs: dict[str, dict]):
24+
self.models = ClientModels()
25+
self.models.node_inputs = node_inputs
26+
27+
@staticmethod
28+
async def connect(url: str, access_token: str = "") -> Client:
29+
return MockClient({})
30+
31+
async def enqueue(self, work: WorkflowInput, front: bool = False) -> str:
32+
return ""
33+
34+
async def listen(self): # type: ignore
35+
return
36+
37+
async def interrupt(self):
38+
pass
39+
40+
async def clear_queue(self):
41+
pass
42+
43+
44+
def create_mock_connection(
45+
initial_workflows: dict[str, dict],
46+
node_inputs: dict[str, dict] | None = None,
47+
state: ConnectionState = ConnectionState.connected,
48+
):
49+
connection = Connection()
50+
connection._client = MockClient(node_inputs or {})
51+
connection._workflows = initial_workflows
52+
connection.state = state
53+
return connection
54+
55+
2256
def _assert_has_workflow(
2357
collection: WorkflowCollection,
2458
name: str,
@@ -44,12 +78,13 @@ def test_collection(tmp_path: Path):
4478
file2_graph = {"0": {"class_type": "F2", "inputs": {}}}
4579
file2.write_text(json.dumps(file2_graph))
4680

47-
connection = Connection()
4881
connection_graph = {"0": {"class_type": "C1", "inputs": {}}}
4982
connection_workflows = {"connection1": connection_graph}
50-
connection._workflows = connection_workflows
83+
connection = create_mock_connection(connection_workflows, state=ConnectionState.disconnected)
5184

5285
collection = WorkflowCollection(connection, tmp_path)
86+
assert len(collection) == 0
87+
connection.state = ConnectionState.connected
5388
assert len(collection) == 3
5489
_assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph, file1)
5590
_assert_has_workflow(collection, "file2", WorkflowSource.local, file2_graph, file2)
@@ -110,7 +145,7 @@ def make_dummy_graph(n: int = 42):
110145
def test_files(tmp_path: Path):
111146
collection_folder = tmp_path / "workflows"
112147

113-
collection = WorkflowCollection(Connection(), collection_folder)
148+
collection = WorkflowCollection(create_mock_connection({}, {}), collection_folder)
114149
assert len(collection) == 0
115150

116151
file1 = tmp_path / "file1.json"
@@ -147,9 +182,8 @@ async def dummy_generate(workflow_input):
147182

148183

149184
def test_workspace():
150-
connection = Connection()
151185
connection_workflows = {"connection1": make_dummy_graph(42)}
152-
connection._workflows = connection_workflows
186+
connection = create_mock_connection(connection_workflows, {})
153187
workflows = WorkflowCollection(connection)
154188

155189
jobs = JobQueue()

0 commit comments

Comments
 (0)