Skip to content

Commit

Permalink
benchmarks: add TensorRT
Browse files Browse the repository at this point in the history
  • Loading branch information
isidentical committed Nov 4, 2023
1 parent 64babfd commit 879ada1
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 8 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Running on an A100 80G SXM hosted at [fal.ai](https://fal.ai).
| Diffusers (fp16, xformers) | 1.758s | 1.759s | 1.746s | 1.772s | 28.43 it/s |
| Diffusers (fp16, SDPA, compiled) | 1.352s | 1.351s | 1.348s | 1.356s | 37.01 it/s |
| Diffusers (fp16, SDPA, compiled, NCHW channels last) | 1.066s | 1.065s | 1.062s | 1.076s | 46.95 it/s |
| TensorRT 9.0 (cuda graphs, static shapes) | 0.819s | 0.818s | 0.817s | 0.821s | 61.14 it/s |

### SDXL Benchmarks
| | mean (s) | median (s) | min (s) | max (s) | speed (it/s) |
Expand All @@ -23,6 +24,7 @@ Running on an A100 80G SXM hosted at [fal.ai](https://fal.ai).
| Diffusers (fp16, xformers) | 5.724s | 5.724s | 5.714s | 5.731s | 8.74 it/s |
| Diffusers (fp16, SDPA, compiled) | 5.246s | 5.247s | 5.233s | 5.259s | 9.53 it/s |
| Diffusers (fp16, SDPA, compiled, NCHW channels last) | 5.132s | 5.132s | 5.121s | 5.142s | 9.74 it/s |
| TensorRT 9.0 (cuda graphs, static shapes) | 4.102s | 4.104s | 4.091s | 4.107s | 12.18 it/s |

<!-- END TABLE -->

Expand Down
5 changes: 3 additions & 2 deletions benchmarks/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

from rich.progress import track

from benchmarks import diffusers
from benchmarks import benchmark_diffusers, benchmark_tensorrt
from benchmarks.settings import BenchmarkSettings, InputParameters

ALL_BENCHMARKS = [
*diffusers.LOCAL_BENCHMARKS,
*benchmark_diffusers.LOCAL_BENCHMARKS,
*benchmark_tensorrt.LOCAL_BENCHMARKS,
]


Expand Down
11 changes: 5 additions & 6 deletions benchmarks/diffusers.py → benchmarks/benchmark_diffusers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
from functools import partial

Expand Down Expand Up @@ -57,13 +59,10 @@ def diffusers_any(
pipeline.unet, fullgraph=True, mode="reduce-overhead"
)

return benchmark_settings.apply(
partial(
pipeline,
parameters.prompt,
num_inference_steps=parameters.steps,
)
inference_func = partial(
pipeline, parameters.prompt, num_inference_steps=parameters.steps
)
return benchmark_settings.apply(inference_func)


LOCAL_BENCHMARKS = [
Expand Down
165 changes: 165 additions & 0 deletions benchmarks/benchmark_tensorrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from __future__ import annotations

import contextlib
import shutil
import subprocess
import sys
from functools import partial
from pathlib import Path

import fal

from benchmarks.settings import BenchmarkResults, BenchmarkSettings, InputParameters

DATA_DIR = Path("/data/tensorrt")
REPO_DIR = DATA_DIR / "repo"


def prepare_tensorrt() -> Path:
DATA_DIR.mkdir(exist_ok=True)

if not REPO_DIR.exists():
try:
subprocess.check_call(
[
"git",
"clone",
"https://github.com/rajeevsrao/TensorRT",
"--branch",
"release/9.0",
"--single-branch",
str(REPO_DIR),
]
)
except subprocess.CalledProcessError:
print("Failed to clone TensorRT repo")
shutil.rmtree(REPO_DIR)
raise

return REPO_DIR


@fal.function(
# Copied from https://github.com/rajeevsrao/TensorRT/blob/release/9.0/demo/Diffusion/requirements.txt
requirements=[
"--pre",
"accelerate==0.24.1",
"colored",
"controlnet_aux==0.0.6",
"cuda-python",
"diffusers==0.19.3",
"ftfy",
"matplotlib",
"nvtx",
"onnx-graphsurgeon",
"onnx==1.14.0",
"onnxruntime==1.15.1",
"polygraphy==0.47.1",
"scipy",
"tensorrt==9.0.1.post12.dev4",
"torch==2.1",
"transformers==4.31.0",
"--extra-index-url",
"https://pypi.nvidia.com",
"--extra-index-url",
"https://pypi.ngc.nvidia.com",
],
machine_type="GPU",
_scheduler="nomad",
_scheduler_options={
"target_node": "65.21.219.34",
},
)
def tensorrt_any(
benchmark_settings: BenchmarkSettings,
parameters: InputParameters,
model_version: str,
image_height: int,
image_width: int,
) -> BenchmarkResults:
trt_path = prepare_tensorrt()
diffusion_dir = trt_path / "demo" / "Diffusion"
if str(diffusion_dir) not in sys.path:
sys.path.insert(0, str(diffusion_dir))

with contextlib.chdir(diffusion_dir):
from cuda import cudart
from stable_diffusion_pipeline import StableDiffusionPipeline
from utilities import PIPELINE_TYPE

# Initialize demo
options = {
"version": model_version,
"denoising_steps": parameters.steps,
"use_cuda_graph": True,
"max_batch_size": 4,
"output_dir": "output",
}

if model_version == "1.5":
options["pipeline_type"] = PIPELINE_TYPE.TXT2IMG
elif model_version == "xl-1.0":
options["pipeline_type"] = PIPELINE_TYPE.XL_BASE
options["vae_scaling_factor"] = 0.13025
else:
raise ValueError(f"Unknown model version: {model_version}")

pipeline = StableDiffusionPipeline(**options)
pipeline.loadEngines(
engine_dir=f"engine-{model_version}",
framework_model_dir="pytorch_model",
onnx_dir=f"onnx-{model_version}",
onnx_opset=18,
opt_batch_size=1,
opt_image_height=image_height,
opt_image_width=image_width,
enable_all_tactics=False,
enable_refit=False,
force_build=False,
force_export=False,
force_optimize=False,
static_batch=True,
static_shape=True,
timing_cache=f"cache-{model_version}",
)

# Load resources
_, shared_device_memory = cudart.cudaMalloc(pipeline.calculateMaxDeviceMemory())
pipeline.activateEngines(shared_device_memory)
pipeline.loadResources(image_height, image_width, 1, seed=0)
inference_func = partial(
pipeline.infer,
[parameters.prompt],
[""],
image_height=image_height,
image_width=image_width,
save_image=False,
)
results = benchmark_settings.apply(inference_func)
pipeline.teardown()

return results


LOCAL_BENCHMARKS = [
{
"name": "TensorRT 9.0 (cuda graphs, static shapes)",
"category": "SD1.5",
"function": tensorrt_any,
"kwargs": {
"model_version": "1.5",
"image_height": 512,
"image_width": 512,
},
},
{
"name": "TensorRT 9.0 (cuda graphs, static shapes)",
"category": "SDXL",
"function": tensorrt_any,
"kwargs": {
"model_version": "xl-1.0",
"image_height": 512,
"image_width": 512,
},
},
]

0 comments on commit 879ada1

Please sign in to comment.