Skip to content

Commit

Permalink
Use standard Exception, chain lora iterables
Browse files Browse the repository at this point in the history
  • Loading branch information
Danamir committed Dec 10, 2023
1 parent 48e44bc commit 6d65ef6
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import math
import re
from itertools import chain
from pathlib import Path
from typing import Any, List, NamedTuple, Optional

Expand All @@ -16,11 +17,6 @@
_pattern_lora = re.compile(r"\s*<lora:([^:<>]+)(?::(-?[^:<>]*))?>\s*", re.IGNORECASE)


class LoraException(Exception):
def __init__(self, msg):
super().__init__(msg)


class ScaledExtent(NamedTuple):
initial: Extent # resolution for initial generation
expanded: Extent # resolution for high res pass
Expand Down Expand Up @@ -216,15 +212,15 @@ def _parse_loras(client: Client, prompt: str) -> list[dict[str, str | float]]:
if not lora_name:
error = f"LoRA not found : {match[0]}"
log.warning(error)
raise LoraException(error)
raise Exception(error)

lora_strength = match[1] if match[1] != "" else 1.0
try:
lora_strength = float(lora_strength)
except ValueError:
error = f"Invalid LoRA strength for {match[0]} : {lora_strength}"
log.warning(error)
raise LoraException(error)
raise Exception(error)

loras.append(dict(name=lora_name, strength=lora_strength))
return loras
Expand Down Expand Up @@ -260,15 +256,9 @@ def load_model_with_lora(
else:
log.warning(f"Style VAE {style.vae} not found, using default VAE from checkpoint")

for lora in style.loras:
if lora["name"] not in comfy.lora_models:
log.warning(f"Style LoRA {lora['name']} not found, skipping")
continue
model, clip = w.load_lora(model, clip, lora["name"], lora["strength"], lora["strength"])

for lora in additional_loras:
for lora in chain(style.loras, additional_loras):
if lora["name"] not in comfy.lora_models:
log.warning(f"Prompt LoRA {lora['name']} not found, skipping")
log.warning(f"LoRA {lora['name']} not found, skipping")
continue
model, clip = w.load_lora(model, clip, lora["name"], lora["strength"], lora["strength"])

Expand Down

0 comments on commit 6d65ef6

Please sign in to comment.