Add Tensor Parallel to torch_native_llama #1876

Open · wants to merge 7 commits into base: main
Changes from 1 commit
7 changes: 7 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -153,6 +153,13 @@ def __init__(
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if self.tp_size > 1:
            logger.info(f"Tensor parallelism is enabled, {self.tp_size} devices will be used.")
            device_mesh = torch.distributed.init_device_mesh(
                self.device, (self.tp_size,)
            )
            self.model.tensor_parallel(device_mesh)

        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
188 changes: 141 additions & 47 deletions python/sglang/srt/models/torch_native_llama.py
@@ -18,11 +18,15 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""

import types
from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, List

import torch
from torch import nn
from torch.nn.parameter import Parameter
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
)
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -42,34 +46,73 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


gate_up_proj_shard_offsets = {}


def build_gate_up_proj_shard_offsets(
    output_sizes: Sequence[int],
):
    global gate_up_proj_shard_offsets
    # shard_id: (shard_offset, shard_size)
    current_shard_offset = 0
    for i, output_size in enumerate(output_sizes):
        gate_up_proj_shard_offsets.setdefault(
            i, (current_shard_offset, output_size)
        )
        current_shard_offset += output_size


def gate_up_proj_weight_loader(
    self,
    param: Parameter,
    loaded_weight: torch.Tensor,
    loaded_shard_id: Optional[int] = None,
):
    if loaded_shard_id is None:
        # Fused checkpoint weight: split it into the gate and up shards and
        # load each one through the per-shard path below.
        for shard_id, (shard_offset, shard_size) in gate_up_proj_shard_offsets.items():
            loaded_weight_shard = loaded_weight.narrow(
                0, shard_offset, shard_size
            )
            self.weight_loader(param, loaded_weight_shard, shard_id)
    else:
        # Single shard (gate or up): copy it into its slice of the fused weight.
        assert loaded_shard_id < len(self.output_sizes)
        param_data = param.data
        shard_offset, shard_size = gate_up_proj_shard_offsets[loaded_shard_id]
        param_data = param_data.narrow(0, shard_offset, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
    return


def shuffle_gate_up_proj_weight(
    param: Parameter,
    tp_size: int,
):
    # Re-order the fused [gate; up] weight from its checkpoint layout
    # ([all gate rows; all up rows]) to a rank-interleaved layout
    # ([gate_0; up_0; gate_1; up_1; ...]) so that the even dim-0 chunking used
    # by column-wise TP gives each rank its own gate and up slices.
    if tp_size == 1:
        return

    param_data = param.data
    new_tensor = torch.empty_like(param_data)
    tp_dim = 0
    tp_chunk_size = param_data.shape[tp_dim] // tp_size
    for shard_id, (shard_offset, shard_size) in gate_up_proj_shard_offsets.items():
        tp_slice_size = shard_size // tp_size
        for i in range(tp_size):
            tp_slice_src_offset = shard_offset + i * tp_slice_size
            tp_slice_dst_offset = i * tp_chunk_size + shard_offset // tp_size
            src_slice = param_data.narrow(tp_dim, tp_slice_src_offset, tp_slice_size)
            dst_slice = new_tensor.narrow(tp_dim, tp_slice_dst_offset, tp_slice_size)
            dst_slice.copy_(src_slice)

    param.data = new_tensor
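To see what this re-ordering accomplishes, here is a small self-contained illustration with toy sizes (not taken from the PR): `ColwiseParallel` splits dimension 0 into equal chunks per rank, so the fused [gate; up] rows have to be interleaved per rank before sharding.

```python
import torch

# Toy fused [gate; up] weight: 2 gate rows (g0, g1) and 2 up rows (u0, u1),
# encoded as the values 0, 1, 2, 3 so the re-ordering is easy to see.
fused = torch.arange(4.0).unsqueeze(1)
tp_size = 2
offsets = {0: (0, 2), 1: (2, 2)}  # shard_id -> (shard_offset, shard_size)

out = torch.empty_like(fused)
chunk = fused.shape[0] // tp_size
for _, (off, size) in offsets.items():
    per_rank = size // tp_size
    for r in range(tp_size):
        src = fused.narrow(0, off + r * per_rank, per_rank)
        dst = out.narrow(0, r * chunk + off // tp_size, per_rank)
        dst.copy_(src)

# out is [g0, u0, g1, u1]: an even 2-way chunk now gives rank 0 its own
# gate/up slices and rank 1 its own gate/up slices.
print(out.squeeze(1).tolist())  # [0.0, 2.0, 1.0, 3.0]
```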


class LlamaMLP(nn.Module):
    _tp_plan = {
        "gate_up_proj": ColwiseParallel(),
        "down_proj": RowwiseParallel(),
    }

    def __init__(
        self,
        hidden_size: int,
@@ -85,6 +128,7 @@ def __init__(
            bias=False,
        )
        self.gate_up_proj.output_sizes = [intermediate_size] * 2
        build_gate_up_proj_shard_offsets(self.gate_up_proj.output_sizes)
        self.gate_up_proj.weight_loader = types.MethodType(
            gate_up_proj_weight_loader, self.gate_up_proj
        )
@@ -104,23 +148,25 @@ def forward(self, x):
        return x


qkv_proj_shard_offsets = {}


def build_qkv_proj_shard_offsets(
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
):
    global qkv_proj_shard_offsets
    # shard_id: (shard_offset, shard_size)
    qkv_proj_shard_offsets.setdefault(
        "q", (0, num_heads * head_size)
    )
    qkv_proj_shard_offsets.setdefault(
        "k", (num_heads * head_size, num_kv_heads * head_size)
    )
    qkv_proj_shard_offsets.setdefault(
        "v", ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size)
    )
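For concreteness, a hedged numeric example (head counts chosen for illustration only, not taken from the PR): with 32 query heads, 8 KV heads, and head size 128, the offsets built above describe a fused [Q | K | V] weight of 6144 rows.

```python
# Hypothetical sizes, for illustration only.
num_heads, num_kv_heads, head_size = 32, 8, 128

offsets = {  # mirrors qkv_proj_shard_offsets
    "q": (0, num_heads * head_size),                                          # (0, 4096)
    "k": (num_heads * head_size, num_kv_heads * head_size),                   # (4096, 1024)
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),  # (5120, 1024)
}
# Total fused rows: 4096 + 1024 + 1024 = 6144, laid out as [Q | K | V] along dim 0.
```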


def qkv_proj_weight_loader(
@@ -130,36 +176,50 @@ def qkv_proj_weight_loader(
    loaded_shard_id: Optional[str] = None,
):
    if loaded_shard_id is None:
        # Fused checkpoint weight: split it into the q, k and v shards and
        # load each one through the per-shard path below.
        for shard_id, (shard_offset, shard_size) in qkv_proj_shard_offsets.items():
            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader(param, loaded_weight_shard, shard_id)
    else:
        # Single q/k/v shard: copy it into its slice of the fused weight.
        shard_offset, shard_size = qkv_proj_shard_offsets[loaded_shard_id]
        param_data = param.data
        param_data = param_data.narrow(0, shard_offset, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
    return


def shuffle_qkv_proj_weight(
    param: Parameter,
    tp_size: int,
):
    # Re-order the fused [q; k; v] weight from its checkpoint layout to a
    # rank-interleaved layout so that the even dim-0 chunking used by
    # column-wise TP gives each rank its own q, k and v slices.
    if tp_size == 1:
        return

    param_data = param.data
    new_tensor = torch.empty_like(param_data)
    tp_dim = 0
    tp_chunk_size = param_data.shape[tp_dim] // tp_size
    for shard_id in ["q", "k", "v"]:
        shard_offset, shard_size = qkv_proj_shard_offsets[shard_id]
        tp_slice_size = shard_size // tp_size
        for i in range(tp_size):
            tp_slice_src_offset = shard_offset + i * tp_slice_size
            tp_slice_dst_offset = i * tp_chunk_size + shard_offset // tp_size
            src_slice = param_data.narrow(tp_dim, tp_slice_src_offset, tp_slice_size)
            dst_slice = new_tensor.narrow(tp_dim, tp_slice_dst_offset, tp_slice_size)
            dst_slice.copy_(src_slice)

    param.data = new_tensor


class LlamaAttention(nn.Module):
    _tp_plan = {
        "qkv_proj": ColwiseParallel(),
        "o_proj": RowwiseParallel(),
    }

    def __init__(
        self,
        config: LlamaConfig,
@@ -205,6 +265,11 @@ def __init__(
            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
            bias=False,
        )
        build_qkv_proj_shard_offsets(
            self.total_num_heads,
            self.total_num_kv_heads,
            self.head_dim,
        )
        self.qkv_proj.total_num_heads = self.total_num_heads
        self.qkv_proj.head_size = self.head_dim
        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
@@ -213,12 +278,6 @@
        self.qkv_proj.weight_loader = types.MethodType(
            qkv_proj_weight_loader, self.qkv_proj
        )
        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
            _get_shard_offset_mapping, self.qkv_proj
        )
        self.qkv_proj._get_shard_size_mapping = types.MethodType(
            _get_shard_size_mapping, self.qkv_proj
        )
Review comment on lines -216 to -221 (Author): Not used now.
        self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
        self.qkv_proj.weight.output_dim = 0
        self.o_proj = torch.nn.Linear(
@@ -495,8 +554,43 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
            param = self.lm_head.weight
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, self.model.embed_tokens.weight)

        # Re-arrange fused matrix for TP
        tp_size = get_tensor_model_parallel_world_size()
Review comment (Contributor): Here, is it possible to do:

  1. split qkv into 3 tensors
  2. apply TP to each of the tensors
  3. concat the 3 tensors into a single DTensor

This way we can rely on the split/concat ops in DTensor itself instead of worrying about the implementation details.

Reply (Author, kwen2501, Nov 2, 2024): Maybe? Although, at this location, we haven't applied TP yet, so there is no notion of DTensor.

Reply (Author): Hmm, I see what you mean. We can use the DTensor API instead of the (higher-level) TP API here.

Reply (Author): In the newer version, I added support for tensor-parallelized weight loading. Then we directly construct the DTensor from the local shard. See the ColwiseParallelSharded strategy.

Reply (Contributor): Just to understand: currently, step 2 is manual, right?

Reply (Author, kwen2501, Nov 9, 2024): Not manual per se. It is already packaged and can be called with parallelize_module like the other styles. So no involvement is needed from the user or model author.

        for name, param in params_dict.items():
            # For these modules, we need to re-arrange the fused matrix to match
            # tensor parallelism.
            if ".qkv_proj" in name:
                shuffle_qkv_proj_weight(param, tp_size)
            elif ".gate_up_proj" in name:
                shuffle_gate_up_proj_weight(param, tp_size)

        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
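Relating this back to the review thread above: the reviewer's split/shard/concat idea and the author's later approach of constructing a DTensor directly from the local shard could look roughly like the following sketch. This is a hedged illustration, not the PR's code; the import paths vary across torch versions (older releases expose these under `torch.distributed._tensor`), and `full_qkv`, `q_size`, `kv_size`, `rank`, and `tp_size` are assumed inputs.

```python
import torch
from torch.distributed.tensor import DTensor, Shard


def qkv_local_shard(
    full_qkv: torch.Tensor, q_size: int, kv_size: int, rank: int, tp_size: int
) -> torch.Tensor:
    # Split the fused weight into q, k, v along dim 0 and keep this rank's
    # slice of each, concatenated back into one local tensor.
    q, k, v = torch.split(full_qkv, [q_size, kv_size, kv_size], dim=0)
    pick = lambda t: t.chunk(tp_size, dim=0)[rank]
    return torch.cat([pick(q), pick(k), pick(v)], dim=0)


# With a 1-D TP mesh already built (e.g. via torch.distributed.init_device_mesh),
# the per-rank local shard can be wrapped into a single DTensor sharded on dim 0:
#   local = qkv_local_shard(full_qkv, q_size, kv_size, rank, tp_size)
#   qkv_weight = DTensor.from_local(local, device_mesh, [Shard(0)])
```

Under that construction, the global view of the DTensor is exactly the rank-interleaved layout that `shuffle_qkv_proj_weight` produces in place.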

    def tensor_parallel(self, device_mesh=None):
        """
        Tensor-parallelize the model across the given device mesh.

        Args:
            device_mesh (`torch.distributed.DeviceMesh`):
                The device mesh to use for tensor parallelism.
        """

        # Tensor-parallelize an nn.Module based on its `_tp_plan` attribute.
        # This is a no-op if the module has no `_tp_plan` attribute.
        # The helper is meant to be used with `model.apply` to recursively
        # parallelize a model.
        def tplize(mod: torch.nn.Module) -> None:
            tp_plan = getattr(mod, "_tp_plan", None)
            if tp_plan:
                torch.distributed.tensor.parallel.parallelize_module(
                    mod,
                    device_mesh=device_mesh,
                    parallelize_plan=tp_plan,
                )

        # `apply` is a native method of `nn.Module` that recursively applies a
        # function to every submodule.
        self.apply(tplize)
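As a usage sketch of the same `_tp_plan` / `parallelize_module` pattern outside this model (toy module, names and sizes made up; launched with e.g. `torchrun --nproc-per-node 2 toy_tp.py`):

```python
import os

import torch
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    # Per-module plan, picked up by the same recursive `apply` trick as above.
    _tp_plan = {"up": ColwiseParallel(), "down": RowwiseParallel()}

    def __init__(self, dim: int = 256, hidden: int = 1024):
        super().__init__()
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))


if __name__ == "__main__":
    torch.distributed.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world = torch.distributed.get_world_size()
    mesh = torch.distributed.init_device_mesh("cuda", (world,))

    model = ToyMLP().cuda()

    def tplize(mod: torch.nn.Module) -> None:
        plan = getattr(mod, "_tp_plan", None)
        if plan:
            parallelize_module(mod, device_mesh=mesh, parallelize_plan=plan)

    model.apply(tplize)  # shards `up` column-wise and `down` row-wise
    out = model(torch.randn(4, 256, device="cuda"))
```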


class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
    pass