-
Notifications
You must be signed in to change notification settings - Fork 20
/
infer_flux_regional_pulid.py
137 lines (111 loc) · 4.72 KB
/
infer_flux_regional_pulid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
from pipeline_flux_regional_pulid import RegionalFluxPipeline_PULID, RegionalFluxAttnProcessor2_0
if __name__ == "__main__":
model_path = "black-forest-labs/FLUX.1-dev"
pipeline = RegionalFluxPipeline_PULID.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")
attn_procs = {}
for name in pipeline.transformer.attn_processors.keys():
if 'transformer_blocks' in name and name.endswith("attn.processor"):
attn_procs[name] = RegionalFluxAttnProcessor2_0()
else:
attn_procs[name] = pipeline.transformer.attn_processors[name]
pipeline.transformer.set_attn_processor(attn_procs)
# load pulid
pipeline.load_pulid_models()
pipeline.load_pretrain()
# single-person example
# generation settings
image_width = 1280
image_height = 1280
num_samples = 1
num_inference_steps = 24
guidance_scale = 3.5
seed = 124
# regional prompting settings
mask_inject_steps = 10
double_inject_blocks_interval = 1
single_inject_blocks_interval = 1
base_ratio = 0.2
# regional prompting settings
base_prompt = "In a classroom during the afternoon, a man is practicing guitar by himself, with sunlight beautifully illuminating the room"
background_prompt = "empty classroom"
regional_prompt_mask_pairs = {
"0": {
"description": "A man in a blue shirt and jeans, playing guitar",
"mask": [64, 320, 448, 1280]
}
}
# pulid input
id_image_paths = ["./assets/lecun.jpeg"]
id_weights = [1.0] # scale for pulid embedding injection
# multi-person example
# generation settings
# image_width = 1280
# image_height = 968
# num_samples = 1
# num_inference_steps = 24
# guidance_scale = 3.5
# seed = 124
# regional prompting settings
# mask_inject_steps = 8
# double_inject_blocks_interval = 1
# single_inject_blocks_interval = 2
# base_ratio = 0.1
# base_prompt = "In an elegant dining room, two men are having dinner at opposite ends of a long formal table, with warm lighting creating an atmospheric ambiance"
# background_prompt = "a dining room"
# regional_prompt_mask_pairs = {
# "0": {
# "description": "A man in a suit sitting at the table, with a plate of food and wine glass in front of him",
# "mask": [64, 128, 320, 968]
# },
# "1": {
# "description": "A man in a suit sitting at the table, with a plate of food and wine glass in front of him",
# "mask": [960, 128, 1216, 968]
# }
# }
# # pulid input
# id_image_paths = ["./assets/trump.jpg", "./assets/musk.jpg"]
# id_weights = [0.8, 0.8] # scale for pulid embedding injection
# prepare regional prompts and masks
# ensure image width and height are divisible by the vae scale factor
image_width = (image_width // pipeline.vae_scale_factor) * pipeline.vae_scale_factor
image_height = (image_height // pipeline.vae_scale_factor) * pipeline.vae_scale_factor
regional_prompts = []
regional_masks = []
background_mask = torch.ones((image_height, image_width))
for region_idx, region in regional_prompt_mask_pairs.items():
description = region['description']
mask = region['mask']
x1, y1, x2, y2 = mask
mask = torch.zeros((image_height, image_width))
mask[y1:y2, x1:x2] = 1.0
background_mask -= mask
regional_prompts.append(description)
regional_masks.append(mask)
# if regional masks don't cover the whole image, append background prompt and mask
if background_mask.sum() > 0:
regional_prompts.append(background_prompt)
regional_masks.append(background_mask)
# setup regional kwargs that pass to the pipeline
joint_attention_kwargs = {
'regional_prompts': regional_prompts,
'regional_masks': regional_masks,
'double_inject_blocks_interval': double_inject_blocks_interval,
'single_inject_blocks_interval': single_inject_blocks_interval,
'base_ratio': base_ratio,
'id_image_paths': id_image_paths,
'id_weights': id_weights,
'id_masks': regional_masks[:len(id_image_paths)], # use foreground mask as id mask
}
images = pipeline(
prompt=base_prompt,
num_samples=num_samples,
width=image_width, height=image_height,
mask_inject_steps=mask_inject_steps,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=torch.Generator("cuda").manual_seed(seed),
joint_attention_kwargs=joint_attention_kwargs,
).images
for idx, image in enumerate(images):
image.save(f"output_{idx}.jpg")