Change in TRT-LLM loading mechanism and exposing aot_joint_export in _compiler.py #3398
base: main
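Note that the hunks below cover only the converter utilities file; the aot_joint_export change named in the title is not shown in this excerpt. For orientation, exposing such a flag through the dynamo settings would look roughly like the hypothetical sketch below (the field name is inferred from the PR title; the default value and placement are assumptions, not part of this diff):

    from dataclasses import dataclass

    @dataclass
    class CompilationSettings:
        # ... existing fields ...
        use_aot_joint_export: bool = True  # hypothetical: surfaces aot_joint_export to _compiler.py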
@@ -3,6 +3,9 @@
import functools
import logging
import os
import shutil
import subprocess
import sys
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
@@ -12,6 +15,7 @@
from torch.fx.node import Argument, Target
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt import _enums
from torch_tensorrt._enums import Platform
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -923,64 +927,151 @@ def args_bounds_check(
    return args[i] if len(args) > i and args[i] is not None else replacement

def install_wget(platform: str) -> None:
    if shutil.which("wget"):
        _LOGGER.debug("wget is already installed")
        return
    if platform.startswith("linux"):
        try:
            # Use apt-get directly when running as root; otherwise fall back to sudo
            if os.geteuid() == 0:
                subprocess.run(["apt-get", "update"], check=True)
                subprocess.run(["apt-get", "install", "-y", "wget"], check=True)
            else:
                _LOGGER.debug("Please run with sudo permissions")
                subprocess.run(["sudo", "apt-get", "update"], check=True)
                subprocess.run(["sudo", "apt-get", "install", "-y", "wget"], check=True)
        except subprocess.CalledProcessError as e:
            _LOGGER.debug(f"Error installing wget: {e}")

def install_mpi(platform: str) -> None:
    if platform.startswith("linux"):
        try:
            # Use apt-get directly when running as root; otherwise fall back to sudo
            if os.geteuid() == 0:
                subprocess.run(["apt-get", "update"], check=True)
                subprocess.run(["apt-get", "install", "-y", "libmpich-dev"], check=True)
                subprocess.run(["apt-get", "install", "-y", "libopenmpi-dev"], check=True)
            else:
                _LOGGER.debug("Please run with sudo permissions")
                subprocess.run(["sudo", "apt-get", "update"], check=True)
                subprocess.run(["sudo", "apt-get", "install", "-y", "libmpich-dev"], check=True)
                subprocess.run(["sudo", "apt-get", "install", "-y", "libopenmpi-dev"], check=True)
        except subprocess.CalledProcessError as e:
            _LOGGER.debug(f"Error installing mpi libs: {e}")

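Since install_mpi only logs on failure, callers get no signal about whether MPI actually became available. A minimal follow-up check could look like the sketch below; mpi_available is a hypothetical helper, not part of this diff:

    def mpi_available() -> bool:
        # Heuristic: the apt MPI packages place a launcher on PATH
        return shutil.which("mpirun") is not None or shutil.which("mpiexec") is not None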
def download_plugin_lib_path(py_version: str, platform: str) -> Optional[str]:
    plugin_lib_path = None
    if py_version not in ("cp310", "cp312"):
        _LOGGER.warning(
            "No wheel available for Python versions other than py3.10 and py3.12"
        )
    install_wget(platform)
    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
    file_name = f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
    download_url = base_url + file_name
    cmd = ["wget", download_url]
    try:
        if not os.path.exists(file_name):
            _LOGGER.info(f"Running command: {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
            _LOGGER.info("Wheel download complete")
        if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"):
            plugin_lib_path = "./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
        else:
            import zipfile

            # A wheel is a zip archive; extract it in place so the bundled
            # tensorrt_llm/libs directory becomes available
            with zipfile.ZipFile(file_name, "r") as zip_ref:
                zip_ref.extractall(".")
            plugin_lib_path = "./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
    except subprocess.CalledProcessError as e:
        _LOGGER.debug(f"Error occurred while trying to download: {e}")
    except Exception as e:
        _LOGGER.debug(f"An unexpected error occurred: {e}")
    return plugin_lib_path

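As the review thread below notes, the wheel is pulled and unzipped manually rather than pip-installed. That fetch-and-extract step can also be done with the standard library alone, avoiding the wget dependency entirely; a minimal sketch under the same URL and naming scheme (fetch_plugin_lib is a hypothetical name, not part of this diff):

    import os
    import urllib.request
    import zipfile

    def fetch_plugin_lib(py_version: str, platform: str) -> str:
        # Same wheel naming scheme as download_plugin_lib_path above
        file_name = f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
        url = "https://pypi.nvidia.com/tensorrt-llm/" + file_name
        if not os.path.exists(file_name):
            urllib.request.urlretrieve(url, file_name)
        # A wheel is a zip archive; unpack it to expose the bundled .so
        with zipfile.ZipFile(file_name, "r") as zf:
            zf.extractall(".")
        return "./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"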
def load_tensorrt_llm() -> bool:
    """
    Attempts to load the TensorRT-LLM plugin and initialize it.

    Returns:
        bool: True if the plugin was successfully loaded and initialized, False otherwise.
    """
    try:
        import tensorrt_llm as trt_llm  # noqa: F401

        _LOGGER.info("TensorRT-LLM successfully imported")
        return True
    except (ImportError, AssertionError) as e_import_error:
        # Check for an environment variable pointing at a prebuilt plugin library
        plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
        if not plugin_lib_path:
            _LOGGER.warning(
                "Please set TRTLLM_PLUGINS_PATH to the directory containing "
                "libnvinfer_plugin_tensorrt_llm.so to use converters for "
                "torch.distributed ops, or set USE_TRTLLM_PLUGINS to download "
                "the shared library"
            )
            use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
                "1",
                "true",
                "yes",
                "on",
            )
            if not use_trtllm_plugin:
                _LOGGER.warning(
                    "Neither TRTLLM_PLUGINS_PATH is set nor is downloading the "
                    "shared library enabled"
                )
                return False
            py_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
Review comment: Why do we restrict to cp310 and cp312? It shouldn't matter if we are pulling the whl and unzipping ourselves.

Author reply: At https://pypi.nvidia.com/tensorrt-llm/ I only see wheel tags for cp310 and cp312, so I added the check.
            platform = str(Platform.current_platform()).lower()
            plugin_lib_path = download_plugin_lib_path(py_version, platform)
            # MPI libraries are needed by the plugin's distributed kernels;
            # install them alongside the downloaded wheel
            install_mpi(platform)

        _LOGGER.info(f"TensorRT-LLM plugin lib path found: {plugin_lib_path}")
        try:
            # Load the shared library
            handle = ctypes.CDLL(plugin_lib_path)
            _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
        except OSError as e_os_error:
            _LOGGER.error(
                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
                f"Ensure the path is correct and the library is compatible",
                exc_info=e_os_error,
            )
            return False
        try:
            # Configure plugin initialization arguments
            handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
            handle.initTrtLlmPlugins.restype = ctypes.c_bool
        except AttributeError as e_plugin_unavailable:
            _LOGGER.warning(
                "Unable to initialize the TensorRT-LLM plugin library",
                exc_info=e_plugin_unavailable,
            )
            return False
        try:
            # Initialize the plugin
            TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
            if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
                _LOGGER.info("TensorRT-LLM plugin successfully initialized")
                return True
            else:
                _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
                return False
        except Exception as e_initialization_error:
            _LOGGER.warning(
                "Exception occurred during TensorRT-LLM plugin library initialization",
                exc_info=e_initialization_error,
            )
            return False
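For orientation, a usage sketch: the loader is driven entirely by environment variables, so a caller that wants the torch.distributed converters might do something like the following (the paths shown are hypothetical, and this assumes load_tensorrt_llm is importable from the converter utils this diff modifies):

    import os

    # Either point at an existing shared library ...
    os.environ["TRTLLM_PLUGINS_PATH"] = "/opt/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
    # ... or let the loader download and extract the wheel itself:
    # os.environ["USE_TRTLLM_PLUGINS"] = "1"

    if load_tensorrt_llm():
        _LOGGER.info("torch.distributed converters can be registered")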
Review comment: Let's put this utility in some place that is more user accessible, like torch_tensorrt.utils.