-
Notifications
You must be signed in to change notification settings - Fork 3
/
sd15_lora.py
173 lines (154 loc) · 7.71 KB
/
sd15_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import sys
import anyio
import dagger
import os
import time
import subprocess
import urllib.request
import zipfile
import textwrap
import yaml
# --- Configuration -----------------------------------------------------------
# All tunables come from config.yml in the working directory; every key falls
# back to the default shown below when absent.
with open("config.yml", "r") as _config_file:
    # safe_load: a config file never needs FullLoader's ability to construct
    # arbitrary Python objects, and safe_load is the PyYAML-recommended entry
    # point. `or {}` guards against an empty file (yaml returns None).
    config = yaml.safe_load(_config_file) or {}

# Base Stable Diffusion checkpoint used for both fine-tuning and inference.
MODEL_NAME = config.get("model_name", "runwayml/stable-diffusion-v1-5")
# Container image providing the lora_pti trainer and the diffusers runtime.
IMAGE = config.get("container_image", "quay.io/lukemarsden/lora:v0.0.2")
# Brand asset sets to fine-tune on; each maps to <URL_PREFIX><brand>.zip.
ASSETS = config.get("brands", [
    "coke",
    "dagger",
    "docker",
    "kubernetes",
    "nike",
    "vision-pro",
])
# Named prompts rendered per brand; <s1><s2> are the learned style tokens
# that lora_pti's textual inversion injects.
PROMPTS = config.get("prompts", {
    "mug": "photograph of a coffee mug with logo on it, in the style of <s1><s2>",
    "mug2": "coffee mug with brand logo on it, in the style of <s1><s2>",
    "mug3": "coffee mug with brand logo on it, in the style of <s1><s2>, 50mm portrait photography, hard rim lighting photography, merchandise",
    "tshirt": "woman torso wearing tshirt with <s1><s2> logo, 50mm portrait photography, hard rim lighting photography, merchandise",
})
# Images rendered per (brand, prompt) pair — one per seed in range(NUM_IMAGES).
NUM_IMAGES = config.get("num_images", 10)
# Where the per-brand training-image zips are hosted.
URL_PREFIX = config.get("url_prefix", "https://storage.googleapis.com/dagger-assets/")
# LoRA blend strength applied to both the UNet and the text encoder at inference.
COEFF = config.get("finetune_weighting", 0.5)
def _download_assets(output_dir):
    """Fetch each brand's training-image zip and unpack it under assets/."""
    for brand in ASSETS:
        archive = os.path.join(output_dir, "downloads", f"{brand}.zip")
        # e.g. https://storage.googleapis.com/dagger-assets/dagger.zip
        urllib.request.urlretrieve(URL_PREFIX + brand + ".zip", archive)
        with zipfile.ZipFile(archive, 'r') as zip_ref:
            zip_ref.extractall(os.path.join(output_dir, "assets"))


async def _train_loras(config, output_dir):
    """Fine-tune one LoRA per brand by running lora_pti in the GPU container.

    A fresh Dagger connection is opened per brand; jobs run sequentially
    because there is a single GPU — no parallelism to exploit here.
    """
    for brand in ASSETS:
        async with dagger.Connection(config) as client:
            try:
                python = (
                    client
                    .container()
                    .from_("docker:latest")  # TODO: pin with '@sha256:...'
                    # The container runs the docker CLI against the host
                    # daemon via the socat TCP forwarder started in main().
                    .with_entrypoint("/usr/local/bin/docker")
                    .with_exec([
                        "-H", "tcp://172.17.0.1:12345",
                        "run", "-i", "--rm", "--gpus", "all",
                        "-v", os.path.join(output_dir, "assets", brand) + ":/input",
                        "-v", os.path.join(output_dir, "loras", brand) + ":/output",
                        IMAGE,
                        'lora_pti',
                        f'--pretrained_model_name_or_path={MODEL_NAME}',
                        '--instance_data_dir=/input', '--output_dir=/output',
                        '--train_text_encoder', '--resolution=512',
                        '--train_batch_size=1',
                        '--gradient_accumulation_steps=4', '--scale_lr',
                        '--learning_rate_unet=1e-4',
                        '--learning_rate_text=1e-5', '--learning_rate_ti=5e-4',
                        '--color_jitter', '--lr_scheduler="linear"',
                        '--lr_warmup_steps=0',
                        '--placeholder_tokens="<s1>|<s2>"',
                        '--use_template="style"', '--save_steps=100',
                        '--max_train_steps_ti=1000',
                        '--max_train_steps_tuning=1000',
                        '--perform_inversion=True', '--clip_ti_decay',
                        '--weight_decay_ti=0.000', '--weight_decay_lora=0.001',
                        '--continue_inversion', '--continue_inversion_lr=1e-4',
                        '--device="cuda:0"', '--lora_rank=1',
                    ])
                )
                # Awaiting stdout/stderr forces execution of the pipeline.
                err = await python.stderr()
                out = await python.stdout()
                print(f"Hello from Dagger, fine tune LoRA on {brand}: {out}{err}")
            except Exception as e:
                # Was `import pdb; pdb.set_trace()` — a breakpoint here hangs
                # unattended runs. Log and continue with the next brand.
                print(f"error: {e}")


