Added flux demo #3418

Open · wants to merge 19 commits into main
154 changes: 154 additions & 0 deletions examples/apps/flux-demo.py
@@ -0,0 +1,154 @@
import time

import gradio as gr
import torch
import torch_tensorrt
from diffusers import FluxPipeline

DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)
backbone = pipe.transformer


batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=8)

# These min/max values for the img_ids input are the ones torch dynamo recommends during model export.
# To see this recommendation, try exporting with min=1, max=4096.
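# Assumption (based on torch.export's dynamic_shapes convention, not stated in
# this PR): an empty dict pins a tensor input as fully static, while None is
# used for the non-tensor return_dict argument.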
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {},
    "img_ids": {},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}

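# Compilation settings for the MutableTorchTensorRTModule. Notably,
# immutable_weights=False keeps the TensorRT engine refittable, which the
# LoRA reload below depends on.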
settings = {
    "strict": False,
    "allow_complex_guards_as_runtime_asserts": True,
    "enabled_precisions": {torch.float32},
    "truncate_double": True,
    "min_block_size": 1,
    "use_fp32_acc": True,
    "use_explicit_typing": True,
    "debug": False,
    "use_python_runtime": True,
    "immutable_weights": False,
    "enable_cuda_graph": True,
}

trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
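# Assuming the (args_shapes, kwargs_shapes) signature used elsewhere in this
# PR: the empty tuple covers positional args (none here), and the dict covers
# keyword args.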
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
pipe.transformer = trt_gm


def generate_image(prompt, inference_step, batch_size=2):
    start_time = time.time()
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=inference_step,
        num_images_per_prompt=batch_size,
    ).images
    end_time = time.time()
    return image, end_time - start_time


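# Warm-up: compilation happens on the first call to the mutable module, so
# run one generation before the UI comes up.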
generate_image(["Test"], 2)
torch.cuda.empty_cache()


def model_change(model):
    if model == "Torch Model":
        pipe.transformer = backbone
        backbone.to(DEVICE)
    else:
        backbone.to("cpu")
        pipe.transformer = trt_gm
        torch.cuda.empty_cache()


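# Fusing a LoRA mutates the backbone weights in place; the warm-up call inside
# load_lora lets the mutable module detect the change and refit the engine.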
def load_lora(path):
    pipe.load_lora_weights(
        path,
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()
    print("LoRA loaded! Begin refitting")
    generate_image(["Test"], 2)
    print("Refitting Finished!")


# Create Gradio interface
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
    gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

    with gr.Row():
        with gr.Column():
            # Input components
            prompt_input = gr.Textbox(
                label="Prompt", placeholder="Enter your prompt here...", lines=3
            )
            model_dropdown = gr.Dropdown(
                choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
                value="Torch-TensorRT Accelerated Model",
                label="Model Variant",
            )

            lora_upload_path = gr.Textbox(
                label="LoRA Path",
                placeholder="Enter the LoRA checkpoint path here",
                value="/home/TensorRT/examples/apps/NGRVNG.safetensors",
                lines=2,
            )
            num_steps = gr.Slider(
                minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
            )
            batch_size = gr.Slider(
                minimum=1, maximum=8, value=1, step=1, label="Batch Size"
            )

            generate_btn = gr.Button("Generate Image")
            load_lora_btn = gr.Button("Load LoRA")

        with gr.Column():
            # Output component
            output_image = gr.Gallery(label="Generated Image")
            time_taken = gr.Textbox(
                label="Generation Time (seconds)", interactive=False
            )

    # Connect the button to the generation function
    model_dropdown.change(model_change, inputs=[model_dropdown])
    load_lora_btn.click(
        fn=load_lora,
        inputs=[
            lora_upload_path,
        ],
    )

    # Update generate button click to include time output
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            num_steps,
            batch_size,
        ],
        outputs=[output_image, time_taken],
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()
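Running the script compiles the backbone and performs the warm-up generation first; Gradio then prints a local URL for the interface.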
8 changes: 4 additions & 4 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -22,6 +22,7 @@
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from diffusers import DiffusionPipeline

np.random.seed(5)
torch.manual_seed(5)
@@ -31,7 +32,7 @@
# Initialize the Mutable Torch TensorRT Module with settings.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
settings = {
    "use_python": False,
    "use_python_runtime": False,
    "enabled_precisions": {torch.float32},
    "immutable_weights": False,
}
@@ -40,7 +41,6 @@
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
# You can use the mutable module just like the original PyTorch module. Compilation happens the first time you call the mutable module.
mutable_module(*inputs)

# %%
# Make modifications to the mutable module.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -73,13 +73,12 @@
# Stable Diffusion with Huggingface
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

from diffusers import DiffusionPipeline

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
        "debug": False,
        "immutable_weights": False,
    }

@@ -106,6 +105,7 @@
        "text_embeds": {0: BATCH},
        "time_ids": {0: BATCH},
    },
    "return_dict": None,
}
pipe.unet.set_expected_dynamic_shape_range(
    args_dynamic_shapes, kwargs_dynamic_shapes
1 change: 1 addition & 0 deletions examples/dynamo/refit_engine_example.py
@@ -101,6 +101,7 @@
)

# Check the output
model2.to("cuda")
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
    assert torch.allclose(
10 changes: 6 additions & 4 deletions examples/dynamo/torch_export_flux_dev.py
@@ -112,6 +112,8 @@
    min_block_size=1,
    use_fp32_acc=True,
    use_explicit_typing=True,
    use_python_runtime=True,
    immutable_weights=False,
)

# %%
@@ -120,13 +122,13 @@
# Release the GPU memory occupied by the exported program and the pipe.transformer
# Set the transformer in the Flux pipeline to the Torch-TRT compiled model

del ep
backbone.to("cpu")
pipe.to(DEVICE)
torch.cuda.empty_cache()
backbone.to("cpu")
pipe.transformer = trt_gm
del ep
torch.cuda.empty_cache()
pipe.transformer.config = config

trt_gm.device = torch.device("cuda")
# %%
# Image generation using prompt
# ---------------------------
17 changes: 8 additions & 9 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -37,6 +37,7 @@
    pre_export_lowering,
)
from torch_tensorrt.dynamo.utils import (
    CPU_DEVICE,
    get_flat_args_with_check,
    get_output_metadata,
    parse_graph_io,
@@ -550,15 +551,6 @@ def compile(
        "`immutable_weights` must be False when `refit_identical_engine_weights` is True."
    )

    if (
        not immutable_weights
        and not refit_identical_engine_weights
        and enable_weight_streaming
    ):
        raise ValueError(
            "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
        )

    if (
        "enable_cross_compile_for_windows" in kwargs.keys()
        and kwargs["enable_cross_compile_for_windows"]
@@ -684,12 +676,17 @@ def compile(
    )

    gm = exported_program.module()
    # Move the weights in the state_dict to CPU
    logger.info(
        "The model is moved to CPU during compilation. If you want to keep the model on GPU, call module.to('cuda') on the model after compilation."
    )
    logger.debug("Input graph: " + str(gm.graph))

    # Apply lowering on the graph module
    gm = post_lowering(gm, settings)
    logger.debug("Lowered Input graph: " + str(gm.graph))

    exported_program.module().to(CPU_DEVICE)
    trt_gm = compile_module(
        gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
    )
@@ -820,6 +817,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
    trt_modules = {}
    # Iterate over all components that can be accelerated
    # Generate the corresponding TRT Module for those

    for name, _ in partitioned_module.named_children():
        submodule = getattr(partitioned_module, name)
        # filter on the GraphModule
@@ -833,6 +831,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                str(name),
                str(submodule.graph),
            )
            submodule.to(torch.cuda.current_device())
            continue

        if name not in submodule_node_dict: