Running the PixArtAlphaPipeline in under 8GB GPU VRAM

It is possible to run the [PixArtAlphaPipeline] in under 8GB of GPU VRAM by loading the text encoder in 8-bit numerical precision and keeping only one large component in memory at a time. Let's walk through a full example.
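Here is a rough illustration of why 8-bit precision helps. A model's weight memory is approximately its parameter count multiplied by the bytes stored per parameter, so going from float16 to int8 halves it. The sketch below computes that estimate. The parameter count in the usage section is a hypothetical placeholder, not the size of PixArt's text encoder. The estimate also ignores activations, the CUDA context, and quantization overhead.

```python
# Bytes needed to store one parameter in each precision
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Approximate memory taken by the weights alone, in GiB.

    Ignores activations, the CUDA context, and any extra tensors
    (e.g. scales) that a quantization scheme stores alongside the weights.
    """
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


if __name__ == "__main__":
    # Hypothetical parameter count, used only to illustrate the arithmetic
    num_params = 5_000_000_000
    for dtype in ("float32", "float16", "int8"):
        print(f"{dtype:>8}: {weight_memory_gib(num_params, dtype):.2f} GiB")
```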

First, install the `bitsandbytes` library:

```bash
pip install -U bitsandbytes
```
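Optionally, check that `bitsandbytes` is importable before trying to load anything in 8-bit. This is a minimal sketch that uses only the standard library. `is_installed` is a hypothetical helper, and it only handles top-level package names.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if the top-level `package` can be imported here."""
    return importlib.util.find_spec(package) is not None


if __name__ == "__main__":
    # bitsandbytes is required for load_in_8bit=True
    if not is_installed("bitsandbytes"):
        print("bitsandbytes is missing; run: pip install -U bitsandbytes")
```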

Then load the text encoder in 8-bit:

```python
import torch
from transformers import T5EncoderModel
from diffusers import PixArtAlphaPipeline

# Load only the T5 text encoder, quantized to 8-bit with bitsandbytes
text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="text_encoder",
    load_in_8bit=True,
    device_map="auto",
)
# transformer=None skips the denoising transformer; this pipeline only encodes prompts
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=text_encoder,
    transformer=None,
    device_map="auto",
)
```
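The sketch below gives some intuition for what "loading in 8-bit" means. It shows plain absmax (symmetric) int8 quantization of one row of weights in pure Python, with one scale per row. This is a simplification. bitsandbytes' 8-bit scheme (LLM.int8()) also keeps outlier features in 16-bit, which this sketch does not model. The function names are hypothetical.

```python
from typing import List, Tuple


def absmax_quantize(row: List[float]) -> Tuple[List[int], float]:
    """Map floats to int8 codes in [-127, 127] using one scale for the row."""
    absmax = max(abs(x) for x in row)
    scale = 127 / absmax if absmax > 0 else 1.0
    return [round(x * scale) for x in row], scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Recover approximate float values from the int8 codes."""
    return [v / scale for v in q]


if __name__ == "__main__":
    weights = [0.12, -0.5, 0.033, 0.91, -0.007]
    q, scale = absmax_quantize(weights)
    restored = dequantize(q, scale)
    # The rounding error is the information lost by storing 8 bits per weight
    max_err = max(abs(a - b) for a, b in zip(weights, restored))
    print(q, f"max abs error = {max_err:.5f}")
```

Each weight's error is at most half a quantization step (`0.5 / scale`). A single large value in a row therefore makes the step coarser for every other value in that row.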

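Define a small `flush()` utility to clear the GPU VRAM:

```python
import gc

import torch


def flush():
    gc.collect()  # free unreachable Python objects, including reference cycles
    torch.cuda.empty_cache()  # return cached, unused CUDA memory to the driver
```

Now, use `pipe` to encode a prompt. Then delete the text encoder and the pipeline, and free the VRAM they occupied:

```python
with torch.no_grad():
    prompt = "cute cat"
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)

del text_encoder
del pipe
flush()
```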
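Why does `flush()` call `gc.collect()` before `torch.cuda.empty_cache()`? `del` only removes a name. An object caught in a reference cycle stays alive until Python's cyclic garbage collector runs, and `empty_cache()` can only release cached CUDA memory that no live tensor still occupies. The pure-Python sketch below needs no GPU. It shows an object outliving `del` until `gc.collect()` runs. `Node` and `survives_del_until_collect` are hypothetical names.

```python
import gc
import weakref


class Node:
    """Hypothetical stand-in for a model component that references itself."""

    def __init__(self) -> None:
        self.self_ref = self  # reference cycle: the object points to itself


def survives_del_until_collect() -> tuple:
    """Return (alive_after_del, alive_after_collect) for a cyclic object."""
    was_enabled = gc.isenabled()
    gc.disable()  # keep automatic collection from running mid-demo
    try:
        node = Node()
        probe = weakref.ref(node)  # observes the object without keeping it alive
        del node
        alive_after_del = probe() is not None
        gc.collect()
        alive_after_collect = probe() is not None
    finally:
        if was_enabled:
            gc.enable()
    return alive_after_del, alive_after_collect


if __name__ == "__main__":
    print(survives_del_until_collect())  # (True, False) on CPython
```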

Then compute the latents, passing the prompt embeddings as inputs:

```python
# Reload the pipeline without the text encoder; the prompt embeddings are already computed
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

latents = pipe(
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",  # return latents instead of decoded images
).images

# The transformer is no longer needed; free it before decoding
del pipe.transformer
flush()
```
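The recipe as a whole keeps only one large component in GPU memory at a time. It encodes with the text encoder and frees it, then denoises with the transformer and frees it, then decodes with the VAE. The load, use, free pattern can be sketched generically as below. `run_stage`, `DummyModel`, and `demo` are hypothetical helpers, not diffusers APIs. In practice, `cleanup` would be the `flush()` function above.

```python
import weakref
from typing import Callable, TypeVar

C = TypeVar("C")  # component type
R = TypeVar("R")  # result type


class DummyModel:
    """Hypothetical stand-in for a large model component."""

    def __call__(self, x: int) -> int:
        return x * 2


def run_stage(load: Callable[[], C], use: Callable[[C], R], cleanup: Callable[[], None]) -> R:
    """Load a component, use it once, then drop it and run cleanup.

    Memory is only reclaimable if `use` does not stash the component elsewhere.
    """
    component = load()
    try:
        return use(component)
    finally:
        del component  # drop the last reference held by this stage
        cleanup()  # e.g. flush(): gc.collect() + torch.cuda.empty_cache()


def demo() -> tuple:
    """Return (result, component_freed_before_cleanup) for a DummyModel stage."""
    probe = {}

    def load() -> DummyModel:
        model = DummyModel()
        probe["ref"] = weakref.ref(model)  # observe without keeping it alive
        return model

    def cleanup() -> None:
        probe["freed"] = probe["ref"]() is None

    result = run_stage(load, lambda m: m(21), cleanup)
    return result, probe["freed"]


if __name__ == "__main__":
    print(demo())  # (42, True) on CPython
```

The helper returns only the result of `use` rather than yielding the component from a `with` block. That way, no name stays bound to the component after the stage ends, which would otherwise keep its memory alive.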

Notice that when initializing `pipe`, you set `text_encoder=None` so that the text encoder isn't loaded again.

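Once the latents are computed, pass them to the VAE to decode them into a real image:

```python
with torch.no_grad():
    # Undo the latent scaling before decoding
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
# postprocess returns a list of PIL images; take the first (and only) one
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
image.save("cat.png")
```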
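The two numeric steps around decoding can be sketched on plain floats. First, latents are divided by the VAE's `scaling_factor` (read from `pipe.vae.config.scaling_factor`) to undo the scaling applied to them. Second, `postprocess` maps the decoder output from roughly [-1, 1] to 8-bit pixel values. This is a simplified re-implementation with hypothetical function names. It assumes the image processor's default normalization, whereas diffusers operates on whole tensors.

```python
def unscale_latent(z: float, scaling_factor: float) -> float:
    """Undo the latent scaling before the value is passed to the VAE decoder."""
    return z / scaling_factor


def to_uint8_pixel(x: float) -> int:
    """Map a decoder output in roughly [-1, 1] to an 8-bit pixel value."""
    x01 = min(max(x / 2 + 0.5, 0.0), 1.0)  # [-1, 1] -> [0, 1], clamped
    return round(x01 * 255)  # scale to [0, 255] and round to the nearest integer


if __name__ == "__main__":
    print([to_uint8_pixel(v) for v in (-1.0, 0.0, 1.0)])  # [0, 128, 255]
```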

Put together, these steps should let you run [PixArtAlphaPipeline] in under 8GB of GPU VRAM.

Find the script here that can be run end-to-end to report the memory being used.
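As a sketch of how peak memory can be measured around a single step, the snippet below uses PyTorch's `torch.cuda.reset_peak_memory_stats()` and `torch.cuda.max_memory_allocated()` on the current device. `run_and_report_peak` is a hypothetical helper. `max_memory_allocated()` only counts memory occupied by tensors through PyTorch's caching allocator. It therefore reads lower than the VRAM usage shown by `nvidia-smi`, which also includes the CUDA context and cached blocks.

```python
from typing import Any, Callable, Optional, Tuple

try:
    import torch
except ImportError:  # let the sketch run where PyTorch is absent
    torch = None


def run_and_report_peak(fn: Callable[[], Any]) -> Tuple[Any, Optional[int]]:
    """Run `fn` and return (result, peak CUDA bytes allocated during the call).

    The peak is None when PyTorch or a CUDA device is unavailable.
    """
    use_cuda = torch is not None and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()  # start a fresh peak measurement
    result = fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for queued GPU work to finish
        return result, torch.cuda.max_memory_allocated()
    return result, None
```

For example, wrapping the decoding step as `run_and_report_peak(lambda: pipe.vae.decode(...))` inside `torch.no_grad()` would report that step's peak allocation.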

Text embeddings computed in 8-bit can affect the quality of the generated images, because the reduced precision loses information in the representation space. It's recommended to compare outputs generated with and without the 8-bit text encoder.
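One simple way to quantify such a comparison is the peak signal-to-noise ratio (PSNR) between two images generated from the same prompt, seed, and settings. The seed can be fixed through the pipeline's `generator` argument. This is a minimal pure-Python sketch, and `psnr` is a hypothetical helper. PSNR measures pixel differences, not perceived quality, so a visual check is still worthwhile.

```python
import math
from typing import Sequence


def psnr(reference: Sequence[int], test: Sequence[int], max_value: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences."""
    if len(reference) != len(test) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, test)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_value**2 / mse)


if __name__ == "__main__":
    # Toy pixel values; higher PSNR means the two outputs are closer
    print(psnr([0, 128, 255, 64], [0, 130, 250, 64]))
```

To apply it to real outputs, flatten each image's pixel values, for instance with Pillow: `[c for px in img.convert("RGB").getdata() for c in px]`.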