|
10 | 10 | from PyQt5.QtCore import pyqtSignal
|
11 | 11 |
|
12 | 12 | from .api import WorkflowInput
|
13 |
| -from .comfy_workflow import ComfyWorkflow |
14 |
| -from .connection import Connection |
| 13 | +from .comfy_workflow import ComfyWorkflow, ComfyNode |
| 14 | +from .connection import Connection, ConnectionState |
15 | 15 | from .image import Bounds, Image
|
16 | 16 | from .jobs import Job, JobParams, JobQueue, JobKind
|
17 | 17 | from .properties import Property, ObservableProperties
|
@@ -51,23 +51,29 @@ class WorkflowCollection(QAbstractListModel):
|
51 | 51 | def __init__(self, connection: Connection, folder: Path | None = None):
|
52 | 52 | super().__init__()
|
53 | 53 | self._connection = connection
|
54 |
| - self._workflows: list[CustomWorkflow] = [] |
55 |
| - |
56 | 54 | self._folder = folder or user_data_dir / "workflows"
|
57 |
| - for file in self._folder.glob("*.json"): |
58 |
| - try: |
59 |
| - self._process_file(file) |
60 |
| - except Exception as e: |
61 |
| - log.exception(f"Error loading workflow from {file}: {e}") |
| 55 | + self._workflows: list[CustomWorkflow] = [] |
62 | 56 |
|
| 57 | + self._connection.state_changed.connect(self._handle_connection) |
63 | 58 | self._connection.workflow_published.connect(self._process_remote_workflow)
|
64 |
| - for wf in self._connection.workflows.keys(): |
65 |
| - self._process_remote_workflow(wf) |
| 59 | + self._handle_connection(self._connection.state) |
| 60 | + |
| 61 | + def _handle_connection(self, state: ConnectionState): |
| 62 | + if state in (ConnectionState.connected, ConnectionState.disconnected): |
| 63 | + self.clear() |
| 64 | + |
| 65 | + if state is ConnectionState.connected: |
| 66 | + for file in self._folder.glob("*.json"): |
| 67 | + try: |
| 68 | + self._process_file(file) |
| 69 | + except Exception as e: |
| 70 | + log.exception(f"Error loading workflow from {file}: {e}") |
| 71 | + |
| 72 | + for wf in self._connection.workflows.keys(): |
| 73 | + self._process_remote_workflow(wf) |
66 | 74 |
|
67 | 75 | def _node_inputs(self):
|
68 |
| - if client := self._connection.client_if_connected: |
69 |
| - return client.models.node_inputs |
70 |
| - return {} |
| 76 | + return self._connection.client.models.node_inputs |
71 | 77 |
|
72 | 78 | def _create_workflow(
|
73 | 79 | self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None
|
@@ -127,6 +133,12 @@ def remove(self, id: str):
|
127 | 133 | self._workflows.pop(idx.row())
|
128 | 134 | self.endRemoveRows()
|
129 | 135 |
|
| 136 | + def clear(self): |
| 137 | + if len(self._workflows) > 0: |
| 138 | + self.beginResetModel() |
| 139 | + self._workflows.clear() |
| 140 | + self.endResetModel() |
| 141 | + |
130 | 142 | def set_graph(self, index: QModelIndex, graph: dict):
|
131 | 143 | wf = self._workflows[index.row()]
|
132 | 144 | wf.workflow = ComfyWorkflow.import_graph(graph, self._node_inputs())
|
@@ -266,20 +278,24 @@ def workflow_parameters(w: ComfyWorkflow):
|
266 | 278 | case ("ETN_Parameter", "choice"):
|
267 | 279 | name = node.input("name", "Parameter")
|
268 | 280 | default = node.input("default", "")
|
269 |
| - connected, input_name = next(w.find_connected(node.output()), (None, "")) |
270 |
| - if connected: |
271 |
| - if input_type := w.input_type(connected.type, input_name): |
272 |
| - if isinstance(input_type[0], list): |
273 |
| - yield CustomParam( |
274 |
| - ParamKind.choice, name, choices=input_type[0], default=default |
275 |
| - ) |
| 281 | + if choices := _get_choices(w, node): |
| 282 | + yield CustomParam(ParamKind.choice, name, choices=choices, default=default) |
276 | 283 | else:
|
277 | 284 | yield CustomParam(ParamKind.text, name, default=default)
|
278 | 285 | case ("ETN_Parameter", unknown_type) if unknown_type != "auto":
|
279 | 286 | unknown = node.input("name", "?") + ": " + unknown_type
|
280 | 287 | log.warning(f"Custom workflow has an unsupported parameter type {unknown}")
|
281 | 288 |
|
282 | 289 |
|
| 290 | +def _get_choices(w: ComfyWorkflow, node: ComfyNode): |
| 291 | + connected, input_name = next(w.find_connected(node.output()), (None, "")) |
| 292 | + if connected: |
| 293 | + if input_type := w.input_type(connected.type, input_name): |
| 294 | + if isinstance(input_type[0], list): |
| 295 | + return input_type[0] |
| 296 | + return None |
| 297 | + |
| 298 | + |
283 | 299 | ImageGenerator = Callable[[WorkflowInput | None], Awaitable[None | Literal[False] | WorkflowInput]]
|
284 | 300 |
|
285 | 301 |
|
|
0 commit comments