Skip to content

Commit

Permalink
Fix download_models minimal missing control nets
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Dec 5, 2023
1 parent 3fee881 commit 6384247
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions scripts/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

import asyncio
from itertools import chain
from itertools import chain, islice
import aiohttp
import sys
from pathlib import Path
Expand All @@ -38,6 +38,10 @@ def all_models():
)


def required_models():
return chain(resources.required_models, islice(resources.default_checkpoints, 1))


def _progress(name: str, size: int | None):
return tqdm(
total=size,
Expand Down Expand Up @@ -89,24 +93,17 @@ async def main(
):
print(f"Generative AI for Krita - Model download - v{ai_diffusion.__version__}")
verbose = verbose or dry_run
models = required_models() if minimal else all_models()

timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=60)
async with aiohttp.ClientSession(timeout=timeout) as client:
for model in all_models():
for model in models:
if (
(no_sd15 and model.sd_version is SDVersion.sd15)
or (no_sdxl and model.sd_version is SDVersion.sdxl)
or (no_controlnet and model.kind is resources.ResourceKind.controlnet)
or (
no_upscalers
and model.kind is resources.ResourceKind.upscaler
and (not minimal or model.name != "NMKD Superscale model")
)
or (
no_checkpoints
and model.kind is resources.ResourceKind.checkpoint
and (not minimal or model.name != "Realistic Vision")
)
or (no_upscalers and model.kind is resources.ResourceKind.upscaler)
or (no_checkpoints and model.kind is resources.ResourceKind.checkpoint)
):
continue
if verbose:
Expand Down Expand Up @@ -143,12 +140,7 @@ async def main(
parser.add_argument("--no-controlnet", action="store_true", help="skip ControlNet models")
parser.add_argument("-m", "--minimal", action="store_true", help="minimum viable set of models")
args = parser.parse_args()
if args.minimal:
assert not args.no_sd15, "Minimal requires SD1.5 models"
args.no_sdxl = True
args.no_upscalers = True
args.no_checkpoints = True
args.no_controlnet = True
args.no_sdxl = args.no_sdxl or args.minimal
asyncio.run(
main(
args.destination,
Expand Down

0 comments on commit 6384247

Please sign in to comment.