-
Notifications
You must be signed in to change notification settings - Fork 46
/
controlnet_union_test_segment.py
76 lines (54 loc) · 2.59 KB
/
controlnet_union_test_segment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import sys
import cv2
import torch
import random
import numpy as np
from PIL import Image
from diffusers import AutoencoderKL
from controlnet_aux import SamDetector
from diffusers import EulerAncestralDiscreteScheduler
from models.controlnet_union import ControlNetModel_Union
from pipeline.pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline
# Every model below is moved onto the first CUDA device.
device = torch.device("cuda:0")

# The same base checkpoint supplies both the scheduler config and the pipeline weights.
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Euler-ancestral sampler, configured from the base model's "scheduler" subfolder.
eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
    BASE_MODEL_ID,
    subfolder="scheduler",
)

# NOTE: when testing with another base model, the VAE has to be changed as well.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# Union ControlNet weights, loaded in half precision from safetensors files.
controlnet_model = ControlNetModel_Union.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0",
    use_safetensors=True,
    torch_dtype=torch.float16,
)

# Assemble the full pipeline from the pieces above and move it to the GPU in one go.
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    BASE_MODEL_ID,
    scheduler=eulera_scheduler,
    vae=vae,
    controlnet=controlnet_model,
    torch_dtype=torch.float16,
).to(device)

# SAM-based detector (MobileSAM weights); it is applied to the input picture
# further down to produce the control image.
processor = SamDetector.from_pretrained("dhkim2810/MobileSAM").to(device)
# Text conditioning for the run: what to generate, and what to steer away from.
prompt = "your prompt, the longer the better, you can describe it as detail as possible"
negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'

# Load the source picture. cv2.imread does not raise on a missing or
# undecodable file -- it returns None -- so stop here with a clear message
# instead of handing None to the detector.
source_image_path = "your image path"
controlnet_img = cv2.imread(source_image_path)
if controlnet_img is None:
    raise OSError(f"cv2.imread could not load an image from {source_image_path!r}")

# Run the detector on the picture. output_type='cv2' is requested because the
# result is read through .shape and passed to cv2.resize right below.
controlnet_img = processor(controlnet_img, output_type='cv2')

# Need to resize the image to about 1024 * 1024 pixels of area (or the same
# bucket resolution) to get the best performance; the aspect ratio is kept.
height, width, _ = controlnet_img.shape
ratio = np.sqrt(1024. * 1024. / (width * height))

# Round each side to the NEAREST multiple of 8 rather than truncating:
#  - truncation can land one pixel short through float error (e.g. 1023.999 -> 1023);
#  - NOTE(review): diffusers' stock SDXL pipeline requires height and width to be
#    divisible by 8; the union pipeline presumably shares that constraint --
#    confirm. Aligning here is harmless either way.
new_width, new_height = (
    int(round(float(side * ratio) / 8)) * 8 for side in (width, height)
)

controlnet_img = cv2.resize(controlnet_img, (new_width, new_height))
# The pipeline call below receives the control image as a PIL image.
controlnet_img = Image.fromarray(controlnet_img)
# Draw a fresh seed for every run and print it, so that a result worth keeping
# can be regenerated later by hard-coding the reported value.
seed = random.randint(0, 2147483647)
print(f"seed: {seed}")
generator = torch.Generator('cuda').manual_seed(seed)

# Control-type slots, by position in image_list / union_control_type:
# 0 -- openpose
# 1 -- depth
# 2 -- hed/pidi/scribble/ted
# 3 -- canny/lineart/anime_lineart/mlsd
# 4 -- normal
# 5 -- segment
# Only slot 5 (segment) is used here: it holds the control image and is the
# only position set to 1 in union_control_type; every other slot is 0.
images = pipe(
    prompt=[prompt] * 1,
    image_list=[0, 0, 0, 0, 0, controlnet_img],
    negative_prompt=[negative_prompt] * 1,
    generator=generator,
    width=new_width,
    height=new_height,
    num_inference_steps=30,
    union_control=True,
    union_control_type=torch.Tensor([0, 0, 0, 0, 0, 1]),
).images

# Save the first generated image. The path is a placeholder to be replaced by
# the user; it is a plain string because it has no fields to interpolate.
images[0].save("your image save path, png format is usually better than jpg or webp in terms of image quality but got much bigger")