diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index 5c26d1b1cc..38bee0ef2b 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -126,7 +126,7 @@ async def connect(url=default_url, access_token=""): # Check for required and optional model resources models = client.models - models.node_inputs = {name: nodes[name]["input"].get("required", None) for name in nodes} + models.node_inputs = {name: nodes[name]["input"] for name in nodes} available_resources = client.models.resources = {} clip_models = nodes["DualCLIPLoader"]["input"]["required"]["clip_name1"][0] diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index a986b1bcaa..09eb729986 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -64,17 +64,22 @@ def __init__(self, node_inputs: dict | None = None, run_mode=ComfyRunMode.server self.node_count = 0 self.sample_count = 0 self._cache: dict[str, Output | Output2 | Output3 | Output4] = {} - self._nodes_required_inputs: dict[str, dict[str, Any]] = node_inputs or {} + self._nodes_inputs: dict[str, dict[str, Any]] = node_inputs or {} self._run_mode: ComfyRunMode = run_mode @staticmethod - def import_graph(existing: dict): - w = ComfyWorkflow() + def import_graph(existing: dict, node_inputs: dict): + w = ComfyWorkflow(node_inputs) + existing = _convert_ui_workflow(existing, node_inputs) node_map: dict[str, str] = {} queue = list(existing.keys()) while queue: id = queue.pop(0) node = deepcopy(existing[id]) + if node_inputs and node["class_type"] not in node_inputs: + raise ValueError( + f"Workflow contains a node of type {node['class_type']} which is not installed on the ComfyUI server." + ) edges = [e for e in node["inputs"].values() if isinstance(e, list)] if any(e[0] not in node_map for e in edges): queue.append(id) # requeue node if an input is not yet mapped @@ -94,7 +99,7 @@ def from_dict(existing: dict): return w def add_default_values(self, node_name: str, args: dict): - if node_inputs := self._nodes_required_inputs.get(node_name, None): + if node_inputs := _inputs_for_node(self._nodes_inputs, node_name, "required"): for k, v in node_inputs.items(): if k not in args: if len(v) == 1 and isinstance(v[0], list) and len(v[0]) > 0: @@ -834,3 +839,59 @@ def estimate_pose(self, image: Output, resolution: int): # use smaller model, but it requires onnxruntime, see #630 mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx" return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls) + + +def _inputs_for_node(node_inputs: dict[str, dict[str, Any]], node_name: str, filter=""): + inputs = node_inputs.get(node_name) + if inputs is None: + return None + if filter: + return inputs.get(filter) + result = inputs.get("required", {}) + result.update(inputs.get("optional", {})) + return result + + +def _convert_ui_workflow(w: dict, node_inputs: dict): + version = w.get("version") + nodes = w.get("nodes") + links = w.get("links") + if not (version and nodes and links): + return w + + primitives = {} + for node in nodes: + if node["type"] == "PrimitiveNode": + primitives[node["id"]] = node["widgets_values"][0] + + r = {} + for node in nodes: + id = node["id"] + type = node["type"] + if type == "PrimitiveNode": + continue + + inputs = {} + fields = _inputs_for_node(node_inputs, type) + if fields is None: + raise ValueError( + f"Workflow uses node type {type}, but it is not installed on the ComfyUI server." + ) + widget_count = 0 + for field_name, field in fields.items(): + field_type = field[0] + if field_type in ["INT", "FLOAT", "BOOL", "STRING"] or isinstance(field_type, list): + inputs[field_name] = node["widgets_values"][widget_count] + widget_count += 1 + for connection in node["inputs"]: + if connection["name"] == field_name and connection["link"] is not None: + link = next(l for l in links if l[0] == connection["link"]) + prim = primitives.get(link[1]) + if prim is not None: + inputs[field_name] = prim + else: + inputs[field_name] = [link[1], link[2]] + break + r[id] = {"class_type": type, "inputs": inputs} + + return r diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index a6fdd3c06d..31c0c74b00 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -25,16 +25,18 @@ class CustomWorkflow: id: str source: WorkflowSource graph: dict + workflow: ComfyWorkflow path: Path | None = None + @staticmethod + def from_api(id: str, source: WorkflowSource, graph: dict, path: Path | None = None): + # doesn't work for UI workflow export (API workflow only) + return CustomWorkflow(id, source, graph, ComfyWorkflow.import_graph(graph, {}), path) + @property def name(self): return self.id.removesuffix(".json") - @property - def workflow(self): - return ComfyWorkflow.import_graph(self.graph) - class WorkflowCollection(QAbstractListModel): @@ -58,12 +60,20 @@ def __init__(self, connection: Connection, folder: Path | None = None): for wf in self._connection.workflows.keys(): self._process_remote_workflow(wf) + def _create_workflow( + self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None + ): + wf = ComfyWorkflow.import_graph(graph, self._connection.client.models.node_inputs) + return CustomWorkflow(id, source, graph, wf, path) + def _process_remote_workflow(self, id: str): - self._process(CustomWorkflow(id, WorkflowSource.remote, self._connection.workflows[id])) + graph = self._connection.workflows[id] + self._process(self._create_workflow(id, WorkflowSource.remote, graph)) def _process_file(self, file: Path): with file.open("r") as f: - self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f), file)) + graph = json.load(f) + self._process(self._create_workflow(file.stem, WorkflowSource.local, graph, file)) def _process(self, workflow: CustomWorkflow): idx = self.find_index(workflow.id) @@ -94,6 +104,9 @@ def append(self, item: CustomWorkflow): self._workflows.append(item) self.endInsertRows() + def add_from_document(self, id: str, graph: dict): + self.append(self._create_workflow(id, WorkflowSource.document, graph)) + def remove(self, id: str): idx = self.find_index(id) if idx.isValid(): @@ -118,7 +131,7 @@ def save_as(self, id: str, graph: dict): self._folder.mkdir(exist_ok=True) path = self._folder / f"{id}.json" path.write_text(json.dumps(graph, indent=2)) - self.append(CustomWorkflow(id, WorkflowSource.local, graph, path)) + self.append(self._create_workflow(id, WorkflowSource.local, graph, path)) return id def import_file(self, filepath: Path): @@ -126,7 +139,7 @@ def import_file(self, filepath: Path): with filepath.open("r") as f: graph = json.load(f) try: - ComfyWorkflow.import_graph(graph) + ComfyWorkflow.import_graph(graph, self._connection.client.models.node_inputs) except Exception as e: raise RuntimeError(f"This is not a supported workflow file ({e})") return self.save_as(filepath.stem, graph) @@ -279,7 +292,7 @@ def _set_workflow_id(self, id: str): def set_graph(self, id: str, graph: dict): if self._workflows.find(id) is None: - self._workflows.append(CustomWorkflow(id, WorkflowSource.document, graph)) + self._workflows.add_from_document(id, graph) self.workflow_id = id def import_file(self, filepath: Path): diff --git a/tests/data/object_info.json b/tests/data/object_info.json new file mode 100644 index 0000000000..8ae9803922 --- /dev/null +++ b/tests/data/object_info.json @@ -0,0 +1,241 @@ +{ + "GrowMask": { + "input": { + "required": { + "mask": [ + "MASK" + ], + "expand": [ + "INT", + { + "default": 0, + "min": -16384, + "max": 16384, + "step": 1 + } + ], + "tapered_corners": [ + "BOOLEAN", + { + "default": true + } + ] + } + }, + "input_order": { + "required": [ + "mask", + "expand", + "tapered_corners" + ] + }, + "output": [ + "MASK" + ], + "output_is_list": [ + false + ], + "output_name": [ + "MASK" + ], + "name": "GrowMask", + "display_name": "GrowMask", + "description": "", + "python_module": "comfy_extras.nodes_mask", + "category": "mask", + "output_node": false + }, + "ImageUpscaleWithModel": { + "input": { + "required": { + "upscale_model": [ + "UPSCALE_MODEL" + ], + "image": [ + "IMAGE" + ] + } + }, + "input_order": { + "required": [ + "upscale_model", + "image" + ] + }, + "output": [ + "IMAGE" + ], + "output_is_list": [ + false + ], + "output_name": [ + "IMAGE" + ], + "name": "ImageUpscaleWithModel", + "display_name": "Upscale Image (using Model)", + "description": "", + "python_module": "comfy_extras.nodes_upscale_model", + "category": "image/upscaling", + "output_node": false + }, + "ETN_ApplyMaskToImage": { + "input": { + "required": { + "image": [ + "IMAGE" + ], + "mask": [ + "MASK" + ] + } + }, + "input_order": { + "required": [ + "image", + "mask" + ] + }, + "output": [ + "IMAGE" + ], + "output_is_list": [ + false + ], + "output_name": [ + "IMAGE" + ], + "name": "ETN_ApplyMaskToImage", + "display_name": "Apply Mask to Image", + "description": "", + "python_module": "custom_nodes.comfyui-tooling-nodes", + "category": "external_tooling", + "output_node": false + }, + "UpscaleModelLoader": { + "input": { + "required": { + "model_name": [ + [ + "4x_NMKD-Superscale-SP_178000_G.pth", + "OmniSR_X2_DIV2K.safetensors", + "OmniSR_X3_DIV2K.safetensors", + "OmniSR_X4_DIV2K.safetensors" + ] + ] + } + }, + "input_order": { + "required": [ + "model_name" + ] + }, + "output": [ + "UPSCALE_MODEL" + ], + "output_is_list": [ + false + ], + "output_name": [ + "UPSCALE_MODEL" + ], + "name": "UpscaleModelLoader", + "display_name": "Load Upscale Model", + "description": "", + "python_module": "comfy_extras.nodes_upscale_model", + "category": "loaders", + "output_node": false + }, + "ETN_KritaCanvas": { + "input": {}, + "input_order": {}, + "output": [ + "IMAGE", + "INT", + "INT", + "INT" + ], + "output_is_list": [ + false, + false, + false, + false + ], + "output_name": [ + "image", + "width", + "height", + "seed" + ], + "name": "ETN_KritaCanvas", + "display_name": "Krita Canvas", + "description": "", + "python_module": "custom_nodes.comfyui-tooling-nodes", + "category": "krita", + "output_node": false + }, + "ETN_KritaOutput": { + "input": { + "required": { + "images": [ + "IMAGE" + ], + "format": [ + [ + "PNG", + "JPEG" + ], + { + "default": "PNG" + } + ] + } + }, + "input_order": { + "required": [ + "images", + "format" + ] + }, + "output": [], + "output_is_list": [], + "output_name": [], + "name": "ETN_KritaOutput", + "display_name": "Krita Output", + "description": "", + "python_module": "custom_nodes.comfyui-tooling-nodes", + "category": "krita", + "output_node": true + }, + "ETN_KritaMaskLayer": { + "input": { + "required": { + "name": [ + "STRING", + { + "default": "Mask" + } + ] + } + }, + "input_order": { + "required": [ + "name" + ] + }, + "output": [ + "MASK" + ], + "output_is_list": [ + false + ], + "output_name": [ + "mask" + ], + "name": "ETN_KritaMaskLayer", + "display_name": "Krita Mask Layer", + "description": "", + "python_module": "custom_nodes.comfyui-tooling-nodes", + "category": "krita", + "output_node": false + } +} \ No newline at end of file diff --git a/tests/data/workflow-api.json b/tests/data/workflow-api.json new file mode 100644 index 0000000000..8dfba8fc14 --- /dev/null +++ b/tests/data/workflow-api.json @@ -0,0 +1,64 @@ +{ + "0": { + "class_type": "UpscaleModelLoader", + "inputs": { + "model_name": "4x_NMKD-Superscale-SP_178000_G.pth" + } + }, + "1": { + "class_type": "ETN_KritaCanvas", + "inputs": {} + }, + "2": { + "class_type": "ETN_KritaMaskLayer", + "inputs": { + "name": "Zauber" + } + }, + "3": { + "class_type": "GrowMask", + "inputs": { + "mask": [ + "2", + 0 + ], + "expand": 4 + } + }, + "4": { + "class_type": "ImageUpscaleWithModel", + "inputs": { + "upscale_model": [ + "0", + 0 + ], + "image": [ + "1", + 0 + ] + } + }, + "5": { + "class_type": "ETN_ApplyMaskToImage", + "inputs": { + "image": [ + "4", + 0 + ], + "mask": [ + "3", + 0 + ] + } + }, + "6": { + "class_type": "ETN_KritaOutput", + "inputs": { + "images": [ + "5", + 0 + ], + "format": "PNG" + } + } +} \ No newline at end of file diff --git a/tests/data/workflow-ui.json b/tests/data/workflow-ui.json new file mode 100644 index 0000000000..7ade8ed334 --- /dev/null +++ b/tests/data/workflow-ui.json @@ -0,0 +1,438 @@ +{ + "last_node_id": 10, + "last_link_id": 10, + "nodes": [ + { + "id": 3, + "type": "GrowMask", + "pos": { + "0": 448, + "1": 57 + }, + "size": [ + 315, + 82 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 3 + }, + { + "name": "expand", + "type": "INT", + "link": 2, + "widget": { + "name": "expand" + } + } + ], + "outputs": [ + { + "name": "MASK", + "type": "MASK", + "links": [ + 10 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "GrowMask" + }, + "widgets_values": [ + 4, + true + ] + }, + { + "id": 4, + "type": "PrimitiveNode", + "pos": { + "0": 138, + "1": 156 + }, + "size": [ + 210, + 82 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "INT", + "type": "INT", + "links": [ + 2 + ], + "slot_index": 0, + "widget": { + "name": "expand" + } + } + ], + "properties": { + "Run widget replace on values": false + }, + "widgets_values": [ + 4, + "fixed" + ] + }, + { + "id": 8, + "type": "PrimitiveNode", + "pos": { + "0": -280, + "1": -370 + }, + "size": [ + 364.38086885107873, + 106 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "COMBO", + "type": "COMBO", + "links": [ + 4 + ], + "slot_index": 0, + "widget": { + "name": "model_name" + } + } + ], + "properties": { + "Run widget replace on values": false + }, + "widgets_values": [ + "4x_NMKD-Superscale-SP_178000_G.pth", + "fixed", + "" + ] + }, + { + "id": 9, + "type": "ImageUpscaleWithModel", + "pos": { + "0": 470, + "1": -370 + }, + "size": { + "0": 340.20001220703125, + "1": 46 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "upscale_model", + "type": "UPSCALE_MODEL", + "link": 5 + }, + { + "name": "image", + "type": "IMAGE", + "link": 6 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 9 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ImageUpscaleWithModel" + } + }, + { + "id": 10, + "type": "ETN_ApplyMaskToImage", + "pos": { + "0": 870, + "1": -190 + }, + "size": { + "0": 239.40000915527344, + "1": 46 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 9 + }, + { + "name": "mask", + "type": "MASK", + "link": 10 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 8 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ETN_ApplyMaskToImage" + } + }, + { + "id": 6, + "type": "UpscaleModelLoader", + "pos": { + "0": 120, + "1": -370 + }, + "size": [ + 315, + 58 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "model_name", + "type": "COMBO", + "link": 4, + "widget": { + "name": "model_name" + } + } + ], + "outputs": [ + { + "name": "UPSCALE_MODEL", + "type": "UPSCALE_MODEL", + "links": [ + 5 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UpscaleModelLoader" + }, + "widgets_values": [ + "4x_NMKD-Superscale-SP_178000_G.pth" + ] + }, + { + "id": 1, + "type": "ETN_KritaCanvas", + "pos": { + "0": 205, + "1": -247 + }, + "size": [ + 200, + 100 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "image", + "type": "IMAGE", + "links": [ + 6 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "width", + "type": "INT", + "links": null, + "shape": 3 + }, + { + "name": "height", + "type": "INT", + "links": null, + "shape": 3 + }, + { + "name": "seed", + "type": "INT", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "ETN_KritaCanvas" + } + }, + { + "id": 2, + "type": "ETN_KritaOutput", + "pos": { + "0": 1140, + "1": -190 + }, + "size": [ + 200, + 120 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 8 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "ETN_KritaOutput" + }, + "widgets_values": [ + "PNG" + ] + }, + { + "id": 5, + "type": "ETN_KritaMaskLayer", + "pos": { + "0": 41, + "1": 11 + }, + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "mask", + "type": "MASK", + "links": [ + 3 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ETN_KritaMaskLayer" + }, + "widgets_values": [ + "Zauber" + ] + } + ], + "links": [ + [ + 2, + 4, + 0, + 3, + 1, + "INT" + ], + [ + 3, + 5, + 0, + 3, + 0, + "MASK" + ], + [ + 4, + 8, + 0, + 6, + 0, + "COMBO" + ], + [ + 5, + 6, + 0, + 9, + 0, + "UPSCALE_MODEL" + ], + [ + 6, + 1, + 0, + 9, + 1, + "IMAGE" + ], + [ + 8, + 10, + 0, + 2, + 0, + "IMAGE" + ], + [ + 9, + 9, + 0, + 10, + 0, + "IMAGE" + ], + [ + 10, + 3, + 0, + 10, + 1, + "MASK" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1.0834705943388394, + "offset": [ + 311.97951538052513, + 487.01874472527845 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 6ce33a9c91..46b86b2c60 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -12,6 +12,8 @@ from ai_diffusion.image import Image, Extent from ai_diffusion import workflow +from .config import test_dir + def test_collection(tmp_path: Path): file1 = tmp_path / "file1.json" @@ -27,13 +29,13 @@ def test_collection(tmp_path: Path): collection = WorkflowCollection(connection, tmp_path) assert len(collection) == 3 - assert collection.find("file1") == CustomWorkflow( + assert collection.find("file1") == CustomWorkflow.from_api( "file1", WorkflowSource.local, {"file": 1}, file1 ) - assert collection.find("file2") == CustomWorkflow( + assert collection.find("file2") == CustomWorkflow.from_api( "file2", WorkflowSource.local, {"file": 2}, file2 ) - assert collection.find("connection1") == CustomWorkflow( + assert collection.find("connection1") == CustomWorkflow.from_api( "connection1", WorkflowSource.remote, {"connection": 1} ) @@ -56,17 +58,17 @@ def on_data_changed(start, end): connection.workflow_published.emit("connection2") assert len(collection) == 4 - assert collection.find("connection2") == CustomWorkflow( + assert collection.find("connection2") == CustomWorkflow.from_api( "connection2", WorkflowSource.remote, {"connection": 2} ) collection.set_graph(collection.index(0), {"file": 3}) - assert collection.find("file1") == CustomWorkflow( + assert collection.find("file1") == CustomWorkflow.from_api( "file1", WorkflowSource.local, {"file": 3}, file1 ) assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)] - collection.append(CustomWorkflow("doc1", WorkflowSource.document, {"doc": 1})) + collection.add_from_document("doc1", {"doc": 1}) sorted = SortedWorkflows(collection) assert sorted[0].source is WorkflowSource.document @@ -158,18 +160,28 @@ def test_workspace(): def test_import(): - w = ComfyWorkflow.import_graph( - { - "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, - "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}}, - "9": {"class_type": "B", "inputs": {"in": ["4", 0]}}, - } - ) + graph = { + "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, + "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}}, + "9": {"class_type": "B", "inputs": {"in": ["4", 0]}}, + } + w = ComfyWorkflow.import_graph(graph, {}) assert w.node(0) == ComfyNode(0, "A", {"int": 4, "float": 1.2, "string": "mouse"}) assert w.node(1) == ComfyNode(1, "B", {"in": Output(0, 0)}) assert w.node(2) == ComfyNode(2, "C", {"in": Output(1, 1)}) +def test_import_ui_workflow(): + graph = json.loads((test_dir / "data" / "workflow-ui.json").read_text()) + object_info = json.loads((test_dir / "data" / "object_info.json").read_text()) + node_inputs = {k: v.get("input") for k, v in object_info.items()} + result = ComfyWorkflow.import_graph(graph, node_inputs) + + expected_graph = json.loads((test_dir / "data" / "workflow-api.json").read_text()) + expected = ComfyWorkflow.import_graph(expected_graph, {}) + assert result.root == expected.root + + def test_parameters(): w = ComfyWorkflow() w.add("ETN_IntParameter", 1, name="int", default=4, min=0, max=10)