Skip to content

Commit

Permalink
Add SD3 Pipeline (#329)
Browse files Browse the repository at this point in the history
* Add SD3 Pipeline

Co-authored-by: atiorh <[email protected]>
Co-authored-by: arda-argmax <[email protected]>

* Use swift-transformers for tokenization

* Use diffusionkit converters in torch2coreml

* Documentation and cleanup

* Add model link

* Consolidate batch prediction logic

* Remove DecoderSD3.swift and consolidate logic into Decoder.swift

* Remove DiffusionKit MLX inference reference from README

---------

Co-authored-by: atiorh <[email protected]>
Co-authored-by: arda-argmax <[email protected]>
Co-authored-by: atila <[email protected]>
  • Loading branch information
4 people authored Jul 23, 2024
1 parent 5a170d2 commit c891f43
Show file tree
Hide file tree
Showing 17 changed files with 1,326 additions and 59 deletions.
11 changes: 7 additions & 4 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import PackageDescription
let package = Package(
name: "stable-diffusion",
platforms: [
.macOS(.v11),
.iOS(.v14),
.macOS(.v13),
.iOS(.v16),
],
products: [
.library(
Expand All @@ -18,12 +18,15 @@ let package = Package(
targets: ["StableDiffusionCLI"])
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.3")
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.3"),
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"),
],
targets: [
.target(
name: "StableDiffusion",
dependencies: [],
dependencies: [
.product(name: "Transformers", package: "swift-transformers"),
],
path: "swift/StableDiffusion"),
.executableTarget(
name: "StableDiffusionCLI",
Expand Down
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,52 @@ An example `<selected-recipe-string-key>` would be `"recipe_4.50_bit_mixedpalett

</details>


## <a name="using-stable-diffusion-3"></a> Using Stable Diffusion 3

<details>
<summary> Details (Click to expand) </summary>

### Model Conversion

Stable Diffusion 3 uses some new and some old models to run. For the text encoders, the conversion can be done using a similar command as before with the `--sd3-version` flag.

```bash
python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli --convert-text-encoder --sd3-version -o <output-dir>
```

For the new models (MMDiT, a new VAE with 16 channels, and the T5 text encoder), there are a number of new CLI flags that utilize the [DiffusionKit](https://www.github.com/argmaxinc/DiffusionKit) repo:

- `--sd3-version`: Indicates to the converter to treat this as a Stable Diffusion 3 model
- `--convert-mmdit`: Convert the MMDiT model
- `--convert-vae-decoder`: Convert the new VAE model (this will use the 16 channel version if --sd3-version is set)
- `--include-t5`: Downloads and includes a pre-converted T5 text encoder in the conversion

e.g.:
```bash
python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli --convert-vae-decoder --convert-mmdit --include-t5 --sd3-version -o <output-dir>
```

To convert the full pipeline with at 1024x1024 resolution, the following command may be used:

```bash
python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli --convert-text-encoder --convert-vae-decoder --convert-mmdit --include-t5 --sd3-version --latent-h 128 --latent-w 128 -o <output-dir>
```

Keep in mind that the MMDiT model is quite large and will require increasingly more memory and time to convert as the latent resolution increases.

Also note that currently the MMDiT model requires fp32 and therefore only supports `CPU_AND_GPU` compute units and `ORIGINAL` attention implementation (the default for this pipeline).

### Swift Inference

Swift inference for Stable Diffusion 3 is similar to the previous versions. The only difference is that the `--sd3` flag should be used to indicate that the model is a Stable Diffusion 3 model.

```bash
swift run StableDiffusionSample <prompt> --resource-path <output-mlpackages-directory/Resources> --output-path <output-dir> --compute-units cpuAndGPU --sd3
```

</details>

## <a name="using-stable-diffusion-xl"></a> Using Stable Diffusion XL

<details>
Expand Down Expand Up @@ -356,6 +402,7 @@ Resources:
- [`stabilityai/stable-diffusion-2-1-base`](https://huggingface.co/apple/coreml-stable-diffusion-2-1-base)
- [`stabilityai/stable-diffusion-xl-base-1.0`](https://huggingface.co/apple/coreml-stable-diffusion-xl-base)
- [`stabilityai/stable-diffusion-xl-{base+refiner}-1.0`](https://huggingface.co/apple/coreml-stable-diffusion-xl-base-with-refiner)
- [`stabilityai/stable-diffusion-3-medium`](https://huggingface.co/stabilityai/stable-diffusion-3-medium)

If you want to use any of those models you may download the weights and proceed to [generate images with Python](#image-generation-with-python) or [Swift](#image-generation-with-swift).

Expand Down
209 changes: 207 additions & 2 deletions python_coreml_stable_diffusion/torch2coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
DiffusionPipeline,
ControlNetModel
)
from diffusionkit.tests.torch2coreml import (
convert_mmdit_to_mlpackage,
convert_vae_to_mlpackage
)
import gc
from huggingface_hub import snapshot_download

import logging

Expand Down Expand Up @@ -207,6 +212,26 @@ def _compile_coreml_model(source_model_path, output_dir, final_name):
return target_path


def _download_t5_model(args, t5_save_path):
t5_url = args.text_encoder_t5_url
match = re.match(r'https://huggingface.co/(.+)/resolve/main/(.+)', t5_url)
if not match:
raise ValueError(f"Invalid Hugging Face URL: {t5_url}")
repo_id, model_subpath = match.groups()

download_path = snapshot_download(
repo_id=repo_id,
revision="main",
allow_patterns=[f"{model_subpath}/*"]
)
logger.info(f"Downloaded T5 model to {download_path}")

# Move the downloaded model to the top level of the Resources directory
logger.info(f"Copying T5 model from {download_path} to {t5_save_path}")
cache_path = os.path.join(download_path, model_subpath)
shutil.copytree(cache_path, t5_save_path)


def bundle_resources_for_swift_cli(args):
"""
- Compiles Core ML models from mlpackage into mlmodelc format
Expand All @@ -228,6 +253,7 @@ def bundle_resources_for_swift_cli(args):
("refiner", "UnetRefiner"),
("refiner_chunk1", "UnetRefinerChunk1"),
("refiner_chunk2", "UnetRefinerChunk2"),
("mmdit", "MultiModalDiffusionTransformer"),
("control-unet", "ControlledUnet"),
("control-unet_chunk1", "ControlledUnetChunk1"),
("control-unet_chunk2", "ControlledUnetChunk2"),
Expand All @@ -241,7 +267,7 @@ def bundle_resources_for_swift_cli(args):
logger.warning(
f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
)

if args.convert_controlnet:
for controlnet_model_version in args.convert_controlnet:
controlnet_model_name = controlnet_model_version.replace("/", "_")
Expand Down Expand Up @@ -271,6 +297,25 @@ def bundle_resources_for_swift_cli(args):
f.write(requests.get(args.text_encoder_merges_url).content)
logger.info("Done")

# Fetch and save pre-converted T5 text encoder model
t5_model_name = "TextEncoderT5.mlmodelc"
t5_save_path = os.path.join(resources_dir, t5_model_name)
if args.include_t5:
if not os.path.exists(t5_save_path):
logger.info("Downloading pre-converted T5 encoder model TextEncoderT5.mlmodelc")
_download_t5_model(args, t5_save_path)
logger.info("Done")
else:
logger.info(f"Skipping T5 download as {t5_save_path} already exists")

# Fetch and save T5 text tokenizer JSON files
logger.info("Downloading and saving T5 tokenizer files tokenizer_config.json and tokenizer.json")
with open(os.path.join(resources_dir, "tokenizer_config.json"), "wb") as f:
f.write(requests.get(args.text_encoder_t5_config_url).content)
with open(os.path.join(resources_dir, "tokenizer.json"), "wb") as f:
f.write(requests.get(args.text_encoder_t5_data_url).content)
logger.info("Done")

return resources_dir


Expand Down Expand Up @@ -557,6 +602,61 @@ def forward(self, z):
del traced_vae_decoder, pipe.vae.decoder, coreml_vae_decoder
gc.collect()

def convert_vae_decoder_sd3(args):
""" Converts the VAE component of Stable Diffusion 3
"""
out_path = _get_out_path(args, "vae_decoder")
if os.path.exists(out_path):
logger.info(
f"`vae_decoder` already exists at {out_path}, skipping conversion."
)
return

# Convert the VAE Decoder model via DiffusionKit
converted_vae_path = convert_vae_to_mlpackage(
model_version=args.model_version,
latent_h=args.latent_h,
latent_w=args.latent_w,
output_dir=args.o,
)

# Load converted model
coreml_vae_decoder = ct.models.MLModel(converted_vae_path)

# Set model metadata
coreml_vae_decoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
coreml_vae_decoder.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
coreml_vae_decoder.version = args.model_version
coreml_vae_decodershort_description = \
"Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
"Please refer to https://arxiv.org/pdf/2403.03206 for details."

# Set the input descriptions
coreml_vae_decoder.input_description["z"] = \
"The denoised latent embeddings from the unet model after the last step of reverse diffusion"

# Set the output descriptions
coreml_vae_decoder.output_description[
"image"] = "Generated image normalized to range [-1, 1]"

# Set package version metadata
from python_coreml_stable_diffusion._version import __version__
coreml_vae_decoder.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = __version__
from diffusionkit.version import __version__
coreml_vae_decoder.user_defined_metadata["com.github.argmax.diffusionkit.version"] = __version__

# Save the updated model
coreml_vae_decoder.save(out_path)

logger.info(f"Saved vae_decoder into {out_path}")

# Delete the original file
if os.path.exists(converted_vae_path):
shutil.rmtree(converted_vae_path)

del coreml_vae_decoder
gc.collect()


def convert_vae_encoder(pipe, args):
""" Converts the VAE Encoder component of Stable Diffusion
Expand Down Expand Up @@ -909,6 +1009,72 @@ def convert_unet(pipe, args, model_name = None):
chunk_mlprogram.main(args)


def convert_mmdit(args):
""" Converts the MMDiT component of Stable Diffusion 3
"""
out_path = _get_out_path(args, "mmdit")
if os.path.exists(out_path):
logger.info(
f"`mmdit` already exists at {out_path}, skipping conversion."
)
return

# Convert the MMDiT model via DiffusionKit
converted_mmdit_path = convert_mmdit_to_mlpackage(
model_version=args.model_version,
latent_h=args.latent_h,
latent_w=args.latent_w,
output_dir=args.o,
# FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
compute_precision=ct.precision.FLOAT32,
compute_unit=ct.ComputeUnit.CPU_AND_GPU,
)

# Load converted model
coreml_mmdit = ct.models.MLModel(converted_mmdit_path)

# Set model metadata
coreml_mmdit.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
coreml_mmdit.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
coreml_mmdit.version = args.model_version
coreml_mmdit.short_description = \
"Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
"Please refer to https://arxiv.org/pdf/2403.03206 for details."

# Set the input descriptions
coreml_mmdit.input_description["latent_image_embeddings"] = \
"The low resolution latent feature maps being denoised through reverse diffusion"
coreml_mmdit.input_description["token_level_text_embeddings"] = \
"Output embeddings from the associated text_encoder model to condition to generated image on text. " \
"A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. "
coreml_mmdit.input_description["pooled_text_embeddings"] = \
"Additional embeddings that if specified are added to the embeddings that are passed along to the MMDiT model."
coreml_mmdit.input_description["timestep"] = \
"A value emitted by the associated scheduler object to condition the model on a given noise schedule"

# Set the output descriptions
coreml_mmdit.output_description["denoiser_output"] = \
"Same shape and dtype as the `latent_image_embeddings` input. " \
"The predicted noise to facilitate the reverse diffusion (denoising) process"

# Set package version metadata
from python_coreml_stable_diffusion._version import __version__
coreml_mmdit.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = __version__
from diffusionkit.version import __version__
coreml_mmdit.user_defined_metadata["com.github.argmax.diffusionkit.version"] = __version__

# Save the updated model
coreml_mmdit.save(out_path)

logger.info(f"Saved vae_decoder into {out_path}")

# Delete the original file
if os.path.exists(converted_mmdit_path):
shutil.rmtree(converted_mmdit_path)

del coreml_mmdit
gc.collect()

def convert_safety_checker(pipe, args):
""" Converts the Safety Checker component of Stable Diffusion
"""
Expand Down Expand Up @@ -1288,6 +1454,16 @@ def get_pipeline(args):
use_safetensors=True,
vae=vae,
use_auth_token=True)
elif args.sd3_version:
# SD3 uses standard SDXL diffusers pipeline besides the vae, denoiser, and T5 text encoder
sdxl_base_version = "stabilityai/stable-diffusion-xl-base-1.0"
args.xl_version = True
logger.info(f"SD3 version specified, initializing DiffusionPipeline with {sdxl_base_version} for non-SD3 components..")
pipe = DiffusionPipeline.from_pretrained(sdxl_base_version,
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
use_auth_token=True)
else:
pipe = DiffusionPipeline.from_pretrained(model_version,
torch_dtype=torch.float16,
Expand Down Expand Up @@ -1316,7 +1492,10 @@ def main(args):
# Convert models
if args.convert_vae_decoder:
logger.info("Converting vae_decoder")
convert_vae_decoder(pipe, args)
if args.sd3_version:
convert_vae_decoder_sd3(args)
else:
convert_vae_decoder(pipe, args)
logger.info("Converted vae_decoder")

if args.convert_vae_encoder:
Expand Down Expand Up @@ -1363,6 +1542,11 @@ def main(args):
del pipe
gc.collect()
logger.info(f"Converted refiner")

if args.convert_mmdit:
logger.info("Converting mmdit")
convert_mmdit(args)
logger.info("Converted mmdit")

if args.quantize_nbits is not None:
logger.info(f"Quantizing weights to {args.quantize_nbits}-bit precision")
Expand All @@ -1383,6 +1567,7 @@ def parser_spec():
parser.add_argument("--convert-vae-decoder", action="store_true")
parser.add_argument("--convert-vae-encoder", action="store_true")
parser.add_argument("--convert-unet", action="store_true")
parser.add_argument("--convert-mmdit", action="store_true")
parser.add_argument("--convert-safety-checker", action="store_true")
parser.add_argument(
"--convert-controlnet",
Expand Down Expand Up @@ -1489,6 +1674,7 @@ def parser_spec():
"If specified, enable unet to receive additional inputs from controlnet. "
"Each input added to corresponding resnet output."
)
parser.add_argument("--include-t5", action="store_true")

# Swift CLI Resource Bundling
parser.add_argument(
Expand All @@ -1508,11 +1694,30 @@ def parser_spec():
default=
"https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/merges.txt",
help="The URL to the merged pairs used in by the text tokenizer.")
parser.add_argument(
"--text-encoder-t5-url",
default=
"https://huggingface.co/argmaxinc/coreml-stable-diffusion-3-medium/resolve/main/TextEncoderT5.mlmodelc",
help="The URL to the pre-converted T5 encoder model.")
parser.add_argument(
"--text-encoder-t5-config-url",
default=
"https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer_config.json",
help="The URL to the merged pairs used in by the text tokenizer.")
parser.add_argument(
"--text-encoder-t5-data-url",
default=
"https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer.json",
help="The URL to the merged pairs used in by the text tokenizer.")
parser.add_argument(
"--xl-version",
action="store_true",
help=("If specified, the pre-trained model will be treated as an instantiation of "
"`diffusers.pipelines.StableDiffusionXLPipeline` instead of `diffusers.pipelines.StableDiffusionPipeline`"))
parser.add_argument(
"--sd3-version",
action="store_true",
help=("If specified, the pre-trained model will be treated as an SD3 model."))

return parser

Expand Down
Loading

0 comments on commit c891f43

Please sign in to comment.