From d25b19e6a834f31bbba5149e9c84b0f990a206ca Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 31 Oct 2024 15:27:12 -0700 Subject: [PATCH 1/9] Add Tensor Parallel to torch_native_llama --- .../sglang/srt/model_executor/model_runner.py | 7 + .../sglang/srt/models/torch_native_llama.py | 188 +++++++++++++----- 2 files changed, 148 insertions(+), 47 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 583cbd968c6..a511eb91904 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -153,6 +153,13 @@ def __init__( min_per_gpu_memory = self.init_torch_distributed() self.sampler = Sampler() self.load_model() + if self.tp_size > 1: + logger.info(f"Tensor parallelism is enabled, {self.tp_size} devices will be used.") + device_mesh = torch.distributed.init_device_mesh( + self.device, (self.tp_size,) + ) + self.model.tensor_parallel(device_mesh) + if server_args.lora_paths is not None: self.init_lora_manager() self.init_memory_pool( diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index d9ce05b8a65..46d78f8f991 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -18,11 +18,15 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" import types -from typing import Any, Dict, Iterable, Optional, Tuple +from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, List import torch from torch import nn from torch.nn.parameter import Parameter +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, +) from transformers import LlamaConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.rotary_embedding import get_rope @@ -42,6 +46,22 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch +gate_up_proj_shard_offsets = {} + + +def build_gate_up_proj_shard_offsets( + output_sizes: Sequence[int], +): + global gate_up_proj_shard_offsets + # shard_id: (shard_offset, shard_size) + current_shard_offset = 0 + for i, output_size in enumerate(output_sizes): + gate_up_proj_shard_offsets.setdefault( + i, (current_shard_offset, output_size) + ) + current_shard_offset += output_size + + def gate_up_proj_weight_loader( self, param: Parameter, @@ -49,27 +69,50 @@ def gate_up_proj_weight_loader( loaded_shard_id: Optional[int] = None, ): if loaded_shard_id is None: - shard_offsets: List[Tuple[int, int, int]] = [] - for i, output_size in enumerate(self.output_sizes): - shard_offsets.append((i, current_shard_offset, output_size)) - current_shard_offset += output_size - for shard_id, shard_offset, shard_size in shard_offsets: + for shard_id, (shard_offset, shard_size) in gate_up_proj_shard_offsets.items(): loaded_weight_shard = loaded_weight.narrow( - output_dim, shard_offset, shard_size + 0, shard_offset, shard_size ) self.weight_loader(param, loaded_weight_shard, shard_id) else: assert loaded_shard_id < len(self.output_sizes) param_data = param.data - shard_size = loaded_weight.shape[0] - shard_offset = loaded_shard_id * shard_size + shard_offset, shard_size = gate_up_proj_shard_offsets[loaded_shard_id] param_data = param_data.narrow(0, shard_offset, shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) return +def shuffle_gate_up_proj_weight( + param: Parameter, + tp_size: int, +): + if tp_size == 1: + 
return + + param_data = param.data + new_tensor = torch.empty_like(param_data) + tp_dim = 0 + tp_chunk_size = param_data.shape[tp_dim] // tp_size + for shard_id, (shard_offset, shard_size) in gate_up_proj_shard_offsets.items(): + tp_slice_size = shard_size // tp_size + for i in range(tp_size): + tp_slice_src_offset = shard_offset + i * tp_slice_size + tp_slice_dst_offset = i * tp_chunk_size + shard_offset // tp_size + src_slice = param_data.narrow(tp_dim, tp_slice_src_offset, tp_slice_size) + dst_slice = new_tensor.narrow(tp_dim, tp_slice_dst_offset, tp_slice_size) + dst_slice.copy_(src_slice) + + param.data = new_tensor + + class LlamaMLP(nn.Module): + _tp_plan = { + "gate_up_proj": ColwiseParallel(), + "down_proj": RowwiseParallel(), + } + def __init__( self, hidden_size: int, @@ -85,6 +128,7 @@ def __init__( bias=False, ) self.gate_up_proj.output_sizes = [intermediate_size] * 2 + build_gate_up_proj_shard_offsets(self.gate_up_proj.output_sizes) self.gate_up_proj.weight_loader = types.MethodType( gate_up_proj_weight_loader, self.gate_up_proj ) @@ -104,23 +148,25 @@ def forward(self, x): return x -def _get_shard_offset_mapping(self, loaded_shard_id: str): - shard_offset_mapping = { - "q": 0, - "k": self.num_heads * self.head_size, - "v": (self.num_heads + self.num_kv_heads) * self.head_size, - "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size, - } - return shard_offset_mapping.get(loaded_shard_id) +qkv_proj_shard_offsets = {} -def _get_shard_size_mapping(self, loaded_shard_id: str): - shard_size_mapping = { - "q": self.num_heads * self.head_size, - "k": self.num_kv_heads * self.head_size, - "v": self.num_kv_heads * self.head_size, - } - return shard_size_mapping.get(loaded_shard_id) +def build_qkv_proj_shard_offsets( + num_heads: int, + num_kv_heads: int, + head_size: int, +): + global qkv_proj_shard_offsets + # shard_id: (shard_offset, shard_size) + qkv_proj_shard_offsets.setdefault( + "q", (0, num_heads * head_size) + ) + qkv_proj_shard_offsets.setdefault( + "k", (num_heads * head_size, num_kv_heads * head_size) + ) + qkv_proj_shard_offsets.setdefault( + "v", ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size) + ) def qkv_proj_weight_loader( @@ -130,28 +176,13 @@ def qkv_proj_weight_loader( loaded_shard_id: Optional[str] = None, ): if loaded_shard_id is None: - shard_offsets = [ - # (shard_id, shard_offset, shard_size) - ("q", 0, self.total_num_heads * self.head_size), - ( - "k", - self.total_num_heads * self.head_size, - self.total_num_kv_heads * self.head_size, - ), - ( - "v", - (self.total_num_heads + self.total_num_kv_heads) * self.head_size, - self.total_num_kv_heads * self.head_size, - ), - ] - for shard_id, shard_offset, shard_size in shard_offsets: + for shard_id, (shard_offset, shard_size) in qkv_proj_shard_offsets.items(): loaded_weight_shard = loaded_weight.narrow( param.output_dim, shard_offset, shard_size ) self.weight_loader(param, loaded_weight_shard, shard_id) else: - shard_offset = self._get_shard_offset_mapping(loaded_shard_id) - shard_size = self._get_shard_size_mapping(loaded_shard_id) + shard_offset, shard_size = qkv_proj_shard_offsets[loaded_shard_id] param_data = param.data param_data = param_data.narrow(0, shard_offset, shard_size) assert param_data.shape == loaded_weight.shape @@ -159,7 +190,36 @@ def qkv_proj_weight_loader( return +def shuffle_qkv_proj_weight( + param: Parameter, + tp_size: int, +): + if tp_size == 1: + return + + param_data = param.data + new_tensor = torch.empty_like(param_data) + tp_dim = 0 + tp_chunk_size = 
param_data.shape[tp_dim] // tp_size + for shard_id in ["q", "k", "v"]: + shard_offset, shard_size = qkv_proj_shard_offsets[shard_id] + tp_slice_size = shard_size // tp_size + for i in range(tp_size): + tp_slice_src_offset = shard_offset + i * tp_slice_size + tp_slice_dst_offset = i * tp_chunk_size + shard_offset // tp_size + src_slice = param_data.narrow(tp_dim, tp_slice_src_offset, tp_slice_size) + dst_slice = new_tensor.narrow(tp_dim, tp_slice_dst_offset, tp_slice_size) + dst_slice.copy_(src_slice) + + param.data = new_tensor + + class LlamaAttention(nn.Module): + _tp_plan = { + "qkv_proj": ColwiseParallel(), + "o_proj": RowwiseParallel(), + } + def __init__( self, config: LlamaConfig, @@ -205,6 +265,11 @@ def __init__( (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, bias=False, ) + build_qkv_proj_shard_offsets( + self.total_num_heads, + self.total_num_kv_heads, + self.head_dim, + ) self.qkv_proj.total_num_heads = self.total_num_heads self.qkv_proj.head_size = self.head_dim self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads @@ -213,12 +278,6 @@ def __init__( self.qkv_proj.weight_loader = types.MethodType( qkv_proj_weight_loader, self.qkv_proj ) - self.qkv_proj._get_shard_offset_mapping = types.MethodType( - _get_shard_offset_mapping, self.qkv_proj - ) - self.qkv_proj._get_shard_size_mapping = types.MethodType( - _get_shard_size_mapping, self.qkv_proj - ) self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader self.qkv_proj.weight.output_dim = 0 self.o_proj = torch.nn.Linear( @@ -495,8 +554,43 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = self.lm_head.weight weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, self.model.embed_tokens.weight) + + # Re-arrange fused matrix for TP + tp_size = get_tensor_model_parallel_world_size() + for name, param in params_dict.items(): + # For these modules, we need to re-arrange the fused matrix to match + # tensor parallelism. + if ".qkv_proj" in name: + shuffle_qkv_proj_weight(param, tp_size) + elif ".gate_up_proj" in name: + shuffle_gate_up_proj_weight(param, tp_size) + apply_torchao_config_(self, params_dict, set(["proj.weight"])) + def tensor_parallel(self, device_mesh=None): + """ + Tensor parallelize the model across the given device mesh. + Args: + device_mesh (`torch.distributed.DeviceMesh`): + The device mesh to use for tensor parallelism. + """ + # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. + # No op if `_tp_plan` attribute does not exist under the module. + # This is a helper function to be used with `model.apply` to recursively + # parallelize a model. + def tplize(mod: torch.nn.Module) -> None: + tp_plan = getattr(mod, "_tp_plan", None) + if tp_plan: + torch.distributed.tensor.parallel.parallelize_module( + mod, + device_mesh=device_mesh, + parallelize_plan=tp_plan, + ) + + # `apply` is a native method of `nn.Module` that recursively applies a + # function to every submodule. 
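+        # (Illustrative aside, not from the upstream docs: when the plan maps an
+        # nn.Linear to ColwiseParallel, parallelize_module shards its weight along
+        # dim 0, the output features; RowwiseParallel shards along dim 1 and
+        # all-reduces the layer's output across the device mesh.)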
+ self.apply(tplize) + class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM): pass From ebb4c756bb1cfc2cb2af0bf3b4baaa1e710f59f4 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 8 Nov 2024 11:11:15 -0800 Subject: [PATCH 2/9] Support loading in sharded mode Move tp to utils Add ColwiseParallelSharded --- .../sglang/srt/model_executor/model_runner.py | 3 +- python/sglang/srt/model_parallel.py | 72 ++++++++ .../sglang/srt/models/torch_native_llama.py | 162 +++--------------- 3 files changed, 99 insertions(+), 138 deletions(-) create mode 100644 python/sglang/srt/model_parallel.py diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a511eb91904..7760ba55c68 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -54,6 +54,7 @@ ReqToTokenPool, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_parallel import tensor_parallel from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -158,7 +159,7 @@ def __init__( device_mesh = torch.distributed.init_device_mesh( self.device, (self.tp_size,) ) - self.model.tensor_parallel(device_mesh) + tensor_parallel(self.model, device_mesh) if server_args.lora_paths is not None: self.init_lora_manager() diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py new file mode 100644 index 00000000000..40f6717fb3a --- /dev/null +++ b/python/sglang/srt/model_parallel.py @@ -0,0 +1,72 @@ +""" +Common utilities for torch model parallelism. +""" + +from typing import Optional, Sequence + +import torch +from torch.distributed.device_mesh import DeviceMesh +try: + from torch.distributed.tensor import DTensor, Shard +except ImportError: + from torch.distributed._tensor import DTensor, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, +) + + +class ColwiseParallelSharded(ColwiseParallel): + """ + A version of ColwiseParallel where the local weight has been already + sharded. This is used for the fused wqkv case, where during loading, we + already sharded wq, wk, wv before fusing them. + """ + # Override the _partition_linear_fn in ColwiseParallel + def _partition_linear_fn(self, name, module, device_mesh): + # colwise shard weight/bias to Shard(0), weight be Shard(0) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) + for name, param in module.named_parameters(): + dtensor = DTensor.from_local( + param, device_mesh, [Shard(0)] + ) + dist_param = torch.nn.Parameter(dtensor, requires_grad=False) + module.register_parameter(name, dist_param) + + +def tensor_parallel( + module: torch.nn.Module, + device_mesh: Optional[DeviceMesh] = None, +): + """ + Tensor parallelize the model across the given device mesh. + Args: + module (`torch.nn.Module`): + The module to tensor parallelize. + device_mesh (`torch.distributed.DeviceMesh`): + The device mesh to use for tensor parallelism. + """ + # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. + # No op if `_tp_plan` attribute does not exist under the module. + # This is a helper function to be used with `model.apply` to recursively + # parallelize a model. 
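+    #
+    # Example plan (copied from LlamaAttention in torch_native_llama.py, shown
+    # here only for illustration):
+    #     _tp_plan = {"qkv_proj": "Colwise_Sharded", "o_proj": "Rowwise"}
+    # Each string is resolved to a ParallelStyle by `tplize` below.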
+ def tplize(mod: torch.nn.Module) -> None: + tp_plan = getattr(mod, "_tp_plan", None) + if tp_plan is None: + return + for child_name, tp_style in tp_plan.items(): + submod = mod.get_submodule(child_name) + if tp_style == "Colwise": + parallelize_module(submod, device_mesh, ColwiseParallel()) + elif tp_style == "Rowwise": + parallelize_module(submod, device_mesh, RowwiseParallel()) + elif tp_style == "Colwise_Sharded": + parallelize_module(submod, device_mesh, ColwiseParallelSharded()) + else: + raise ValueError(f"Unknown TP style {tp_style}") + + # `apply` is a native method of `nn.Module` that recursively applies a + # function to every submodule. + module.apply(tplize) diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 46d78f8f991..bc03c484fb7 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -18,17 +18,13 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" import types -from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, List +from typing import Any, Dict, Iterable, Optional, Tuple import torch from torch import nn from torch.nn.parameter import Parameter -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - RowwiseParallel, -) from transformers import LlamaConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -46,20 +42,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch -gate_up_proj_shard_offsets = {} - - -def build_gate_up_proj_shard_offsets( - output_sizes: Sequence[int], -): - global gate_up_proj_shard_offsets - # shard_id: (shard_offset, shard_size) - current_shard_offset = 0 - for i, output_size in enumerate(output_sizes): - gate_up_proj_shard_offsets.setdefault( - i, (current_shard_offset, output_size) - ) - current_shard_offset += output_size +tp_size = get_tensor_model_parallel_world_size() +tp_rank = get_tensor_model_parallel_rank() def gate_up_proj_weight_loader( @@ -68,6 +52,12 @@ def gate_up_proj_weight_loader( loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None, ): + # shard_id: (shard_offset, shard_size) + gate_up_proj_shard_offsets = {} + current_shard_offset = 0 + for i, output_size in enumerate(self.output_sizes): + gate_up_proj_shard_offsets[i] = (current_shard_offset, output_size) + current_shard_offset += output_size if loaded_shard_id is None: for shard_id, (shard_offset, shard_size) in gate_up_proj_shard_offsets.items(): loaded_weight_shard = loaded_weight.narrow( @@ -79,38 +69,16 @@ def gate_up_proj_weight_loader( param_data = param.data shard_offset, shard_size = gate_up_proj_shard_offsets[loaded_shard_id] param_data = param_data.narrow(0, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) return -def shuffle_gate_up_proj_weight( - param: Parameter, - tp_size: int, -): - if tp_size == 1: - return - - param_data = param.data - new_tensor = torch.empty_like(param_data) - tp_dim = 0 - tp_chunk_size = param_data.shape[tp_dim] // tp_size - for shard_id, (shard_offset, shard_size) in gate_up_proj_shard_offsets.items(): - tp_slice_size = shard_size // tp_size 
- for i in range(tp_size): - tp_slice_src_offset = shard_offset + i * tp_slice_size - tp_slice_dst_offset = i * tp_chunk_size + shard_offset // tp_size - src_slice = param_data.narrow(tp_dim, tp_slice_src_offset, tp_slice_size) - dst_slice = new_tensor.narrow(tp_dim, tp_slice_dst_offset, tp_slice_size) - dst_slice.copy_(src_slice) - - param.data = new_tensor - - class LlamaMLP(nn.Module): _tp_plan = { - "gate_up_proj": ColwiseParallel(), - "down_proj": RowwiseParallel(), + "gate_up_proj": "Colwise_Sharded", + "down_proj": "Rowwise", } def __init__( @@ -127,8 +95,7 @@ def __init__( intermediate_size * 2, bias=False, ) - self.gate_up_proj.output_sizes = [intermediate_size] * 2 - build_gate_up_proj_shard_offsets(self.gate_up_proj.output_sizes) + self.gate_up_proj.output_sizes = [intermediate_size // tp_size] * 2 self.gate_up_proj.weight_loader = types.MethodType( gate_up_proj_weight_loader, self.gate_up_proj ) @@ -148,33 +115,18 @@ def forward(self, x): return x -qkv_proj_shard_offsets = {} - - -def build_qkv_proj_shard_offsets( - num_heads: int, - num_kv_heads: int, - head_size: int, -): - global qkv_proj_shard_offsets - # shard_id: (shard_offset, shard_size) - qkv_proj_shard_offsets.setdefault( - "q", (0, num_heads * head_size) - ) - qkv_proj_shard_offsets.setdefault( - "k", (num_heads * head_size, num_kv_heads * head_size) - ) - qkv_proj_shard_offsets.setdefault( - "v", ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size) - ) - - def qkv_proj_weight_loader( self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None, ): + # shard_id: (shard_offset, shard_size) + qkv_proj_shard_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": (self.num_heads * self.head_size, self.num_kv_heads * self.head_size), + "v": ((self.num_heads + self.num_kv_heads) * self.head_size, self.num_kv_heads * self.head_size), + } if loaded_shard_id is None: for shard_id, (shard_offset, shard_size) in qkv_proj_shard_offsets.items(): loaded_weight_shard = loaded_weight.narrow( @@ -185,39 +137,16 @@ def qkv_proj_weight_loader( shard_offset, shard_size = qkv_proj_shard_offsets[loaded_shard_id] param_data = param.data param_data = param_data.narrow(0, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) return -def shuffle_qkv_proj_weight( - param: Parameter, - tp_size: int, -): - if tp_size == 1: - return - - param_data = param.data - new_tensor = torch.empty_like(param_data) - tp_dim = 0 - tp_chunk_size = param_data.shape[tp_dim] // tp_size - for shard_id in ["q", "k", "v"]: - shard_offset, shard_size = qkv_proj_shard_offsets[shard_id] - tp_slice_size = shard_size // tp_size - for i in range(tp_size): - tp_slice_src_offset = shard_offset + i * tp_slice_size - tp_slice_dst_offset = i * tp_chunk_size + shard_offset // tp_size - src_slice = param_data.narrow(tp_dim, tp_slice_src_offset, tp_slice_size) - dst_slice = new_tensor.narrow(tp_dim, tp_slice_dst_offset, tp_slice_size) - dst_slice.copy_(src_slice) - - param.data = new_tensor - - class LlamaAttention(nn.Module): _tp_plan = { - "qkv_proj": ColwiseParallel(), - "o_proj": RowwiseParallel(), + "qkv_proj": "Colwise_Sharded", + "o_proj": "Rowwise", } def __init__( @@ -236,7 +165,6 @@ def __init__( ) -> None: super().__init__() self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 
self.num_heads = self.total_num_heads // tp_size @@ -265,16 +193,11 @@ def __init__( (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, bias=False, ) - build_qkv_proj_shard_offsets( - self.total_num_heads, - self.total_num_kv_heads, - self.head_dim, - ) self.qkv_proj.total_num_heads = self.total_num_heads self.qkv_proj.head_size = self.head_dim self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads - self.qkv_proj.num_heads = self.total_num_heads - self.qkv_proj.num_kv_heads = self.total_num_kv_heads + self.qkv_proj.num_heads = self.num_heads + self.qkv_proj.num_kv_heads = self.num_kv_heads self.qkv_proj.weight_loader = types.MethodType( qkv_proj_weight_loader, self.qkv_proj ) @@ -554,43 +477,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = self.lm_head.weight weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, self.model.embed_tokens.weight) - - # Re-arrange fused matrix for TP - tp_size = get_tensor_model_parallel_world_size() - for name, param in params_dict.items(): - # For these modules, we need to re-arrange the fused matrix to match - # tensor parallelism. - if ".qkv_proj" in name: - shuffle_qkv_proj_weight(param, tp_size) - elif ".gate_up_proj" in name: - shuffle_gate_up_proj_weight(param, tp_size) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) - def tensor_parallel(self, device_mesh=None): - """ - Tensor parallelize the model across the given device mesh. - Args: - device_mesh (`torch.distributed.DeviceMesh`): - The device mesh to use for tensor parallelism. - """ - # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. - # No op if `_tp_plan` attribute does not exist under the module. - # This is a helper function to be used with `model.apply` to recursively - # parallelize a model. - def tplize(mod: torch.nn.Module) -> None: - tp_plan = getattr(mod, "_tp_plan", None) - if tp_plan: - torch.distributed.tensor.parallel.parallelize_module( - mod, - device_mesh=device_mesh, - parallelize_plan=tp_plan, - ) - - # `apply` is a native method of `nn.Module` that recursively applies a - # function to every submodule. 
- self.apply(tplize) - class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM): pass From e78520c98405d3c0783e616a14caf248b86bdab5 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 8 Nov 2024 15:52:20 -0800 Subject: [PATCH 3/9] Add supports_torch_tp gate --- python/sglang/srt/model_executor/model_runner.py | 5 +++-- python/sglang/srt/models/torch_native_llama.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7760ba55c68..dfdfc64ca75 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -154,8 +154,9 @@ def __init__( min_per_gpu_memory = self.init_torch_distributed() self.sampler = Sampler() self.load_model() - if self.tp_size > 1: - logger.info(f"Tensor parallelism is enabled, {self.tp_size} devices will be used.") + supports_torch_tp = getattr(self.model, "supports_torch_tp", False) + if self.tp_size > 1 and supports_torch_tp: + logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") device_mesh = torch.distributed.init_device_mesh( self.device, (self.tp_size,) ) diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index bc03c484fb7..22c288216e3 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -53,13 +53,13 @@ def gate_up_proj_weight_loader( loaded_shard_id: Optional[int] = None, ): # shard_id: (shard_offset, shard_size) - gate_up_proj_shard_offsets = {} + gate_up_offsets = {} current_shard_offset = 0 for i, output_size in enumerate(self.output_sizes): - gate_up_proj_shard_offsets[i] = (current_shard_offset, output_size) + gate_up_offsets[i] = (current_shard_offset, output_size) current_shard_offset += output_size if loaded_shard_id is None: - for shard_id, (shard_offset, shard_size) in gate_up_proj_shard_offsets.items(): + for shard_id, (shard_offset, shard_size) in gate_up_offsets.items(): loaded_weight_shard = loaded_weight.narrow( 0, shard_offset, shard_size ) @@ -67,7 +67,7 @@ def gate_up_proj_weight_loader( else: assert loaded_shard_id < len(self.output_sizes) param_data = param.data - shard_offset, shard_size = gate_up_proj_shard_offsets[loaded_shard_id] + shard_offset, shard_size = gate_up_offsets[loaded_shard_id] param_data = param_data.narrow(0, shard_offset, shard_size) loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size) assert param_data.shape == loaded_weight.shape @@ -122,19 +122,19 @@ def qkv_proj_weight_loader( loaded_shard_id: Optional[str] = None, ): # shard_id: (shard_offset, shard_size) - qkv_proj_shard_offsets = { + qkv_offsets = { "q": (0, self.num_heads * self.head_size), "k": (self.num_heads * self.head_size, self.num_kv_heads * self.head_size), "v": ((self.num_heads + self.num_kv_heads) * self.head_size, self.num_kv_heads * self.head_size), } if loaded_shard_id is None: - for shard_id, (shard_offset, shard_size) in qkv_proj_shard_offsets.items(): + for shard_id, (shard_offset, shard_size) in qkv_offsets.items(): loaded_weight_shard = loaded_weight.narrow( param.output_dim, shard_offset, shard_size ) self.weight_loader(param, loaded_weight_shard, shard_id) else: - shard_offset, shard_size = qkv_proj_shard_offsets[loaded_shard_id] + shard_offset, shard_size = qkv_offsets[loaded_shard_id] param_data = param.data param_data = param_data.narrow(0, shard_offset, shard_size) loaded_weight = 
loaded_weight.narrow(0, tp_rank * shard_size, shard_size) @@ -367,6 +367,7 @@ def __init__( self.config = config self.quant_config = quant_config self.torchao_config = global_server_args_dict["torchao_config"] + self.supports_torch_tp = True self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) From 11a153cd09588c4340a2876a51c343075e76f3ce Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 8 Nov 2024 16:00:05 -0800 Subject: [PATCH 4/9] Add torch version --- python/sglang/srt/model_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py index 40f6717fb3a..fd7eeca5100 100644 --- a/python/sglang/srt/model_parallel.py +++ b/python/sglang/srt/model_parallel.py @@ -9,6 +9,7 @@ try: from torch.distributed.tensor import DTensor, Shard except ImportError: + # torch 2.4 or older from torch.distributed._tensor import DTensor, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, From 5cc3ca6846176460688f04486dce416dd9213b72 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 8 Nov 2024 16:26:31 -0800 Subject: [PATCH 5/9] Modularize TP application --- .../sglang/srt/model_executor/model_runner.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index dfdfc64ca75..4698d952ea3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -54,7 +54,6 @@ ReqToTokenPool, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_parallel import tensor_parallel from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -154,13 +153,11 @@ def __init__( min_per_gpu_memory = self.init_torch_distributed() self.sampler = Sampler() self.load_model() + + # Apply torch TP if model supports it supports_torch_tp = getattr(self.model, "supports_torch_tp", False) if self.tp_size > 1 and supports_torch_tp: - logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") - device_mesh = torch.distributed.init_device_mesh( - self.device, (self.tp_size,) - ) - tensor_parallel(self.model, device_mesh) + self.apply_torch_tp() if server_args.lora_paths is not None: self.init_lora_manager() @@ -565,6 +562,15 @@ def init_cuda_graphs(self): logger.info("Capture cuda graph begin. 
This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) + def apply_torch_tp(self): + logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") + from sglang.srt.model_parallel import tensor_parallel + + device_mesh = torch.distributed.init_device_mesh( + self.device, (self.tp_size,) + ) + tensor_parallel(self.model, device_mesh) + def forward_decode(self, forward_batch: ForwardBatch): if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch): return self.cuda_graph_runner.replay(forward_batch) From ee80b5d500e023e60280a25dff40f969c2c34877 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 8 Nov 2024 17:28:21 -0800 Subject: [PATCH 6/9] Move tp_size to weight loader --- python/sglang/srt/models/torch_native_llama.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 22c288216e3..37a66e30821 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -68,6 +68,8 @@ def gate_up_proj_weight_loader( assert loaded_shard_id < len(self.output_sizes) param_data = param.data shard_offset, shard_size = gate_up_offsets[loaded_shard_id] + # Everything shrinks by tp_size if TP enabled + shard_offset, shard_size = shard_offset // tp_size, shard_size // tp_size param_data = param_data.narrow(0, shard_offset, shard_size) loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size) assert param_data.shape == loaded_weight.shape @@ -95,7 +97,7 @@ def __init__( intermediate_size * 2, bias=False, ) - self.gate_up_proj.output_sizes = [intermediate_size // tp_size] * 2 + self.gate_up_proj.output_sizes = [intermediate_size] * 2 self.gate_up_proj.weight_loader = types.MethodType( gate_up_proj_weight_loader, self.gate_up_proj ) @@ -135,6 +137,8 @@ def qkv_proj_weight_loader( self.weight_loader(param, loaded_weight_shard, shard_id) else: shard_offset, shard_size = qkv_offsets[loaded_shard_id] + # Everything shrinks by tp_size if TP enabled + shard_offset, shard_size = shard_offset // tp_size, shard_size // tp_size param_data = param.data param_data = param_data.narrow(0, shard_offset, shard_size) loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size) @@ -193,11 +197,9 @@ def __init__( (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, bias=False, ) - self.qkv_proj.total_num_heads = self.total_num_heads self.qkv_proj.head_size = self.head_dim - self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads - self.qkv_proj.num_heads = self.num_heads - self.qkv_proj.num_kv_heads = self.num_kv_heads + self.qkv_proj.num_heads = self.total_num_heads + self.qkv_proj.num_kv_heads = self.total_num_kv_heads self.qkv_proj.weight_loader = types.MethodType( qkv_proj_weight_loader, self.qkv_proj ) From 4488442b0e7886f210888e4699e096ed0662ef80 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 12 Nov 2024 07:30:54 -0800 Subject: [PATCH 7/9] Wait async tensor; fix param size; conditional inference mode --- python/sglang/bench_latency.py | 22 +++++- .../sglang/srt/model_executor/model_runner.py | 3 + python/sglang/srt/model_parallel.py | 19 ++++- .../sglang/srt/models/torch_native_llama.py | 73 +++++++++---------- 4 files changed, 74 insertions(+), 43 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index d97b641ea15..d7700f9a27d 100644 --- a/python/sglang/bench_latency.py +++ 
b/python/sglang/bench_latency.py @@ -220,8 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len): return reqs -@torch.inference_mode() -def extend(reqs, model_runner): +def _extend(reqs, model_runner): batch = ScheduleBatch.init_new( reqs=reqs, req_to_token_pool=model_runner.req_to_token_pool, @@ -237,8 +236,15 @@ def extend(reqs, model_runner): return next_token_ids, logits_output.next_token_logits, batch -@torch.inference_mode() -def decode(input_token_ids, batch, model_runner): +def extend(reqs, model_runner): + # Disable inference mode for now when torch TP is applied. We can remove + # this workaround once DTensor adds support for inference mode. + use_inf_mode = not model_runner.torch_tp_applied + with torch.inference_mode(use_inf_mode): + return _extend(reqs, model_runner) + + +def _decode(input_token_ids, batch, model_runner): batch.output_ids = input_token_ids batch.prepare_for_decode() model_worker_batch = batch.get_model_worker_batch() @@ -248,6 +254,14 @@ def decode(input_token_ids, batch, model_runner): return next_token_ids, logits_output.next_token_logits +def decode(input_token_ids, batch, model_runner): + # Disable inference mode for now when torch TP is applied. We can remove + # this workaround once DTensor adds support for inference mode. + use_inf_mode = not model_runner.torch_tp_applied + with torch.inference_mode(use_inf_mode): + return _decode(input_token_ids, batch, model_runner) + + def correctness_test( server_args, port_args, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4698d952ea3..301b0643f29 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -158,6 +158,9 @@ def __init__( supports_torch_tp = getattr(self.model, "supports_torch_tp", False) if self.tp_size > 1 and supports_torch_tp: self.apply_torch_tp() + self.torch_tp_applied = True + else: + self.torch_tp_applied = False if server_args.lora_paths is not None: self.init_lora_manager() diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py index fd7eeca5100..1503a2a91b8 100644 --- a/python/sglang/srt/model_parallel.py +++ b/python/sglang/srt/model_parallel.py @@ -16,6 +16,7 @@ RowwiseParallel, parallelize_module, ) +from torch.distributed._functional_collectives import AsyncCollectiveTensor class ColwiseParallelSharded(ColwiseParallel): @@ -36,6 +37,22 @@ def _partition_linear_fn(self, name, module, device_mesh): dist_param = torch.nn.Parameter(dtensor, requires_grad=False) module.register_parameter(name, dist_param) +class RowwiseParallelMaybeWait(RowwiseParallel): + """ + A version of RowwiseParallel that waits for the output (establish dependency + between comm stream and compute stream in CUDA sense) before going into the + next op. This is needed to workaround the current interaction between + AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`. 
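+
+    Sketch of the mechanism (restating the override below): `_prepare_output_fn`
+    defers to the parent class and, if the result is an AsyncCollectiveTensor,
+    calls `.wait()` on it so downstream custom ops receive a materialized tensor.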
+ """ + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + outputs = super(RowwiseParallelMaybeWait, RowwiseParallelMaybeWait)._prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh) + # wait for the output to be ready + if isinstance(outputs, AsyncCollectiveTensor): + return outputs.wait() + else: + return outputs + def tensor_parallel( module: torch.nn.Module, @@ -62,7 +79,7 @@ def tplize(mod: torch.nn.Module) -> None: if tp_style == "Colwise": parallelize_module(submod, device_mesh, ColwiseParallel()) elif tp_style == "Rowwise": - parallelize_module(submod, device_mesh, RowwiseParallel()) + parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait()) elif tp_style == "Colwise_Sharded": parallelize_module(submod, device_mesh, ColwiseParallelSharded()) else: diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 37a66e30821..d4b37d95c7a 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -50,31 +50,29 @@ def gate_up_proj_weight_loader( self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None, + loaded_shard_id: int, ): # shard_id: (shard_offset, shard_size) gate_up_offsets = {} current_shard_offset = 0 for i, output_size in enumerate(self.output_sizes): + # Everything shrinks by tp_size if TP enabled + output_size = output_size // tp_size gate_up_offsets[i] = (current_shard_offset, output_size) current_shard_offset += output_size - if loaded_shard_id is None: - for shard_id, (shard_offset, shard_size) in gate_up_offsets.items(): - loaded_weight_shard = loaded_weight.narrow( - 0, shard_offset, shard_size - ) - self.weight_loader(param, loaded_weight_shard, shard_id) - else: - assert loaded_shard_id < len(self.output_sizes) - param_data = param.data - shard_offset, shard_size = gate_up_offsets[loaded_shard_id] - # Everything shrinks by tp_size if TP enabled - shard_offset, shard_size = shard_offset // tp_size, shard_size // tp_size - param_data = param_data.narrow(0, shard_offset, shard_size) - loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size) - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - return + # Re-size the param to the size after TP + if current_shard_offset != param.shape[0]: + # The clone will free the original, full tensor + param.data = param.data.narrow(0, 0, current_shard_offset).clone() + + # Now load gate or up + assert loaded_shard_id < len(self.output_sizes) + param_data = param.data + shard_offset, shard_size = gate_up_offsets[loaded_shard_id] + param_data = param_data.narrow(0, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) class LlamaMLP(nn.Module): @@ -121,30 +119,29 @@ def qkv_proj_weight_loader( self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None, + loaded_shard_id: str, ): + num_heads = self.num_heads // tp_size + num_kv_heads = self.num_kv_heads // tp_size # shard_id: (shard_offset, shard_size) qkv_offsets = { - "q": (0, self.num_heads * self.head_size), - "k": (self.num_heads * self.head_size, self.num_kv_heads * self.head_size), - "v": ((self.num_heads + self.num_kv_heads) * self.head_size, self.num_kv_heads * self.head_size), + "q": (0, num_heads * self.head_size), + "k": (num_heads * 
self.head_size, num_kv_heads * self.head_size), + "v": ((num_heads + num_kv_heads) * self.head_size, num_kv_heads * self.head_size), } - if loaded_shard_id is None: - for shard_id, (shard_offset, shard_size) in qkv_offsets.items(): - loaded_weight_shard = loaded_weight.narrow( - param.output_dim, shard_offset, shard_size - ) - self.weight_loader(param, loaded_weight_shard, shard_id) - else: - shard_offset, shard_size = qkv_offsets[loaded_shard_id] - # Everything shrinks by tp_size if TP enabled - shard_offset, shard_size = shard_offset // tp_size, shard_size // tp_size - param_data = param.data - param_data = param_data.narrow(0, shard_offset, shard_size) - loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size) - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - return + total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1] + # Re-size the param to the size after TP + if total_size != param.shape[0]: + # The clone will free the original, full tensor + param.data = param.data.narrow(0, 0, total_size).clone() + + # Now load q, k or v + shard_offset, shard_size = qkv_offsets[loaded_shard_id] + param_data = param.data + param_data = param_data.narrow(0, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) class LlamaAttention(nn.Module): From 1173be9883c848a2ee3e875ed7c8013976ac7edb Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 12 Nov 2024 12:12:18 -0800 Subject: [PATCH 8/9] Lint --- python/sglang/srt/model_parallel.py | 4 +++- python/sglang/srt/models/torch_native_llama.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py index 1503a2a91b8..8e503952047 100644 --- a/python/sglang/srt/model_parallel.py +++ b/python/sglang/srt/model_parallel.py @@ -6,17 +6,19 @@ import torch from torch.distributed.device_mesh import DeviceMesh + try: from torch.distributed.tensor import DTensor, Shard except ImportError: # torch 2.4 or older from torch.distributed._tensor import DTensor, Shard + +from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed.tensor.parallel import ( ColwiseParallel, RowwiseParallel, parallelize_module, ) -from torch.distributed._functional_collectives import AsyncCollectiveTensor class ColwiseParallelSharded(ColwiseParallel): diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index d4b37d95c7a..140fba30c4d 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -24,7 +24,10 @@ from torch import nn from torch.nn.parameter import Parameter from transformers import LlamaConfig -from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -41,7 +44,6 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch - tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() From 95ee81134bbe6a4fd7f9acc62df783073f06097c Mon Sep 17 00:00:00 2001 From: Ke 
Wen Date: Tue, 12 Nov 2024 20:54:30 -0800 Subject: [PATCH 9/9] black lint --- python/sglang/srt/model_executor/model_runner.py | 4 +--- python/sglang/srt/model_parallel.py | 14 ++++++++++---- python/sglang/srt/models/torch_native_llama.py | 5 ++++- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 301b0643f29..a0cd4250d86 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -569,9 +569,7 @@ def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") from sglang.srt.model_parallel import tensor_parallel - device_mesh = torch.distributed.init_device_mesh( - self.device, (self.tp_size,) - ) + device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,)) tensor_parallel(self.model, device_mesh) def forward_decode(self, forward_batch: ForwardBatch): diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py index 8e503952047..afe504082b2 100644 --- a/python/sglang/srt/model_parallel.py +++ b/python/sglang/srt/model_parallel.py @@ -27,18 +27,18 @@ class ColwiseParallelSharded(ColwiseParallel): sharded. This is used for the fused wqkv case, where during loading, we already sharded wq, wk, wv before fusing them. """ + # Override the _partition_linear_fn in ColwiseParallel def _partition_linear_fn(self, name, module, device_mesh): # colwise shard weight/bias to Shard(0), weight be Shard(0) # means Colwise as Linear is input * weight^T + bias, where # weight would become Shard(1) for name, param in module.named_parameters(): - dtensor = DTensor.from_local( - param, device_mesh, [Shard(0)] - ) + dtensor = DTensor.from_local(param, device_mesh, [Shard(0)]) dist_param = torch.nn.Parameter(dtensor, requires_grad=False) module.register_parameter(name, dist_param) + class RowwiseParallelMaybeWait(RowwiseParallel): """ A version of RowwiseParallel that waits for the output (establish dependency @@ -46,9 +46,14 @@ class RowwiseParallelMaybeWait(RowwiseParallel): next op. This is needed to workaround the current interaction between AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`. """ + @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - outputs = super(RowwiseParallelMaybeWait, RowwiseParallelMaybeWait)._prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh) + outputs = super( + RowwiseParallelMaybeWait, RowwiseParallelMaybeWait + )._prepare_output_fn( + output_layouts, use_local_output, mod, outputs, device_mesh + ) # wait for the output to be ready if isinstance(outputs, AsyncCollectiveTensor): return outputs.wait() @@ -68,6 +73,7 @@ def tensor_parallel( device_mesh (`torch.distributed.DeviceMesh`): The device mesh to use for tensor parallelism. """ + # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. # No op if `_tp_plan` attribute does not exist under the module. 
# This is a helper function to be used with `model.apply` to recursively diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 140fba30c4d..a7403240bea 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -129,7 +129,10 @@ def qkv_proj_weight_loader( qkv_offsets = { "q": (0, num_heads * self.head_size), "k": (num_heads * self.head_size, num_kv_heads * self.head_size), - "v": ((num_heads + num_kv_heads) * self.head_size, num_kv_heads * self.head_size), + "v": ( + (num_heads + num_kv_heads) * self.head_size, + num_kv_heads * self.head_size, + ), } total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1] # Re-size the param to the size after TP