diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index d355cefe77..47c8a41db6 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -96,6 +96,7 @@ def cross_compile_for_windows(
     strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
+    use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -169,6 +170,7 @@ def cross_compile_for_windows(
         strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
         enable_weight_streaming (bool): Enable weight streaming.
+        use_aot_joint_export (bool): Use aot_export_joint_simple to export the graph; otherwise, wrap the backend with AOT autograd, which is required for distributed tensors.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -326,6 +328,7 @@ def cross_compile_for_windows(
         "immutable_weights": immutable_weights,
         "enable_cross_compile_for_windows": True,
         "enable_weight_streaming": enable_weight_streaming,
+        "use_aot_joint_export": use_aot_joint_export,
     }
 
     # disable the following settings is not supported for cross compilation for windows feature
@@ -413,6 +416,7 @@ def compile(
     strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
+    use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -488,6 +492,7 @@ def compile(
         strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
         enable_weight_streaming (bool): Enable weight streaming.
+        use_aot_joint_export (bool): Use aot_export_joint_simple to export the graph; otherwise, wrap the backend with AOT autograd, which is required for distributed tensors.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -662,6 +667,7 @@ def compile(
         "immutable_weights": immutable_weights,
         "enable_cross_compile_for_windows": False,
         "enable_weight_streaming": enable_weight_streaming,
+        "use_aot_joint_export": use_aot_joint_export,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -950,6 +956,7 @@ def convert_exported_program_to_serialized_trt_engine(
     strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
+    use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1013,6 +1020,7 @@ def convert_exported_program_to_serialized_trt_engine(
         strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
         enable_weight_streaming (bool): Enable weight streaming.
+        use_aot_joint_export (bool): Use aot_export_joint_simple to export the graph; otherwise, wrap the backend with AOT autograd, which is required for distributed tensors.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -1129,6 +1137,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "strip_engine_weights": strip_engine_weights,
         "immutable_weights": immutable_weights,
         "enable_weight_streaming": enable_weight_streaming,
+        "use_aot_joint_export": use_aot_joint_export,
     }
 
     settings = CompilationSettings(**compilation_options)
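Note: `use_aot_joint_export` is threaded into `CompilationSettings` in all three entry points above. A minimal usage sketch follows (illustrative only; the toy model and inputs are assumptions, not part of this patch):

import torch
import torch_tensorrt

# Assumed toy module; any model whose graph carries torch.distributed collectives applies.
model = torch.nn.Linear(8, 8).cuda().eval()
inputs = [torch.randn(1, 8).cuda()]

# Per the new docstrings, graphs with distributed tensors need the backend wrapped with
# AOT autograd, so aot_export_joint_simple is disabled here.
trt_model = torch_tensorrt.dynamo.compile(
    torch.export.export(model, tuple(inputs)),
    inputs=inputs,
    use_aot_joint_export=False,
)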
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 62526080c4..83e398e2fd 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -3,6 +3,9 @@
 import functools
 import logging
 import os
+import shutil
+import subprocess
+import sys
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
 
 import numpy as np
@@ -12,6 +15,7 @@
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
+from torch_tensorrt._enums import Platform
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -923,6 +927,84 @@ def args_bounds_check(
     return args[i] if len(args) > i and args[i] is not None else replacement
 
 
+def install_wget(platform: str) -> None:
+    if shutil.which("wget"):
+        _LOGGER.debug("wget is already installed")
+        return
+    if platform.startswith("linux"):
+        try:
+            # if running as root
+            if os.geteuid() == 0:
+                subprocess.run(["apt-get", "update"], check=True)
+                subprocess.run(["apt-get", "install", "-y", "wget"], check=True)
+            else:
+                _LOGGER.debug("Please run with sudo permissions")
+                subprocess.run(["sudo", "apt-get", "update"], check=True)
+                subprocess.run(["sudo", "apt-get", "install", "-y", "wget"], check=True)
+        except subprocess.CalledProcessError as e:
+            _LOGGER.debug(f"Error installing wget: {e}")
+
+
+def install_mpi(platform: str) -> None:
+    if platform.startswith("linux"):
+        try:
+            # if running as root
+            if os.geteuid() == 0:
+                subprocess.run(["apt-get", "update"], check=True)
+                subprocess.run(["apt-get", "install", "-y", "libmpich-dev"], check=True)
+                subprocess.run(
+                    ["apt-get", "install", "-y", "libopenmpi-dev"], check=True
+                )
+            else:
+                _LOGGER.debug("Please run with sudo permissions")
+                subprocess.run(["sudo", "apt-get", "update"], check=True)
+                subprocess.run(
+                    ["sudo", "apt-get", "install", "-y", "libmpich-dev"], check=True
+                )
+                subprocess.run(
+                    ["sudo", "apt-get", "install", "-y", "libopenmpi-dev"], check=True
+                )
+        except subprocess.CalledProcessError as e:
+            _LOGGER.debug(f"Error installing mpi libs: {e}")
+
+
+def download_plugin_lib_path(py_version: str, platform: str) -> Optional[str]:
+    plugin_lib_path = None
+    if py_version not in ("cp310", "cp312"):
+        _LOGGER.warning(
+            "No TensorRT-LLM wheel is available for Python versions other than 3.10 and 3.12"
+        )
+    install_wget(platform)
+    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+    file_name = f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
+    download_url = base_url + file_name
+    cmd = ["wget", download_url]
+    try:
+        if not os.path.exists(file_name):
+            _LOGGER.info(f"Running command: {' '.join(cmd)}")
+            subprocess.run(cmd, check=True)
+            _LOGGER.info("Wheel download complete")
+        if os.path.exists(file_name):
+            _LOGGER.info(f"Wheel file {file_name} is present")
+            if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"):
+                plugin_lib_path = (
+                    "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
+                )
+            else:
+                import zipfile
+
+                with zipfile.ZipFile(file_name, "r") as zip_ref:
+                    zip_ref.extractall(".")  # Extracts to a folder named 'tensorrt_llm'
+                plugin_lib_path = (
+                    "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
+                )
+    except subprocess.CalledProcessError as e:
+        _LOGGER.debug(f"Error occurred while trying to download: {e}")
+    except Exception as e:
+        _LOGGER.debug(f"An unexpected error occurred: {e}")
+    return plugin_lib_path
+
+
 def load_tensorrt_llm() -> bool:
     """
     Attempts to load the TensorRT-LLM plugin and initialize it.
@@ -930,57 +1012,66 @@ def load_tensorrt_llm() -> bool:
     Returns:
         bool: True if the plugin was successfully loaded and initialized, False otherwise.
     """
-    try:
-        import tensorrt_llm as trt_llm  # noqa: F401
-
-        _LOGGER.info("TensorRT-LLM successfully imported")
-        return True
-    except (ImportError, AssertionError) as e_import_error:
-        # Check for environment variable for the plugin library path
-        plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
-        if not plugin_lib_path:
+    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+    if not plugin_lib_path:
+        _LOGGER.warning(
+            "Please set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops, or set USE_TRTLLM_PLUGINS to download the shared library",
+        )
+        for key, value in os.environ.items():
+            print(f"{key}: {value}")
+        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
+            "1",
+            "true",
+            "yes",
+            "on",
+        )
+        if not use_trtllm_plugin:
             _LOGGER.warning(
-                "TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
+                "Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled to download the shared library"
             )
             return False
+        else:
+            py_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+            platform = Platform.current_platform()
 
-        _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
-        try:
-            # Load the shared library
-            handle = ctypes.CDLL(plugin_lib_path)
-            _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
-        except OSError as e_os_error:
-            _LOGGER.error(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
-                f"Ensure the path is correct and the library is compatible",
-                exc_info=e_os_error,
-            )
-            return False
+            platform = str(platform).lower()
+            plugin_lib_path = download_plugin_lib_path(py_version, platform)
+    try:
+        # Load the shared library
+        install_mpi(platform)
+        handle = ctypes.CDLL(plugin_lib_path)
+        _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
+    except OSError as e_os_error:
+        _LOGGER.error(
+            f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+            f"Ensure the path is correct and the library is compatible",
+            exc_info=e_os_error,
+        )
+        return False
 
-        try:
-            # Configure plugin initialization arguments
-            handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-            handle.initTrtLlmPlugins.restype = ctypes.c_bool
-        except AttributeError as e_plugin_unavailable:
-            _LOGGER.warning(
-                "Unable to initialize the TensorRT-LLM plugin library",
-                exc_info=e_plugin_unavailable,
-            )
-            return False
+    try:
+        # Configure plugin initialization arguments
+        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+        handle.initTrtLlmPlugins.restype = ctypes.c_bool
+    except AttributeError as e_plugin_unavailable:
+        _LOGGER.warning(
+            "Unable to initialize the TensorRT-LLM plugin library",
+            exc_info=e_plugin_unavailable,
+        )
+        return False
 
-        try:
-            # Initialize the plugin
-            TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
-            if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
-                _LOGGER.info("TensorRT-LLM plugin successfully initialized")
-                return True
-            else:
-                _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
-                return False
-        except Exception as e_initialization_error:
-            _LOGGER.warning(
-                "Exception occurred during TensorRT-LLM plugin library initialization",
-                exc_info=e_initialization_error,
-            )
+    try:
+        # Initialize the plugin
+        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
+        if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
+            _LOGGER.info("TensorRT-LLM plugin successfully initialized")
+            return True
+        else:
+            _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
             return False
-    return False
+    except Exception as e_initialization_error:
+        _LOGGER.warning(
+            "Exception occurred during TensorRT-LLM plugin library initialization",
+            exc_info=e_initialization_error,
+        )
+        return False
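With the rewrite above, `load_tensorrt_llm()` no longer requires a full `tensorrt_llm` install: it either reads a user-supplied `TRTLLM_PLUGINS_PATH` or, when `USE_TRTLLM_PLUGINS` is enabled, downloads the TensorRT-LLM wheel and loads the bundled shared library. A rough sketch of both entry points (the library path below is a placeholder, not a value this patch pins down):

import os

# Option 1 (assumed local build, takes priority when set): point directly at an existing plugin library.
# os.environ["TRTLLM_PLUGINS_PATH"] = "/path/to/libnvinfer_plugin_tensorrt_llm.so"

# Option 2: let the new helpers fetch the tensorrt_llm wheel and extract the bundled library.
os.environ["USE_TRTLLM_PLUGINS"] = "1"

from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm

assert load_tensorrt_llm(), "TensorRT-LLM plugin failed to load or initialize"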
diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index 9813548a10..a7c8d4759b 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -351,6 +351,7 @@ def generate_graph(
         enable_passes: bool,
         propagate_shapes: bool = False,
         settings: CompilationSettings = CompilationSettings(),
+        fuse_distributed_ops: bool = False,
         torch_export_dynamic_shapes: Optional[Any] = None,
     ):
         mod = mod.eval()
@@ -366,6 +367,16 @@ def generate_graph(
                 tuple(torch_export_inputs),
                 dynamic_shapes=torch_export_dynamic_shapes,
             )
+            if fuse_distributed_ops:
+                from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+                    fuse_distributed_ops,
+                )
+
+                gm = exported_program.graph_module
+                gm = fuse_distributed_ops(gm, settings)
+                exported_program = exported_program.run_decompositions(
+                    get_decompositions(False)
+                )
             if enable_passes:
                 exported_program = pre_export_lowering(exported_program, settings)
                 exported_program = exported_program.run_decompositions(
@@ -404,6 +415,7 @@ def run_test(
         propagate_shapes=False,
         int32_reqd=False,
         immutable_weights=True,
+        fuse_distributed_ops=False,
     ):
         # TODO: lan to remove this and set use_dynamo_traccer to True by default
        # once all the converter test files are moved to use_dynamo_tracer
@@ -424,6 +436,7 @@ def run_test(
            enable_passes=enable_passes,
            propagate_shapes=propagate_shapes,
            settings=compilation_settings,
+            fuse_distributed_ops=fuse_distributed_ops,
         )
 
         num_inputs = len(inputs)
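The `fuse_distributed_ops` hook added to the harness runs the fusion pass on the exported graph module before decompositions. Outside the harness, the same sequence would look roughly like this (a sketch only; `model` and `inputs` are assumed, and the pass is called with a graph module plus `CompilationSettings` exactly as the harness does):

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering import get_decompositions
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    fuse_distributed_ops,
)

exported_program = torch.export.export(model, tuple(inputs))  # model/inputs assumed
settings = CompilationSettings()

# Run the distributed-op fusion pass (presumably fusing each functional collective with its
# wait_tensor) on the graph module, then apply the standard decompositions, mirroring
# generate_graph when fuse_distributed_ops=True.
gm = exported_program.graph_module
gm = fuse_distributed_ops(gm, settings)
exported_program = exported_program.run_decompositions(get_decompositions(False))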
diff --git a/tests/py/dynamo/conversion/test_nccl_ops.py b/tests/py/dynamo/conversion/test_nccl_ops.py
new file mode 100644
index 0000000000..4db24881c8
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_nccl_ops.py
@@ -0,0 +1,80 @@
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+
+def set_environment_variables():
+    os.environ["WORLD_SIZE"] = str(1)
+    os.environ["RANK"] = str(0)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(29500)
+    os.environ["USE_TRTLLM_PLUGINS"] = "1"
+
+
+set_environment_variables()
+dist.init_process_group(backend="nccl", init_method="env://")
+group = dist.new_group(ranks=[0])
+group_name = group.group_name
+
+from .harness import DispatchTestCase
+
+
+class TestGatherNcclOpsConverter(DispatchTestCase):
+    @parameterized.expand([(8)])
+    def test_nccl_ops(self, linear_layer_dim):
+        class DistributedGatherModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                world_size = 1
+                gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
+                    x, world_size, group_name
+                )
+                gathered_tensor = torch.ops._c10d_functional.wait_tensor(
+                    gathered_tensor
+                )
+                return gathered_tensor
+
+        inputs = [torch.randn(1, linear_layer_dim).to("cuda")]
+
+        self.run_test(
+            DistributedGatherModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            fuse_distributed_ops=True,
+        )
+
+    # TODO: Look at this
+    # @parameterized.expand(
+    #     [
+    #         (8)
+    #     ]
+    # )
+    # def test_nccl_ops_scatter(self, linear_layer_dim):
+
+    #     class DistributedReduceScatterModel(nn.Module):
+    #         def __init__(self, input_dim):
+    #             super().__init__()
+    #         def forward(self, x):
+    #             world_size = 1
+    #             scatter_reduce_tensor = torch.ops._c10d_functional.reduce_scatter_tensor(x, "sum", world_size, group_name)
+    #             scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor(scatter_reduce_tensor)
+    #             return scatter_reduce_tensor
+    #     inputs = [torch.zeros(1, linear_layer_dim).to("cuda")]
+
+    #     self.run_test(
+    #         DistributedReduceScatterModel(linear_layer_dim).cuda(),
+    #         inputs,
+    #         use_dynamo_tracer=True,
+    #     )
+
+
+if __name__ == "__main__":
+    run_tests()
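Because the new test module sets WORLD_SIZE/RANK/MASTER_ADDR/MASTER_PORT itself and creates a single-rank NCCL group, it should be runnable on one GPU without a distributed launcher (assumptions: NCCL and a CUDA device are available, and the TensorRT-LLM plugin can be fetched since the module sets USE_TRTLLM_PLUGINS=1). For example:

import pytest

# Single-process invocation; the module-level setup in test_nccl_ops.py handles the
# process-group initialization, so no torchrun/mpirun launcher is needed for world_size=1.
pytest.main(["tests/py/dynamo/conversion/test_nccl_ops.py", "-v"])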