Skip to content

Commit

Permalink
Fix pose vector test, clean up unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Oct 16, 2024
1 parent af8907f commit f9559c3
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 24 deletions.
2 changes: 1 addition & 1 deletion ai_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Generative AI plugin for Krita"""

__version__ = "1.25.0"
__version__ = "1.26.0"

import importlib.util

Expand Down
21 changes: 1 addition & 20 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,26 +727,7 @@ def _extract_pose_json(msg: dict):
try:
output = msg["data"]["output"]
if "openpose_json" in output:
result = json.loads(output["openpose_json"][0])
return result[0] if isinstance(result, list) else result
return json.loads(output["openpose_json"][0])
except Exception as e:
log.warning(f"Error processing message, error={str(e)}, msg={msg}")
return None


def _validate_executed_node(msg: dict, image_count: int):
try:
output = msg["data"]["output"]
assert "openpose_json" not in output

images = output["images"]
if len(images) != image_count: # not critical
log.warning(f"Received number of images does not match: {len(images)} != {image_count}")
if image_count == 0 or len(images) == 0:
log.warning(f"Received no images (execution cached?)")
return False
if "source" in images[0] and images[0]["type"] == "output":
return True
except Exception as e:
log.warning(f"Error processing message, error={str(e)}, msg={msg}")
return False
7 changes: 5 additions & 2 deletions scripts/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

sys.path.append(str(Path(__file__).parent.parent))
from ai_diffusion import resources
from ai_diffusion.resources import Arch, ResourceKind
from ai_diffusion.resources import Arch, ResourceKind, ModelResource
from ai_diffusion.resources import required_models, default_checkpoints, optional_models

try:
Expand Down Expand Up @@ -115,6 +115,7 @@ async def main(
minimal=False,
recommended=False,
all=False,
exclude=[],
retry_attempts=5,
continue_on_error=False,
):
Expand All @@ -134,7 +135,7 @@ async def main(

timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=60)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as client:
models = set()
models: set[ModelResource] = set()
models.update([m for m in default_checkpoints if all or (m.id.identifier in checkpoints)])
if minimal or recommended or all or sd15 or sdxl:
models.update([m for m in required_models if m.arch in versions])
Expand All @@ -151,6 +152,8 @@ async def main(
if prefetch or all:
models.update(resources.prefetch_models)

models = models - set([m for m in models if m.id.string in exclude])

if len(models) == 0:
print("\nNo models selected for download.")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ async def main():
if not job_id:
job_id = await client.enqueue(job)
if msg.event is ClientEvent.finished and msg.job_id == job_id:
assert isinstance(msg.result, dict)
assert isinstance(msg.result, (dict, list))
result = Pose.from_open_pose_json(msg.result).to_svg()
(result_dir / image_name).write_text(result)
return
Expand Down

0 comments on commit f9559c3

Please sign in to comment.