async def _run_inference(config, output_dir):
    """Render NUM_IMAGES seeds for every (brand, prompt) pair."""
    async with dagger.Connection(config) as client:
        for brand in ASSETS:
            for key, prompt in PROMPTS.items():
                for seed in range(NUM_IMAGES):
                    # Python script executed inside the GPU container.
                    # {prompt!r}/{MODEL_NAME!r} interpolate as proper Python
                    # string literals, so a quote character in a
                    # config-supplied prompt cannot break the script.
                    script = textwrap.dedent(f"""
                    from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
                    import torch
                    from lora_diffusion import tune_lora_scale, patch_pipe
                    model_id = {MODEL_NAME!r}
                    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(
                        "cuda"
                    )
                    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
                    prompt = {prompt!r}
                    seed = {seed}
                    torch.manual_seed(seed)
                    patch_pipe(
                        pipe,
                        "/input/final_lora.safetensors",
                        patch_text=True,
                        patch_ti=True,
                        patch_unet=True,
                    )
                    coeff = {COEFF}
                    tune_lora_scale(pipe.unet, coeff)
                    tune_lora_scale(pipe.text_encoder, coeff)
                    image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]
                    image.save(f"/output/{key}-{{seed}}.jpg")
                    """)
                    python = (
                        client
                        .container()
                        .from_("docker:latest")
                        .with_entrypoint("/usr/local/bin/docker")
                        .with_exec([
                            "-H", "tcp://172.17.0.1:12345",
                            "run", "-i", "--rm", "--gpus", "all",
                            "-v", os.path.join(output_dir, "loras", brand) + ":/input",
                            "-v", os.path.join(output_dir, "inference", brand) + ":/output",
                            IMAGE,
                            'python3',
                            '-c',
                            script,
                        ])
                    )
                    err = await python.stderr()
                    out = await python.stdout()
                    print(f"Hello from Dagger, inference {brand}, prompt: {prompt} and {out}{err}")


async def main():
    """Fine-tune a LoRA per brand, then render every prompt for each brand.

    Pipeline:
      1. Start a socat forwarder exposing the host Docker socket over TCP
         (containers launched by Dagger dial 172.17.0.1:12345).
      2. Download and unpack each brand's training images.
      3. Train one LoRA per brand (sequential GPU jobs).
      4. Run inference for every (brand, prompt, seed) combination.
    """
    print("Spawning docker socket forwarder...")
    p = subprocess.Popen([
        "socat",
        "TCP-LISTEN:12345,reuseaddr,fork,bind=172.17.0.1",
        "UNIX-CONNECT:/var/run/docker.sock",
    ])
    try:
        # Give socat a moment to bind before anything dials it; anyio.sleep
        # keeps the event loop responsive (time.sleep would block it).
        await anyio.sleep(1)
        print("Done!")
        config = dagger.Config(log_output=sys.stdout)
        # Output layout on the host, bind-mounted into the containers.
        output_dir = os.path.join(os.getcwd(), "output")
        print("=============================")
        print(f"OUTPUT DIRECTORY: {output_dir}")
        print("=============================")
        for subdir in ("assets", "downloads", "loras", "inference"):
            os.makedirs(os.path.join(output_dir, subdir), exist_ok=True)
        _download_assets(output_dir)
        await _train_loras(config, output_dir)
        await _run_inference(config, output_dir)
    finally:
        # Always tear down the forwarder, even when a step above raises —
        # otherwise a stray socat keeps listening on 12345 after a crash.
        p.terminate()
anyio.run(main)