Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Nov 22, 2024
1 parent 562ddce commit 6fb7369
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ pip install optimum diffusers onnx onnxruntime-gpu
optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sd_xl_base_onnx
```

SD3 and Flux requires transformers >= 4.45, and optimum > 1.23.3:
```
git clone https://github.com/huggingface/optimum
pip install -e .
optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium-diffusers sd3_onnx_fp32
optimum-cli export onnx --model stabilityai/stable-diffusion-3.5-medium-diffusers sd3.5_onnx_fp32
```

### Optimize ONNX Pipeline

Example to optimize the exported float32 ONNX models, and save to float16 models:
Expand All @@ -230,6 +238,10 @@ For SDXL model, it is recommended to use a machine with 48 GB or more memory to
python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16
```

For SD3 model:
```
python optimize_pipeline.py -i sd3_onnx_fp32 -o sd3_onnx_fp16 --float16
```
### Run Benchmark

The benchmark.py script will run a warm-up prompt twice, and measure the peak GPU memory usage in these two runs, then record them as first_run_memory_MB and second_run_memory_MB. Then it will run 5 runs to get average latency (in seconds), and output the results to benchmark_result.csv.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
"2.0": "stabilityai/stable-diffusion-2",
"2.1": "stabilityai/stable-diffusion-2-1",
"xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0",
"3.0": "stabilityai/stable-diffusion-3-medium-diffusers",
# "3.5": "stabilityai/stable-diffusion-3.5-medium",
# "3.5-large": "stabilityai/stable-diffusion-3.5-large",
# "flux.1-schnell": "black-forest-labs/FLUX.1-schnell",
# "flux.1-dev": "black-forest-labs/FLUX.1-dev",
}

