Add Tensor Parallel to torch_native_llama #1876

Open · wants to merge 7 commits into base: main · Changes from 5 commits
15 changes: 15 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -153,6 +153,12 @@ def __init__(
min_per_gpu_memory = self.init_torch_distributed()
self.sampler = Sampler()
self.load_model()

# Apply torch TP if model supports it
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
if self.tp_size > 1 and supports_torch_tp:
self.apply_torch_tp()

if server_args.lora_paths is not None:
self.init_lora_manager()
self.init_memory_pool(
@@ -556,6 +562,15 @@ def init_cuda_graphs(self):
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)

def apply_torch_tp(self):
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
from sglang.srt.model_parallel import tensor_parallel

device_mesh = torch.distributed.init_device_mesh(
self.device, (self.tp_size,)
)
tensor_parallel(self.model, device_mesh)

def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)
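A minimal standalone sketch of the opt-in flow this hunk adds; the helper name `maybe_apply_torch_tp` and the assumption that `torch.distributed` is already initialized (e.g. via `torchrun`) are ours, not part of the PR:

```python
import torch

from sglang.srt.model_parallel import tensor_parallel


def maybe_apply_torch_tp(model: torch.nn.Module, tp_size: int, device: str = "cuda"):
    # Mirrors ModelRunner: only models that opt in via `supports_torch_tp`
    # are parallelized, and only when more than one rank is in use.
    if tp_size > 1 and getattr(model, "supports_torch_tp", False):
        device_mesh = torch.distributed.init_device_mesh(device, (tp_size,))
        tensor_parallel(model, device_mesh)
    return model
```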
73 changes: 73 additions & 0 deletions python/sglang/srt/model_parallel.py
@@ -0,0 +1,73 @@
"""
Common utilities for torch model parallelism.
"""

from typing import Optional, Sequence

import torch
from torch.distributed.device_mesh import DeviceMesh
try:
from torch.distributed.tensor import DTensor, Shard
except ImportError:
# torch 2.4 or older
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)


class ColwiseParallelSharded(ColwiseParallel):
"""
A version of ColwiseParallel where the local weight has already been
sharded. This is used for the fused wqkv case, where wq, wk, and wv are
sharded during loading, before being fused.
"""
# Override the _partition_linear_fn in ColwiseParallel
def _partition_linear_fn(self, name, module, device_mesh):
# Colwise shards weight/bias as Shard(0) on the stored parameter. Because
# Linear computes input * weight^T + bias, Shard(0) on the stored weight
# corresponds to Shard(1) on weight^T, i.e. a column-wise split.
for name, param in module.named_parameters():
dtensor = DTensor.from_local(
param, device_mesh, [Shard(0)]
)
dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
module.register_parameter(name, dist_param)


def tensor_parallel(
module: torch.nn.Module,
device_mesh: Optional[DeviceMesh] = None,
):
"""
Tensor parallelize the model across the given device mesh.
Args:
module (`torch.nn.Module`):
The module to tensor parallelize.
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""
# Tensor-parallelize an nn.Module based on its `_tp_plan` attribute.
# No-op if the module has no `_tp_plan` attribute.
# This helper is meant to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan is None:
return
for child_name, tp_style in tp_plan.items():
submod = mod.get_submodule(child_name)
if tp_style == "Colwise":
parallelize_module(submod, device_mesh, ColwiseParallel())
elif tp_style == "Rowwise":
parallelize_module(submod, device_mesh, RowwiseParallel())
elif tp_style == "Colwise_Sharded":
parallelize_module(submod, device_mesh, ColwiseParallelSharded())
else:
raise ValueError(f"Unknown TP style {tp_style}")

# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
module.apply(tplize)
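
A usage sketch of the `_tp_plan` contract that `tensor_parallel` consumes; the toy `ToyMLP` module, the NCCL setup, and the sizes are illustrative assumptions, not code from this PR (run under `torchrun` with one GPU per rank):

```python
import torch
import torch.distributed as dist

from sglang.srt.model_parallel import tensor_parallel


class ToyMLP(torch.nn.Module):
    # Keys are child-module names; values pick the parallel style that
    # tensor_parallel() maps to ColwiseParallel / RowwiseParallel.
    _tp_plan = {
        "up": "Colwise",
        "down": "Rowwise",
    }

    def __init__(self, hidden: int = 128, inter: int = 512):
        super().__init__()
        self.up = torch.nn.Linear(hidden, inter, bias=False)
        self.down = torch.nn.Linear(inter, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(torch.nn.functional.relu(self.up(x)))


dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = torch.distributed.init_device_mesh("cuda", (dist.get_world_size(),))

model = ToyMLP().cuda()
tensor_parallel(model, mesh)  # parallelize_module() applied per _tp_plan entry
out = model(torch.randn(4, 128, device="cuda"))  # Rowwise output is replicated across ranks
```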
93 changes: 38 additions & 55 deletions python/sglang/srt/models/torch_native_llama.py
@@ -24,7 +24,7 @@
from torch import nn
from torch.nn.parameter import Parameter
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

@@ -42,34 +42,45 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()


def gate_up_proj_weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None,
):
# shard_id: (shard_offset, shard_size)
gate_up_offsets = {}
current_shard_offset = 0
for i, output_size in enumerate(self.output_sizes):
gate_up_offsets[i] = (current_shard_offset, output_size)
current_shard_offset += output_size
if loaded_shard_id is None:
shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
for shard_id, shard_offset, shard_size in shard_offsets:
for shard_id, (shard_offset, shard_size) in gate_up_offsets.items():
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
0, shard_offset, shard_size
Author: These are style changes only.

)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
assert loaded_shard_id < len(self.output_sizes)
param_data = param.data
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
param_data = param_data.narrow(0, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return


class LlamaMLP(nn.Module):
_tp_plan = {
"gate_up_proj": "Colwise_Sharded",
"down_proj": "Rowwise",
}

def __init__(
self,
hidden_size: int,
@@ -84,7 +84,7 @@ def __init__(
intermediate_size * 2,
bias=False,
)
self.gate_up_proj.output_sizes = [intermediate_size] * 2
self.gate_up_proj.output_sizes = [intermediate_size // tp_size] * 2
kwen2501 marked this conversation as resolved.
self.gate_up_proj.weight_loader = types.MethodType(
gate_up_proj_weight_loader, self.gate_up_proj
)
@@ -104,62 +104,40 @@ def forward(self, x):
return x


def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
}
return shard_offset_mapping.get(loaded_shard_id)


def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
}
return shard_size_mapping.get(loaded_shard_id)


def qkv_proj_weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None,
):
# shard_id: (shard_offset, shard_size)
qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (self.num_heads * self.head_size, self.num_kv_heads * self.head_size),
"v": ((self.num_heads + self.num_kv_heads) * self.head_size, self.num_kv_heads * self.head_size),
}
if loaded_shard_id is None:
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
(
"k",
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size,
),
]
for shard_id, shard_offset, shard_size in shard_offsets:
for shard_id, (shard_offset, shard_size) in qkv_offsets.items():
Author: These are style changes only.

loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
shard_offset, shard_size = qkv_offsets[loaded_shard_id]
param_data = param.data
param_data = param_data.narrow(0, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
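
A worked example of the offset arithmetic above, with hypothetical sizes (32 query heads and 8 KV heads in the checkpoint, head_dim 128, tp_size 4) chosen purely for illustration:

```python
# Hypothetical sizes; not taken from any particular checkpoint.
head_size = 128
num_heads = 32 // 4     # query heads per rank (32 total, tp_size 4)
num_kv_heads = 8 // 4   # KV heads per rank (8 total, tp_size 4)

qkv_offsets = {
    "q": (0, num_heads * head_size),                                          # (0, 1024)
    "k": (num_heads * head_size, num_kv_heads * head_size),                   # (1024, 256)
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),  # (1280, 256)
}
# On rank r, the "k" shard copies rows [r * 256, (r + 1) * 256) of the full
# wk checkpoint tensor into rows [1024, 1280) of this rank's fused qkv weight.
```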


class LlamaAttention(nn.Module):
_tp_plan = {
"qkv_proj": "Colwise_Sharded",
"o_proj": "Rowwise",
}

def __init__(
self,
config: LlamaConfig,
@@ -176,7 +165,6 @@ def __init__(
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
Contributor: I see that we are already doing manual sharding here; I feel this code should move to separate TP-related code instead of being embedded in the model, if possible.

Author: Yeah, it seems the "local" n_heads are needed for constructing the RadixAttention later:

self.attn = RadixAttention(
    self.num_heads,
    ...

I am not sure if I can remove it given that it involves a contract change with that module.

Contributor (@jerryzh168, Nov 9, 2024): I think ideally we want to keep all TP logic separate from the model code; that way we can apply it to other models without modifying the model at all.

assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
@@ -208,17 +196,11 @@ def __init__(
self.qkv_proj.total_num_heads = self.total_num_heads
self.qkv_proj.head_size = self.head_dim
self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
self.qkv_proj.num_heads = self.total_num_heads
self.qkv_proj.num_kv_heads = self.total_num_kv_heads
self.qkv_proj.num_heads = self.num_heads
self.qkv_proj.num_kv_heads = self.num_kv_heads
kwen2501 marked this conversation as resolved.
self.qkv_proj.weight_loader = types.MethodType(
qkv_proj_weight_loader, self.qkv_proj
)
self.qkv_proj._get_shard_offset_mapping = types.MethodType(
_get_shard_offset_mapping, self.qkv_proj
)
self.qkv_proj._get_shard_size_mapping = types.MethodType(
_get_shard_size_mapping, self.qkv_proj
)
Author (on lines -216 to -221): Not used now.

self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
self.qkv_proj.weight.output_dim = 0
self.o_proj = torch.nn.Linear(
@@ -385,6 +367,7 @@ def __init__(
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.supports_torch_tp = True
self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)