diff --git a/examples/apps/NGRVNG.safetensors b/examples/apps/NGRVNG.safetensors new file mode 100644 index 0000000000..0fe8121f51 Binary files /dev/null and b/examples/apps/NGRVNG.safetensors differ diff --git a/examples/apps/flux-demo.py b/examples/apps/flux-demo.py new file mode 100644 index 0000000000..9adf654921 --- /dev/null +++ b/examples/apps/flux-demo.py @@ -0,0 +1,154 @@ +import time + +import gradio as gr +import torch +import torch_tensorrt +from diffusers import FluxPipeline + +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, +) +pipe.to(DEVICE).to(torch.float16) +backbone = pipe.transformer + + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=8) + +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. +# To see this recommendation, you can try exporting using min=1, max=4096 +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {}, + "img_ids": {}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} + +settings = { + "strict": False, + "allow_complex_guards_as_runtime_asserts": True, + "enabled_precisions": {torch.float32}, + "truncate_double": True, + "min_block_size": 1, + "use_fp32_acc": True, + "use_explicit_typing": True, + "debug": False, + "use_python_runtime": True, + "immutable_weights": False, + "enable_cuda_graph": True, +} + +trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) +trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) +pipe.transformer = trt_gm + + +def generate_image(prompt, inference_step, batch_size=2): + start_time = time.time() + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end_time = time.time() + return image, end_time - start_time + + +generate_image(["Test"], 2) +torch.cuda.empty_cache() + + +def model_change(model): + if model == "Torch Model": + pipe.transformer = backbone + backbone.to(DEVICE) + else: + backbone.to("cpu") + pipe.transformer = trt_gm + torch.cuda.empty_cache() + + +def load_lora(path): + + pipe.load_lora_weights( + path, + adapter_name="lora1", + ) + pipe.set_adapters(["lora1"], adapter_weights=[1]) + pipe.fuse_lora() + pipe.unload_lora_weights() + print("LoRA loaded! 
Begin refitting") + generate_image(["Test"], 2) + print("Refitting Finished!") + + +# Create Gradio interface +with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo: + gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT") + + with gr.Row(): + with gr.Column(): + # Input components + prompt_input = gr.Textbox( + label="Prompt", placeholder="Enter your prompt here...", lines=3 + ) + model_dropdown = gr.Dropdown( + choices=["Torch Model", "Torch-TensorRT Accelerated Model"], + value="Torch-TensorRT Accelerated Model", + label="Model Variant", + ) + + lora_upload_path = gr.Textbox( + label="LoRA Path", + placeholder="Enter the LoRA checkpoint path here", + value="/home/TensorRT/examples/apps/NGRVNG.safetensors", + lines=2, + ) + num_steps = gr.Slider( + minimum=20, maximum=100, value=20, step=1, label="Inference Steps" + ) + batch_size = gr.Slider( + minimum=1, maximum=8, value=1, step=1, label="Batch Size" + ) + + generate_btn = gr.Button("Generate Image") + load_lora_btn = gr.Button("Load LoRA") + + with gr.Column(): + # Output component + output_image = gr.Gallery(label="Generated Image") + time_taken = gr.Textbox( + label="Generation Time (seconds)", interactive=False + ) + + # Connect the button to the generation function + model_dropdown.change(model_change, inputs=[model_dropdown]) + load_lora_btn.click( + fn=load_lora, + inputs=[ + lora_upload_path, + ], + ) + + # Update generate button click to include time output + generate_btn.click( + fn=generate_image, + inputs=[ + prompt_input, + num_steps, + batch_size, + ], + outputs=[output_image, time_taken], + ) + +# Launch the interface +if __name__ == "__main__": + demo.launch() diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py index f264b8a8d3..665bda1b51 100644 --- a/examples/dynamo/mutable_torchtrt_module_example.py +++ b/examples/dynamo/mutable_torchtrt_module_example.py @@ -22,6 +22,7 @@ import torch import torch_tensorrt as torch_trt import torchvision.models as models +from diffusers import DiffusionPipeline np.random.seed(5) torch.manual_seed(5) @@ -31,7 +32,7 @@ # Initialize the Mutable Torch TensorRT Module with settings. # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ settings = { - "use_python": False, + "use_python_runtime": False, "enabled_precisions": {torch.float32}, "immutable_weights": False, } @@ -40,7 +41,6 @@ mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings) # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module. mutable_module(*inputs) - # %% # Make modifications to the mutable module. 
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -73,13 +73,12 @@ # Stable Diffusion with Huggingface # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -from diffusers import DiffusionPipeline with torch.no_grad(): settings = { "use_python_runtime": True, "enabled_precisions": {torch.float16}, - "debug": True, + "debug": False, "immutable_weights": False, } @@ -106,6 +105,7 @@ "text_embeds": {0: BATCH}, "time_ids": {0: BATCH}, }, + "return_dict": None, } pipe.unet.set_expected_dynamic_shape_range( args_dynamic_shapes, kwargs_dynamic_shapes diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py index 66a1a70964..51202528c5 100644 --- a/examples/dynamo/refit_engine_example.py +++ b/examples/dynamo/refit_engine_example.py @@ -101,6 +101,7 @@ ) # Check the output +model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs) for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): assert torch.allclose( diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py index 3891fcbb9a..9dcd073f73 100644 --- a/examples/dynamo/torch_export_flux_dev.py +++ b/examples/dynamo/torch_export_flux_dev.py @@ -112,6 +112,8 @@ min_block_size=1, use_fp32_acc=True, use_explicit_typing=True, + use_python_runtime=True, + immutable_weights=False, ) # %% @@ -120,13 +122,13 @@ # Release the GPU memory occupied by the exported program and the pipe.transformer # Set the transformer in the Flux pipeline to the Torch-TRT compiled model -del ep -backbone.to("cpu") pipe.to(DEVICE) -torch.cuda.empty_cache() +backbone.to("cpu") pipe.transformer = trt_gm +del ep +torch.cuda.empty_cache() pipe.transformer.config = config - +trt_gm.device = torch.device("cuda") # %% # Image generation using prompt # --------------------------- diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 6928347baa..d8339ed858 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -37,6 +37,7 @@ pre_export_lowering, ) from torch_tensorrt.dynamo.utils import ( + CPU_DEVICE, get_flat_args_with_check, get_output_metadata, parse_graph_io, @@ -550,15 +551,6 @@ def compile( "`immutable_weights` must be False when `refit_identical_engine_weights` is True." ) - if ( - not immutable_weights - and not refit_identical_engine_weights - and enable_weight_streaming - ): - raise ValueError( - "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" - ) - if ( "enable_cross_compile_for_windows" in kwargs.keys() and kwargs["enable_cross_compile_for_windows"] @@ -684,12 +676,21 @@ def compile( ) gm = exported_program.module() + # Move the weights in the state_dict to CPU + exported_program.module().to("cpu") + logger.info( + "The model is moved to CPU during compilation. If you want to keep the model on GPU, call module.to('cuda') on the model after compilation." + ) logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module gm = post_lowering(gm, settings) logger.debug("Lowered Input graph: " + str(gm.graph)) - + if offload_module_to_cpu: + exported_program.module().to(CPU_DEVICE) + logger.info( + "The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False." 
+ ) trt_gm = compile_module( gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache ) @@ -820,6 +821,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: trt_modules = {} # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those + for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -833,6 +835,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: str(name), str(submodule.graph), ) + submodule.to(torch.cuda.current_device()) continue if name not in submodule_node_dict: diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index c128e9cc82..6498f8dc57 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -2,6 +2,7 @@ import collections.abc import copy +import gc import logging from typing import Any, List, Optional, Sequence, Tuple @@ -35,7 +36,9 @@ TorchTensorRTModule, ) from torch_tensorrt.dynamo.utils import ( + CPU_DEVICE, check_module_output, + delete_module, get_model_device, get_torch_inputs, set_log_level, @@ -109,7 +112,9 @@ def construct_refit_mapping( def construct_refit_mapping_from_weight_name_map( - weight_name_map: dict[Any, Any], state_dict: dict[Any, Any] + weight_name_map: dict[Any, Any], + state_dict: dict[Any, Any], + settings: CompilationSettings, ) -> dict[Any, Any]: engine_weight_map = {} for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items(): @@ -120,7 +125,9 @@ def construct_refit_mapping_from_weight_name_map( # If weights is not in sd, we can leave it unchanged continue else: - engine_weight_map[engine_weight_name] = state_dict[sd_weight_name] + engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to( + to_torch_device(settings.device) + ) engine_weight_map[engine_weight_name] = ( engine_weight_map[engine_weight_name] @@ -163,7 +170,7 @@ def _refit_single_trt_engine_with_gm( "constant_mapping", {} ) # type: ignore mapping = construct_refit_mapping_from_weight_name_map( - weight_name_map, new_gm.state_dict() + weight_name_map, new_gm.state_dict(), settings ) constant_mapping_with_type = {} @@ -309,42 +316,68 @@ def refit_module_weights( get_decompositions(settings.enable_experimental_decompositions) ) new_gm = new_weight_module.module() + logger.debug("Input graph: " + str(new_gm.graph)) # Apply lowering on the graph module new_gm = post_lowering(new_gm, settings) - logger.info("Compilation Settings: %s\n", settings) + logger.debug("Lowered Input graph: " + str(new_gm.graph)) # Set torch-executed ops - CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) + CONVERTERS.set_compilation_settings(settings) + + # Check the number of supported operations in the graph + num_supported_ops, total_ops = partitioning.get_graph_converter_support( + new_gm, settings.debug, settings.torch_executed_ops + ) + + if num_supported_ops == 0 or ( + num_supported_ops < settings.min_block_size and not settings.dryrun + ): + logger.warning( + f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. " + f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}" + ) + return new_gm + else: + logger.debug( + f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph." 
+ ) # If specified, try using the fast partitioner and fall back to the global one on failure if settings.use_fast_partitioner: try: + logger.info("Partitioning the graph via the fast partitioner") new_partitioned_module, supported_ops = partitioning.fast_partition( new_gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=(num_supported_ops == total_ops), ) + except torch.fx.passes.splitter_base.FxNetSplitterInternalError: logger.error( "Partitioning failed on the subgraph with fast partition. See trace above. " - + "Retrying with global partition.", + "Retrying with global partition.", exc_info=True, ) settings.use_fast_partitioner = False if not settings.use_fast_partitioner: + logger.info("Partitioning the graph via the global partitioner") new_partitioned_module, supported_ops = partitioning.global_partition( new_gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, ) + # Done Partition if inline_module: # Preprocess the partitioned module to be in the same format as the inline module inline_torch_modules(new_partitioned_module) @@ -361,7 +394,7 @@ def refit_module_weights( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those - + new_weight_module.module().to(CPU_DEVICE) for name, new_submodule in new_partitioned_module.named_children(): # Refit each submodule # Extract engine from the submodule @@ -464,26 +497,33 @@ def refit_module_weights( settings=settings, weight_name_map=None, ) + delete_module(new_submodule) # clear EXCLUDE_WEIGHTS flag serialization_config = engine.create_serialization_config() serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) serialized_engine = engine.serialize_with_config(serialization_config) - if isinstance( - compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) - ): + if isinstance(compiled_submodule, PythonTorchTensorRTModule): + compiled_submodule.serialized_engine = bytes(serialized_engine) + elif isinstance(compiled_submodule, TorchTensorRTModule): compiled_submodule.engine = None # Clear the engine for TorchTensorRTModule, otherwise it won't be updated compiled_submodule.serialized_engine = bytes(serialized_engine) compiled_submodule.setup_engine() - elif inline_module: new_engine_info = list(engine_info) new_engine_info[ENGINE_IDX] = bytes(serialized_engine) refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) setattr(compiled_module, f"{name}_engine", refitted_engine) + del engine + gc.collect() + torch.cuda.empty_cache() + + delete_module(new_partitioned_module) + if verify_output and arg_inputs is not None: + new_gm.to(torch.cuda.current_device()) if check_module_output( new_module=new_gm, refitted_module=compiled_module, @@ -491,6 +531,7 @@ def refit_module_weights( kwarg_inputs=torch_kwarg_inputs, ): logger.info("Refitting Succeed!") + new_gm.to(CPU_DEVICE) else: if weight_name_map: logger.warning( @@ -506,6 +547,7 @@ def refit_module_weights( in_place=in_place, ) logger.error("Refitting Failed! The outputs do not match.") + new_gm.to(CPU_DEVICE) else: logger.info("Refitting Completed! 
Output verification skipped.") diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 17f2fccbff..73cb685808 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -44,7 +44,7 @@ get_trt_tensor, to_torch, ) -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, delete_module, to_torch_device from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER @@ -491,15 +491,11 @@ def _save_weight_mapping(self) -> None: _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping torch_device = to_torch_device(self.compilation_settings.device) - gm_is_on_cuda = get_model_device(self.module).type == "cuda" - if not gm_is_on_cuda: - # If the model original position is on CPU, move it GPU - sd = { - k: v.reshape(-1).to(torch_device) - for k, v in self.module.state_dict().items() - } - else: - sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()} + sd = { + k: v.reshape(-1).to(torch_device) + for k, v in self.module.state_dict().items() + } + weight_name_map: dict[str, Any] = {} np_map = {} constant_mapping = {} @@ -737,7 +733,8 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - + if self.compilation_settings.offload_module_to_cpu: + delete_module(self.module) serialized_engine = self.builder.build_serialized_network( self.ctx.net, builder_config ) diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index eaeb6a8c28..88eff5757b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -2,17 +2,16 @@ import logging from copy import deepcopy from enum import Enum, auto -from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union +from typing import Any, Dict, Iterator, Optional, Union import numpy as np import torch -from torch.fx.node import Target +import torch_tensorrt +from torch.export._trace import _export from torch_tensorrt._Device import Device -from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._compiler import compile as dynamo_compile from torch_tensorrt.dynamo._refit import refit_module_weights -from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.utils import ( check_output_equal, to_torch_device, @@ -63,35 +62,11 @@ def __init__( pytorch_model: torch.nn.Module, *, device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE, - disable_tf32: bool = _defaults.DISABLE_TF32, - assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, - sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - enabled_precisions: Set[ - Union[torch.dtype, dtype] - ] = _defaults.ENABLED_PRECISIONS, - engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - immutable_weights: bool = False, - debug: bool = _defaults.DEBUG, - num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, - workspace_size: int = _defaults.WORKSPACE_SIZE, - dla_sram_size: int = _defaults.DLA_SRAM_SIZE, - dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, - truncate_double: bool = 
_defaults.TRUNCATE_DOUBLE, - require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, - min_block_size: int = _defaults.MIN_BLOCK_SIZE, - torch_executed_ops: Optional[Collection[Target]] = None, - torch_executed_modules: Optional[List[str]] = None, - pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, - version_compatible: bool = _defaults.VERSION_COMPATIBLE, - optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME, - use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - dryrun: bool = _defaults.DRYRUN, - hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, - timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + immutable_weights: bool = False, + strict: bool = True, + allow_complex_guards_as_runtime_asserts: bool = False, + weight_streaming_budget: Optional[int] = None, **kwargs: Any, ) -> None: """ @@ -154,53 +129,35 @@ def __init__( self.exp_program: Any = None self.arg_inputs: tuple[Any, ...] = tuple() self.kwarg_inputs: dict[str, Any] = {} - device = to_torch_tensorrt_device(device) - enabled_precisions = {dtype._from(p) for p in enabled_precisions} + self.additional_settings = kwargs + self.strict = strict + self.allow_complex_guards_as_runtime_asserts = ( + allow_complex_guards_as_runtime_asserts + ) + self.use_python_runtime = use_python_runtime + self.trt_device = to_torch_tensorrt_device(device) assert ( not immutable_weights - ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule." - compilation_options = { - "enabled_precisions": ( - enabled_precisions - if enabled_precisions - else _defaults.ENABLED_PRECISIONS - ), - "debug": debug, - "device": device, - "assume_dynamic_shape_support": assume_dynamic_shape_support, - "workspace_size": workspace_size, - "min_block_size": min_block_size, - "torch_executed_ops": ( - torch_executed_ops if torch_executed_ops is not None else set() - ), - "pass_through_build_failures": pass_through_build_failures, - "max_aux_streams": max_aux_streams, - "version_compatible": version_compatible, - "optimization_level": optimization_level, - "use_python_runtime": use_python_runtime, - "truncate_double": truncate_double, - "use_fast_partitioner": use_fast_partitioner, - "num_avg_timing_iters": num_avg_timing_iters, - "enable_experimental_decompositions": enable_experimental_decompositions, - "require_full_compilation": require_full_compilation, - "disable_tf32": disable_tf32, - "sparse_weights": sparse_weights, - "immutable_weights": immutable_weights, - "engine_capability": engine_capability, - "dla_sram_size": dla_sram_size, - "dla_local_dram_size": dla_local_dram_size, - "dla_global_dram_size": dla_global_dram_size, - "dryrun": dryrun, - "hardware_compatible": hardware_compatible, - "timing_cache_path": timing_cache_path, - } + ), "`immutable_weights has to be False for a MutableTorchTensorRTModule" + self.arg_dynamic_shapes: Optional[tuple[Any]] = None self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None - - self.settings = CompilationSettings(**compilation_options) + self.serializable_dynamic_shapes_dims: dict[str, tuple[str, int, int]] = {} self.run_info: Optional[tuple[Any, ...]] = None self.state_dict_metadata: dict[str, torch.Size] = {} self._store_state_dict_metadata() + self.enable_weight_streaming = ( + kwargs["enable_weight_streaming"] + if 
"enable_weight_streaming" in kwargs + else False + ) + self.weight_streaming_ctx = None + self.weight_streaming_budget = weight_streaming_budget + if self.enable_weight_streaming: + if weight_streaming_budget is None: + logger.warning( + "Weight stremaing budget is not set. Using auto weight streaming budget" + ) cls = self.__class__ self.__class__ = type( @@ -293,7 +250,7 @@ def update_refit_condition(self) -> None: # to determine whether refit/recompilation is needed. If the output is the same, no further process needed. if self.run_info: args, kwargs, result = self.run_info - self.original_model.to(to_torch_device(self.settings.device)) + self.original_model.to(to_torch_device(self.trt_device)) new_result = self.original_model(*args, **kwargs) self.original_model.cpu() torch.cuda.empty_cache() @@ -325,17 +282,17 @@ def refit_gm(self) -> None: MutableTorchTensorRTModule automatically catches weight value updates and call this function to refit the module. If it fails to catch the changes, please call this function manually to update the TRT graph module. """ - self.original_model.to(to_torch_device(self.settings.device)) + if self.exp_program is None: - self.exp_program = torch.export.export( - self.original_model, self.arg_inputs, kwargs=self.kwarg_inputs - ) + self.original_model.to(to_torch_device(self.trt_device)) + self.exp_program = self.get_exported_program() else: self.exp_program._state_dict = ( MutableTorchTensorRTModule._transform_state_dict( self.original_model.state_dict() ) ) + self.exp_program.module().to(to_torch_device(self.trt_device)) self.gm = refit_module_weights( self.gm, self.exp_program, @@ -345,9 +302,28 @@ def refit_gm(self) -> None: in_place=True, ) - self.original_model.cpu() + self.original_model.to("cpu") torch.cuda.empty_cache() + def get_exported_program(self) -> torch.export.ExportedProgram: + if self.allow_complex_guards_as_runtime_asserts: + return _export( + self.original_model, + self.arg_inputs, + kwargs=self.kwarg_inputs, + dynamic_shapes=self._get_total_dynamic_shapes(), + strict=self.strict, + allow_complex_guards_as_runtime_asserts=self.allow_complex_guards_as_runtime_asserts, + ) + else: + return torch.export.export( + self.original_model, + self.arg_inputs, + kwargs=self.kwarg_inputs, + dynamic_shapes=self._get_total_dynamic_shapes(), + strict=self.strict, + ) + def compile(self) -> None: """ (Re)compile the TRT graph module using the PyTorch module. @@ -356,25 +332,37 @@ def compile(self) -> None: If it fails to catch the changes, please call this function manually to recompile the TRT graph module. """ # Export the module - self.original_model.to(to_torch_device(self.settings.device)) - self.exp_program = torch.export.export( - self.original_model, - self.arg_inputs, - kwargs=self.kwarg_inputs, - dynamic_shapes=self._get_total_dynamic_shapes(), - ) + self.original_model.to(to_torch_device(self.trt_device)) + self.exp_program = self.get_exported_program() self.gm = dynamo_compile( self.exp_program, arg_inputs=self.arg_inputs, kwarg_inputs=self.kwarg_inputs, - **self.settings.__dict__, + immutable_weights=False, + use_python_runtime=self.use_python_runtime, + **self.additional_settings, ) - self.original_model.cpu() + self.original_model.to("cpu") torch.cuda.empty_cache() + if self.enable_weight_streaming: + self.set_weight_streaming_ctx(self.weight_streaming_budget) + + def set_weight_streaming_ctx(self, requested_budget: Optional[int] = None) -> None: + """ + Set the weight streaming budget. 
If budget is not set, then automatic weight streaming budget + is used. + """ + self.weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(self.gm) + requested_budget = ( + requested_budget + if requested_budget is not None + else self.weight_streaming_ctx.get_automatic_weight_streaming_budget() + ) + self.weight_streaming_ctx.device_budget = requested_budget def _validate_inputs(self, *args: Any, **kwargs: Any) -> None: - if not self.arg_inputs: + if not self.arg_inputs and not self.kwarg_inputs: logger.info("First time compilation initiated. This may take some time.") self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE) self._store_inputs(args, kwargs) @@ -491,14 +479,24 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: self._store_state_dict_metadata() self.refit_state.set_state(RefitFlag.LIVE) + weight_streaming_ctx = ( + self.weight_streaming_ctx if self.enable_weight_streaming else None + ) result = self.gm(*args, **kwargs) # Storing inputs and outputs for verification when the state is unknown self.run_info = (args, kwargs, result) return result - def to(self, device: str) -> None: - logger.warning("Original PyTorch model is moved. CPU offload may failed.") - self.original_model.to(device) + def to(self, *args: Any, **kwargs: Any) -> None: + logger.warning( + "Trying to move the original PyTorch model. This will cause CPU offloading failing and increase GPU memory usage." + + "If this is absolute necessary, please call module.pytorch_model.to(...) \n" + + "The model is still on the original device." + ) + + @property + def device(self) -> torch.device: + return to_torch_device(self.trt_device) def __deepcopy__(self, memo: Any) -> Any: cls = self.__class__ @@ -624,18 +622,58 @@ def _check_tensor_shapes_with_dynamic_shapes( return True + def serialize_dynamic_shapes(self) -> None: + dims = self.serializable_dynamic_shapes_dims + + def resursivly_serialize_dynamic_shape(obj: Any) -> None: + if isinstance(obj, dict): + for axis, v in obj.items(): + if isinstance(v, torch.export.dynamic_shapes._Dim): + name = str(v).split("'")[1].split(".")[-1] + # We use string of the hash to be the unique identifier of Dim object + dims.setdefault(str(hash(v)), (name, v.min, v.max)) + obj[axis] = str(hash(v)) + else: + resursivly_serialize_dynamic_shape(v) + if isinstance(obj, (tuple, list)): + for v in obj: + resursivly_serialize_dynamic_shape(v) + + resursivly_serialize_dynamic_shape(self.arg_dynamic_shapes) + resursivly_serialize_dynamic_shape(self.kwarg_dynamic_shapes) + + def deserialize_dynamic_shapes(self) -> None: + dims = self.serializable_dynamic_shapes_dims + + def resursivly_deserialize_dynamic_shape(obj: Any) -> None: + if isinstance(obj, dict): + for axis, v in obj.items(): + if isinstance(v, str): + obj[axis] = torch.export.Dim( + dims[v][0], min=dims[v][1], max=dims[v][2] + ) + else: + resursivly_deserialize_dynamic_shape(v) + if isinstance(obj, (tuple, list)): + for v in obj: + resursivly_deserialize_dynamic_shape(v) + + resursivly_deserialize_dynamic_shape(self.arg_dynamic_shapes) + resursivly_deserialize_dynamic_shape(self.kwarg_dynamic_shapes) + @staticmethod def save(module: Any, path: str) -> None: # Cast the object back to MutableTorchTensorRTModule to save assert ( - not module.settings.use_python_runtime + not module.use_python_runtime ), "Python runtime does not support serialization. Save failed." 
module.init_finished = False module.__class__ = MutableTorchTensorRTModule exp_program = module.exp_program module.pytorch_model = None module.exp_program = None - torch.save(module, path) + module.serialize_dynamic_shapes() + torch.save(module, path, pickle_protocol=4) # Restore deleted attributes module.exp_program = exp_program module.pytorch_model = _make_refit_change_trigger( @@ -658,7 +696,7 @@ def load(path: str) -> Any: module.pytorch_model = _make_refit_change_trigger( module.original_model, module.refit_state ) - module.original_model.to(to_torch_device(module.settings.device)) + module.original_model.to(to_torch_device(module.device)) module.exp_program = torch.export.export( module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs ) @@ -669,6 +707,7 @@ def load(path: str) -> Any: (cls, module.original_model.__class__), {}, ) + module.deserialize_dynamic_shapes() module.init_finished = True return module diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index 346132145e..de0a7b9fdf 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -69,48 +69,16 @@ def __init__(self, compiled_module: torch.nn.Module) -> None: self.old_mode = _PY_RT_CUDAGRAPHS self.compiled_module = compiled_module self.cudagraphs_module: Optional[CudaGraphsTorchTensorRTModule] = None + self.old_module = None - def __enter__(self) -> torch.nn.Module: - global _PY_RT_CUDAGRAPHS - - num_torch_module = 0 - num_trt_module = 0 - for name, module in self.compiled_module.named_children(): - # need to disable cudagraphs if any model requires output allocator - if ( - hasattr(module, "requires_output_allocator") - and module.requires_output_allocator - ): - raise RuntimeError( - "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs." - ) - if "_run_on_acc" in name: - num_trt_module += 1 - elif "_run_on_gpu" in name: - num_torch_module += 1 - - if num_torch_module > 0: - # Set whole cudagraphs mode and returns wrapped module - _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS - # Set new mode for C++ - if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: - torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) + def __enter__(self) -> torch.nn.Module | torch.fx.GraphModule: - logger.debug( - "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" - ) - self.cudagraphs_module = CudaGraphsTorchTensorRTModule(self.compiled_module) - return self.cudagraphs_module - else: - if num_trt_module > 0: - logger.debug("No graph breaks detected, using runtime cudagraphs mode") - else: - logger.debug( - "Please consider dynamo if there is graph breaks. 
Using runtime cudagraphs mode" - ) - # Enable cudagraphs for TRT submodule - set_cudagraphs_mode(True) + if isinstance(self.compiled_module, torch_tensorrt.MutableTorchTensorRTModule): + self.old_module = self.compiled_module.gm + self.compiled_module.gm = get_cuda_graph_module(self.compiled_module.gm) return self.compiled_module + else: + return get_cuda_graph_module(self.compiled_module) def __exit__(self, *args: Any) -> None: # Set cudagraphs back to old mode @@ -118,6 +86,52 @@ def __exit__(self, *args: Any) -> None: # __del__ is not entirely predictable, so we reset cudagraph here if self.cudagraphs_module: self.cudagraphs_module._reset_captured_graph() + if self.old_module: # MutableTorchTRTModule + self.compiled_module.gm = self.old_module + + +def get_cuda_graph_module( + compiled_module: torch.fx.GraphModule, +) -> torch.nn.Module | torch.fx.GraphModule: + global _PY_RT_CUDAGRAPHS + + num_torch_module = 0 + num_trt_module = 0 + for name, module in compiled_module.named_children(): + # need to disable cudagraphs if any model requires output allocator + if ( + hasattr(module, "requires_output_allocator") + and module.requires_output_allocator + ): + raise RuntimeError( + "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs." + ) + if "_run_on_acc" in name: + num_trt_module += 1 + elif "_run_on_gpu" in name: + num_torch_module += 1 + + if num_torch_module > 0: + # Set whole cudagraphs mode and returns wrapped module + _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS + # Set new mode for C++ + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) + + logger.debug( + "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" + ) + return CudaGraphsTorchTensorRTModule(compiled_module) + else: + if num_trt_module > 0: + logger.debug("No graph breaks detected, using runtime cudagraphs mode") + else: + logger.debug( + "Please consider dynamo if there is graph breaks. 
Using runtime cudagraphs mode" + ) + # Enable cudagraphs for TRT submodule + set_cudagraphs_mode(True) + return compiled_module def enable_cudagraphs( diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index a0b3292c29..a18fee7c44 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -93,7 +93,7 @@ def test_refit_one_engine_with_weightmap(): enabled_precisions = {torch.float} debug = False min_block_size = 1 - use_python_runtime = False + use_python_runtime = True exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) @@ -117,6 +117,7 @@ def test_refit_one_engine_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -167,6 +168,7 @@ def test_refit_one_engine_no_map_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -192,7 +194,7 @@ def test_refit_one_engine_with_wrong_weightmap(): enabled_precisions = {torch.float} debug = False min_block_size = 1 - use_python_runtime = False + use_python_runtime = True exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) @@ -221,6 +223,7 @@ def test_refit_one_engine_with_wrong_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -249,7 +252,7 @@ def test_refit_one_engine_bert_with_weightmap(): enabled_precisions = {torch.float} debug = False min_block_size = 1 - use_python_runtime = False + use_python_runtime = True exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) @@ -272,6 +275,7 @@ def test_refit_one_engine_bert_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -294,7 +298,7 @@ def test_refit_one_engine_bert_with_weightmap(): "TorchScript Frontend is not available", ) @pytest.mark.unit -def test_refit_one_engine_inline_runtime__with_weightmap(): +def test_refit_one_engine_inline_runtime_with_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -326,6 +330,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -370,6 +375,7 @@ def test_refit_one_engine_python_runtime_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -441,6 +447,7 @@ def forward(self, x): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -489,6 +496,7 @@ def test_refit_one_engine_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -540,6 +548,7 @@ def test_refit_one_engine_bert_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -594,6 
+603,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
     )
 
     # Check the output
+    model2.to("cuda")
     expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
         *inputs
     )
@@ -638,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
     )
 
     # Check the output
+    model2.to("cuda")
     expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
         *inputs
     )
@@ -709,6 +720,7 @@ def forward(self, x):
     )
 
     # Check the output
+    model2.to("cuda")
     expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
         *inputs
     )
@@ -763,6 +775,7 @@ def forward(self, x):
     )
 
     # Check the output
+    model.to("cuda")
     pyt_outputs, trt_outputs = exp_program.module()(*inputs), trt_gm(*inputs)
     for pyt_output, trt_output in zip(pyt_outputs, trt_outputs):
         assertions.assertTrue(
diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
index c07e04b6a4..d9105f5a75 100644
--- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
+++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
@@ -75,7 +75,7 @@ def test_check_input_shape_dynamic():
 
 
 @pytest.mark.unit
-def test_model_complex_dynamic_shape():
+def test_model_complex_dynamic_shape_with_saving():
     device = "cuda:0"
 
     class Model(torch.nn.Module):
@@ -111,6 +111,13 @@ def forward(self, a, b, c=None):
     # Run inference
     trt_gm(*inputs, **kwargs)
 
+    try:
+        save_path = os.path.join(tempfile.gettempdir(), "mutable_module.pkl")
+        torch_trt.MutableTorchTensorRTModule.save(mutable_module, save_path)
+        model = torch_trt.MutableTorchTensorRTModule.load(save_path)
+    except Exception as e:
+        pytest.fail(f"Module saving and reloading with dynamic shape failed: {e}")
+
     inputs_2 = [torch.rand(10, 9).to(device)]
     kwargs_2 = {
         "b": torch.rand(9, 30).to(device),
diff --git a/tools/perf/Flux/benchmark.sh b/tools/perf/Flux/benchmark.sh
new file mode 100644
index 0000000000..fc4468f1ed
--- /dev/null
+++ b/tools/perf/Flux/benchmark.sh
@@ -0,0 +1,4 @@
+#TODO: Enter the HF Token
+huggingface-cli login --token HF_TOKEN
+
+python flux_perf.py > benchmark_output.txt
\ No newline at end of file
diff --git a/tools/perf/Flux/create_env.sh b/tools/perf/Flux/create_env.sh
new file mode 100644
index 0000000000..d46d82ad31
--- /dev/null
+++ b/tools/perf/Flux/create_env.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+git config --global --add safe.directory /home/TensorRT
+
+#Install bazel
+apt install apt-transport-https curl gnupg -y
+curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor >bazel-archive-keyring.gpg
+mv bazel-archive-keyring.gpg /usr/share/keyrings
+echo "deb [arch=amd64 signed-by=/usr/share/keyrings/bazel-archive-keyring.gpg] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list
+
+
+apt update && apt install bazel-7.2.1
+apt install bazel
+bazel
+cd /home/TensorRT
+
+python -m pip install --pre -e . --extra-index-url https://download.pytorch.org/whl/nightly/cu128
+pip install tensorrt==10.9.0.34 --force-reinstall
+
+pip3 install --pre torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
+
+
+pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"
+
+pip install notebook
+pip install gradio safetensors peft pyinstrument
diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py
new file mode 100644
index 0000000000..e5e7dceecd
--- /dev/null
+++ b/tools/perf/Flux/flux_perf.py
@@ -0,0 +1,93 @@
+from time import time
+
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+
+for i in range(torch.cuda.device_count()):
+    print(torch.cuda.get_device_properties(i).name)
+
+DEVICE = "cuda:0"
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.float32,
+)
+pipe.to(DEVICE).to(torch.float32)
+backbone = pipe.transformer
+
+
+batch_size = 2
+BATCH = torch.export.Dim("batch", min=1, max=8)
+
+# These particular min/max values for the img_id input are recommended by torch dynamo during export of the model.
+# To see this recommendation, you can try exporting using min=1, max=4096
+dynamic_shapes = {
+    "hidden_states": {0: BATCH},
+    "encoder_hidden_states": {0: BATCH},
+    "pooled_projections": {0: BATCH},
+    "timestep": {0: BATCH},
+    "txt_ids": {},
+    "img_ids": {},
+    "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None,
+}
+
+settings = {
+    "strict": False,
+    "allow_complex_guards_as_runtime_asserts": True,
+    "enabled_precisions": {torch.float32},
+    "truncate_double": True,
+    "min_block_size": 1,
+    "use_fp32_acc": True,
+    "use_explicit_typing": True,
+    "debug": False,
+    "use_python_runtime": True,
+    "immutable_weights": False,
+}
+
+
+def generate_image(
+    prompt, inference_step, batch_size=2, benchmark=False, iterations=1
+):
+
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+    if benchmark:
+        print("Time elapsed for", iterations, "iterations:", end - start)
+        print("Average Latency Per Step:", (end - start) / inference_step / iterations)
+    return image
+
+
+generate_image(["Test"], 2)
+print("Benchmark Original PyTorch Module Latency (float32)")
+generate_image(["Test"], 50, benchmark=True, iterations=3)
+
+pipe.to(torch.float16)
+print("Benchmark Original PyTorch Module Latency (float16)")
+generate_image(["Test"], 50, benchmark=True, iterations=3)
+
+
+trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
+trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
+pipe.transformer = trt_gm
+
+start = time()
+generate_image(["Test"], 2)
+end = time()
+print("Time elapsed for compilation:", end - start)
+print()
+print("Benchmark TRT Accelerated Latency")
+generate_image(["Test"], 50, benchmark=True, iterations=3)
+torch.cuda.empty_cache()
+
+
+with torch_tensorrt.runtime.enable_cudagraphs(trt_gm):
+    generate_image(["Test"], 2)
+    print("Benchmark TRT Accelerated Latency with Cuda Graph")
+    generate_image(["Test"], 50, benchmark=True, iterations=3)
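The demo and perf scripts above exercise the reworked MutableTorchTensorRTModule end to end on Flux. Below is a minimal, self-contained sketch of the same workflow on a toy module, assuming only the API surface shown in this patch; the model, tensor shapes, and setting values are illustrative assumptions, not part of the change.

# hypothetical usage sketch; assumes only APIs shown in this patch
import torch
import torch_tensorrt


class ToyModel(torch.nn.Module):  # stand-in for the Flux backbone
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = ToyModel().eval().cuda()

# Extra keyword arguments are now forwarded straight to dynamo_compile.
trt_module = torch_tensorrt.MutableTorchTensorRTModule(
    model,
    use_python_runtime=True,
    enabled_precisions={torch.float32},
    immutable_weights=False,
    min_block_size=1,
)

# Register the expected dynamic batch dimension before the first call.
BATCH = torch.export.Dim("batch", min=1, max=8)
trt_module.set_expected_dynamic_shape_range(({0: BATCH},), {})

# The first call exports and compiles; later calls reuse the engine.
out = trt_module(torch.randn(2, 16, device="cuda"))

# Weight updates made through the mutable module trigger a refit on the next call.
trt_module.load_state_dict(ToyModel().state_dict())
out = trt_module(torch.randn(4, 16, device="cuda"))

# enable_cudagraphs() now accepts a MutableTorchTensorRTModule directly.
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cg_module:
    out = cg_module(torch.randn(4, 16, device="cuda"))

Saving and reloading follow the same pattern via MutableTorchTensorRTModule.save()/load(), with the dynamic-shape Dims round-tripped by the new serialize_dynamic_shapes()/deserialize_dynamic_shapes() helpers; this requires the C++ runtime, since the Python runtime does not support serialization.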