Skip to content

Commit

Permalink
Add some GGUF model detection
Browse files Browse the repository at this point in the history
- this makes eg. SD3.5 medium GGUF work
Acly committed Nov 29, 2024
1 parent fbe3781 commit a1571c2
Showing 3 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion ai_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Generative AI plugin for Krita"""

__version__ = "1.28.1"
__version__ = "1.29.0"

import importlib.util

22 changes: 11 additions & 11 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
from enum import Enum
from collections import deque
from itertools import chain, product
from typing import NamedTuple, Optional, Sequence
from typing import Any, NamedTuple, Optional, Sequence

from .api import WorkflowInput
from .client import Client, CheckpointInfo, ClientMessage, ClientEvent, DeviceInfo, ClientModels
@@ -163,6 +163,7 @@ async def connect(url=default_url, access_token=""):
# Retrieve list of checkpoints
checkpoints = await client.try_inspect("checkpoints")
diffusion_models = await client.try_inspect("diffusion_models")
diffusion_models.update(await client.try_inspect("unet_gguf"))
client._refresh_models(nodes, checkpoints, diffusion_models)

# Check supported SD versions and make sure there is at least one
@@ -369,11 +370,11 @@ async def disconnect(self):
self._unsubscribe_workflows(),
)

async def try_inspect(self, folder_name: str):
async def try_inspect(self, folder_name: str) -> dict[str, Any]:
try:
return await self._get(f"api/etn/model_info/{folder_name}")
except NetworkError:
return None # server has old external tooling version
return {} # server has old external tooling version

@property
def queued_count(self):
@@ -384,11 +385,13 @@ def is_executing(self):
return self._active is not None

async def refresh(self):
nodes, checkpoints, diffusion_models = await asyncio.gather(
nodes, checkpoints, diffusion_models, diffusion_gguf = await asyncio.gather(
self._get("object_info"),
self.try_inspect("checkpoints"),
self.try_inspect("diffusion_models"),
self.try_inspect("unet_gguf"),
)
diffusion_models.update(diffusion_gguf)
self._refresh_models(nodes, checkpoints, diffusion_models)

def _refresh_models(self, nodes: dict, checkpoints: dict | None, diffusion_models: dict | None):
@@ -407,7 +410,7 @@ def parse_model_info(models: dict, model_format: FileFormat):
return {
filename: CheckpointInfo(filename, arch, model_format)
for filename, arch, is_inpaint, is_refiner in parsed
if not (arch is None or is_inpaint or is_refiner)
if not (arch is None or (is_inpaint and arch is not Arch.flux) or is_refiner)
}

if checkpoints:
@@ -424,12 +427,9 @@ def parse_model_info(models: dict, model_format: FileFormat):
models.loras = nodes["LoraLoader"]["input"]["required"]["lora_name"][0]

if gguf_node := nodes.get("UnetLoaderGGUF", None):
gguf_models = {
name: CheckpointInfo(name, Arch.flux, FileFormat.diffusion)
for name in gguf_node["input"]["required"]["unet_name"][0]
}
models.checkpoints.update(gguf_models)
log.info(f"GGUF support: {len(gguf_models)} models found.")
for name in gguf_node["input"]["required"]["unet_name"][0]:
if name not in models.checkpoints:
models.checkpoints[name] = CheckpointInfo(name, Arch.flux, FileFormat.diffusion)
else:
log.info(f"GGUF support: node is not installed.")

9 changes: 5 additions & 4 deletions ai_diffusion/resources.py
Original file line number Diff line number Diff line change
@@ -6,10 +6,10 @@

# Version identifier for all the resources defined here. This is used as the server version.
# It usually follows the plugin version, but not all new plugin versions also require a server update.
version = "1.28.0"
version = "1.29.0"

comfy_url = "https://github.com/comfyanonymous/ComfyUI"
comfy_version = "61196d88576c95c1cd8535e881af48172d5af525"
comfy_version = "bf2650a80e5a7a888da206eab45c53dbb22940f7"


class CustomNode(NamedTuple):
@@ -39,7 +39,7 @@ class CustomNode(NamedTuple):
"External Tooling Nodes",
"comfyui-tooling-nodes",
"https://github.com/Acly/comfyui-tooling-nodes",
"e10daee9edea458fc709f60e725970a25567fca4",
"d7d421baaa7d3140fd7fc500d928244045211217",
["ETN_LoadImageBase64", "ETN_LoadMaskBase64", "ETN_SendImageWebSocket", "ETN_Translate"],
),
CustomNode(
@@ -56,7 +56,7 @@ class CustomNode(NamedTuple):
"GGUF",
"ComfyUI-GGUF",
"https://github.com/city96/ComfyUI-GGUF",
"8e898fad4caab59bf4144e0cf11978b893de7e54",
"4a8432884167f2526d60ef36e985bdabebb9e1e0",
["UnetLoaderGGUF", "DualCLIPLoaderGGUF"],
)
]
@@ -939,6 +939,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
resource_id(ResourceKind.text_encoder, Arch.all, "t5"): ["t5"],
resource_id(ResourceKind.vae, Arch.sd15, "default"): ["vae-ft-mse-840000-ema"],
resource_id(ResourceKind.vae, Arch.sdxl, "default"): ["sdxl_vae"],
resource_id(ResourceKind.vae, Arch.sd3, "default"): ["sd3"],
resource_id(ResourceKind.vae, Arch.flux, "default"): ["ae.s"],
}
# fmt: on

0 comments on commit a1571c2

Please sign in to comment.