PROVIDERS = {
Expand Down Expand Up @@ -322,22 +327,10 @@ def get_optimum_ort_pipeline(
disable_safety_checker: bool = True,
use_io_binding: bool = False,
):
from optimum.onnxruntime import ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline
from optimum.onnxruntime import ORTPipelineForText2Image, ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline

if directory is not None and os.path.exists(directory):
if "xl" in model_name:
pipeline = ORTStableDiffusionXLPipeline.from_pretrained(
directory,
provider=provider,
session_options=None,
use_io_binding=False, # Not supported by Optimum version 1.17.1 at the time of verification.
)
else:
pipeline = ORTStableDiffusionPipeline.from_pretrained(
directory,
provider=provider,
use_io_binding=use_io_binding,
)
pipeline = ORTPipelineForText2Image.from_pretrained(directory, provider=provider, use_io_binding=use_io_binding)
elif "xl" in model_name:
pipeline = ORTStableDiffusionXLPipeline.from_pretrained(
model_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import coloredlogs
import onnx
from fusion_options import FusionOptions
from onnx_model_bert import BertOnnxModel
from onnx_model_clip import ClipOnnxModel
from onnx_model_unet import UnetOnnxModel
from onnx_model_vae import VaeOnnxModel
Expand All @@ -46,9 +47,20 @@ def has_external_data(onnx_model_path):
return False


def _get_model_list(source_dir: Path):
is_xl = (source_dir / "text_encoder_2").exists()
is_sd3 = (source_dir / "text_encoder_3").exists()
model_list_sd3 = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"]
model_list_sdxl = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"]
model_list_sd = ["text_encoder", "unet", "vae_encoder", "vae_decoder"]
model_list = model_list_sd3 if is_sd3 else (model_list_sdxl if is_xl else model_list_sd)
return model_list


def _optimize_sd_pipeline(
source_dir: Path,
target_dir: Path,
model_list: List[str],
use_external_data_format: Optional[bool],
float16: bool,
force_fp32_ops: List[str],
Expand All @@ -60,6 +72,7 @@ def _optimize_sd_pipeline(
Args:
source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
model_list (List[str]): list of directory names with onnx model.
use_external_data_format (Optional[bool]): use external data format.
float16 (bool): use half precision
force_fp32_ops(List[str]): operators that are forced to run in float32.
Expand All @@ -70,18 +83,21 @@ def _optimize_sd_pipeline(
RuntimeError: output onnx model path existed
"""
model_type_mapping = {
"transformer": "mmdit",
"unet": "unet",
"vae_encoder": "vae",
"vae_decoder": "vae",
"text_encoder": "clip",
"text_encoder_2": "clip",
"safety_checker": "unet",
"text_encoder_3": "clip",
}

model_type_class_mapping = {
"unet": UnetOnnxModel,
"vae": VaeOnnxModel,
"clip": ClipOnnxModel,
"mmdit": BertOnnxModel, # TODO: have a new class for DiT
}

force_fp32_operators = {
Expand All @@ -91,10 +107,10 @@ def _optimize_sd_pipeline(
"text_encoder": [],
"text_encoder_2": [],
"safety_checker": [],
"text_encoder_3": [],
"transformer": [],
}

is_xl = (source_dir / "text_encoder_2").exists()

if force_fp32_ops:
for fp32_operator in force_fp32_ops:
parts = fp32_operator.split(":")
Expand All @@ -108,8 +124,8 @@ def _optimize_sd_pipeline(
for name, model_type in model_type_mapping.items():
onnx_model_path = source_dir / name / "model.onnx"
if not os.path.exists(onnx_model_path):
if name != "safety_checker":
logger.info("input onnx model does not exist: %s", onnx_model_path)
if name != "safety_checker" and name in model_list:
logger.warning("input onnx model does not exist: %s", onnx_model_path)
# some model are optional so we do not raise error here.
continue

Expand All @@ -122,7 +138,7 @@ def _optimize_sd_pipeline(
use_external_data_format = has_external_data(onnx_model_path)

# Graph fusion before fp16 conversion, otherwise they cannot be fused later.
logger.info(f"Optimize {onnx_model_path}...")
logger.info("Optimize %s ...", onnx_model_path)

args.model_type = model_type
fusion_options = FusionOptions.parse(args)
Expand All @@ -147,6 +163,7 @@ def _optimize_sd_pipeline(

if float16:
# For SD-XL, use FP16 in VAE decoder will cause NaN and black image so we keep it in FP32.
is_xl = (source_dir / "text_encoder_2").exists()
if is_xl and name == "vae_decoder":
logger.info("Skip converting %s to float16 to avoid NaN", name)
else:
Expand Down Expand Up @@ -181,17 +198,18 @@ def _optimize_sd_pipeline(
logger.info("*" * 20)


def _copy_extra_directory(source_dir: Path, target_dir: Path):
def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: List[str]):
"""Copy extra directory that does not have onnx model
Args:
source_dir (Path): source directory
target_dir (Path): target directory
model_list (List[str]): list of directory names with onnx model.
Raises:
RuntimeError: source path does not exist
"""
extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "feature_extractor"]
extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"]

for name in extra_dirs:
source_path = source_dir / name
Expand All @@ -213,8 +231,7 @@ def _copy_extra_directory(source_dir: Path, target_dir: Path):
logger.info("%s => %s", source_path, target_path)

# Some directory are optional
onnx_model_dirs = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder", "safety_checker"]
for onnx_model_dir in onnx_model_dirs:
for onnx_model_dir in model_list:
source_path = source_dir / onnx_model_dir / "config.json"
target_path = target_dir / onnx_model_dir / "config.json"
if source_path.exists():
Expand All @@ -236,17 +253,20 @@ def optimize_stable_diffusion_pipeline(
if overwrite:
shutil.rmtree(output_dir, ignore_errors=True)
else:
raise RuntimeError("output directory existed:{output_dir}. Add --overwrite to empty the directory.")
raise RuntimeError(f"output directory existed:{output_dir}. Add --overwrite to empty the directory.")

source_dir = Path(input_dir)
target_dir = Path(output_dir)
target_dir.mkdir(parents=True, exist_ok=True)

_copy_extra_directory(source_dir, target_dir)
model_list = _get_model_list(source_dir)

_copy_extra_directory(source_dir, target_dir, model_list)

_optimize_sd_pipeline(
source_dir,
target_dir,
model_list,
use_external_data_format,
float16,
args.force_fp32_ops,
Expand Down

0 comments on commit 6fb7369

Please sign in to comment.