Add Tensor Parallel to torch_native_llama #1876
@@ -0,0 +1,73 @@
"""
Common utilities for torch model parallelism.
"""

from typing import Optional, Sequence

import torch
from torch.distributed.device_mesh import DeviceMesh

try:
    from torch.distributed.tensor import DTensor, Shard
except ImportError:
    # torch 2.4 or older
    from torch.distributed._tensor import DTensor, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ColwiseParallelSharded(ColwiseParallel):
    """
    A version of ColwiseParallel where the local weight has already been
    sharded. This is used for the fused wqkv case: during loading, wq, wk,
    and wv are sharded before being fused.
    """

    # Override the _partition_linear_fn in ColwiseParallel
    def _partition_linear_fn(self, name, module, device_mesh):
        # Colwise: shard weight/bias along dim 0, i.e. Shard(0). Since Linear
        # computes input * weight^T + bias, a Shard(0) weight corresponds to
        # column-wise parallelism (the effective weight^T is Shard(1)).
        for name, param in module.named_parameters():
            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
            module.register_parameter(name, dist_param)


def tensor_parallel(
    module: torch.nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
):
    """
    Tensor-parallelize the model across the given device mesh.

    Args:
        module (`torch.nn.Module`):
            The module to tensor-parallelize.
        device_mesh (`torch.distributed.DeviceMesh`):
            The device mesh to use for tensor parallelism.
    """

    # Tensor-parallelize an nn.Module based on its `_tp_plan` attribute.
    # No-op if the module has no `_tp_plan` attribute. This helper is meant
    # to be used with `module.apply` to recursively parallelize a model.
    def tplize(mod: torch.nn.Module) -> None:
        tp_plan = getattr(mod, "_tp_plan", None)
        if tp_plan is None:
            return
        for child_name, tp_style in tp_plan.items():
            submod = mod.get_submodule(child_name)
            if tp_style == "Colwise":
                parallelize_module(submod, device_mesh, ColwiseParallel())
            elif tp_style == "Rowwise":
                parallelize_module(submod, device_mesh, RowwiseParallel())
            elif tp_style == "Colwise_Sharded":
                parallelize_module(submod, device_mesh, ColwiseParallelSharded())
            else:
                raise ValueError(f"Unknown TP style {tp_style}")

    # `apply` is a native method of `nn.Module` that recursively applies a
    # function to every submodule.
    module.apply(tplize)
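For orientation, here is a minimal usage sketch of the helper above. It is not part of the diff: `build_model()` is a hypothetical stand-in for constructing the model, `tensor_parallel` is assumed to be imported from the module above, and the process group is assumed to be initialized (e.g. under torchrun) with one GPU per rank.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumes torchrun has spawned one process per GPU.
dist.init_process_group(backend="nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

model = build_model()         # hypothetical: any model whose submodules define `_tp_plan`
tensor_parallel(model, mesh)  # applies Colwise/Rowwise/Colwise_Sharded per the plan
```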
@@ -24,7 +24,7 @@
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -42,34 +42,45 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 tp_size = get_tensor_model_parallel_world_size()
+tp_rank = get_tensor_model_parallel_rank()


 def gate_up_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
     loaded_shard_id: Optional[int] = None,
 ):
+    # shard_id: (shard_offset, shard_size)
+    gate_up_offsets = {}
+    current_shard_offset = 0
+    for i, output_size in enumerate(self.output_sizes):
+        gate_up_offsets[i] = (current_shard_offset, output_size)
+        current_shard_offset += output_size
     if loaded_shard_id is None:
-        shard_offsets: List[Tuple[int, int, int]] = []
-        for i, output_size in enumerate(self.output_sizes):
-            shard_offsets.append((i, current_shard_offset, output_size))
-            current_shard_offset += output_size
-        for shard_id, shard_offset, shard_size in shard_offsets:
+        for shard_id, (shard_offset, shard_size) in gate_up_offsets.items():
             loaded_weight_shard = loaded_weight.narrow(
-                output_dim, shard_offset, shard_size
+                0, shard_offset, shard_size
             )
             self.weight_loader(param, loaded_weight_shard, shard_id)
     else:
         assert loaded_shard_id < len(self.output_sizes)
         param_data = param.data
-        shard_size = loaded_weight.shape[0]
-        shard_offset = loaded_shard_id * shard_size
+        shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
         param_data = param_data.narrow(0, shard_offset, shard_size)
+        loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
     return


 class LlamaMLP(nn.Module):
+    _tp_plan = {
+        "gate_up_proj": "Colwise_Sharded",
+        "down_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         hidden_size: int,
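To make the fused gate/up sharding arithmetic concrete, here is a small self-contained sketch of the offsets the loader builds. The numbers (a Llama-2-7B-like intermediate size, two ranks) are illustrative and not taken from the diff.

```python
# Illustrative sizes only: intermediate_size 11008, tensor-parallel degree 2.
intermediate_size, tp_size, tp_rank = 11008, 2, 1

# Per-rank output sizes for the fused gate_up projection, as set in the diff.
output_sizes = [intermediate_size // tp_size] * 2        # [5504, 5504]

gate_up_offsets, current = {}, 0
for i, size in enumerate(output_sizes):
    gate_up_offsets[i] = (current, size)
    current += size
print(gate_up_offsets)                                   # {0: (0, 5504), 1: (5504, 5504)}

# Loading the full "up" weight (shard_id 1) on rank 1: the tp_rank narrow
# selects rows 5504:11008 of the checkpoint tensor as the local shard.
shard_offset, shard_size = gate_up_offsets[1]
rows = (tp_rank * shard_size, (tp_rank + 1) * shard_size)
print(shard_offset, rows)                                # 5504 (5504, 11008)
```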
@@ -84,7 +95,7 @@ def __init__(
             intermediate_size * 2,
             bias=False,
         )
-        self.gate_up_proj.output_sizes = [intermediate_size] * 2
+        self.gate_up_proj.output_sizes = [intermediate_size // tp_size] * 2
         self.gate_up_proj.weight_loader = types.MethodType(
             gate_up_proj_weight_loader, self.gate_up_proj
         )
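The `types.MethodType` pattern above bolts a custom weight loader onto a plain `torch.nn.Linear`, so that vLLM-style loading code can call `weight_loader` on it. A standalone sketch of that pattern (the loader name and shapes here are illustrative, not from the diff):

```python
import types
import torch

def toy_weight_loader(self, param, loaded_weight):
    # `self` is the Linear module itself, because MethodType binds it below.
    param.data.copy_(loaded_weight)

proj = torch.nn.Linear(16, 32, bias=False)
proj.weight_loader = types.MethodType(toy_weight_loader, proj)

# Callers can now treat the plain Linear like a layer that owns a weight_loader.
proj.weight_loader(proj.weight, torch.randn(32, 16))
```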
@@ -104,62 +115,40 @@ def forward(self, x):
         return x


-def _get_shard_offset_mapping(self, loaded_shard_id: str):
-    shard_offset_mapping = {
-        "q": 0,
-        "k": self.num_heads * self.head_size,
-        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
-    }
-    return shard_offset_mapping.get(loaded_shard_id)
-
-
-def _get_shard_size_mapping(self, loaded_shard_id: str):
-    shard_size_mapping = {
-        "q": self.num_heads * self.head_size,
-        "k": self.num_kv_heads * self.head_size,
-        "v": self.num_kv_heads * self.head_size,
-    }
-    return shard_size_mapping.get(loaded_shard_id)
-
-
 def qkv_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
     loaded_shard_id: Optional[str] = None,
 ):
+    # shard_id: (shard_offset, shard_size)
+    qkv_offsets = {
+        "q": (0, self.num_heads * self.head_size),
+        "k": (self.num_heads * self.head_size, self.num_kv_heads * self.head_size),
+        "v": ((self.num_heads + self.num_kv_heads) * self.head_size, self.num_kv_heads * self.head_size),
+    }
     if loaded_shard_id is None:
-        shard_offsets = [
-            # (shard_id, shard_offset, shard_size)
-            ("q", 0, self.total_num_heads * self.head_size),
-            (
-                "k",
-                self.total_num_heads * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-            (
-                "v",
-                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-        ]
-        for shard_id, shard_offset, shard_size in shard_offsets:
+        for shard_id, (shard_offset, shard_size) in qkv_offsets.items():
             loaded_weight_shard = loaded_weight.narrow(
                 param.output_dim, shard_offset, shard_size
             )
             self.weight_loader(param, loaded_weight_shard, shard_id)
     else:
-        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
-        shard_size = self._get_shard_size_mapping(loaded_shard_id)
+        shard_offset, shard_size = qkv_offsets[loaded_shard_id]
         param_data = param.data
         param_data = param_data.narrow(0, shard_offset, shard_size)
+        loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
     return


 class LlamaAttention(nn.Module):
+    _tp_plan = {
+        "qkv_proj": "Colwise_Sharded",
+        "o_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         config: LlamaConfig,

Review comment (on the qkv_offsets loop): These are style changes only.
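A worked example of the `qkv_offsets` table for a grouped-query configuration may help; note that `num_heads` and `num_kv_heads` are already the per-rank (local) counts, which is why the attribute changes further below matter. The numbers are illustrative only, not from the diff.

```python
# Illustrative GQA sizes: 32 query heads and 8 KV heads globally, head_size 128,
# tp_size 2, so the loader sees the local counts 16 and 4.
num_heads, num_kv_heads, head_size = 16, 4, 128

qkv_offsets = {
    "q": (0, num_heads * head_size),                                          # (0, 2048)
    "k": (num_heads * head_size, num_kv_heads * head_size),                   # (2048, 512)
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),  # (2560, 512)
}
# Rows of the local fused qkv weight: 2048 + 512 + 512 = 3072.
fused_rows = (num_heads + 2 * num_kv_heads) * head_size
print(qkv_offsets, fused_rows)
```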
@@ -176,7 +165,6 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size

Review discussion:
- "I see that we are already doing manual sharding here; I feel this code should move into separate TP-related code instead of being embedded in the model, if possible."
- "Yeah, it seems the 'local' … I am not sure if I can remove it, given that it involves a contract change with that module."
- "I think ideally we want to keep all TP logic separate from the model code, so that we can apply it to other models without modifying the model at all."
@@ -208,17 +196,11 @@ def __init__(
         self.qkv_proj.total_num_heads = self.total_num_heads
         self.qkv_proj.head_size = self.head_dim
         self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
-        self.qkv_proj.num_heads = self.total_num_heads
-        self.qkv_proj.num_kv_heads = self.total_num_kv_heads
+        self.qkv_proj.num_heads = self.num_heads
+        self.qkv_proj.num_kv_heads = self.num_kv_heads
         self.qkv_proj.weight_loader = types.MethodType(
             qkv_proj_weight_loader, self.qkv_proj
         )
-        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
-            _get_shard_offset_mapping, self.qkv_proj
-        )
-        self.qkv_proj._get_shard_size_mapping = types.MethodType(
-            _get_shard_size_mapping, self.qkv_proj
-        )
         self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
         self.qkv_proj.weight.output_dim = 0
         self.o_proj = torch.nn.Linear(

Review comment (on the removed _get_shard_*_mapping bindings): Not used now.
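Because the loaders write per-rank shards straight into the fused local weight, `ColwiseParallelSharded` only needs to wrap that local tensor as one `Shard(0)` piece of the global weight. A quick shape check under the same illustrative sizes as above (plain arithmetic, no process group required):

```python
# Illustrative sizes: 32/8 heads, head_size 128, tp_size 2.
tp_size, head_size = 2, 128
total_num_heads, total_num_kv_heads = 32, 8
num_heads, num_kv_heads = total_num_heads // tp_size, total_num_kv_heads // tp_size

local_rows = (num_heads + 2 * num_kv_heads) * head_size               # 3072 rows per rank
global_rows = (total_num_heads + 2 * total_num_kv_heads) * head_size  # 6144 rows overall

# DTensor.from_local(weight, mesh, [Shard(0)]) then treats the (3072, hidden)
# local weight as one slice of a (6144, hidden) global weight along dim 0.
assert local_rows * tp_size == global_rows
```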
@@ -385,6 +367,7 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
+        self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
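The `supports_torch_tp` flag presumably lets the engine decide whether to build a device mesh and apply the DTensor plan; that wiring is not part of this diff, so the following is only an assumed sketch (`model` and `device_mesh` are placeholders):

```python
# Assumed integration sketch only; the runner-side wiring is not shown in this PR.
if getattr(model, "supports_torch_tp", False):
    tensor_parallel(model, device_mesh)  # helper from the new utility module above
```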