Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 39 additions & 28 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,13 @@
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh

logger = logging.getLogger(__name__)

def find_repo_root(max_depth=10):
dir_path = os.path.dirname(os.path.realpath(__file__))
for i in range(max_depth):
files = os.listdir(dir_path)
if "MODULE.bazel" in files:
return dir_path
else:
dir_path = os.path.dirname(dir_path)

raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
# this is kept at the application level, when mpirun is used to run the application
def initialize_distributed_env(rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
Expand All @@ -50,9 +31,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = (
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)
Expand All @@ -66,16 +44,49 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger
return device_mesh, world_size, rank


def cleanup_distributed_env():
"""Clean up distributed process group to prevent resource leaks."""
if dist.is_initialized():
dist.destroy_process_group()


def check_tensor_parallel_device_number(world_size: int) -> None:
if world_size % 2 != 0:
raise ValueError(
f"TP examples require even number of GPUs, but got {world_size} gpus"
)


def get_tensor_parallel_device_mesh(
rank: int = 0, world_size: int = 1
) -> tuple[DeviceMesh, int, int]:
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank


def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger:
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger
21 changes: 13 additions & 8 deletions examples/distributed_inference/tensor_parallel_rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,31 @@

"""

import logging
import os
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
import torch.distributed as dist
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
get_tensor_parallel_device_mesh,
initialize_distributed_env,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_rotary_embedding"
)
if not dist.is_initialized():
initialize_distributed_env()

import torch_tensorrt

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding")

from rotary_embedding import RotaryAttention, parallel_rotary_block

"""
This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
Command to run with single GPU: USE_TRTLLM_PLUGINS=1 mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
Command to run with 2 GPUs: USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py
"""

BATCH = 2
Expand Down
19 changes: 14 additions & 5 deletions examples/distributed_inference/tensor_parallel_simple_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
-----
.. code-block:: bash

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
"""

import time
Expand All @@ -25,22 +25,31 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
get_tensor_parallel_device_mesh,
initialize_distributed_env,
initialize_distributed_logger,
)

if not dist.is_initialized():
initialize_distributed_env()
import torch_tensorrt
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
from torch_tensorrt.dynamo.distributed.utils import (
get_tensor_parallel_device_mesh,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_simple_example")


"""
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
Expand Down
88 changes: 56 additions & 32 deletions py/torch_tensorrt/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import platform
import sys
import tempfile
import time
import urllib.request
from pathlib import Path
from typing import Any, Optional
Expand Down Expand Up @@ -143,13 +144,65 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
)


def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
"""
Safely extract a wheel file to a directory with a lock to prevent concurrent extraction.
"""
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) # MPI rank from OpenMPI
torch.cuda.set_device(rank)
lock_file = extract_dir / ".extracting"

# Rank 0 performs extraction
if rank == 0:
logger.debug(
f"[Rank {rank}] Starting extraction of {wheel_path} to {extract_dir}"
)
try:
import zipfile
except ImportError as e:
raise ImportError(
"zipfile module is required but not found. Please install zipfile"
)
# Create lock file to signal extraction in progress
extract_dir.mkdir(parents=True, exist_ok=False)
lock_file.touch(exist_ok=False)
try:
with zipfile.ZipFile(wheel_path) as zip_ref:
zip_ref.extractall(extract_dir)
logger.debug(f"[Rank {rank}] Extraction complete: {extract_dir}")
except FileNotFoundError as e:
logger.error(f"[Rank {rank}] Wheel file not found at {wheel_path}: {e}")
raise RuntimeError(
f"Failed to find downloaded wheel file at {wheel_path}"
) from e
except zipfile.BadZipFile as e:
logger.error(f"[Rank {rank}] Invalid or corrupted wheel file: {e}")
raise RuntimeError(
"Downloaded wheel file is corrupted or not a valid zip archive"
) from e
except Exception as e:
logger.error(f"[Rank {rank}] Unexpected error while extracting wheel: {e}")
raise RuntimeError(
"Unexpected error during extraction of TensorRT-LLM wheel"
) from e
finally:
# Remove lock file to signal completion
lock_file.unlink(missing_ok=True)

else:
# Other ranks wait for extraction to complete
while lock_file.exists():
logger.debug(
f"[Rank {rank}] Waiting for extraction to finish at {extract_dir}..."
)
time.sleep(0.5)


def download_and_get_plugin_lib_path() -> Optional[str]:
"""
Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.

Args:
platform (str): Platform identifier (e.g., 'linux_x86_64')

Returns:
Optional[str]: Path to shared library or None if operation fails.
"""
Expand All @@ -174,7 +227,6 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
return str(plugin_lib_path)

wheel_path.parent.mkdir(parents=True, exist_ok=True)
extract_dir.mkdir(parents=True, exist_ok=True)

if not wheel_path.exists():
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
Expand All @@ -194,32 +246,7 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
except OSError as e:
logger.error(f"Local file write error: {e}")

try:
import zipfile
except ImportError as e:
raise ImportError(
"zipfile module is required but not found. Please install zipfile"
)
try:
with zipfile.ZipFile(wheel_path) as zip_ref:
zip_ref.extractall(extract_dir)
logger.debug(f"Extracted wheel to {extract_dir}")
except FileNotFoundError as e:
# This should capture the errors in the download failure above
logger.error(f"Wheel file not found at {wheel_path}: {e}")
raise RuntimeError(
f"Failed to find downloaded wheel file at {wheel_path}"
) from e
except zipfile.BadZipFile as e:
logger.error(f"Invalid or corrupted wheel file: {e}")
raise RuntimeError(
"Downloaded wheel file is corrupted or not a valid zip archive"
) from e
except Exception as e:
logger.error(f"Unexpected error while extracting wheel: {e}")
raise RuntimeError(
"Unexpected error during extraction of TensorRT-LLM wheel"
) from e
extract_wheel_file(wheel_path, extract_dir)

try:
wheel_path.unlink(missing_ok=True)
Expand All @@ -238,10 +265,8 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
"""
Loads and initializes the TensorRT-LLM plugin from the given shared library path.

Args:
plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.

Returns:
bool: True if successful, False otherwise.
"""
Expand Down Expand Up @@ -293,7 +318,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
Attempts to load the TensorRT-LLM plugin and initialize it.
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it

Returns:
bool: True if the plugin was successfully loaded and initialized, False otherwise.
"""
Expand Down
36 changes: 11 additions & 25 deletions tests/py/dynamo/distributed/distributed_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import random

import numpy as np
import tensorrt as trt
Expand All @@ -8,24 +9,21 @@
from torch.distributed._tensor.device_mesh import init_device_mesh


def set_environment_variables_pytest():
# the below two functions are used to set the environment variables for the pytest single and multi process
# this is for the github CI where we use pytest
def set_environment_variables_pytest_single_process():
port = 29500 + random.randint(1, 1000)
os.environ["WORLD_SIZE"] = str(1)
os.environ["RANK"] = str(0)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(29500)


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger
os.environ["MASTER_PORT"] = str(port)


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
def set_environment_variables_pytest_multi_process(
rank: int = 0, world_size: int = 1
) -> None:
port = 29500 + random.randint(1, 1000)
# these variables are set by mpirun -n 2
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
Expand All @@ -36,7 +34,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so"

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)
Expand All @@ -46,14 +43,3 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950

# set a manual seed for reproducibility
torch.manual_seed(1111)

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger
Loading
Loading