change in TRT-LLM loading mechanism and exposing aot_joint_export in _compiler.py #3398

Open: wants to merge 2 commits into main
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -96,6 +96,7 @@ def cross_compile_for_windows(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -169,6 +170,7 @@ def cross_compile_for_windows(
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
use_aot_joint_export (bool): Use aot_export_joint_simple; otherwise wrap the backend with aot_autograd, which is required for distributed tensors
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -326,6 +328,7 @@ def cross_compile_for_windows(
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": True,
"enable_weight_streaming": enable_weight_streaming,
"use_aot_joint_export": use_aot_joint_export,
}

# Disable the following settings, as they are not supported by the cross-compilation-for-Windows feature
@@ -413,6 +416,7 @@ def compile(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -488,6 +492,7 @@ def compile(
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
use_aot_joint_export (bool): Use aot_export_joint_simple; otherwise wrap the backend with aot_autograd, which is required for distributed tensors
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -662,6 +667,7 @@ def compile(
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": False,
"enable_weight_streaming": enable_weight_streaming,
"use_aot_joint_export": use_aot_joint_export,
}

settings = CompilationSettings(**compilation_options)
@@ -950,6 +956,7 @@ def convert_exported_program_to_serialized_trt_engine(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1013,6 +1020,7 @@ def convert_exported_program_to_serialized_trt_engine(
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
use_aot_joint_export (bool): Use aot_export_joint_simple; otherwise wrap the backend with aot_autograd, which is required for distributed tensors
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
@@ -1129,6 +1137,7 @@ def convert_exported_program_to_serialized_trt_engine(
"strip_engine_weights": strip_engine_weights,
"immutable_weights": immutable_weights,
"enable_weight_streaming": enable_weight_streaming,
"use_aot_joint_export": use_aot_joint_export,
}

settings = CompilationSettings(**compilation_options)
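For reference, a minimal sketch of how the new flag might be passed through the public `compile` API. This is not part of the diff: the model and shapes are placeholders, and the assumption that `_defaults.USE_AOT_JOINT_EXPORT` defaults to `True` is inferred from the parameter defaults above.

```python
import torch
import torch_tensorrt

# Toy model standing in for a real workload (placeholder)
model = torch.nn.Linear(64, 32).eval().cuda()
inputs = [torch.randn(8, 64).cuda()]

exported_program = torch.export.export(model, tuple(inputs))

# use_aot_joint_export=True (the assumed default) traces with aot_export_joint_simple;
# setting it to False wraps the backend with aot_autograd, which the docstring
# says is required when the graph contains distributed tensors.
trt_module = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=inputs,
    use_aot_joint_export=False,
)
```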
185 changes: 138 additions & 47 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -3,6 +3,9 @@
import functools
[Collaborator review comment] Let's put this utility in some place that is more user accessible, like torch_tensorrt.utils.

import logging
import os
import shutil
import subprocess
import sys
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
@@ -12,6 +15,7 @@
from torch.fx.node import Argument, Target
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt import _enums
from torch_tensorrt._enums import Platform
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -923,64 +927,151 @@ def args_bounds_check(
return args[i] if len(args) > i and args[i] is not None else replacement


def install_wget(platform: str) -> None:
if shutil.which("wget"):
_LOGGER.debug("wget is already installed")
return
if platform.startswith("linux"):
try:
# if its root
if os.geteuid() == 0:
subprocess.run(["apt-get", "update"], check=True)
subprocess.run(["apt-get", "install", "-y", "wget"], check=True)
else:
_LOGGER.debug("Please run with sudo permissions")
subprocess.run(["sudo", "apt-get", "update"], check=True)
subprocess.run(["sudo", "apt-get", "install", "-y", "wget"], check=True)
except subprocess.CalledProcessError as e:
_LOGGER.debug("Error installing wget:", e)


def install_mpi(platform: str) -> None:
if platform.startswith("linux"):
try:
# if its root
if os.geteuid() == 0:
subprocess.run(["apt-get", "update"], check=True)
subprocess.run(["apt-get", "install", "-y", "libmpich-dev"], check=True)
subprocess.run(
["apt-get", "install", "-y", "libopenmpi-dev"], check=True
)
else:
_LOGGER.debug("Please run with sudo permissions")
subprocess.run(["sudo", "apt-get", "update"], check=True)
subprocess.run(
["sudo", "apt-get", "install", "-y", "libmpich-dev"], check=True
)
subprocess.run(
["sudo", "apt-get", "install", "-y", "libopenmpi-dev"], check=True
)
except subprocess.CalledProcessError as e:
_LOGGER.debug("Error installing mpi libs:", e)


def download_plugin_lib_path(py_version: str, platform: str) -> str:
plugin_lib_path = None
if py_version not in ("cp310", "cp312"):
_LOGGER.warning(
"No available wheel for python versions other than py3.10 and py3.12"
)
install_wget(platform)
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
file_name = f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
download_url = base_url + file_name
cmd = ["wget", download_url]
try:
if not (os.path.exists(file_name)):
_LOGGER.info(f"Running command: {' '.join(cmd)}")
subprocess.run(cmd)
_LOGGER.info("Download complete of wheel")
if os.path.exists(file_name):
_LOGGER.info("filename now present")
if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"):
plugin_lib_path = (
"./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
)
else:
import zipfile

with zipfile.ZipFile(file_name, "r") as zip_ref:
zip_ref.extractall(".") # Extract to a folder named 'tensorrt_llm'
plugin_lib_path = (
"./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
)
except subprocess.CalledProcessError as e:
_LOGGER.debug(f"Error occurred while trying to download: {e}")
except Exception as e:
_LOGGER.debug(f"An unexpected error occurred: {e}")
return plugin_lib_path


def load_tensorrt_llm() -> bool:
"""
Attempts to load the TensorRT-LLM plugin and initialize it. If the library is not found locally, it can be downloaded automatically when USE_TRTLLM_PLUGINS is set.

Returns:
bool: True if the plugin was successfully loaded and initialized, False otherwise.
"""
try:
import tensorrt_llm as trt_llm # noqa: F401

_LOGGER.info("TensorRT-LLM successfully imported")
return True
except (ImportError, AssertionError) as e_import_error:
# Check for environment variable for the plugin library path
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
if not plugin_lib_path:
_LOGGER.warning(
"Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library",
)
for key, value in os.environ.items():
print(f"{key}: {value}")
use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
"1",
"true",
"yes",
"on",
)
if not use_trtllm_plugin:
_LOGGER.warning(
"TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
"Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library"
)
return False
else:
py_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
[Collaborator review comment] Why do we restrict to cp310 and cp312? It shouldn't matter if we are pulling the whl and unzipping ourselves.

[Author reply] At https://pypi.nvidia.com/tensorrt-llm/ I only see tags for cp310 and cp312, so I added the check.

platform = Platform.current_platform()

_LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
try:
# Load the shared library
handle = ctypes.CDLL(plugin_lib_path)
_LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
except OSError as e_os_error:
_LOGGER.error(
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
f"Ensure the path is correct and the library is compatible",
exc_info=e_os_error,
)
return False
platform = str(platform).lower()
plugin_lib_path = download_plugin_lib_path(py_version, platform)
try:
# Load the shared library
install_mpi(platform)
handle = ctypes.CDLL(plugin_lib_path)
_LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
except OSError as e_os_error:
_LOGGER.error(
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
f"Ensure the path is correct and the library is compatible",
exc_info=e_os_error,
)
return False

try:
# Configure plugin initialization arguments
handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
handle.initTrtLlmPlugins.restype = ctypes.c_bool
except AttributeError as e_plugin_unavailable:
_LOGGER.warning(
"Unable to initialize the TensorRT-LLM plugin library",
exc_info=e_plugin_unavailable,
)
return False

try:
# Initialize the plugin
TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
_LOGGER.info("TensorRT-LLM plugin successfully initialized")
return True
else:
_LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
return False
except Exception as e_initialization_error:
_LOGGER.warning(
"Exception occurred during TensorRT-LLM plugin library initialization",
exc_info=e_initialization_error,
)
return False
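A hedged sketch of how the two environment-variable paths handled by `load_tensorrt_llm` could be exercised. This is illustrative only: the .so path is a placeholder, and the import location may change if the utility is relocated to `torch_tensorrt.utils` as suggested in the review comment above.

```python
import os

# Option 1: point directly at an existing plugin build (placeholder path)
os.environ["TRTLLM_PLUGINS_PATH"] = (
    "/opt/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
)

# Option 2: let torch_tensorrt download the TensorRT-LLM wheel from
# https://pypi.nvidia.com/tensorrt-llm/ and extract the plugin itself,
# as implemented in download_plugin_lib_path above
# os.environ["USE_TRTLLM_PLUGINS"] = "1"

from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm

if load_tensorrt_llm():
    print("TensorRT-LLM plugin loaded; torch.distributed converters are usable")
else:
    print("Plugin load failed; torch.distributed converters will be unavailable")
```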
13 changes: 13 additions & 0 deletions tests/py/dynamo/conversion/harness.py
@@ -351,6 +351,7 @@ def generate_graph(
enable_passes: bool,
propagate_shapes: bool = False,
settings: CompilationSettings = CompilationSettings(),
fuse_distributed_ops: bool = False,
torch_export_dynamic_shapes: Optional[Any] = None,
):
mod = mod.eval()
@@ -366,6 +367,16 @@
tuple(torch_export_inputs),
dynamic_shapes=torch_export_dynamic_shapes,
)
if fuse_distributed_ops:
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
fuse_distributed_ops,
)

gm = exported_program.graph_module
gm = fuse_distributed_ops(gm, settings)
exported_program = exported_program.run_decompositions(
get_decompositions(False)
)
if enable_passes:
exported_program = pre_export_lowering(exported_program, settings)
exported_program = exported_program.run_decompositions(
@@ -404,6 +415,7 @@ def run_test(
propagate_shapes=False,
int32_reqd=False,
immutable_weights=True,
fuse_distributed_ops=False,
):
# TODO: lan to remove this and set use_dynamo_tracer to True by default
# once all the converter test files are moved to use_dynamo_tracer
@@ -424,6 +436,7 @@
enable_passes=enable_passes,
propagate_shapes=propagate_shapes,
settings=compilation_settings,
fuse_distributed_ops=fuse_distributed_ops,
)

num_inputs = len(inputs)
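To show how a converter test might opt into the new flag, here is a hedged sketch (not from the PR): the test class, the process-group wiring, and the collective used are assumptions, and actually running it requires a CUDA device plus the TensorRT-LLM plugin converters this PR targets.

```python
import os

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as fc
from harness import DispatchTestCase  # existing harness in tests/py/dynamo/conversion


class TestFuseDistributedOps(DispatchTestCase):
    def test_all_gather_is_fused(self):
        # Single-process NCCL group so the collective can execute eagerly when
        # the harness computes reference outputs (illustrative wiring only).
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        if not dist.is_initialized():
            dist.init_process_group("nccl", rank=0, world_size=1)

        class AllGatherModule(torch.nn.Module):
            def forward(self, x):
                # Traces to all_gather_into_tensor + wait_tensor, the pattern
                # the fuse_distributed_ops lowering pass is meant to rewrite.
                return fc.all_gather_tensor(x, gather_dim=0, group=dist.group.WORLD)

        inputs = [torch.randn(4, 4).cuda()]
        # fuse_distributed_ops=True makes generate_graph() apply the new pass
        # to the exported graph before decomposition and conversion.
        self.run_test(
            AllGatherModule(),
            inputs,
            use_dynamo_tracer=True,
            fuse_distributed_ops=True,
        )
```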