
Commit 0313372

Tensor parallel Llama3 tutorial illustrating use of torch.distributed and nccl ops
1 parent 543bc9b commit 0313372

File tree: 2 files changed (+67, −3 lines)

docsrc/index.rst

+2
@@ -68,6 +68,7 @@ Tutorials
 * :ref:`mutable_torchtrt_module_example`
 * :ref:`weight_streaming_example`
 * :ref:`pre_allocated_output_example`
+* :ref:`tensor_parallel_llama`

 .. toctree::
    :caption: Tutorials
@@ -87,6 +88,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
    tutorials/_rendered_examples/dynamo/weight_streaming_example
    tutorials/_rendered_examples/dynamo/pre_allocated_output_example
+   tutorials/_rendered_examples/dynamo/tensor_parallel_llama

 Dynamo Frontend
 ----------------

examples/distributed_inference/tensor_parallel_llama3.py

+65 −3
@@ -1,11 +1,27 @@
 # Taken and modified from PyTorch Lightning
 # https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
+"""
+.. _tensor_parallel_llama:
+
+Torch distributed example for the Llama3-7B model
+======================================================
+
+As model sizes increase, large models with billions of parameters are trained on many GPUs, where regular data-parallel training is no longer possible. In this example, we illustrate Llama3-7B model inference with the Torch-TensorRT backend, split across multiple GPUs using a form of model parallelism called tensor parallelism. We make use of the PyTorch Distributed Tensor Parallel module. Please refer to these tutorials: https://pytorch.org/tutorials/intermediate/TP_tutorial.html and https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning?section=featured"""
+
+# %%
+# Imports and Model Definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 import logging
 import os
 import time
-
 import torch
+
+# %%
+# PyTorch Tensor Parallel APIs offer a set of module-level primitives (ParallelStyle) to configure the sharding of tensors in each layer of the model.
+# ParallelTransformer creates the parallelize_plan for the FeedForward layer of the model.
 from llama3_model import ModelArgs, ParallelTransformer
+
 from tensor_parallel_initialize_dist import initialize_distributed_env
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
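Note: the parallelize_plan mentioned in the comments above is built inside ParallelTransformer, which lives in llama3_model.py and is not part of this diff. As a hedged illustration of how such a plan is applied, the minimal sketch below uses PyTorch's parallelize_module API on a toy feed-forward block; the module name, sizes, and mesh setup are assumptions for illustration only, not the repository's code.

    # Minimal sketch (not from this commit): applying ParallelStyle primitives to a
    # toy feed-forward block. ToyFeedForward and its sizes are illustrative only.
    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class ToyFeedForward(nn.Module):
        def __init__(self, dim: int = 128, hidden: int = 256):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden, bias=False)
            self.w2 = nn.Linear(hidden, dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(torch.relu(self.w1(x)))


    if __name__ == "__main__":
        # Requires a torchrun launch so that RANK/WORLD_SIZE are set for NCCL.
        mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
        ff = ToyFeedForward().cuda()
        # w1 is sharded column-wise and w2 row-wise, so the hidden activation stays
        # sharded per GPU and only w2's output is all-reduced back to a full tensor.
        ff = parallelize_module(ff, mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        out = ff(torch.randn(8, 128, device="cuda"))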
@@ -14,11 +30,24 @@
     checkpoint_wrapper,
 )

+# %%
+# Initialize the distributed environment
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# Depending on the sharded DTensor layouts specified for the inputs/outputs above, proper communication operations are required to transform between DTensor layouts
+# (e.g. allreduce, allgather, reduce_scatter).
+# NCCL provides these collective operations.
+# The API below does the following:
+# - Initializes the communicators and the distributed environment
+# - Sets the path to the TRT-LLM plugin .so, which is required for the NCCL operations in the Torch-TensorRT backend. Note that in a Python 3.10 environment, `import tensorrt_llm` should be enough
+# - Initializes the logger. For example, with 2 GPUs the log files are `./tensor_parallel_llama3_0.log` and `./tensor_parallel_llama3_1.log`
 device_mesh, _world_size, _rank, logger = initialize_distributed_env(
     "./tensor_parallel_llama3"
 )
-# Import should be after initialization of the TRT-LLM plugin .so path
-import tensorrt_llm
+
+# %%
+# Model initialization with the torch distributed parallel plan
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
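The initialize_distributed_env helper referenced here is defined in tensor_parallel_initialize_dist.py, which this commit does not touch. A hedged sketch of what such a helper typically does is shown below; the function name, logger setup, and the omitted TRT-LLM plugin path handling are assumptions, not the repository's actual implementation.

    # Hedged sketch of a distributed-init helper like initialize_distributed_env;
    # the real implementation is in tensor_parallel_initialize_dist.py (not shown).
    import logging
    import os

    import torch
    from torch.distributed.device_mesh import init_device_mesh


    def sketch_initialize_distributed_env(logger_prefix: str):
        world_size = int(os.environ["WORLD_SIZE"])  # set by the launcher (e.g. torchrun)
        rank = int(os.environ["RANK"])
        torch.cuda.set_device(rank % torch.cuda.device_count())

        # One-dimensional device mesh over all ranks; init_device_mesh also
        # initializes the default NCCL process group if it is not yet initialized.
        device_mesh = init_device_mesh("cuda", (world_size,))

        # Per-rank log file, e.g. ./tensor_parallel_llama3_0.log on rank 0.
        logger = logging.getLogger(logger_prefix)
        logger.setLevel(logging.INFO)
        logger.addHandler(logging.FileHandler(f"{logger_prefix}_{rank}.log"))

        # The real helper additionally sets the TRT-LLM plugin .so path needed for
        # NCCL ops in the Torch-TensorRT backend; that step is omitted here.
        return device_mesh, world_size, rank, logger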
@@ -36,7 +65,38 @@
 )

 with torch.no_grad():
+    # The plan is:
+    # plan = {
+    #     "attention": PrepareModuleInput(
+    #         input_layouts=(Shard(1), None),
+    #         desired_input_layouts=(Replicate(), None),
+    #     ),
+    #     "attention.wq": ColwiseParallel(),
+    #     "attention.wk": ColwiseParallel(),
+    #     "attention.wv": ColwiseParallel(),
+    #     "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+    #     "attention_norm": SequenceParallel(),
+    #     "feed_forward": PrepareModuleInput(
+    #         input_layouts=(Shard(1),),
+    #         desired_input_layouts=(Replicate(),),
+    #     ),
+    #     "feed_forward.w1": ColwiseParallel(),
+    #     "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+    #     "feed_forward.w3": ColwiseParallel(),
+    #     "ffn_norm": SequenceParallel(),
+    # }
+
     model = ParallelTransformer(model_args, device_mesh)
+
+    # %%
+    # Model inference with Torch-TensorRT backend
+    # -------------------------------------------
+    # When we compile the distributed model with the Torch-TensorRT backend, the PyTorch distributed libraries create the sharded model
+    # on multiple GPUs, and the communication operations handle the required data exchange. In the plan above,
+    # `ColwiseParallel` and `RowwiseParallel` shard the attention layers column-wise or row-wise,
+    # `SequenceParallel` performs the sharded computation of the normalization layers, and
+    # `PrepareModuleInput` configures the model input with the proper communication operations.
+
     torch.manual_seed(0)
     inp = torch.randint(32000, (8, 256), device="cuda")
     python_result = model(inp)
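The torch.compile call that produces the compiled model used in the timing loop below sits between these two hunks and is unchanged by this commit. As a hedged sketch only, compiling the sharded model with the Torch-TensorRT backend generally looks like the following; the specific options are assumptions, not the file's actual settings.

    # Hedged sketch of the (unchanged) compilation step; options are assumptions.
    import torch
    import torch_tensorrt  # registers the "torch_tensorrt" backend for torch.compile

    model = torch.compile(
        model,
        backend="torch_tensorrt",
        dynamic=False,
        options={
            "use_python_runtime": True,  # keep the Python runtime for distributed execution
            "enabled_precisions": {torch.float32},
            "min_block_size": 1,
        },
    )
    # The first forward pass triggers TensorRT engine compilation; subsequent
    # passes in the loop below measure pure inference latency.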
@@ -62,9 +122,11 @@
         output = model(inp)
         end = time.time()
         if i == 0:
+            # Logging the compilation time
             logger.info(f"Compilation time is {end-start}")
             assert (
                 python_result - output
             ).std() < 0.01, "Compilation result is not correct."
         elif _rank == 0:
+            # Logging the inference time
            logger.info(f"Inference time is {end-start}")
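Since the script relies on torch.distributed with NCCL and writes one log file per rank, it presumably runs under a multi-process launcher. The diff does not include launch instructions, but a typical invocation (an assumption, not part of this commit) would be `torchrun --nproc_per_node=2 examples/distributed_inference/tensor_parallel_llama3.py`, which produces the `tensor_parallel_llama3_0.log` and `tensor_parallel_llama3_1.log` files mentioned in the comments.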
