diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 841ecf56d0e..28962eb9ff8 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -220,8 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs
 
 
-@torch.inference_mode()
-def extend(reqs, model_runner):
+def _extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
@@ -237,8 +236,15 @@ def extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch
 
 
-@torch.inference_mode()
-def decode(input_token_ids, batch, model_runner):
+def extend(reqs, model_runner):
+    # Disable inference mode for now when torch TP is applied. We can remove
+    # this workaround once DTensor adds support for inference mode.
+    use_inf_mode = not model_runner.torch_tp_applied
+    with torch.inference_mode(use_inf_mode):
+        return _extend(reqs, model_runner)
+
+
+def _decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
@@ -248,6 +254,14 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
+def decode(input_token_ids, batch, model_runner):
+    # Disable inference mode for now when torch TP is applied. We can remove
+    # this workaround once DTensor adds support for inference mode.
+    use_inf_mode = not model_runner.torch_tp_applied
+    with torch.inference_mode(use_inf_mode):
+        return _decode(input_token_ids, batch, model_runner)
+
+
 def correctness_test(
     server_args,
     port_args,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 8b06d2ceac8..f87a1da7f7a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -148,6 +148,15 @@ def __init__(
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
+
+        # Apply torch TP if model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -548,6 +557,13 @@ def init_cuda_graphs(self):
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py
new file mode 100644
index 00000000000..afe504082b2
--- /dev/null
+++ b/python/sglang/srt/model_parallel.py
@@ -0,0 +1,98 @@
+"""
+Common utilities for torch model parallelism.
+""" + +from typing import Optional, Sequence + +import torch +from torch.distributed.device_mesh import DeviceMesh + +try: + from torch.distributed.tensor import DTensor, Shard +except ImportError: + # torch 2.4 or older + from torch.distributed._tensor import DTensor, Shard + +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, +) + + +class ColwiseParallelSharded(ColwiseParallel): + """ + A version of ColwiseParallel where the local weight has been already + sharded. This is used for the fused wqkv case, where during loading, we + already sharded wq, wk, wv before fusing them. + """ + + # Override the _partition_linear_fn in ColwiseParallel + def _partition_linear_fn(self, name, module, device_mesh): + # colwise shard weight/bias to Shard(0), weight be Shard(0) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) + for name, param in module.named_parameters(): + dtensor = DTensor.from_local(param, device_mesh, [Shard(0)]) + dist_param = torch.nn.Parameter(dtensor, requires_grad=False) + module.register_parameter(name, dist_param) + + +class RowwiseParallelMaybeWait(RowwiseParallel): + """ + A version of RowwiseParallel that waits for the output (establish dependency + between comm stream and compute stream in CUDA sense) before going into the + next op. This is needed to workaround the current interaction between + AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`. + """ + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + outputs = super( + RowwiseParallelMaybeWait, RowwiseParallelMaybeWait + )._prepare_output_fn( + output_layouts, use_local_output, mod, outputs, device_mesh + ) + # wait for the output to be ready + if isinstance(outputs, AsyncCollectiveTensor): + return outputs.wait() + else: + return outputs + + +def tensor_parallel( + module: torch.nn.Module, + device_mesh: Optional[DeviceMesh] = None, +): + """ + Tensor parallelize the model across the given device mesh. + Args: + module (`torch.nn.Module`): + The module to tensor parallelize. + device_mesh (`torch.distributed.DeviceMesh`): + The device mesh to use for tensor parallelism. + """ + + # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. + # No op if `_tp_plan` attribute does not exist under the module. + # This is a helper function to be used with `model.apply` to recursively + # parallelize a model. + def tplize(mod: torch.nn.Module) -> None: + tp_plan = getattr(mod, "_tp_plan", None) + if tp_plan is None: + return + for child_name, tp_style in tp_plan.items(): + submod = mod.get_submodule(child_name) + if tp_style == "Colwise": + parallelize_module(submod, device_mesh, ColwiseParallel()) + elif tp_style == "Rowwise": + parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait()) + elif tp_style == "Colwise_Sharded": + parallelize_module(submod, device_mesh, ColwiseParallelSharded()) + else: + raise ValueError(f"Unknown TP style {tp_style}") + + # `apply` is a native method of `nn.Module` that recursively applies a + # function to every submodule. 
+    module.apply(tplize)
diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py
index d9ce05b8a65..a7403240bea 100644
--- a/python/sglang/srt/models/torch_native_llama.py
+++ b/python/sglang/srt/models/torch_native_llama.py
@@ -24,7 +24,10 @@
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -41,35 +44,45 @@
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
+tp_size = get_tensor_model_parallel_world_size()
+tp_rank = get_tensor_model_parallel_rank()
+
 
 def gate_up_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[int] = None,
+    loaded_shard_id: int,
 ):
-    if loaded_shard_id is None:
-        shard_offsets: List[Tuple[int, int, int]] = []
-        for i, output_size in enumerate(self.output_sizes):
-            shard_offsets.append((i, current_shard_offset, output_size))
-            current_shard_offset += output_size
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        assert loaded_shard_id < len(self.output_sizes)
-        param_data = param.data
-        shard_size = loaded_weight.shape[0]
-        shard_offset = loaded_shard_id * shard_size
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-    return
+    # shard_id: (shard_offset, shard_size)
+    gate_up_offsets = {}
+    current_shard_offset = 0
+    for i, output_size in enumerate(self.output_sizes):
+        # Everything shrinks by tp_size if TP enabled
+        output_size = output_size // tp_size
+        gate_up_offsets[i] = (current_shard_offset, output_size)
+        current_shard_offset += output_size
+    # Re-size the param to the size after TP
+    if current_shard_offset != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, current_shard_offset).clone()
+
+    # Now load gate or up
+    assert loaded_shard_id < len(self.output_sizes)
+    param_data = param.data
+    shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaMLP(nn.Module):
+    _tp_plan = {
+        "gate_up_proj": "Colwise_Sharded",
+        "down_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         hidden_size: int,
@@ -104,62 +117,44 @@ def forward(self, x):
         return x
 
 
-def _get_shard_offset_mapping(self, loaded_shard_id: str):
-    shard_offset_mapping = {
-        "q": 0,
-        "k": self.num_heads * self.head_size,
-        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
-    }
-    return shard_offset_mapping.get(loaded_shard_id)
-
-
-def _get_shard_size_mapping(self, loaded_shard_id: str):
-    shard_size_mapping = {
-        "q": self.num_heads * self.head_size,
-        "k": self.num_kv_heads * self.head_size,
-        "v": self.num_kv_heads * self.head_size,
-    }
-    return shard_size_mapping.get(loaded_shard_id)
-
-
 def qkv_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[str] = None,
+    loaded_shard_id: str,
 ):
-    if loaded_shard_id is None:
-        shard_offsets = [
-            # (shard_id, shard_offset, shard_size)
-            ("q", 0, self.total_num_heads * self.head_size),
-            (
-                "k",
-                self.total_num_heads * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-            (
-                "v",
-                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-        ]
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                param.output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
-        shard_size = self._get_shard_size_mapping(loaded_shard_id)
-        param_data = param.data
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-    return
+    num_heads = self.num_heads // tp_size
+    num_kv_heads = self.num_kv_heads // tp_size
+    # shard_id: (shard_offset, shard_size)
+    qkv_offsets = {
+        "q": (0, num_heads * self.head_size),
+        "k": (num_heads * self.head_size, num_kv_heads * self.head_size),
+        "v": (
+            (num_heads + num_kv_heads) * self.head_size,
+            num_kv_heads * self.head_size,
+        ),
+    }
+    total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
+    # Re-size the param to the size after TP
+    if total_size != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, total_size).clone()
+
+    # Now load q, k or v
+    shard_offset, shard_size = qkv_offsets[loaded_shard_id]
+    param_data = param.data
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaAttention(nn.Module):
+    _tp_plan = {
+        "qkv_proj": "Colwise_Sharded",
+        "o_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
@@ -176,7 +171,6 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -205,20 +199,12 @@ def __init__(
             (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
             bias=False,
         )
-        self.qkv_proj.total_num_heads = self.total_num_heads
         self.qkv_proj.head_size = self.head_dim
-        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.num_heads = self.total_num_heads
         self.qkv_proj.num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.weight_loader = types.MethodType(
             qkv_proj_weight_loader, self.qkv_proj
         )
-        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
-            _get_shard_offset_mapping, self.qkv_proj
-        )
-        self.qkv_proj._get_shard_size_mapping = types.MethodType(
-            _get_shard_size_mapping, self.qkv_proj
-        )
         self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
         self.qkv_proj.weight.output_dim = 0
         self.o_proj = torch.nn.Linear(
@@ -385,6 +371,7 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
+        self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
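
Usage note (illustrative, not part of the change above): the snippet below is a
minimal sketch of the `_tp_plan` convention consumed by
`sglang.srt.model_parallel.tensor_parallel()`, mirroring how
`ModelRunner.apply_torch_tp()` builds a 1-D device mesh over the TP ranks. The
`ToyMLP` module, file name, and `torchrun` launch line are hypothetical
placeholders, not code from this PR.

# toy_tp.py -- hypothetical example; launch one process per GPU, e.g.
#   torchrun --nproc-per-node=2 toy_tp.py
import torch
import torch.distributed as dist

from sglang.srt.model_parallel import tensor_parallel


class ToyMLP(torch.nn.Module):
    # Same convention as LlamaMLP/LlamaAttention above: keys are child module
    # names, values are TP styles understood by tplize().
    _tp_plan = {
        "up_proj": "Colwise",
        "down_proj": "Rowwise",
    }

    def __init__(self, hidden: int = 128):
        super().__init__()
        self.up_proj = torch.nn.Linear(hidden, 4 * hidden, bias=False)
        self.down_proj = torch.nn.Linear(4 * hidden, hidden, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # assumes a single node

    model = ToyMLP().cuda()
    # Mirrors ModelRunner.apply_torch_tp(): a 1-D mesh over all ranks; tplize()
    # then picks up each submodule's `_tp_plan` recursively via module.apply().
    mesh = torch.distributed.init_device_mesh("cuda", (dist.get_world_size(),))
    tensor_parallel(model, mesh)

    x = torch.randn(4, 128, device="cuda")
    y = model(x)  # up_proj is column-sharded; down_proj row-sharded, all-reduced
    print(rank, y.shape)
    dist.destroy_process_group()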