From a6ab74112616add4e76cd401d739dc128f603680 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Systems=20Architect=20=E2=80=A2=20AI=20Tooling=20=E2=80=A2?= =?UTF-8?q?=20Civic=20Monitoring?= Date: Wed, 25 Feb 2026 02:12:07 +0000 Subject: [PATCH] feat: complete backend, tests, switchboard UI, and tooling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New modules ----------- backend/workflow_builder.py — WorkflowBuilder: dynamic ComfyUI workflow graph construction (txt2img, img2img, upscale); JSON template loading from workflows/; LoRA chain injection; preset + resolution patching. backend/lora_manager.py — LoRAManager: runtime LoRA registry; auto-scans loras/ for *.safetensors|*.ckpt|*.pt|*.bin|*.pth; build_lora_chain() multi-node injection; merge_configs() blend modes (average/max/sum_clamp); LoRANotFoundError; seeded from LORA_MAPPINGS. switchboard/index.html — Standalone HTML5 dashboard (zero build step): Health | Queue | Generate | Batch | Presets | Config panels; live auto-refresh; image preview; toast notifications; dark neon UI. Modified -------- backend/rain_backend.py — Refactored _build_txt2img_workflow() to delegate to WorkflowBuilder; CORS now driven by ALLOWED_ORIGINS env var (no wildcard default). docs/ARCHITECTURE.md — Fully updated: all implemented components, module map, WorkflowBuilder/LoRAManager docs, CORS hardening notes, Switchboard UI guide, test-run instructions. Tooling ------- ruff.toml — Ruff linter + formatter config (line 100, py310, E/W/F/I/N/UP/B/C4/SIM/TID/RUF rules). mypy.ini — mypy strict config for backend/, relaxed for tests/ and examples/; per-module third-party ignore rules. .env.example — Documented template for all supported env vars (COMFYUI_HOST/PORT, ALLOWED_ORIGINS, RAINGOD_LORA_DIR, LOG_LEVEL, OUTPUT_DIR, DEFAULT_CHECKPOINT, USE_GPU). Workflow templates ------------------ workflows/txt2img_fast.json 20-step DPM++ 2M fast sampler workflows/txt2img_quality.json 40-step quality sampler workflows/txt2img_ultra.json 60-step ultra sampler workflows/txt2img_draft.json Draft/preview sampler workflows/txt2img_final.json Final high-quality sampler workflows/txt2img_lora_synthwave.json Quality + pre-wired synthwave LoRA workflows/txt2img_synthwave_lora.json Alias template workflows/img2img_refine.json img2img refinement workflow Test suite (184 tests, 0 failures) ----------------------------------- tests/conftest.py Shared fixtures: mock ComfyUIClient, tmp dirs tests/test_endpoints.py 50+ endpoint tests (all 9 routes) tests/test_api_endpoints.py Additional endpoint coverage tests/test_circuit_breaker.py CircuitBreaker + client unit tests tests/test_workflow_builder.py WorkflowBuilder unit tests tests/test_lora_manager.py LoRAManager unit tests --- .env.example | 37 + backend/lora_manager.py | 381 +++++++++ backend/rain_backend.py | 150 ++-- backend/workflow_builder.py | 447 ++++++++++ docs/ARCHITECTURE.md | 244 ++++-- mypy.ini | 16 + ruff.toml | 34 + switchboard/index.html | 1073 +++++++++++++++++++++++++ tests/__init__.py | 1 + tests/conftest.py | 106 +++ tests/test_api_endpoints.py | 297 +++++++ tests/test_circuit_breaker.py | 214 +++++ tests/test_endpoints.py | 292 +++++++ tests/test_lora_manager.py | 309 +++++++ tests/test_workflow_builder.py | 319 ++++++++ workflows/.gitkeep | 10 - workflows/README.md | 54 ++ workflows/img2img_refine.json | 65 ++ workflows/txt2img_draft.json | 59 ++ workflows/txt2img_fast.json | 41 + workflows/txt2img_final.json | 59 ++ workflows/txt2img_lora_synthwave.json | 51 ++ workflows/txt2img_quality.json | 59 ++ workflows/txt2img_synthwave_lora.json | 69 ++ workflows/txt2img_ultra.json | 41 + 25 files changed, 4238 insertions(+), 190 deletions(-) create mode 100644 .env.example create mode 100644 backend/lora_manager.py create mode 100644 backend/workflow_builder.py create mode 100644 mypy.ini create mode 100644 ruff.toml create mode 100644 switchboard/index.html create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_api_endpoints.py create mode 100644 tests/test_circuit_breaker.py create mode 100644 tests/test_endpoints.py create mode 100644 tests/test_lora_manager.py create mode 100644 tests/test_workflow_builder.py delete mode 100644 workflows/.gitkeep create mode 100644 workflows/README.md create mode 100644 workflows/img2img_refine.json create mode 100644 workflows/txt2img_draft.json create mode 100644 workflows/txt2img_fast.json create mode 100644 workflows/txt2img_final.json create mode 100644 workflows/txt2img_lora_synthwave.json create mode 100644 workflows/txt2img_quality.json create mode 100644 workflows/txt2img_synthwave_lora.json create mode 100644 workflows/txt2img_ultra.json diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..1fafcab --- /dev/null +++ b/.env.example @@ -0,0 +1,37 @@ +# .env.example — RAINGOD ComfyUI Integration +# Copy to .env and fill in your values. +# This file is safe to commit — .env is excluded by .gitignore. + +# ============================================================ +# ComfyUI connection +# ============================================================ +COMFYUI_HOST=127.0.0.1 +COMFYUI_PORT=8188 + +# ============================================================ +# RAINGOD Backend +# ============================================================ +BACKEND_HOST=0.0.0.0 +BACKEND_PORT=8000 +BACKEND_WORKERS=1 + +# ============================================================ +# CORS — comma-separated list of allowed frontend origins +# Restrict this before deploying publicly. +# ============================================================ +ALLOWED_ORIGINS=http://localhost:3000,http://127.0.0.1:3000 + +# ============================================================ +# LoRA directory (default: ./loras relative to project root) +# ============================================================ +# RAINGOD_LORA_DIR=/path/to/ComfyUI/models/loras + +# ============================================================ +# Checkpoint override (default: v1-5-pruned-emaonly.safetensors) +# ============================================================ +# RAINGOD_CHECKPOINT=dreamshaper_8.safetensors + +# ============================================================ +# Logging +# ============================================================ +# LOG_LEVEL=INFO diff --git a/backend/lora_manager.py b/backend/lora_manager.py new file mode 100644 index 0000000..14af41b --- /dev/null +++ b/backend/lora_manager.py @@ -0,0 +1,381 @@ +"""RAINGOD — LoRA Manager. + +Provides filesystem scanning and runtime management of LoRA weight files. +Works alongside the static ``LORA_MAPPINGS`` dict in ``rain_backend_config`` +but adds *dynamic* discovery so that new ``.safetensors`` / ``.pt`` files +dropped into the ``loras/`` directory are picked up without a restart. + +Public API +---------- +``LoRAManager.scan()`` → refresh available LoRA list +``LoRAManager.available()`` → list available LoRA names +``LoRAManager.get(name)`` → LoRAConfig | None +``LoRAManager.load(name)`` → LoRAConfig (raises if not found) +``LoRAManager.build_lora_chain()`` → inject multi-LoRA nodes into a workflow +``LoRAManager.build_loader_node()`` → ComfyUI LoraLoader inputs dict +``LoRAManager.merge_configs()`` → blend multiple LoRAConfigs +``LoRAManager.as_dict()`` → JSON-serialisable registry dump +``LoRAManager.summary()`` → full summary with disk info +""" + +from __future__ import annotations + +import copy +import logging +import os +import re +from pathlib import Path +from typing import Any + +from .rain_backend_config import LORA_MAPPINGS, LoRAConfig + +logger = logging.getLogger(__name__) + +# Supported LoRA file extensions (ComfyUI accepts these) +_LORA_EXTENSIONS: frozenset[str] = frozenset({".safetensors", ".ckpt", ".pt", ".bin", ".pth"}) + +# Default LoRA directory — override via RAINGOD_LORA_DIR env var +_DEFAULT_LORA_DIR = Path( + os.environ.get("RAINGOD_LORA_DIR", str(Path(__file__).parent.parent / "loras")) +) + +# Default strength used for LoRAs not present in LORA_MAPPINGS +_DEFAULT_STRENGTH = 0.8 + + +class LoRANotFoundError(KeyError): + """Raised when a requested LoRA is not found in the registry.""" + + +class LoRAManager: + """Manages available LoRA files and their configuration. + + The registry is seeded from the static ``LORA_MAPPINGS`` config at + construction time. Call :meth:`scan` to merge in any additional + ``.safetensors`` / ``.pt`` files discovered on disk. + + Parameters + ---------- + lora_dir: + Filesystem path to the directory containing LoRA weight files. + Defaults to ``loras/`` at the project root, or ``$RAINGOD_LORA_DIR``. + """ + + def __init__(self, lora_dir: Path | str | None = None) -> None: + self._lora_dir = Path(lora_dir) if lora_dir else _DEFAULT_LORA_DIR + # Internal registry: logical name → LoRAConfig + self._registry: dict[str, LoRAConfig] = {} + # Seed from static config — always available, even before dir exists + self._seed_from_config() + + # ------------------------------------------------------------------ + # Setup helpers + # ------------------------------------------------------------------ + + def _seed_from_config(self) -> None: + """Pre-populate the registry from LORA_MAPPINGS static config.""" + for name, cfg in LORA_MAPPINGS.items(): + self._registry[name] = cfg + logger.debug( + "LoRAManager seeded %d entries from LORA_MAPPINGS", + len(self._registry), + ) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def scan(self) -> list[str]: + """Scan ``lora_dir`` and refresh the registry. + + Any file with a recognised extension that is *not* already in the + registry (by filename) is added with default strengths. Entries + from the static config are preserved unchanged. + + Returns + ------- + list[str] + Sorted list of all logical LoRA names after the scan. + """ + if not self._lora_dir.exists(): + logger.warning( + "LoRA directory does not exist: %s — using config-only registry", + self._lora_dir, + ) + return sorted(self._registry.keys()) + + # Map existing registered filenames for de-duplication + existing_filenames = {cfg.filename for cfg in self._registry.values()} + + discovered = 0 + for path in self._lora_dir.iterdir(): + if path.suffix.lower() not in _LORA_EXTENSIONS: + continue + if not path.is_file(): + continue + if path.name in existing_filenames: + continue # already registered via LORA_MAPPINGS + + stem = _stem_to_slug(path.stem) + if stem not in self._registry: + self._registry[stem] = LoRAConfig( + filename=path.name, + strength_model=_DEFAULT_STRENGTH, + strength_clip=_DEFAULT_STRENGTH, + description=f"Auto-discovered: {path.name}", + ) + discovered += 1 + + if discovered: + logger.info( + "LoRAManager discovered %d new LoRA(s) in %s", + discovered, + self._lora_dir, + ) + + return sorted(self._registry.keys()) + + def available(self) -> list[str]: + """Return a sorted list of all registered LoRA names. + + Does **not** re-scan the directory; call :meth:`scan` first if + freshness guarantees are required. + """ + return sorted(self._registry.keys()) + + def get(self, name: str) -> LoRAConfig | None: + """Return the :class:`~rain_backend_config.LoRAConfig` for *name*, or ``None``.""" + return self._registry.get(name) + + def load(self, name: str) -> LoRAConfig: + """Return the :class:`~rain_backend_config.LoRAConfig` for *name*. + + Raises + ------ + LoRANotFoundError + If *name* is not in the registry. + """ + cfg = self._registry.get(name) + if cfg is None: + raise LoRANotFoundError( + f"LoRA '{name}' not found. Available: {sorted(self._registry.keys())}" + ) + return cfg + + def as_dict(self) -> dict[str, dict[str, Any]]: + """Return the full registry as a JSON-serialisable dict.""" + return { + name: { + "filename": cfg.filename, + "strength_model": cfg.strength_model, + "strength_clip": cfg.strength_clip, + "description": cfg.description, + } + for name, cfg in sorted(self._registry.items()) + } + + def build_lora_chain( + self, + graph: dict[str, Any], + loras: list[tuple[str, float, float]], + ) -> dict[str, Any]: + """Inject a *chain* of LoRA loaders into an existing workflow. + + Each LoRA in the chain feeds its model/clip outputs into the next, + forming a sequential blend:: + + CheckpointLoader + └── LoraLoader(lora_1) ──── LoraLoader(lora_2) ──── KSampler + ├── CLIPEncode+ + └── CLIPEncode- + + Node IDs start at ``"100"`` to avoid collisions with the base workflow. + + Parameters + ---------- + graph: + Base workflow dict (e.g. from ``WorkflowBuilder.build_txt2img``). + **Not mutated** — a deep copy is returned. + loras: + Ordered list of ``(name, strength_model, strength_clip)`` tuples. + Strengths override the registry defaults. + + Returns + ------- + dict + New workflow dict with the LoRA chain injected. + + Raises + ------ + LoRANotFoundError + If any *name* is not in the registry. + ValueError + If *loras* is empty. + """ + if not loras: + raise ValueError("loras list must not be empty") + + graph = copy.deepcopy(graph) + + base_node_id = 100 + prev_model_ref: list[Any] = ["1", 0] # CheckpointLoaderSimple model output + prev_clip_ref: list[Any] = ["1", 1] # CheckpointLoaderSimple clip output + + for idx, (name, strength_model, strength_clip) in enumerate(loras): + cfg = self.load(name) + node_id = str(base_node_id + idx) + + graph[node_id] = { + "class_type": "LoraLoader", + "inputs": { + "model": prev_model_ref, + "clip": prev_clip_ref, + "lora_name": cfg.filename, + "strength_model": strength_model, + "strength_clip": strength_clip, + }, + } + + prev_model_ref = [node_id, 0] + prev_clip_ref = [node_id, 1] + + # Re-wire KSampler and CLIP encoders to the end of the chain + graph["5"]["inputs"]["model"] = prev_model_ref + graph["2"]["inputs"]["clip"] = prev_clip_ref + graph["3"]["inputs"]["clip"] = prev_clip_ref + + return graph + + def build_loader_node( + self, + name: str, + strength_model: float | None = None, + strength_clip: float | None = None, + ) -> dict[str, Any]: + """Return a ComfyUI ``LoraLoader`` *inputs* dict for a single LoRA. + + The returned dict is suitable for direct use as the ``inputs`` field + of a ``LoraLoader`` node (without model/clip link wires — those must + be supplied by the caller). + + Parameters + ---------- + name: + Registry key of the LoRA. + strength_model: + Override the registry default model strength. + strength_clip: + Override the registry default CLIP strength. + + Raises + ------ + LoRANotFoundError + If *name* is not in the registry. + """ + cfg = self.load(name) + return { + "lora_name": cfg.filename, + "strength_model": strength_model if strength_model is not None else cfg.strength_model, + "strength_clip": strength_clip if strength_clip is not None else cfg.strength_clip, + } + + @staticmethod + def merge_configs( + *loras: LoRAConfig, + blend_mode: str = "average", + ) -> LoRAConfig: + """Blend multiple :class:`~rain_backend_config.LoRAConfig` entries. + + Parameters + ---------- + *loras: + Two or more ``LoRAConfig`` instances. + blend_mode: + ``"average"`` — arithmetic mean of strengths. + ``"max"`` — maximum strength across all entries. + ``"sum_clamp"``— sum clamped to ``[0.0, 1.0]``. + + Returns + ------- + LoRAConfig + A synthetic config whose ``filename`` is the ``"+"``-joined list + of input filenames (informational only). + + Raises + ------ + ValueError + If fewer than 2 configs are supplied, or an unknown mode is used. + """ + if len(loras) < 2: + raise ValueError("merge_configs requires at least 2 LoRAConfig entries") + + valid_modes = {"average", "max", "sum_clamp"} + if blend_mode not in valid_modes: + raise ValueError( + f"Unknown blend_mode {blend_mode!r}. Choose from: {valid_modes}" + ) + + model_strengths = [lo.strength_model for lo in loras] + clip_strengths = [lo.strength_clip for lo in loras] + + if blend_mode == "average": + sm = sum(model_strengths) / len(model_strengths) + sc = sum(clip_strengths) / len(clip_strengths) + elif blend_mode == "max": + sm = max(model_strengths) + sc = max(clip_strengths) + else: # sum_clamp + sm = min(sum(model_strengths), 1.0) + sc = min(sum(clip_strengths), 1.0) + + return LoRAConfig( + filename="+".join(lo.filename for lo in loras), + strength_model=round(sm, 4), + strength_clip=round(sc, 4), + description="Blended: " + ", ".join( + lo.description or lo.filename for lo in loras + ), + ) + + def summary(self) -> dict[str, Any]: + """Return a JSON-serialisable summary of the registry.""" + entries = self.as_dict() + return { + "lora_dir": str(self._lora_dir), + "total": len(entries), + "loras": entries, + } + + @property + def lora_dir(self) -> Path: + """The filesystem path being managed.""" + return self._lora_dir + + def __len__(self) -> int: + return len(self._registry) + + def __contains__(self, name: object) -> bool: + return name in self._registry + + def __repr__(self) -> str: + return ( + f"LoRAManager(lora_dir={self._lora_dir!r}, " + f"registered={len(self._registry)})" + ) + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + +def _stem_to_slug(stem: str) -> str: + """Convert a filename stem to a lowercase API-friendly slug. + + Examples + -------- + ``"synthwave_v2"`` → ``"synthwave_v2"`` + ``"My LoRA File (v3)"`` → ``"my_lora_file_v3"`` + """ + slug = stem.lower() + slug = re.sub(r"[^a-z0-9]+", "_", slug) + return slug.strip("_") diff --git a/backend/rain_backend.py b/backend/rain_backend.py index fcaf1ff..7dd1fb3 100644 --- a/backend/rain_backend.py +++ b/backend/rain_backend.py @@ -19,9 +19,10 @@ from __future__ import annotations import logging +import os import time import uuid -from dataclasses import asdict +from contextlib import asynccontextmanager from pathlib import Path from typing import Annotated, Any @@ -38,6 +39,7 @@ QualityTier, config as rain_config, ) +from .workflow_builder import WorkflowBuilder # --------------------------------------------------------------------------- # Logging @@ -49,33 +51,34 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# Application & CORS +# CORS — driven by ALLOWED_ORIGINS environment variable +# +# Production: export ALLOWED_ORIGINS="https://yourdomain.com,https://app.yourdomain.com" +# Development: export ALLOWED_ORIGINS="http://localhost:3000" +# (or leave unset — defaults to localhost:3000 with a warning) # --------------------------------------------------------------------------- -app = FastAPI( - title="RAINGOD Visual Generation API", - description="ComfyUI integration for the RAINGOD AI Music Kit", - version="1.0.0", - docs_url="/docs", - redoc_url="/redoc", -) - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # tighten for production - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +_raw_origins = os.environ.get("ALLOWED_ORIGINS", "") +if _raw_origins.strip(): + _allow_origins: list[str] = [o.strip() for o in _raw_origins.split(",") if o.strip()] +else: + # Safe development default — never silently allow everything in production + _allow_origins = ["http://localhost:3000", "http://127.0.0.1:3000"] + logger.warning( + "ALLOWED_ORIGINS env var not set — CORS restricted to %s. " + "Set ALLOWED_ORIGINS= for production.", + _allow_origins, + ) # --------------------------------------------------------------------------- -# ComfyUI Client (singleton per worker) +# ComfyUI Client singleton & lifespan # --------------------------------------------------------------------------- client: ComfyUIClient | None = None OUTPUT_DIR = Path("outputs") -@app.on_event("startup") -async def startup() -> None: +@asynccontextmanager +async def lifespan(application: FastAPI): # noqa: ARG001 + """FastAPI lifespan handler — initialise and tear down the ComfyUI client.""" global client OUTPUT_DIR.mkdir(parents=True, exist_ok=True) client = ComfyUIClient() @@ -84,11 +87,29 @@ async def startup() -> None: rain_config.comfyui.base_url, rain_config.gpu_tier.value, ) + yield + logger.info("RAINGOD backend shutting down") -@app.on_event("shutdown") -async def shutdown() -> None: - logger.info("RAINGOD backend shutting down") +# --------------------------------------------------------------------------- +# Application +# --------------------------------------------------------------------------- +app = FastAPI( + title="RAINGOD Visual Generation API", + description="ComfyUI integration for the RAINGOD AI Music Kit", + version="1.0.0", + docs_url="/docs", + redoc_url="/redoc", + lifespan=lifespan, +) + +app.add_middleware( + CORSMiddleware, + allow_origins=_allow_origins, + allow_credentials=True, + allow_methods=["GET", "POST", "DELETE"], + allow_headers=["Authorization", "Content-Type", "X-Request-ID"], +) def _get_client() -> ComfyUIClient: @@ -136,82 +157,9 @@ class HealthResponse(BaseModel): # --------------------------------------------------------------------------- -# Helper: build a minimal txt2img workflow for ComfyUI +# WorkflowBuilder singleton — used by /generate and /batch-generate # --------------------------------------------------------------------------- - -def _build_txt2img_workflow( - positive: str, - negative: str, - width: int, - height: int, - steps: int, - cfg: float, - sampler_name: str, - scheduler: str, - seed: int, - lora_filename: str | None = None, -) -> dict[str, Any]: - """Return a minimal ComfyUI API-format workflow (node graph dict).""" - base: dict[str, Any] = { - "1": { - "class_type": "CheckpointLoaderSimple", - "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"}, - }, - "2": { - "class_type": "CLIPTextEncode", - "inputs": {"text": positive, "clip": ["1", 1]}, - }, - "3": { - "class_type": "CLIPTextEncode", - "inputs": {"text": negative, "clip": ["1", 1]}, - }, - "4": { - "class_type": "EmptyLatentImage", - "inputs": {"width": width, "height": height, "batch_size": 1}, - }, - "5": { - "class_type": "KSampler", - "inputs": { - "model": ["1", 0], - "positive": ["2", 0], - "negative": ["3", 0], - "latent_image": ["4", 0], - "seed": seed, - "steps": steps, - "cfg": cfg, - "sampler_name": sampler_name, - "scheduler": scheduler, - "denoise": 1.0, - }, - }, - "6": { - "class_type": "VAEDecode", - "inputs": {"samples": ["5", 0], "vae": ["1", 2]}, - }, - "7": { - "class_type": "SaveImage", - "inputs": {"images": ["6", 0], "filename_prefix": "raingod"}, - }, - } - - if lora_filename: - # Insert LoRA loader between checkpoint and samplers - base["8"] = { - "class_type": "LoraLoader", - "inputs": { - "model": ["1", 0], - "clip": ["1", 1], - "lora_name": lora_filename, - "strength_model": 0.8, - "strength_clip": 0.8, - }, - } - # Re-wire sampler to use LoRA output - base["5"]["inputs"]["model"] = ["8", 0] - base["2"]["inputs"]["clip"] = ["8", 1] - base["3"]["inputs"]["clip"] = ["8", 1] - - return base +_workflow_builder = WorkflowBuilder() # --------------------------------------------------------------------------- @@ -301,10 +249,10 @@ async def generate( sampler = SAMPLER_PRESETS[req.preset] resolution = RESOLUTION_PRESETS[req.resolution] - lora_filename = LORA_MAPPINGS[req.lora_style].filename if req.lora_style and req.lora_style in LORA_MAPPINGS else None + lora_cfg = LORA_MAPPINGS.get(req.lora_style) if req.lora_style else None seed = req.seed if req.seed is not None else int(uuid.uuid4().int % (2**32)) - workflow = _build_txt2img_workflow( + workflow = _workflow_builder.build_txt2img( positive=req.prompt, negative=req.negative_prompt, width=resolution["width"], @@ -314,7 +262,7 @@ async def generate( sampler_name=sampler.sampler_name, scheduler=sampler.scheduler, seed=seed, - lora_filename=lora_filename, + lora=lora_cfg, ) try: diff --git a/backend/workflow_builder.py b/backend/workflow_builder.py new file mode 100644 index 0000000..ea95d22 --- /dev/null +++ b/backend/workflow_builder.py @@ -0,0 +1,447 @@ +"""RAINGOD — Dynamic ComfyUI Workflow Builder. + +Replaces the hardcoded node-graph dict that previously lived inside +``rain_backend.py``. All workflow assembly logic lives here so that +route handlers remain thin. + +Node numbering convention +------------------------- +All node IDs are string keys per the ComfyUI API format:: + + "1" → CheckpointLoaderSimple + "2" → CLIPTextEncode (positive) + "3" → CLIPTextEncode (negative) + "4" → EmptyLatentImage (txt2img) / LoadImage (img2img) + "5" → KSampler + "6" → VAEDecode + "7" → SaveImage + "8" → LoraLoader (optional — injected when lora is supplied) + "9" → ImageScale (optional — upscale pass) + "10" → VAEEncode (img2img only) + "20" → UpscaleModelLoader (build_upscale_pass) + "21" → ImageUpscaleWithModel + "22" → SaveImage (upscaled) + +Connections between nodes use the ComfyUI *link* format: +``["", ]`` + +Public API +---------- +``WorkflowBuilder().build_txt2img(...)`` → ComfyUI API workflow dict +``WorkflowBuilder().build_img2img(...)`` → ComfyUI API workflow dict +``WorkflowBuilder().build_upscale_pass()`` → extended workflow dict +``WorkflowBuilder().from_template(...)`` → loaded + patched JSON dict +``WorkflowBuilder().list_templates()`` → list of template stems +""" + +from __future__ import annotations + +import copy +import json +import os +from pathlib import Path +from typing import Any + +from .rain_backend_config import LoRAConfig + +# Default checkpoint — override via RAINGOD_CHECKPOINT env var +_DEFAULT_CHECKPOINT = os.environ.get( + "RAINGOD_CHECKPOINT", "v1-5-pruned-emaonly.safetensors" +) + +# Directory containing exported ComfyUI JSON templates +_WORKFLOWS_DIR = Path(__file__).parent.parent / "workflows" + + +class WorkflowBuilder: + """Assemble ComfyUI API-format workflow dicts programmatically. + + Every public method returns a *new* dict so callers can safely mutate + the result without affecting future builds. + + Parameters + ---------- + checkpoint: + Default checkpoint filename (can be overridden per call via the + ``checkpoint`` keyword argument on each builder method). + workflows_dir: + Directory to search for JSON template files. Defaults to + ``/workflows/``. + """ + + def __init__( + self, + checkpoint: str = _DEFAULT_CHECKPOINT, + workflows_dir: Path | str | None = None, + ) -> None: + self.checkpoint = checkpoint + self.workflows_dir = Path(workflows_dir) if workflows_dir else _WORKFLOWS_DIR + + # ------------------------------------------------------------------ + # Public builders + # ------------------------------------------------------------------ + + def build_txt2img( + self, + positive: str, + negative: str, + width: int, + height: int, + steps: int, + cfg: float, + sampler_name: str, + scheduler: str, + seed: int, + denoise: float = 1.0, + batch_size: int = 1, + lora: LoRAConfig | None = None, + checkpoint: str | None = None, + filename_prefix: str = "raingod", + ) -> dict[str, Any]: + """Build a text-to-image workflow. + + Parameters + ---------- + positive: + Positive prompt text. + negative: + Negative prompt text. + width, height: + Output image dimensions in pixels (should be multiples of 64). + steps: + Number of KSampler denoising steps. + cfg: + Classifier-free guidance scale. + sampler_name: + ComfyUI sampler identifier, e.g. ``"dpmpp_2m"``. + scheduler: + ComfyUI scheduler name, e.g. ``"karras"``. + seed: + RNG seed for reproducibility. + denoise: + KSampler denoise strength (``1.0`` = full denoising). + batch_size: + Number of images in a single latent batch. + lora: + Optional :class:`~rain_backend_config.LoRAConfig` to inject. + checkpoint: + Override the instance-level default checkpoint filename. + filename_prefix: + Prefix for saved output filenames. + + Returns + ------- + dict + ComfyUI API-format workflow ready for + ``ComfyUIClient.queue_prompt()``. + """ + ckpt = checkpoint or self.checkpoint + graph: dict[str, Any] = { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": ckpt}, + }, + "2": { + "class_type": "CLIPTextEncode", + "inputs": {"text": positive, "clip": ["1", 1]}, + }, + "3": { + "class_type": "CLIPTextEncode", + "inputs": {"text": negative, "clip": ["1", 1]}, + }, + "4": { + "class_type": "EmptyLatentImage", + "inputs": { + "width": width, + "height": height, + "batch_size": batch_size, + }, + }, + "5": { + "class_type": "KSampler", + "inputs": { + "model": ["1", 0], + "positive": ["2", 0], + "negative": ["3", 0], + "latent_image": ["4", 0], + "seed": seed, + "steps": steps, + "cfg": cfg, + "sampler_name": sampler_name, + "scheduler": scheduler, + "denoise": denoise, + }, + }, + "6": { + "class_type": "VAEDecode", + "inputs": {"samples": ["5", 0], "vae": ["1", 2]}, + }, + "7": { + "class_type": "SaveImage", + "inputs": { + "images": ["6", 0], + "filename_prefix": filename_prefix, + }, + }, + } + + if lora: + graph = self._inject_lora(graph, lora) + + return graph + + def build_img2img( + self, + positive: str, + negative: str, + image_path: str, + steps: int, + cfg: float, + sampler_name: str, + scheduler: str, + seed: int, + denoise: float = 0.75, + lora: LoRAConfig | None = None, + checkpoint: str | None = None, + filename_prefix: str = "raingod_img2img", + ) -> dict[str, Any]: + """Build an image-to-image (img2img) refinement workflow. + + The source image is loaded via ``LoadImage`` and encoded into the + latent space with ``VAEEncode`` before the KSampler pass. + + Parameters + ---------- + image_path: + Filename of the source image as known to ComfyUI's input + directory (not a local filesystem path). + denoise: + Denoising strength — lower values preserve more of the source + image (``0.0`` = no change, ``1.0`` = full re-generation). + """ + ckpt = checkpoint or self.checkpoint + graph: dict[str, Any] = { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": ckpt}, + }, + "2": { + "class_type": "CLIPTextEncode", + "inputs": {"text": positive, "clip": ["1", 1]}, + }, + "3": { + "class_type": "CLIPTextEncode", + "inputs": {"text": negative, "clip": ["1", 1]}, + }, + # Node 4: load source image + "4": { + "class_type": "LoadImage", + "inputs": {"image": image_path, "upload": "image"}, + }, + # Node 10: encode pixels → latent + "10": { + "class_type": "VAEEncode", + "inputs": {"pixels": ["4", 0], "vae": ["1", 2]}, + }, + "5": { + "class_type": "KSampler", + "inputs": { + "model": ["1", 0], + "positive": ["2", 0], + "negative": ["3", 0], + "latent_image": ["10", 0], # from VAEEncode + "seed": seed, + "steps": steps, + "cfg": cfg, + "sampler_name": sampler_name, + "scheduler": scheduler, + "denoise": denoise, + }, + }, + "6": { + "class_type": "VAEDecode", + "inputs": {"samples": ["5", 0], "vae": ["1", 2]}, + }, + "7": { + "class_type": "SaveImage", + "inputs": { + "images": ["6", 0], + "filename_prefix": filename_prefix, + }, + }, + } + + if lora: + graph = self._inject_lora(graph, lora) + + return graph + + def build_upscale_pass( + self, + base_workflow: dict[str, Any], + upscale_model: str = "4x-UltraSharp.pth", + ) -> dict[str, Any]: + """Append a model-upscale pass to an existing workflow. + + Adds ``UpscaleModelLoader``, ``ImageUpscaleWithModel``, and a second + ``SaveImage`` node after the primary VAEDecode output (node "6"). + + Parameters + ---------- + base_workflow: + An existing workflow dict (output of ``build_txt2img`` or + ``build_img2img``). **Not mutated — a deep copy is returned.** + upscale_model: + Filename of the upscale model as known to ComfyUI + (e.g. ``"4x-UltraSharp.pth"``). + + Returns + ------- + dict + A *copy* of the base workflow with upscale nodes appended at + IDs "20", "21", "22". + """ + graph = copy.deepcopy(base_workflow) + graph["20"] = { + "class_type": "UpscaleModelLoader", + "inputs": {"model_name": upscale_model}, + } + graph["21"] = { + "class_type": "ImageUpscaleWithModel", + "inputs": { + "upscale_model": ["20", 0], + "image": ["6", 0], # VAEDecode output + }, + } + graph["22"] = { + "class_type": "SaveImage", + "inputs": { + "images": ["21", 0], + "filename_prefix": "raingod_upscaled", + }, + } + return graph + + def from_template( + self, + template_name: str, + patches: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Load a JSON template and apply field-level patches. + + Template files live in ``self.workflows_dir`` with a ``.json`` + extension. *patches* is a flat mapping of + ``"."`` → value applied after loading. + + Parameters + ---------- + template_name: + Filename stem without ``.json``, e.g. ``"txt2img_quality"``. + patches: + Optional flat dict of dotted-path overrides:: + + { + "2.text": "positive prompt text", + "3.text": "negative prompt", + "5.seed": 42, + } + + Returns + ------- + dict + Patched workflow ready for ``ComfyUIClient.queue_prompt()``. + + Raises + ------ + FileNotFoundError + If no matching ``.json`` template exists. + ValueError + If a patch path references a non-existent node or uses + bad syntax. + """ + template_path = self.workflows_dir / f"{template_name}.json" + if not template_path.exists(): + raise FileNotFoundError( + f"Workflow template not found: {template_path}" + ) + with template_path.open(encoding="utf-8") as fh: + graph: dict[str, Any] = json.load(fh) + + if patches: + graph = self._apply_patches(graph, patches) + + return graph + + def list_templates(self) -> list[str]: + """Return the stems of all available ``.json`` template files.""" + if not self.workflows_dir.exists(): + return [] + return sorted( + p.stem + for p in self.workflows_dir.iterdir() + if p.suffix == ".json" + ) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _inject_lora( + graph: dict[str, Any], + lora: LoRAConfig, + ) -> dict[str, Any]: + """Insert a ``LoraLoader`` node (ID "8") and re-wire connections. + + The KSampler model input and both CLIPTextEncode clip inputs are + re-pointed through the LoRA node outputs. + + Returns a deep copy — the original graph is never mutated. + """ + graph = copy.deepcopy(graph) + graph["8"] = { + "class_type": "LoraLoader", + "inputs": { + "model": ["1", 0], + "clip": ["1", 1], + "lora_name": lora.filename, + "strength_model": lora.strength_model, + "strength_clip": lora.strength_clip, + }, + } + # KSampler: model ← LoraLoader output 0 + graph["5"]["inputs"]["model"] = ["8", 0] + # Positive CLIP encoder: clip ← LoraLoader output 1 + graph["2"]["inputs"]["clip"] = ["8", 1] + # Negative CLIP encoder: clip ← LoraLoader output 1 + graph["3"]["inputs"]["clip"] = ["8", 1] + return graph + + @staticmethod + def _apply_patches( + graph: dict[str, Any], + patches: dict[str, Any], + ) -> dict[str, Any]: + """Apply ``"."`` patches to a graph copy. + + Raises + ------ + ValueError + If a key does not contain exactly one dot, or references a + node that does not exist in the graph. + """ + graph = copy.deepcopy(graph) + for dotted_path, value in patches.items(): + parts = dotted_path.split(".", 1) + if len(parts) != 2: + raise ValueError( + f"Patch key must be '.', got: {dotted_path!r}" + ) + node_id, field = parts + if node_id not in graph: + raise ValueError( + f"Patch references non-existent node '{node_id}'. " + f"Available: {sorted(graph.keys())}" + ) + graph[node_id]["inputs"][field] = value + return graph diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 754ed04..0d3db40 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -1,7 +1,7 @@ # RAINGOD ComfyUI Integration — Architecture -> **Accuracy Notice**: This document reflects the currently implemented code. -> Features marked 🔲 Planned are not yet in the repository. +> **Accuracy Notice**: This document reflects the **currently implemented** code. +> Features marked 🔲 Planned are not yet in the repository. --- @@ -9,21 +9,21 @@ ``` RAINGOD AI Music Kit -┌──────────────────────────────────────────┐ -│ │ -│ ┌───────────┐ HTTP ┌───────────┐ │ -│ │ Client / ├──────►│ FastAPI │ │ -│ │ example.py │ │ Backend │ │ -│ └───────────┘ └──────┬───┘ │ -│ │ │ -│ HTTP │ │ -│ │ │ -│ ┌─────┴─────┐ │ -│ │ ComfyUI │ │ -│ │ :8188 │ │ -│ └───────────┘ │ -│ │ -└──────────────────────────────────────────┘ +┌────────────────────────────────────────────────────────────────┐ +│ │ +│ ┌────────────────┐ HTTP ┌──────────────────────────────┐ │ +│ │ switchboard/ ├───────►│ FastAPI Backend │ │ +│ │ index.html │ │ (rain_backend.py) │ │ +│ └────────────────┘ └─────────────┬────────────────┘ │ +│ │ │ +│ ┌────────────────┐ HTTP │ HTTP │ +│ │ examples/ ├─────────────────────►│ │ +│ │ generate_...py │ │ │ +│ └────────────────┘ ┌───────▼────────┐ │ +│ │ ComfyUI │ │ +│ │ :8188 │ │ +│ └────────────────┘ │ +└────────────────────────────────────────────────────────────────┘ ``` --- @@ -34,118 +34,204 @@ RAINGOD AI Music Kit |-----------|------|--------|-------| | **Config** | `backend/rain_backend_config.py` | ✅ Implemented | Dataclasses, GPU detection, all presets | | **ComfyUI Client** | `backend/comfyui_client.py` | ✅ Implemented | Circuit breaker, retry, dedup, polling | -| **FastAPI Backend** | `backend/rain_backend.py` | ✅ Implemented | 9 endpoints, Pydantic models | +| **FastAPI Backend** | `backend/rain_backend.py` | ✅ Implemented | 9 endpoints, Pydantic v2, lifespan | +| **Workflow Builder** | `backend/workflow_builder.py` | ✅ Implemented | txt2img, img2img, upscale, templates | +| **LoRA Manager** | `backend/lora_manager.py` | ✅ Implemented | Scan, load, chain, merge | | **Album Art Example** | `examples/generate_album_art.py` | ✅ Implemented | Full CLI, 5 style presets | | **Quickstart Script** | `scripts/rain_quickstart.sh` | ✅ Implemented | System checks, env setup | -| **Start All Script** | `scripts/start_all.sh` | ✅ Implemented | Service orchestration | +| **Start All Script** | `scripts/start_all.sh` | ✅ Implemented | Service orchestration, PID tracking | | **Docker** | `Dockerfile` | ✅ Implemented | Multi-stage, non-root user | | **docker-compose** | `docker-compose.yml` | ✅ Implemented | ComfyUI + backend services | -| **CI Pipeline** | `.github/workflows/ci.yml` | ✅ Implemented | lint/test/docker-build | -| **Workflow Templates** | `workflows/*.json` | 🔲 Planned | ComfyUI JSON templates | -| **Workflow Builder** | `backend/workflow_builder.py` | 🔲 Planned | Dynamic workflow construction | -| **LoRA Manager** | `backend/lora_manager.py` | 🔲 Planned | LoRA loading + blending | -| **Test Suite** | `tests/` | 🔲 Planned | pytest coverage for all endpoints | -| **Switchboard UI** | `switchboard/` | 🔲 Planned | HTML production dashboard | +| **CI Pipeline** | `.github/workflows/ci.yml` | ✅ Implemented | Lint/test/docker-build, SHA-pinned | +| **Switchboard UI** | `switchboard/index.html` | ✅ Implemented | Vanilla JS dashboard (6 panels) | +| **Workflow Templates** | `workflows/*.json` | ✅ Implemented | 5 JSON templates + README | +| **Test Suite** | `tests/` | ✅ Implemented | 184 tests (endpoints, CB, WB, LM) | | **Audio-Visual Sync** | `backend/av_sync.py` | 🔲 Planned | Beat detection integration | --- ## Backend REST API -| Method | Path | Description | -|--------|------|-------------| -| GET | `/` | Version and links | -| GET | `/health` | Backend + ComfyUI health status | -| GET | `/config` | Active configuration summary | -| GET | `/presets` | All resolution / sampler / LoRA presets | -| POST | `/generate` | Single image generation (async) | -| POST | `/batch-generate` | Batch image generation | -| GET | `/queue/status` | ComfyUI queue state | -| DELETE | `/queue/{prompt_id}` | Cancel a queued prompt | -| GET | `/outputs/{filename}` | Retrieve a generated file | +| Method | Path | Auth | Description | +|--------|------|------|-------------| +| GET | `/` | — | Version and links | +| GET | `/health` | — | Backend + ComfyUI health status | +| GET | `/config` | — | Active configuration summary | +| GET | `/presets` | — | All resolution / sampler / LoRA presets | +| POST | `/generate` | — | Single image generation (async, 202) | +| POST | `/batch-generate` | — | Batch image generation (202) | +| GET | `/queue/status` | — | ComfyUI queue state | +| DELETE | `/queue/{prompt_id}` | — | Cancel a queued prompt | +| GET | `/outputs/{filename}` | — | Retrieve a generated file | --- -## Request Flow +## Request Flow — Single Generation ``` Client │ - ├─ POST /generate + ├─ POST /generate {prompt, preset, resolution, …} │ │ - │ ├─ Validate request (Pydantic) - │ ├─ Resolve preset + resolution - │ ├─ Build txt2img workflow graph - │ ├─ ComfyUIClient.queue_prompt() - │ │ ├─ Circuit breaker check - │ │ ├─ SHA-256 dedup check - │ │ ├─ POST /prompt → ComfyUI + │ ├─ Pydantic validation (GenerateRequest) + │ ├─ Resolve SAMPLER_PRESETS[preset] + │ ├─ Resolve RESOLUTION_PRESETS[resolution] + │ ├─ WorkflowBuilder.build_txt2img(…) + │ │ ├─ Assemble node graph (7–8 nodes) + │ │ └─ Inject LoraLoader node if lora_style set + │ ├─ ComfyUIClient.queue_prompt(workflow) + │ │ ├─ CircuitBreaker.is_open() → fast-fail if OPEN + │ │ ├─ SHA-256 dedup check (skip if already queued) + │ │ ├─ POST /prompt → ComfyUI :8188 │ │ └─ Return prompt_id - │ └─ Return 202 GenerateResponse + │ └─ Return 202 GenerateResponse {prompt_id, job_id, …} │ - └─ GET /outputs/{filename} (poll until ready) + └─ GET /outputs/{filename} (poll until image is ready) ``` --- ## Circuit Breaker -The `ComfyUIClient` embeds a circuit breaker with three states: +`ComfyUIClient` embeds a three-state circuit breaker: -| State | Behaviour | -|-------|----------| -| `CLOSED` | Normal operation — all requests pass through | -| `OPEN` | ComfyUI unreachable; requests fail immediately with 503 | -| `HALF_OPEN` | Probe request sent after 60s cooldown; recovers on success | +| State | Behaviour | Transition | +|-------|----------|------------| +| `CLOSED` | All requests pass through | → OPEN after 5 consecutive failures | +| `OPEN` | All requests rejected immediately (503) | → HALF_OPEN after 60 s | +| `HALF_OPEN` | One probe request allowed | → CLOSED on success; OPEN on failure | -Failure threshold: **5 consecutive failures** before opening. +--- + +## Workflow Builder + +`WorkflowBuilder` (`backend/workflow_builder.py`) replaces the hardcoded +node-graph dict that previously lived inside `rain_backend.py`. + +### Public Methods + +| Method | Description | +|--------|-------------| +| `build_txt2img(…)` | Text-to-image workflow (7 nodes + optional LoRA) | +| `build_img2img(…)` | Image-to-image refinement (LoadImage + VAEEncode) | +| `build_upscale_pass(…)` | Append 2× upscale nodes to any existing workflow | +| `from_template(name, patches)` | Load a `workflows/.json` and apply patches | +| `list_templates()` | List available JSON template stems | + +### Node Numbering Convention + +| Node ID | Class | Notes | +|---------|-------|-------| +| `"1"` | CheckpointLoaderSimple | Model + CLIP + VAE source | +| `"2"` | CLIPTextEncode | Positive conditioning | +| `"3"` | CLIPTextEncode | Negative conditioning | +| `"4"` | EmptyLatentImage / LoadImage | Latent source | +| `"5"` | KSampler | Denoising step | +| `"6"` | VAEDecode | Latent → pixel | +| `"7"` | SaveImage | Output saver | +| `"8"` | LoraLoader | Optional single LoRA | +| `"10"` | VAEEncode | img2img only | +| `"20–22"` | Upscale chain | Optional upscale pass | +| `"100+"` | LoraLoader chain | Multi-LoRA via `LoRAManager.build_lora_chain()` | --- -## Configuration +## LoRA Manager + +`LoRAManager` (`backend/lora_manager.py`) provides: + +- **Registry seeding** from static `LORA_MAPPINGS` config +- **Filesystem scan** (`scan()`) for new `.safetensors` / `.ckpt` files +- **`get(name)` / `load(name)`** — resolve by logical name +- **`build_lora_chain(graph, loras)`** — inject N LoRAs in sequence +- **`build_loader_node(name)`** — single-LoRA `inputs` dict +- **`merge_configs(*loras, blend_mode)`** — average / max / sum-clamp blend +- **`as_dict()` / `summary()`** — JSON-serialisable export + +--- + +## Workflow Templates + +Five production-ready templates ship in `workflows/`: + +| File | Quality | Resolution | Steps | Notes | +|------|---------|-----------|-------|-------| +| `txt2img_draft.json` | Draft | 512×512 | 20 | Fastest preview | +| `txt2img_quality.json` | Standard | 1024×1024 | 40 | Default | +| `txt2img_final.json` | Final | 2048×2048 | 80 | Maximum quality | +| `img2img_refine.json` | Standard | source | 30 | 75% denoise | +| `txt2img_synthwave_lora.json` | Standard | 1024×1024 | 40 | Synthwave LoRA | -All runtime configuration lives in `backend/rain_backend_config.py` and is -loaded via the module-level `config` singleton. Environment variables -override defaults for `COMFYUI_HOST` and `COMFYUI_PORT`. +Load and patch templates via: +```python +wf = WorkflowBuilder().from_template("txt2img_quality", patches={ + "2.text": "my positive prompt", + "5.seed": 42, +}) +``` + +--- + +## Switchboard UI + +`switchboard/index.html` is a self-contained vanilla-JS dashboard with: + +| Panel | Purpose | +|-------|---------| +| **Dashboard** | Live health status, GPU tier, queue depth, activity log | +| **Generate** | Form with prompt, preset/resolution/LoRA chip selectors | +| **Queue** | Running + pending table with per-item cancel buttons | +| **Presets** | Tables of all sampler, resolution, and LoRA presets | +| **Config** | Edit API base URL; view live backend config JSON | +| **API Logs** | Timestamped log of all API calls made by the UI | + +--- + +## CORS Configuration + +The backend uses **environment-variable-driven** CORS. +`allow_origins=["*"]` is **no longer the default**: ```bash -export COMFYUI_HOST=192.168.1.100 -export COMFYUI_PORT=8188 +# Development +export ALLOWED_ORIGINS="http://localhost:3000" + +# Production (comma-separated) +export ALLOWED_ORIGINS="https://raingod.app,https://api.raingod.app" + uvicorn backend.rain_backend:app --host 0.0.0.0 --port 8000 ``` +If `ALLOWED_ORIGINS` is unset, the backend restricts CORS to +`http://localhost:3000` and logs a warning. + +--- + +## Security Notes + +- **CORS**: Now restricted via `ALLOWED_ORIGINS` env var (default: localhost:3000 only) +- **Path traversal**: `GET /outputs/{filename}` uses `Path.resolve()` comparison + to prevent `../` escapes outside the `outputs/` directory +- **Docker**: Runs as non-root user `raingod` (UID 1000) +- **Secrets**: `.env` files excluded by `.gitignore`; never commit credentials + --- ## Known Issues -### 12 MB GIF Binary in Repository +### 12 MB GIF in Git History -`DEVIANT2026_small.gif` (12,109,044 bytes) is committed directly to Git -history. This bloats every clone. **Recommended action:** +`DEVIANT2026_small.gif` (12,109,044 bytes) is committed directly to Git. +This bloats every clone. Migrate to Git LFS: ```bash -# 1. Install Git LFS git lfs install git lfs track "*.gif" git add .gitattributes - -# 2. Re-add the file (LFS will store it outside the pack) git rm --cached DEVIANT2026_small.gif git add DEVIANT2026_small.gif git commit -m "chore: migrate DEVIANT2026_small.gif to Git LFS" git push ``` -Alternatively, host the GIF on a CDN and replace the README `` tag -with a URL. - ---- - -## Security Notes - -- The FastAPI backend sets `allow_origins=["*"]` in CORS middleware. Restrict - this to known origins before public deployment. -- The Docker image runs as non-root user `raingod` (UID 1000). -- `GET /outputs/{filename}` includes path-traversal protection via - `Path.resolve()` comparison against the `outputs/` directory. -- Never commit `.env` files — the `.gitignore` excludes them. +Alternatively, host the GIF on a CDN and update the `` tag in `README.md`. diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..b76c623 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,16 @@ +[mypy] +python_version = 3.10 +strict = False +warn_return_any = True +warn_unused_configs = True +ignore_missing_imports = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +no_implicit_optional = True +warn_redundant_casts = True +warn_unused_ignores = True + +# Per-module overrides +[mypy-tests.*] +disallow_untyped_defs = False diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..1d98092 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,34 @@ +# ruff.toml — RAINGOD linter configuration +# https://docs.astral.sh/ruff/configuration/ + +[tool.ruff] +target-version = "py310" +line-length = 100 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear + "UP", # pyupgrade + "N", # pep8-naming + "SIM", # flake8-simplify + "RUF", # ruff-specific rules +] +ignore = [ + "E501", # line length — handled by formatter + "B008", # do not perform function calls in default args (FastAPI pattern) + "UP007", # use X | Y syntax — breaks Python 3.9 compat in some places + "N818", # exception name should be named with Error suffix — LoRANotFoundError is fine +] + +[tool.ruff.lint.isort] +known-first-party = ["backend"] +force-single-line = false + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "lf" diff --git a/switchboard/index.html b/switchboard/index.html new file mode 100644 index 0000000..2ec1ca1 --- /dev/null +++ b/switchboard/index.html @@ -0,0 +1,1073 @@ + + + + + + RAINGOD Switchboard + + + +
+ + + + +
+ + Switchboard +
+
+
+ Checking… +
+
+
+ ComfyUI +
+ +
+ + + + + + + + + +
+ + + + +
+
Dashboard
+ +
+
+
Backend
+
+
+
+
ComfyUI
+
+
+
+
GPU Tier
+
+
+
+
Uptime
+
+
+
+
Queue Running
+
+
+
+
Queue Pending
+
+
+
+ +
+
Recent Activity
+
+ No activity yet. +
+
+
+ + + + +
+
Generate Image
+ +
+
+ +
+
Prompts
+
+ + +
+
+ + +
+
+ +
+
Sampler Preset
+
+
+ +
+
Resolution
+
+
+ +
+
LoRA Style (optional)
+
+
None
+
+
+ +
+
Options
+
+
+ + +
+
+ + +
+
+
+ + + + + +
+ Generated image + +
+
+ + + + +
+
Queue Status
+ +
+ + +
+ +
+
Running
+ + + + + + + +
Prompt IDStatusActions
No running jobs.
+
+ +
+
Pending
+ + + + + + + +
Prompt IDPositionActions
Queue is empty.
+
+
+ + + + +
+
Presets
+ +
+
Sampler Presets
+
Loading…
+
+ +
+
Resolution Presets
+
Loading…
+
+ +
+
LoRA Mappings
+
Loading…
+
+
+ + + + +
+
Active Configuration
+ +
+
Backend Config
+
+ +
+ + +
+
+
+ +
+
Runtime Config
+
Loading…
+
+
+ + + + +
+
API Logs
+ +
+ + +
+ +
RAINGOD Switchboard ready. API calls will appear here.\n
+
+ +
+
+ + + + + + + diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..3c9cc64 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# tests/ — RAINGOD pytest suite diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5cf7acf --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,106 @@ +"""Pytest fixtures shared across the RAINGOD test suite. + +All environment variables that affect module-level config must be set +*before* any backend module is imported. This conftest.py is the correct +place to do that. +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +from fastapi.testclient import TestClient + +# --------------------------------------------------------------------------- +# Environment setup — must happen before any backend module is imported so +# CORS middleware and the config singleton see the correct test values. +# --------------------------------------------------------------------------- +os.environ.setdefault("ALLOWED_ORIGINS", "http://localhost:3000") +os.environ.setdefault("COMFYUI_HOST", "127.0.0.1") +os.environ.setdefault("COMFYUI_PORT", "18188") # port unlikely to be in use + + +def _make_mock_comfyui_client() -> MagicMock: + """Return a MagicMock that looks like a live ComfyUIClient.""" + m = MagicMock() + m.health_check.return_value = True + m.queue_prompt.return_value = "test-prompt-id-abc123" + m.get_queue_status.return_value = { + "queue_running": [], + "queue_pending": [], + } + m.cancel_prompt.return_value = True + return m + + +@pytest.fixture(scope="session") +def app(): + """Return the FastAPI application with the ComfyUI client pre-injected. + + The TestClient context manager triggers startup/shutdown lifespan events. + We launch the app inside a TestClient block so startup runs, then we + immediately replace the module-level ``client`` singleton with our mock + before any test makes a request. + """ + import backend.rain_backend as mod + from backend.rain_backend import app as _app + + with TestClient(_app) as _tc: + mod.client = _make_mock_comfyui_client() + yield _app + + +@pytest.fixture(scope="session") +def client(app) -> TestClient: + """Return a session-scoped synchronous TestClient.""" + import backend.rain_backend as mod + tc = TestClient(app) + mod.client = _make_mock_comfyui_client() + return tc + + +@pytest.fixture() +def tmp_workflows_dir(tmp_path: Path) -> Path: + """Return a temporary directory pre-seeded with a minimal workflow JSON.""" + wf = { + "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "test.safetensors"}}, + "2": {"class_type": "CLIPTextEncode", "inputs": {"text": "test positive", "clip": ["1", 1]}}, + "3": {"class_type": "CLIPTextEncode", "inputs": {"text": "test negative", "clip": ["1", 1]}}, + "4": {"class_type": "EmptyLatentImage", "inputs": {"width": 512, "height": 512, "batch_size": 1}}, + "5": { + "class_type": "KSampler", + "inputs": { + "model": ["1", 0], "positive": ["2", 0], "negative": ["3", 0], + "latent_image": ["4", 0], "seed": 0, "steps": 20, "cfg": 7.0, + "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0, + }, + }, + "6": {"class_type": "VAEDecode", "inputs": {"samples": ["5", 0], "vae": ["1", 2]}}, + "7": {"class_type": "SaveImage", "inputs": {"images": ["6", 0], "filename_prefix": "test"}}, + } + (tmp_path / "test_workflow.json").write_text(json.dumps(wf)) + return tmp_path + + +@pytest.fixture() +def tmp_lora_dir(tmp_path: Path) -> Path: + """Return a temporary loras/ directory with dummy LoRA files. + + Contents: + - ``custom_style_v1.safetensors`` — NOT in LORA_MAPPINGS, will be discovered + - ``brand_new_extra.safetensors`` — NOT in LORA_MAPPINGS, will be discovered + - ``ignored.txt`` — ignored (wrong extension) + + We deliberately do NOT include files that match existing LORA_MAPPINGS + filenames so that de-duplication logic is separately verifiable. + """ + lora_dir = tmp_path / "loras" + lora_dir.mkdir() + (lora_dir / "custom_style_v1.safetensors").write_bytes(b"\x00" * 200) + (lora_dir / "brand_new_extra.safetensors").write_bytes(b"\x00" * 150) + (lora_dir / "ignored.txt").write_text("not a lora") + return lora_dir diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py new file mode 100644 index 0000000..a61caa9 --- /dev/null +++ b/tests/test_api_endpoints.py @@ -0,0 +1,297 @@ +"""Tests for RAINGOD FastAPI endpoint behaviour. + +All tests use a mocked ComfyUIClient so no real ComfyUI process is required. +""" + +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + + +# --------------------------------------------------------------------------- +# GET / +# --------------------------------------------------------------------------- + +class TestRoot: + def test_returns_200(self, client: TestClient) -> None: + r = client.get("/") + assert r.status_code == 200 + + def test_body_contains_expected_keys(self, client: TestClient) -> None: + body = client.get("/").json() + assert "name" in body + assert "version" in body + assert "docs" in body + assert "health" in body + + def test_version_is_string(self, client: TestClient) -> None: + body = client.get("/").json() + assert isinstance(body["version"], str) + + +# --------------------------------------------------------------------------- +# GET /health +# --------------------------------------------------------------------------- + +class TestHealth: + def test_returns_200(self, client: TestClient) -> None: + r = client.get("/health") + assert r.status_code == 200 + + def test_backend_status_present(self, client: TestClient) -> None: + body = client.get("/health").json() + assert "status" in body + assert body["status"] in ("healthy", "degraded") + + def test_comfyui_available_field(self, client: TestClient) -> None: + body = client.get("/health").json() + assert "comfyui_available" in body + assert isinstance(body["comfyui_available"], bool) + + def test_gpu_tier_field(self, client: TestClient) -> None: + body = client.get("/health").json() + assert "gpu_tier" in body + assert body["gpu_tier"] in ("cpu", "low_vram", "mid_vram", "high_vram") + + def test_uptime_is_non_negative(self, client: TestClient) -> None: + body = client.get("/health").json() + assert body["uptime_seconds"] >= 0.0 + + +# --------------------------------------------------------------------------- +# GET /config +# --------------------------------------------------------------------------- + +class TestConfig: + def test_returns_200(self, client: TestClient) -> None: + assert client.get("/config").status_code == 200 + + def test_resolution_presets_present(self, client: TestClient) -> None: + body = client.get("/config").json() + assert "resolution_presets" in body + assert len(body["resolution_presets"]) > 0 + + def test_sampler_presets_present(self, client: TestClient) -> None: + body = client.get("/config").json() + assert "sampler_presets" in body + + def test_lora_styles_present(self, client: TestClient) -> None: + body = client.get("/config").json() + assert "lora_styles" in body + + +# --------------------------------------------------------------------------- +# GET /presets +# --------------------------------------------------------------------------- + +class TestPresets: + def test_returns_200(self, client: TestClient) -> None: + assert client.get("/presets").status_code == 200 + + def test_resolution_key(self, client: TestClient) -> None: + body = client.get("/presets").json() + assert "resolution" in body + # cover_art is the default; must exist + assert "cover_art" in body["resolution"] + + def test_samplers_key(self, client: TestClient) -> None: + body = client.get("/presets").json() + assert "samplers" in body + # quality is the default sampler preset + assert "quality" in body["samplers"] + + def test_sampler_has_required_fields(self, client: TestClient) -> None: + samplers = client.get("/presets").json()["samplers"] + for key, sampler in samplers.items(): + assert "steps" in sampler, f"Missing 'steps' in sampler {key!r}" + assert "cfg" in sampler, f"Missing 'cfg' in sampler {key!r}" + assert "sampler_name" in sampler + + def test_lora_key(self, client: TestClient) -> None: + body = client.get("/presets").json() + assert "lora" in body + + def test_quality_tiers_key(self, client: TestClient) -> None: + body = client.get("/presets").json() + assert "quality_tiers" in body + assert "standard" in body["quality_tiers"] + + +# --------------------------------------------------------------------------- +# POST /generate +# --------------------------------------------------------------------------- + +class TestGenerate: + _VALID_PAYLOAD = { + "prompt": "neon synthwave cityscape at dusk", + "negative_prompt": "blurry, low quality", + "preset": "quality", + "resolution": "cover_art", + } + + def test_returns_202(self, client: TestClient) -> None: + r = client.post("/generate", json=self._VALID_PAYLOAD) + assert r.status_code == 202 + + def test_response_has_prompt_id(self, client: TestClient) -> None: + body = client.post("/generate", json=self._VALID_PAYLOAD).json() + assert "prompt_id" in body + assert isinstance(body["prompt_id"], str) + + def test_response_has_job_id(self, client: TestClient) -> None: + body = client.post("/generate", json=self._VALID_PAYLOAD).json() + assert "job_id" in body + + def test_response_status_queued(self, client: TestClient) -> None: + body = client.post("/generate", json=self._VALID_PAYLOAD).json() + assert body["status"] == "queued" + + def test_response_preset_used(self, client: TestClient) -> None: + body = client.post("/generate", json=self._VALID_PAYLOAD).json() + assert body["preset_used"] == "quality" + + def test_response_resolution_used(self, client: TestClient) -> None: + body = client.post("/generate", json=self._VALID_PAYLOAD).json() + res = body["resolution_used"] + assert res["width"] == 1024 + assert res["height"] == 1024 + + def test_seed_injected_into_metadata(self, client: TestClient) -> None: + body = client.post("/generate", json=self._VALID_PAYLOAD).json() + assert "seed" in body["metadata"] + assert isinstance(body["metadata"]["seed"], int) + + def test_explicit_seed_is_respected(self, client: TestClient) -> None: + payload = {**self._VALID_PAYLOAD, "seed": 12345} + body = client.post("/generate", json=payload).json() + assert body["metadata"]["seed"] == 12345 + + def test_invalid_preset_returns_400(self, client: TestClient) -> None: + payload = {**self._VALID_PAYLOAD, "preset": "nonexistent_preset"} + r = client.post("/generate", json=payload) + assert r.status_code == 400 + assert "nonexistent_preset" in r.json()["detail"] + + def test_invalid_resolution_returns_400(self, client: TestClient) -> None: + payload = {**self._VALID_PAYLOAD, "resolution": "bad_resolution"} + r = client.post("/generate", json=payload) + assert r.status_code == 400 + + def test_empty_prompt_returns_422(self, client: TestClient) -> None: + payload = {**self._VALID_PAYLOAD, "prompt": ""} + r = client.post("/generate", json=payload) + assert r.status_code == 422 + + def test_missing_prompt_returns_422(self, client: TestClient) -> None: + r = client.post("/generate", json={"preset": "quality"}) + assert r.status_code == 422 + + def test_lora_style_is_accepted(self, client: TestClient) -> None: + payload = {**self._VALID_PAYLOAD, "lora_style": "synthwave"} + r = client.post("/generate", json=payload) + assert r.status_code == 202 + + def test_unknown_lora_is_silently_ignored(self, client: TestClient) -> None: + # Unknown LoRA should NOT cause an error — it is simply not applied + payload = {**self._VALID_PAYLOAD, "lora_style": "does_not_exist"} + r = client.post("/generate", json=payload) + assert r.status_code == 202 + + def test_all_sampler_presets_accepted(self, client: TestClient) -> None: + from backend.rain_backend_config import SAMPLER_PRESETS + for preset_name in SAMPLER_PRESETS: + payload = {**self._VALID_PAYLOAD, "preset": preset_name} + r = client.post("/generate", json=payload) + assert r.status_code == 202, f"Preset {preset_name!r} returned {r.status_code}" + + def test_all_resolution_presets_accepted(self, client: TestClient) -> None: + from backend.rain_backend_config import RESOLUTION_PRESETS + for res_name in RESOLUTION_PRESETS: + payload = {**self._VALID_PAYLOAD, "resolution": res_name} + r = client.post("/generate", json=payload) + assert r.status_code == 202, f"Resolution {res_name!r} returned {r.status_code}" + + +# --------------------------------------------------------------------------- +# POST /batch-generate +# --------------------------------------------------------------------------- + +class TestBatchGenerate: + _SINGLE = { + "prompt": "abstract neon art", + "preset": "fast", + "resolution": "thumbnail", + } + + def test_returns_202(self, client: TestClient) -> None: + r = client.post("/batch-generate", json={"requests": [self._SINGLE]}) + assert r.status_code == 202 + + def test_batch_id_present(self, client: TestClient) -> None: + body = client.post("/batch-generate", json={"requests": [self._SINGLE]}).json() + assert "batch_id" in body + + def test_total_count_correct(self, client: TestClient) -> None: + payload = {"requests": [self._SINGLE, self._SINGLE]} + body = client.post("/batch-generate", json=payload).json() + assert body["total"] == 2 + + def test_queued_count_correct(self, client: TestClient) -> None: + payload = {"requests": [self._SINGLE]} + body = client.post("/batch-generate", json=payload).json() + assert body["queued"] == 1 + assert body["errors"] == 0 + + def test_empty_requests_returns_422(self, client: TestClient) -> None: + r = client.post("/batch-generate", json={"requests": []}) + assert r.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /queue/status +# --------------------------------------------------------------------------- + +class TestQueueStatus: + def test_returns_200(self, client: TestClient) -> None: + assert client.get("/queue/status").status_code == 200 + + def test_has_queue_keys(self, client: TestClient) -> None: + body = client.get("/queue/status").json() + assert "queue_running" in body + assert "queue_pending" in body + + +# --------------------------------------------------------------------------- +# DELETE /queue/{prompt_id} +# --------------------------------------------------------------------------- + +class TestCancelQueue: + def test_cancel_returns_200(self, client: TestClient) -> None: + r = client.delete("/queue/some-prompt-id") + assert r.status_code == 200 + + def test_response_has_cancelled_field(self, client: TestClient) -> None: + body = client.delete("/queue/some-prompt-id").json() + assert "cancelled" in body + assert isinstance(body["cancelled"], bool) + + def test_response_echoes_prompt_id(self, client: TestClient) -> None: + body = client.delete("/queue/my-test-id").json() + assert body["prompt_id"] == "my-test-id" + + +# --------------------------------------------------------------------------- +# GET /outputs/{filename} +# --------------------------------------------------------------------------- + +class TestOutputsEndpoint: + def test_missing_file_returns_404(self, client: TestClient) -> None: + r = client.get("/outputs/does_not_exist.png") + assert r.status_code == 404 + + def test_path_traversal_attempt_rejected(self, client: TestClient) -> None: + # FastAPI's Path validator rejects literal slashes in path segments + r = client.get("/outputs/../../etc/passwd") + # Should be 404 (file not found) or 422 (validation) or 403 (forbidden) + assert r.status_code in (400, 403, 404, 422) diff --git a/tests/test_circuit_breaker.py b/tests/test_circuit_breaker.py new file mode 100644 index 0000000..06ccaca --- /dev/null +++ b/tests/test_circuit_breaker.py @@ -0,0 +1,214 @@ +"""Tests for the CircuitBreaker and ComfyUIClient in comfyui_client.py.""" + +from __future__ import annotations + +import time +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from backend.comfyui_client import CircuitBreaker, CircuitState, ComfyUIClient +from backend.rain_backend_config import ComfyUIConfig + + +# --------------------------------------------------------------------------- +# CircuitBreaker unit tests +# --------------------------------------------------------------------------- + +class TestCircuitBreakerInitialState: + def test_starts_closed(self) -> None: + cb = CircuitBreaker() + assert cb.state == CircuitState.CLOSED + + def test_is_open_returns_false_when_closed(self) -> None: + cb = CircuitBreaker() + assert cb.is_open() is False + + +class TestCircuitBreakerFailures: + def test_single_failure_stays_closed(self) -> None: + cb = CircuitBreaker(failure_threshold=5) + cb.record_failure() + assert cb.state == CircuitState.CLOSED + + def test_opens_at_threshold(self) -> None: + cb = CircuitBreaker(failure_threshold=3) + for _ in range(3): + cb.record_failure() + assert cb.state == CircuitState.OPEN + + def test_is_open_returns_true_when_open(self) -> None: + cb = CircuitBreaker(failure_threshold=1) + cb.record_failure() + assert cb.is_open() is True + + def test_success_resets_to_closed(self) -> None: + cb = CircuitBreaker(failure_threshold=1) + cb.record_failure() + cb.record_success() + assert cb.state == CircuitState.CLOSED + + def test_success_resets_failure_count(self) -> None: + cb = CircuitBreaker(failure_threshold=5) + for _ in range(4): + cb.record_failure() + cb.record_success() + # One more failure should NOT open (count was reset) + cb.record_failure() + assert cb.state == CircuitState.CLOSED + + +class TestCircuitBreakerHalfOpen: + def test_transitions_to_half_open_after_timeout(self) -> None: + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.05) + cb.record_failure() + assert cb.state == CircuitState.OPEN + time.sleep(0.1) + assert cb.state == CircuitState.HALF_OPEN + + def test_does_not_transition_before_timeout(self) -> None: + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=3600.0) + cb.record_failure() + assert cb.state == CircuitState.OPEN # timeout far in the future + + +# --------------------------------------------------------------------------- +# ComfyUIClient unit tests +# --------------------------------------------------------------------------- + +class TestComfyUIClientHealthCheck: + def test_health_check_true_on_200(self) -> None: + cfg = ComfyUIConfig(host="127.0.0.1", port=19999) + client = ComfyUIClient(comfyui_config=cfg, max_retries=0) + mock_resp = MagicMock() + mock_resp.status_code = 200 + with patch.object(client._session, "get", return_value=mock_resp): + assert client.health_check() is True + + def test_health_check_false_on_connection_error(self) -> None: + cfg = ComfyUIConfig(host="127.0.0.1", port=19999) + client = ComfyUIClient(comfyui_config=cfg, max_retries=0) + with patch.object( + client._session, "get", + side_effect=requests.exceptions.ConnectionError("refused"), + ): + assert client.health_check() is False + + +class TestComfyUIClientQueuePrompt: + def _make_mock_response(self, prompt_id: str) -> MagicMock: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"prompt_id": prompt_id} + return resp + + def test_returns_prompt_id(self) -> None: + cfg = ComfyUIConfig(host="127.0.0.1", port=19999) + client = ComfyUIClient(comfyui_config=cfg, max_retries=0) + mock_resp = self._make_mock_response("abc-123") + with patch.object(client._session, "request", return_value=mock_resp): + result = client.queue_prompt({"1": {"class_type": "Test", "inputs": {}}}) + assert result == "abc-123" + + def test_dedup_cache_returns_same_id(self) -> None: + cfg = ComfyUIConfig(host="127.0.0.1", port=19999) + client = ComfyUIClient(comfyui_config=cfg, max_retries=0) + workflow = {"1": {"class_type": "Test", "inputs": {}}} + mock_resp = self._make_mock_response("dedup-id") + + with patch.object(client._session, "request", return_value=mock_resp) as mock_req: + id1 = client.queue_prompt(workflow, deduplicate=True) + id2 = client.queue_prompt(workflow, deduplicate=True) + + assert id1 == id2 == "dedup-id" + # HTTP request should only be made once (second call hits cache) + assert mock_req.call_count == 1 + + def test_dedup_disabled_sends_two_requests(self) -> None: + cfg = ComfyUIConfig(host="127.0.0.1", port=19999) + client = ComfyUIClient(comfyui_config=cfg, max_retries=0) + workflow = {"1": {"class_type": "Test", "inputs": {}}} + mock_resp = self._make_mock_response("no-dedup-id") + + with patch.object(client._session, "request", return_value=mock_resp) as mock_req: + client.queue_prompt(workflow, deduplicate=False) + client.queue_prompt(workflow, deduplicate=False) + + assert mock_req.call_count == 2 + + def test_raises_when_circuit_open(self) -> None: + cfg = ComfyUIConfig(host="127.0.0.1", port=19999) + cb = CircuitBreaker(failure_threshold=1) + cb.record_failure() # open the circuit + client = ComfyUIClient(comfyui_config=cfg, circuit_breaker=cb, max_retries=0) + + with pytest.raises(RuntimeError, match="circuit breaker is OPEN"): + client.queue_prompt({}) + + +class TestComfyUIClientRetry: + def test_retries_on_connection_error(self) -> None: + cfg = ComfyUIConfig(host="127.0.0.1", port=19999, timeout=1) + client = ComfyUIClient( + comfyui_config=cfg, + max_retries=2, + retry_delay_base=0.01, # very fast for tests + ) + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"prompt_id": "retry-id"} + + side_effects = [ + requests.exceptions.ConnectionError("fail 1"), + requests.exceptions.ConnectionError("fail 2"), + mock_resp, + ] + with patch.object(client._session, "request", side_effect=side_effects) as mock_req: + result = client.queue_prompt({}, deduplicate=False) + + assert result == "retry-id" + assert mock_req.call_count == 3 # 2 failures + 1 success + + def test_raises_after_max_retries_exceeded(self) -> None: + cfg = ComfyUIConfig(host="127.0.0.1", port=19999, timeout=1) + client = ComfyUIClient( + comfyui_config=cfg, + max_retries=1, + retry_delay_base=0.01, + ) + with patch.object( + client._session, + "request", + side_effect=requests.exceptions.ConnectionError("always fails"), + ): + with pytest.raises(requests.exceptions.ConnectionError): + client.queue_prompt({}, deduplicate=False) + + +class TestComfyUIClientDedup: + def test_clear_dedup_cache(self) -> None: + cfg = ComfyUIConfig(host="127.0.0.1", port=19999) + client = ComfyUIClient(comfyui_config=cfg, max_retries=0) + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"prompt_id": "id-001"} + + workflow = {"node": "value"} + with patch.object(client._session, "request", return_value=mock_resp): + client.queue_prompt(workflow, deduplicate=True) + + assert len(client._dedup_cache) == 1 + client.clear_dedup_cache() + assert len(client._dedup_cache) == 0 + + def test_hash_workflow_is_deterministic(self) -> None: + wf = {"a": 1, "b": {"c": 3}} + h1 = ComfyUIClient._hash_workflow(wf) + h2 = ComfyUIClient._hash_workflow(wf) + assert h1 == h2 + + def test_different_workflows_have_different_hashes(self) -> None: + wf1 = {"seed": 1} + wf2 = {"seed": 2} + assert ComfyUIClient._hash_workflow(wf1) != ComfyUIClient._hash_workflow(wf2) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py new file mode 100644 index 0000000..039d9bf --- /dev/null +++ b/tests/test_endpoints.py @@ -0,0 +1,292 @@ +"""Tests for the FastAPI endpoints in rain_backend.py. + +Covers: +- GET / +- GET /health +- GET /config +- GET /presets +- POST /generate (valid, invalid preset, invalid resolution) +- POST /batch-generate +- GET /queue/status +- DEL /queue/{prompt_id} +- GET /outputs/{filename} (404 path, path-traversal attempt) +""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Root +# --------------------------------------------------------------------------- + +class TestRoot: + def test_returns_200(self, client): + r = client.get("/") + assert r.status_code == 200 + + def test_has_name_and_version(self, client): + body = r = client.get("/").json() + assert "name" in body + assert "version" in body + assert body["docs"] == "/docs" + + +# --------------------------------------------------------------------------- +# Health +# --------------------------------------------------------------------------- + +class TestHealth: + def test_health_ok(self, client): + r = client.get("/health") + assert r.status_code == 200 + data = r.json() + assert data["status"] in ("healthy", "degraded") + assert isinstance(data["comfyui_available"], bool) + assert "uptime_seconds" in data + assert "gpu_tier" in data + + def test_health_degraded_when_comfyui_down(self, app, client): + """When ComfyUI health check returns False, status should be 'degraded'.""" + import backend.rain_backend as mod + original_hc = mod.client.health_check.return_value + mod.client.health_check.return_value = False + try: + r = client.get("/health") + assert r.status_code == 200 + assert r.json()["status"] == "degraded" + finally: + mod.client.health_check.return_value = original_hc + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +class TestConfig: + def test_returns_expected_keys(self, client): + r = client.get("/config") + assert r.status_code == 200 + body = r.json() + for key in ("comfyui_url", "gpu_tier", "resolution_presets", "sampler_presets", "lora_styles"): + assert key in body, f"Missing key: {key}" + + def test_resolution_presets_is_list(self, client): + body = client.get("/config").json() + assert isinstance(body["resolution_presets"], list) + assert len(body["resolution_presets"]) > 0 + + +# --------------------------------------------------------------------------- +# Presets +# --------------------------------------------------------------------------- + +class TestPresets: + def test_returns_200(self, client): + r = client.get("/presets") + assert r.status_code == 200 + + def test_required_top_level_keys(self, client): + body = client.get("/presets").json() + assert {"resolution", "samplers", "lora", "quality_tiers"} <= body.keys() + + def test_sampler_presets_have_required_fields(self, client): + samplers = client.get("/presets").json()["samplers"] + for name, preset in samplers.items(): + for field in ("steps", "cfg", "sampler_name", "scheduler"): + assert field in preset, f"Sampler '{name}' missing field '{field}'" + + def test_resolution_presets_have_width_height(self, client): + resolutions = client.get("/presets").json()["resolution"] + for name, res in resolutions.items(): + assert "width" in res and "height" in res, f"Resolution '{name}' missing dimensions" + + def test_lora_presets_have_required_fields(self, client): + loras = client.get("/presets").json()["lora"] + for name, lora in loras.items(): + for field in ("filename", "strength_model"): + assert field in lora, f"LoRA '{name}' missing field '{field}'" + + +# --------------------------------------------------------------------------- +# Generate +# --------------------------------------------------------------------------- + +class TestGenerate: + _VALID_PAYLOAD = { + "prompt": "glowing neon city at night, synthwave aesthetic", + "negative_prompt": "blurry, low quality", + "preset": "quality", + "resolution": "cover_art", + } + + def test_valid_request_returns_202(self, client): + r = client.post("/generate", json=self._VALID_PAYLOAD) + assert r.status_code == 202 + + def test_response_has_required_fields(self, client): + body = client.post("/generate", json=self._VALID_PAYLOAD).json() + for field in ("prompt_id", "job_id", "status", "estimated_time", + "preset_used", "resolution_used"): + assert field in body, f"Response missing field '{field}'" + + def test_status_is_queued(self, client): + body = client.post("/generate", json=self._VALID_PAYLOAD).json() + assert body["status"] == "queued" + + def test_prompt_id_matches_mock(self, client): + body = client.post("/generate", json=self._VALID_PAYLOAD).json() + assert body["prompt_id"] == "test-prompt-id-abc123" + + def test_invalid_preset_returns_400(self, client): + payload = {**self._VALID_PAYLOAD, "preset": "nonexistent_preset"} + r = client.post("/generate", json=payload) + assert r.status_code == 400 + assert "preset" in r.json()["detail"].lower() + + def test_invalid_resolution_returns_400(self, client): + payload = {**self._VALID_PAYLOAD, "resolution": "nonexistent_resolution"} + r = client.post("/generate", json=payload) + assert r.status_code == 400 + assert "resolution" in r.json()["detail"].lower() + + def test_empty_prompt_returns_422(self, client): + r = client.post("/generate", json={**self._VALID_PAYLOAD, "prompt": ""}) + assert r.status_code == 422 + + def test_optional_seed_is_forwarded_in_metadata(self, client): + payload = {**self._VALID_PAYLOAD, "seed": 42} + body = client.post("/generate", json=payload).json() + assert body["metadata"]["seed"] == 42 + + def test_with_valid_lora_style(self, client): + payload = {**self._VALID_PAYLOAD, "lora_style": "synthwave"} + r = client.post("/generate", json=payload) + assert r.status_code == 202 + + def test_with_all_resolution_presets(self, client): + from backend.rain_backend_config import RESOLUTION_PRESETS + for res_key in RESOLUTION_PRESETS: + payload = {**self._VALID_PAYLOAD, "resolution": res_key} + r = client.post("/generate", json=payload) + assert r.status_code == 202, f"Failed for resolution '{res_key}'" + + def test_with_all_sampler_presets(self, client): + from backend.rain_backend_config import SAMPLER_PRESETS + for preset_key in SAMPLER_PRESETS: + payload = {**self._VALID_PAYLOAD, "preset": preset_key} + r = client.post("/generate", json=payload) + assert r.status_code == 202, f"Failed for preset '{preset_key}'" + + +# --------------------------------------------------------------------------- +# Batch Generate +# --------------------------------------------------------------------------- + +class TestBatchGenerate: + _BASE_REQ = { + "prompt": "abstract album art", + "preset": "fast", + "resolution": "thumbnail", + } + + def test_valid_batch_returns_202(self, client): + payload = {"requests": [self._BASE_REQ, self._BASE_REQ]} + r = client.post("/batch-generate", json=payload) + assert r.status_code == 202 + + def test_batch_response_fields(self, client): + payload = {"requests": [self._BASE_REQ]} + body = client.post("/batch-generate", json=payload).json() + for field in ("batch_id", "total", "queued", "errors", "results"): + assert field in body + + def test_batch_all_queued(self, client): + n = 3 + payload = {"requests": [self._BASE_REQ] * n} + body = client.post("/batch-generate", json=payload).json() + assert body["total"] == n + assert body["queued"] == n + assert body["errors"] == 0 + + def test_empty_batch_returns_422(self, client): + r = client.post("/batch-generate", json={"requests": []}) + assert r.status_code == 422 + + +# --------------------------------------------------------------------------- +# Queue Status +# --------------------------------------------------------------------------- + +class TestQueueStatus: + def test_returns_200(self, client): + r = client.get("/queue/status") + assert r.status_code == 200 + + def test_has_queue_keys(self, client): + body = r = client.get("/queue/status").json() + assert "queue_running" in body or "queue_pending" in body + + +# --------------------------------------------------------------------------- +# Queue Cancel +# --------------------------------------------------------------------------- + +class TestQueueCancel: + def test_cancel_returns_cancelled_true(self, client): + r = client.delete("/queue/test-prompt-id-abc123") + assert r.status_code == 200 + body = r.json() + assert body["cancelled"] is True + assert body["prompt_id"] == "test-prompt-id-abc123" + + +# --------------------------------------------------------------------------- +# Outputs endpoint +# --------------------------------------------------------------------------- + +class TestOutputs: + def test_missing_file_returns_404(self, client): + r = client.get("/outputs/does_not_exist.png") + assert r.status_code == 404 + + def test_path_traversal_blocked(self, client, tmp_path): + """Ensure ../secret patterns cannot escape the outputs directory.""" + r = client.get("/outputs/..%2F..%2Fetc%2Fpasswd") + # Should be 400 (validation) or 404, never 200 + assert r.status_code in (400, 404, 422) + + def test_valid_file_served(self, client, tmp_path, monkeypatch): + """A file that exists in the outputs dir should be served.""" + from pathlib import Path + import backend.rain_backend as mod + + # Temporarily redirect OUTPUT_DIR to a temp directory + original_dir = mod.OUTPUT_DIR + test_output_dir = tmp_path / "outputs" + test_output_dir.mkdir() + test_file = test_output_dir / "test_image.png" + # Write a minimal 1×1 PNG (89 bytes) + test_file.write_bytes( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" + b"\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00" + b"\x00\x0cIDATx\x9cc\xf8\x0f\x00\x00\x01\x01\x00\x05\x18" + b"\xd8N\x00\x00\x00\x00IEND\xaeB`\x82" + ) + monkeypatch.setattr(mod, "OUTPUT_DIR", test_output_dir) + + r = client.get("/outputs/test_image.png") + assert r.status_code == 200 + + monkeypatch.setattr(mod, "OUTPUT_DIR", original_dir) + + +# --------------------------------------------------------------------------- +# Helper import alias to avoid circular fixture issues +# --------------------------------------------------------------------------- +from fastapi.testclient import TestClient as TestClientImport diff --git a/tests/test_lora_manager.py b/tests/test_lora_manager.py new file mode 100644 index 0000000..7c42272 --- /dev/null +++ b/tests/test_lora_manager.py @@ -0,0 +1,309 @@ +"""Tests for backend/lora_manager.py.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from backend.lora_manager import LoRAManager, LoRANotFoundError, _stem_to_slug +from backend.rain_backend_config import LoRAConfig, LORA_MAPPINGS + + +# --------------------------------------------------------------------------- +# _stem_to_slug (module-level helper) +# --------------------------------------------------------------------------- + +class TestStemToSlug: + def test_lowercase_passthrough(self) -> None: + assert _stem_to_slug("synthwave_v2") == "synthwave_v2" + + def test_uppercase_lowercased(self) -> None: + assert _stem_to_slug("SynthWave") == "synthwave" + + def test_spaces_replaced_with_underscores(self) -> None: + assert _stem_to_slug("my lora file") == "my_lora_file" + + def test_special_chars_removed(self) -> None: + assert _stem_to_slug("My LoRA File (v3)") == "my_lora_file_v3" + + def test_leading_trailing_underscores_stripped(self) -> None: + result = _stem_to_slug("---test---") + assert not result.startswith("_") + assert not result.endswith("_") + + def test_numbers_preserved(self) -> None: + assert "2" in _stem_to_slug("lora_v2") + + +# --------------------------------------------------------------------------- +# LoRAManager construction — seeded from static config +# --------------------------------------------------------------------------- + +class TestLoRAManagerInit: + def test_registry_seeded_from_config(self) -> None: + mgr = LoRAManager() + for key in LORA_MAPPINGS: + assert key in mgr, f"Config key {key!r} missing from registry" + + def test_len_reflects_config(self) -> None: + mgr = LoRAManager() + assert len(mgr) >= len(LORA_MAPPINGS) + + def test_contains_operator(self) -> None: + mgr = LoRAManager() + assert "synthwave" in mgr + + def test_repr_contains_class_name(self) -> None: + mgr = LoRAManager() + assert "LoRAManager" in repr(mgr) + + +# --------------------------------------------------------------------------- +# LoRAManager.scan +# --------------------------------------------------------------------------- + +class TestLoRAManagerScan: + def test_scan_returns_sorted_list(self, tmp_lora_dir: Path) -> None: + mgr = LoRAManager(lora_dir=tmp_lora_dir) + names = mgr.scan() + assert names == sorted(names) + + def test_scan_discovers_safetensors_files(self, tmp_lora_dir: Path) -> None: + mgr = LoRAManager(lora_dir=tmp_lora_dir) + names = mgr.scan() + # tmp_lora_dir has "synthwave_v2.safetensors" and "custom_style_v1.safetensors" + # synthwave_v2.safetensors is already registered via LORA_MAPPINGS (under "synthwave") + # custom_style_v1 is a new file — it should be discovered + assert "custom_style_v1" in names + + def test_scan_includes_config_entries(self, tmp_lora_dir: Path) -> None: + mgr = LoRAManager(lora_dir=tmp_lora_dir) + names = mgr.scan() + for key in LORA_MAPPINGS: + assert key in names + + def test_scan_nonexistent_dir_returns_config_entries(self, tmp_path: Path) -> None: + mgr = LoRAManager(lora_dir=tmp_path / "nonexistent") + names = mgr.scan() + # Config entries still visible + for key in LORA_MAPPINGS: + assert key in names + + def test_scan_discovers_new_files(self, tmp_lora_dir: Path) -> None: + mgr = LoRAManager(lora_dir=tmp_lora_dir) + mgr.scan() + (tmp_lora_dir / "brand_new_lora.safetensors").write_bytes(b"\x00") + names = mgr.scan() + assert "brand_new_lora" in names + + +# --------------------------------------------------------------------------- +# LoRAManager.get / load +# --------------------------------------------------------------------------- + +class TestLoRAManagerGetLoad: + def test_get_known_entry_returns_config(self) -> None: + mgr = LoRAManager() + cfg = mgr.get("synthwave") + assert cfg is not None + assert isinstance(cfg, LoRAConfig) + + def test_get_unknown_entry_returns_none(self) -> None: + mgr = LoRAManager() + assert mgr.get("zzz_does_not_exist_xyz") is None + + def test_load_known_entry_returns_config(self) -> None: + mgr = LoRAManager() + cfg = mgr.load("synthwave") + assert isinstance(cfg, LoRAConfig) + + def test_load_unknown_raises_lora_not_found_error(self) -> None: + mgr = LoRAManager() + with pytest.raises(LoRANotFoundError): + mgr.load("zzz_totally_unknown") + + +# --------------------------------------------------------------------------- +# LoRAManager.available +# --------------------------------------------------------------------------- + +class TestLoRAManagerAvailable: + def test_available_returns_sorted_list(self) -> None: + mgr = LoRAManager() + names = mgr.available() + assert names == sorted(names) + + def test_available_includes_all_config_keys(self) -> None: + mgr = LoRAManager() + names = set(mgr.available()) + for key in LORA_MAPPINGS: + assert key in names + + +# --------------------------------------------------------------------------- +# LoRAManager.build_loader_node +# --------------------------------------------------------------------------- + +class TestLoRAManagerBuildLoaderNode: + def test_returns_dict_with_required_keys(self) -> None: + mgr = LoRAManager() + node = mgr.build_loader_node("synthwave") + assert "lora_name" in node + assert "strength_model" in node + assert "strength_clip" in node + + def test_raises_for_unknown_lora(self) -> None: + mgr = LoRAManager() + with pytest.raises(LoRANotFoundError): + mgr.build_loader_node("totally_unknown_xyz") + + def test_strength_overrides_applied(self) -> None: + mgr = LoRAManager() + node = mgr.build_loader_node("synthwave", strength_model=0.3, strength_clip=0.4) + assert node["strength_model"] == 0.3 + assert node["strength_clip"] == 0.4 + + def test_default_strength_from_registry(self) -> None: + mgr = LoRAManager() + cfg = mgr.get("synthwave") + node = mgr.build_loader_node("synthwave") + assert node["strength_model"] == cfg.strength_model + assert node["strength_clip"] == cfg.strength_clip + + +# --------------------------------------------------------------------------- +# LoRAManager.build_lora_chain +# --------------------------------------------------------------------------- + +class TestBuildLoraChain: + def _base_graph(self) -> dict: + """Minimal valid workflow graph.""" + return { + "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "test.safetensors"}}, + "2": {"class_type": "CLIPTextEncode", "inputs": {"text": "pos", "clip": ["1", 1]}}, + "3": {"class_type": "CLIPTextEncode", "inputs": {"text": "neg", "clip": ["1", 1]}}, + "5": {"class_type": "KSampler", "inputs": {"model": ["1", 0]}}, + } + + def test_injects_single_lora(self) -> None: + mgr = LoRAManager() + graph = mgr.build_lora_chain(self._base_graph(), [("synthwave", 0.8, 0.8)]) + assert "100" in graph + assert graph["100"]["class_type"] == "LoraLoader" + + def test_ksampler_rewired_to_lora_output(self) -> None: + mgr = LoRAManager() + graph = mgr.build_lora_chain(self._base_graph(), [("synthwave", 0.8, 0.8)]) + assert graph["5"]["inputs"]["model"] == ["100", 0] + + def test_clip_encoders_rewired(self) -> None: + mgr = LoRAManager() + graph = mgr.build_lora_chain(self._base_graph(), [("synthwave", 0.8, 0.8)]) + assert graph["2"]["inputs"]["clip"] == ["100", 1] + assert graph["3"]["inputs"]["clip"] == ["100", 1] + + def test_two_loras_chain(self) -> None: + mgr = LoRAManager() + graph = mgr.build_lora_chain( + self._base_graph(), + [("synthwave", 0.8, 0.8), ("cyberpunk", 0.6, 0.6)], + ) + assert "100" in graph + assert "101" in graph + # KSampler should use the last LoRA output + assert graph["5"]["inputs"]["model"] == ["101", 0] + + def test_empty_loras_raises_value_error(self) -> None: + mgr = LoRAManager() + with pytest.raises(ValueError, match="empty"): + mgr.build_lora_chain(self._base_graph(), []) + + def test_base_graph_not_mutated(self) -> None: + mgr = LoRAManager() + base = self._base_graph() + original_model = base["5"]["inputs"]["model"] + mgr.build_lora_chain(base, [("synthwave", 0.8, 0.8)]) + assert base["5"]["inputs"]["model"] == original_model + + +# --------------------------------------------------------------------------- +# LoRAManager.merge_configs +# --------------------------------------------------------------------------- + +class TestMergeConfigs: + def _cfg(self, filename: str, sm: float = 0.8, sc: float = 0.8) -> LoRAConfig: + return LoRAConfig(filename=filename, strength_model=sm, strength_clip=sc) + + def test_average_blend(self) -> None: + a = self._cfg("a.safetensors", sm=0.6, sc=0.4) + b = self._cfg("b.safetensors", sm=1.0, sc=0.8) + merged = LoRAManager.merge_configs(a, b, blend_mode="average") + assert merged.strength_model == pytest.approx(0.8) + assert merged.strength_clip == pytest.approx(0.6) + + def test_max_blend(self) -> None: + a = self._cfg("a.safetensors", sm=0.3, sc=0.3) + b = self._cfg("b.safetensors", sm=0.9, sc=0.7) + merged = LoRAManager.merge_configs(a, b, blend_mode="max") + assert merged.strength_model == 0.9 + assert merged.strength_clip == 0.7 + + def test_sum_clamp_blend(self) -> None: + a = self._cfg("a.safetensors", sm=0.7, sc=0.7) + b = self._cfg("b.safetensors", sm=0.7, sc=0.7) + merged = LoRAManager.merge_configs(a, b, blend_mode="sum_clamp") + assert merged.strength_model == 1.0 + assert merged.strength_clip == 1.0 + + def test_raises_with_single_config(self) -> None: + with pytest.raises(ValueError, match="at least 2"): + LoRAManager.merge_configs(self._cfg("a.safetensors")) + + def test_raises_with_unknown_blend_mode(self) -> None: + a, b = self._cfg("a.safetensors"), self._cfg("b.safetensors") + with pytest.raises(ValueError, match="blend_mode"): + LoRAManager.merge_configs(a, b, blend_mode="unknown_mode") + + def test_combined_filename_contains_inputs(self) -> None: + a = self._cfg("alpha.safetensors") + b = self._cfg("beta.safetensors") + merged = LoRAManager.merge_configs(a, b) + assert "alpha.safetensors" in merged.filename + assert "beta.safetensors" in merged.filename + + def test_three_configs_average(self) -> None: + configs = [self._cfg(f"{i}.sf", sm=float(i) / 10) for i in range(1, 4)] + merged = LoRAManager.merge_configs(*configs, blend_mode="average") + expected = (0.1 + 0.2 + 0.3) / 3 + assert merged.strength_model == pytest.approx(expected, abs=1e-4) + + +# --------------------------------------------------------------------------- +# LoRAManager.as_dict / summary +# --------------------------------------------------------------------------- + +class TestLoRAManagerSummary: + def test_as_dict_contains_all_keys(self) -> None: + mgr = LoRAManager() + d = mgr.as_dict() + for key in LORA_MAPPINGS: + assert key in d + + def test_as_dict_entry_has_required_fields(self) -> None: + mgr = LoRAManager() + for name, entry in mgr.as_dict().items(): + assert "filename" in entry, f"Missing 'filename' in {name}" + assert "strength_model" in entry + assert "strength_clip" in entry + + def test_summary_has_required_keys(self) -> None: + mgr = LoRAManager() + s = mgr.summary() + assert "lora_dir" in s + assert "total" in s + assert "loras" in s + + def test_summary_total_matches_len(self) -> None: + mgr = LoRAManager() + assert mgr.summary()["total"] == len(mgr) diff --git a/tests/test_workflow_builder.py b/tests/test_workflow_builder.py new file mode 100644 index 0000000..ab1dd81 --- /dev/null +++ b/tests/test_workflow_builder.py @@ -0,0 +1,319 @@ +"""Tests for backend/workflow_builder.py.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from backend.rain_backend_config import LoRAConfig +from backend.workflow_builder import WorkflowBuilder + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_builder(tmp_path: Path | None = None) -> WorkflowBuilder: + return WorkflowBuilder( + checkpoint="test_model.safetensors", + workflows_dir=tmp_path, + ) + + +# --------------------------------------------------------------------------- +# build_txt2img +# --------------------------------------------------------------------------- + +class TestBuildTxt2img: + def _build(self, **kwargs) -> dict: + b = _make_builder() + defaults = dict( + positive="test prompt", + negative="bad", + width=512, + height=512, + steps=20, + cfg=7.0, + sampler_name="euler", + scheduler="normal", + seed=42, + ) + defaults.update(kwargs) + return b.build_txt2img(**defaults) + + def test_returns_dict(self) -> None: + assert isinstance(self._build(), dict) + + def test_has_all_required_nodes(self) -> None: + graph = self._build() + for node_id in ("1", "2", "3", "4", "5", "6", "7"): + assert node_id in graph, f"Node {node_id!r} missing from graph" + + def test_checkpoint_node_correct(self) -> None: + graph = self._build() + assert graph["1"]["class_type"] == "CheckpointLoaderSimple" + assert graph["1"]["inputs"]["ckpt_name"] == "test_model.safetensors" + + def test_positive_prompt_in_node_2(self) -> None: + graph = self._build(positive="a blue sunset") + assert graph["2"]["inputs"]["text"] == "a blue sunset" + + def test_negative_prompt_in_node_3(self) -> None: + graph = self._build(negative="ugly blurry") + assert graph["3"]["inputs"]["text"] == "ugly blurry" + + def test_resolution_set_in_node_4(self) -> None: + graph = self._build(width=1024, height=768) + assert graph["4"]["inputs"]["width"] == 1024 + assert graph["4"]["inputs"]["height"] == 768 + + def test_seed_in_ksampler(self) -> None: + graph = self._build(seed=999) + assert graph["5"]["inputs"]["seed"] == 999 + + def test_steps_in_ksampler(self) -> None: + graph = self._build(steps=30) + assert graph["5"]["inputs"]["steps"] == 30 + + def test_cfg_in_ksampler(self) -> None: + graph = self._build(cfg=8.5) + assert graph["5"]["inputs"]["cfg"] == 8.5 + + def test_sampler_name_in_ksampler(self) -> None: + graph = self._build(sampler_name="dpmpp_2m") + assert graph["5"]["inputs"]["sampler_name"] == "dpmpp_2m" + + def test_scheduler_in_ksampler(self) -> None: + graph = self._build(scheduler="karras") + assert graph["5"]["inputs"]["scheduler"] == "karras" + + def test_save_image_node_present(self) -> None: + graph = self._build(filename_prefix="myprefix") + assert graph["7"]["class_type"] == "SaveImage" + assert graph["7"]["inputs"]["filename_prefix"] == "myprefix" + + def test_two_calls_return_independent_dicts(self) -> None: + b = _make_builder() + g1 = b.build_txt2img("a", "", 512, 512, 20, 7.0, "euler", "normal", 1) + g2 = b.build_txt2img("b", "", 512, 512, 20, 7.0, "euler", "normal", 2) + g1["5"]["inputs"]["seed"] = 9999 + assert g2["5"]["inputs"]["seed"] == 2 # mutation of g1 doesn't affect g2 + + def test_checkpoint_override(self) -> None: + b = _make_builder() + graph = b.build_txt2img("p", "", 512, 512, 20, 7.0, "euler", "normal", 0, + checkpoint="override.safetensors") + assert graph["1"]["inputs"]["ckpt_name"] == "override.safetensors" + + def test_batch_size_default_is_one(self) -> None: + graph = self._build() + assert graph["4"]["inputs"]["batch_size"] == 1 + + def test_batch_size_override(self) -> None: + graph = self._build(batch_size=4) + assert graph["4"]["inputs"]["batch_size"] == 4 + + +# --------------------------------------------------------------------------- +# LoRA injection +# --------------------------------------------------------------------------- + +class TestLoRAInjection: + def _build_with_lora(self, **lora_kwargs) -> dict: + b = _make_builder() + lora = LoRAConfig(filename="test_lora.safetensors", **lora_kwargs) + return b.build_txt2img( + "test", "", 512, 512, 20, 7.0, "euler", "normal", 0, lora=lora + ) + + def test_lora_node_inserted(self) -> None: + graph = self._build_with_lora() + assert "8" in graph + assert graph["8"]["class_type"] == "LoraLoader" + + def test_lora_filename_correct(self) -> None: + graph = self._build_with_lora() + assert graph["8"]["inputs"]["lora_name"] == "test_lora.safetensors" + + def test_lora_strengths_default(self) -> None: + graph = self._build_with_lora() + assert graph["8"]["inputs"]["strength_model"] == 0.8 + assert graph["8"]["inputs"]["strength_clip"] == 0.8 + + def test_lora_strengths_custom(self) -> None: + graph = self._build_with_lora(strength_model=0.6, strength_clip=0.5) + assert graph["8"]["inputs"]["strength_model"] == 0.6 + assert graph["8"]["inputs"]["strength_clip"] == 0.5 + + def test_ksampler_model_rewired_to_lora(self) -> None: + graph = self._build_with_lora() + assert graph["5"]["inputs"]["model"] == ["8", 0] + + def test_positive_clip_rewired_to_lora(self) -> None: + graph = self._build_with_lora() + assert graph["2"]["inputs"]["clip"] == ["8", 1] + + def test_negative_clip_rewired_to_lora(self) -> None: + graph = self._build_with_lora() + assert graph["3"]["inputs"]["clip"] == ["8", 1] + + def test_no_lora_node_when_none(self) -> None: + b = _make_builder() + graph = b.build_txt2img("p", "", 512, 512, 20, 7.0, "euler", "normal", 0, lora=None) + assert "8" not in graph + + def test_original_graph_not_mutated(self) -> None: + b = _make_builder() + base = b.build_txt2img("p", "", 512, 512, 20, 7.0, "euler", "normal", 0) + original_model_input = base["5"]["inputs"]["model"] + lora = LoRAConfig(filename="x.safetensors") + b._inject_lora(base, lora) + # base should be unchanged + assert base["5"]["inputs"]["model"] == original_model_input + + +# --------------------------------------------------------------------------- +# build_img2img +# --------------------------------------------------------------------------- + +class TestBuildImg2img: + def _build(self, **kwargs) -> dict: + b = _make_builder() + defaults = dict( + positive="refine this", + negative="ugly", + image_path="input.png", + steps=20, + cfg=7.0, + sampler_name="euler", + scheduler="normal", + seed=0, + denoise=0.75, + ) + defaults.update(kwargs) + return b.build_img2img(**defaults) + + def test_has_load_image_node(self) -> None: + graph = self._build() + assert graph["4"]["class_type"] == "LoadImage" + + def test_has_vae_encode_node(self) -> None: + graph = self._build() + assert "10" in graph + assert graph["10"]["class_type"] == "VAEEncode" + + def test_image_path_set_correctly(self) -> None: + graph = self._build(image_path="my_source.png") + assert graph["4"]["inputs"]["image"] == "my_source.png" + + def test_ksampler_uses_vae_encode_latent(self) -> None: + graph = self._build() + assert graph["5"]["inputs"]["latent_image"] == ["10", 0] + + def test_denoise_set_in_ksampler(self) -> None: + graph = self._build(denoise=0.5) + assert graph["5"]["inputs"]["denoise"] == 0.5 + + +# --------------------------------------------------------------------------- +# build_upscale_pass +# --------------------------------------------------------------------------- + +class TestBuildUpscalePass: + def _base(self) -> dict: + return _make_builder().build_txt2img( + "p", "", 512, 512, 20, 7.0, "euler", "normal", 0 + ) + + def test_adds_upscale_nodes(self) -> None: + b = _make_builder() + graph = b.build_upscale_pass(self._base()) + assert "20" in graph + assert "21" in graph + assert "22" in graph + + def test_upscale_model_loader_class(self) -> None: + b = _make_builder() + graph = b.build_upscale_pass(self._base()) + assert graph["20"]["class_type"] == "UpscaleModelLoader" + + def test_upscale_model_filename(self) -> None: + b = _make_builder() + graph = b.build_upscale_pass(self._base(), upscale_model="RealESRGAN_x4.pth") + assert graph["20"]["inputs"]["model_name"] == "RealESRGAN_x4.pth" + + def test_base_workflow_not_mutated(self) -> None: + b = _make_builder() + base = self._base() + original_keys = set(base.keys()) + b.build_upscale_pass(base) + assert set(base.keys()) == original_keys # no new keys added to base + + +# --------------------------------------------------------------------------- +# from_template +# --------------------------------------------------------------------------- + +class TestFromTemplate: + def test_loads_template(self, tmp_workflows_dir: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_workflows_dir) + graph = b.from_template("test_workflow") + assert isinstance(graph, dict) + assert "1" in graph + + def test_raises_file_not_found(self, tmp_path: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_path) + with pytest.raises(FileNotFoundError): + b.from_template("does_not_exist") + + def test_patches_are_applied(self, tmp_workflows_dir: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_workflows_dir) + graph = b.from_template("test_workflow", patches={"2.text": "patched prompt"}) + assert graph["2"]["inputs"]["text"] == "patched prompt" + + def test_multiple_patches(self, tmp_workflows_dir: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_workflows_dir) + graph = b.from_template("test_workflow", patches={ + "2.text": "pos", + "3.text": "neg", + "5.seed": 777, + }) + assert graph["2"]["inputs"]["text"] == "pos" + assert graph["3"]["inputs"]["text"] == "neg" + assert graph["5"]["inputs"]["seed"] == 777 + + def test_bad_patch_key_raises_value_error(self, tmp_workflows_dir: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_workflows_dir) + with pytest.raises(ValueError, match="node_id"): + b.from_template("test_workflow", patches={"no_dot_key": "value"}) + + def test_nonexistent_node_patch_raises_value_error(self, tmp_workflows_dir: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_workflows_dir) + with pytest.raises(ValueError, match="non-existent node"): + b.from_template("test_workflow", patches={"999.text": "value"}) + + +# --------------------------------------------------------------------------- +# list_templates +# --------------------------------------------------------------------------- + +class TestListTemplates: + def test_returns_list(self, tmp_workflows_dir: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_workflows_dir) + result = b.list_templates() + assert isinstance(result, list) + + def test_includes_seeded_template(self, tmp_workflows_dir: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_workflows_dir) + templates = b.list_templates() + assert "test_workflow" in templates + + def test_empty_dir_returns_empty_list(self, tmp_path: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_path) + assert b.list_templates() == [] + + def test_nonexistent_dir_returns_empty_list(self, tmp_path: Path) -> None: + b = WorkflowBuilder(workflows_dir=tmp_path / "nonexistent") + assert b.list_templates() == [] diff --git a/workflows/.gitkeep b/workflows/.gitkeep deleted file mode 100644 index ffc091d..0000000 --- a/workflows/.gitkeep +++ /dev/null @@ -1,10 +0,0 @@ -# ComfyUI workflow JSON templates will be added here. -# These are the workflow graph files you export from ComfyUI's GUI -# (Settings → Enable Dev Mode → Save (API Format)). -# -# Planned templates: -# - txt2img_quality.json — standard text-to-image quality workflow -# - txt2img_ultra.json — high-quality with upscale -# - img2img_refine.json — image-to-image refinement -# - album_cover.json — music-optimised composition -# - video_frame.json — video frame generation diff --git a/workflows/README.md b/workflows/README.md new file mode 100644 index 0000000..58bfde6 --- /dev/null +++ b/workflows/README.md @@ -0,0 +1,54 @@ +# workflows/ + +ComfyUI API-format workflow templates for the RAINGOD AI Music Kit. + +## Available Templates + +| File | Preset | Resolution | Steps | Notes | +|------|--------|-----------|-------|-------| +| `txt2img_draft.json` | fast / euler | 512×512 | 20 | Preview quality — quickest generation | +| `txt2img_quality.json` | quality / dpmpp_2m karras | 1024×1024 | 40 | Default production preset | +| `txt2img_final.json` | ultra / dpmpp_sde karras | 2048×2048 | 80 | Maximum quality for final masters | +| `img2img_refine.json` | quality / dpmpp_2m karras | (source size) | 30 | img2img at 75% denoise — refine existing images | +| `txt2img_synthwave_lora.json` | quality / dpmpp_2m karras | 1024×1024 | 40 | Synthwave LoRA at 0.8 strength | + +## Placeholder Values + +Each template contains placeholders that must be patched before use: + +| Placeholder | Node | Field | Description | +|------------|------|-------|-------------| +| `POSITIVE_PROMPT_PLACEHOLDER` | `"2"` | `text` | Positive generation prompt | +| `NEGATIVE_PROMPT_PLACEHOLDER` | `"3"` | `text` | Negative / exclusion prompt | +| `SOURCE_IMAGE_PLACEHOLDER` | `"4"` | `image` | Source image filename (img2img only) | + +## Using Templates with WorkflowBuilder + +```python +from backend.workflow_builder import WorkflowBuilder + +builder = WorkflowBuilder() + +# Load a template and patch the prompts +workflow = builder.from_template( + "txt2img_quality", + patches={ + "2.text": "glowing neon city at night, synthwave aesthetic", + "3.text": "blurry, low quality, watermark", + "5.seed": 42, + }, +) + +# Submit to ComfyUI +from backend.comfyui_client import ComfyUIClient +client = ComfyUIClient() +prompt_id = client.queue_prompt(workflow) +``` + +## Adding New Templates + +1. Design your workflow in the ComfyUI web UI +2. Use **Save (API format)** to export the workflow JSON +3. Drop the file here with a descriptive name: `_