"""repaint.py: inpaint a sample face image with the diffusers RePaint pipeline."""
from io import BytesIO
import torch
import PIL
import requests
from diffusers import RePaintPipeline, RePaintScheduler
def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up; forwarded
            to ``requests.get``. Without it a stalled server would block the
            script indefinitely.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # `import PIL` alone does not load the Image submodule; import it
    # explicitly so this function does not rely on some other library having
    # loaded it first.
    import PIL.Image

    response = requests.get(url, timeout=timeout)
    # Fail here on an HTTP error instead of handing an error page to PIL,
    # which would surface as a confusing image-decoding failure.
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")
# Demo: fill the masked region of a 256x256 face image with RePaint.
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"

# Load the original image and the mask as PIL images
original_image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))

# Prefer the second GPU, as this script always has, but fall back to whatever
# is available so the demo does not crash on hosts with fewer than two GPUs.
DEVICE = "cuda:1"
if not torch.cuda.is_available():
    DEVICE = "cpu"
elif torch.cuda.device_count() < 2:
    DEVICE = "cuda:0"

# NOTE(review): this cache location looks machine-specific; adjust it (or make
# it configurable) before running elsewhere.
CACHE_DIR = "/comp_robot/rentianhe/weights/diffusers/"

# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256", cache_dir=CACHE_DIR)
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler, cache_dir=CACHE_DIR)
pipe = pipe.to(DEVICE)

# Seed the generator on the same device as the pipeline for a repeatable run.
generator = torch.Generator(device=DEVICE).manual_seed(0)
output = pipe(
    image=original_image,
    mask_image=mask_image,
    num_inference_steps=250,
    eta=0.0,
    jump_length=10,
    jump_n_sample=10,
    generator=generator,
)

# The pipeline output exposes a list of images; keep the first one.
inpainted_image = output.images[0]
inpainted_image.save("./repaint_demo.jpg")