diff --git a/HowToApproachvLLM.md b/HowToApproachvLLM.md index b29a0ce..7c970a2 100644 --- a/HowToApproachvLLM.md +++ b/HowToApproachvLLM.md @@ -145,6 +145,20 @@ for name, param in model.named_parameters(): - Each GPU stores complete heads (don't shard head_size) - Enables independent attention computation per GPU +**Benchmark Results:** +Use the following commands to enter the `src/myvllm/layers` directory, and verify whether the execution results are correct in a distributed environment. +``` +cd src/myvllm/layers +CUDA_VISIBLE_DEVICES=0,1,2,3 uv run torchrun --nproc_per_node=4 linear.py +``` +If the output shows `allclose=True`, it indicates that the multi-GPU parallel results are consistent with the single-GPU results. +``` +[ColumnParallel] allclose=True, max_abs_err=0.000107 +[MergedColumnParallel] allclose=True, max_abs_err=0.000103 +[QKVColumnParallel] allclose=True, max_abs_err=0.000061 +[RowParallel] allclose=True, max_abs_err=0.000011 +``` + **MLP Layer Pattern:** - One ColumnParallel → One RowParallel → `dist.all_reduce` - Output sharding of first layer = Input sharding of second layer diff --git a/HowToApproachvLLM_zh.md b/HowToApproachvLLM_zh.md index f911196..eb93425 100644 --- a/HowToApproachvLLM_zh.md +++ b/HowToApproachvLLM_zh.md @@ -146,10 +146,25 @@ for name, param in model.named_parameters(): - 每张 GPU 存完整的 heads(不对 `head_size` 维度做切分) - 使每张 GPU 可以独立完成注意力计算 +**基准测试:** +使用如下指令,进入`src/myvllm/layers`目录,在分布式环境测试运行结果是否正确 +``` +cd src/myvllm/layers +CUDA_VISIBLE_DEVICES=0,1,2,3 uv run torchrun --nproc_per_node=4 linear.py +``` +输出结果`allclose=True`则代表多卡并行结果与单卡一致 +``` +[ColumnParallel] allclose=True, max_abs_err=0.000107 +[MergedColumnParallel] allclose=True, max_abs_err=0.000103 +[QKVColumnParallel] allclose=True, max_abs_err=0.000061 +[RowParallel] allclose=True, max_abs_err=0.000011 +``` **MLP 层的常见模式:** - 一个 `ColumnParallel` → 一个 RowParallel → `dist.all_reduce` - 第一层的输出切分方式 = 第二层的输入切分方式 + + --- ### 1.4 词表嵌入(Vocab Embedding)与 LM Head ✅ diff
--git a/src/myvllm/layers/linear.py b/src/myvllm/layers/linear.py index 4f032d2..7942fee 100644 --- a/src/myvllm/layers/linear.py +++ b/src/myvllm/layers/linear.py @@ -1,6 +1,7 @@ import torch.nn as nn import torch import torch.distributed as dist +import os class LinearBase(nn.Module): """ @@ -225,14 +226,260 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return result + + if __name__ == "__main__": - # Example usage - if dist.is_available() and not dist.is_initialized(): - dist.init_process_group( - backend="gloo", - init_method="tcp://127.0.0.1:29500", - rank=0, - world_size=1, + # how to run? + # 1. cd src/myvllm/layers + # 2. CUDA_VISIBLE_DEVICES=0,1,2,3 uv run torchrun --nproc_per_node=4 linear.py # 4 GPUs + + def _init_dist(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + backend = "nccl" if torch.cuda.is_available() else "gloo" + + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + dist.init_process_group( + backend=backend, + init_method="env://", + device_id=local_rank, + ) + else: + device = torch.device("cpu") + dist.init_process_group( backend=backend, init_method="env://", ) + return rank, world_size, local_rank, device + + # Single linear layer and column parallel test + @torch.no_grad() + def test_column_parallel(device): + tp_rank = dist.get_rank() + tp_size = dist.get_world_size() # Number of parallel GPUs + + in_features = 1024 * tp_size + out_features = 1024 * tp_size + batch = 4 + + # Ensure that each rank gets exactly the same full input/weight + g = torch.Generator(device="cpu").manual_seed(2026) + x_full = torch.randn(batch, in_features, generator=g) + w_full = torch.randn(out_features, in_features, generator=g) + b_full = torch.randn(out_features, generator=g) + + x_full = x_full.to(device) + w_full = w_full.to(device) + b_full = b_full.to(device) + + # reference (single GPU) + 
single_layer = ReplicatedLinear(in_features, out_features, bias=True).to(device) + single_layer.weight.weight_loader(single_layer.weight, w_full) + single_layer.bias.weight_loader(single_layer.bias, b_full) + y_single = single_layer(x_full) + + # TP layer (each rank stores out_features/tp) + col_tp_layer = ColumnParallelLinear(in_features, out_features, bias=True).to(device) + col_tp_layer.weight.weight_loader(col_tp_layer.weight, w_full) + col_tp_layer.bias.weight_loader(col_tp_layer.bias, b_full) + + # forward + y_col_tp = col_tp_layer(x_full) # [batch, out_features/tp] + + # Restore full output: all_gather+concat + y_parts = [torch.empty_like(y_col_tp) for _ in range(tp_size)] + dist.all_gather(y_parts, y_col_tp) + y_full = torch.cat(y_parts, dim=-1) # [batch, out_features] + + # Alignment check (print only at rank0) + max_err = (y_full - y_single).abs().max().item() + ok = torch.allclose(y_full, y_single, rtol=1e-4, atol=1e-4) + if tp_rank == 0: + print(f"[ColumnParallel] allclose={ok}, max_abs_err={max_err:.6f}") + + + # MergedColumnParallelLinear test + @torch.no_grad() + def test_merged_column_parallel(device): + tp_rank = dist.get_rank() + tp_size = dist.get_world_size() + + # Make the dimension automatically adapt to tp_size to ensure divisibility + in_features = 1024 * tp_size + out_each = 512 * tp_size + out_sizes = [out_each, out_each, out_each] # Combination of Q, K and V matrices + batch = 4 + + g = torch.Generator(device="cpu").manual_seed(2026) + x_full = torch.randn(batch, in_features, generator=g) + w_q = torch.randn(out_sizes[0], in_features, generator=g) + w_k = torch.randn(out_sizes[1], in_features, generator=g) + w_v = torch.randn(out_sizes[2], in_features, generator=g) + + x_full = x_full.to(device) + w_q = w_q.to(device) + w_k = w_k.to(device) + w_v = w_v.to(device) + + # Reference: single-card equivalent output (Q | K | V concat) + y_ref = torch.cat( + [ + nn.functional.linear(x_full, w_q, None), + nn.functional.linear(x_full, w_k, None), + 
nn.functional.linear(x_full, w_v, None), + ], + dim=-1, + ) + + # TP merged layer + # NOTE: your MergedColumnParallelLinear defines weight_loader(param, loaded_weights, loaded_weight_id), + # so bias loader signature doesn't match base (param, loaded_weights). Therefore bias=False here. + merged = MergedColumnParallelLinear(in_features, out_sizes, bias=False).to(device) + merged.weight_loader(merged.weight, w_q, 0) + merged.weight_loader(merged.weight, w_k, 1) + merged.weight_loader(merged.weight, w_v, 2) + + y_local = merged(x_full) # [batch, sum(out_sizes)/tp], layout: [q_local, k_local, v_local] + + # all_gather then re-pack to [Q_all | K_all | V_all] + y_parts = [torch.empty_like(y_local) for _ in range(tp_size)] + dist.all_gather(y_parts, y_local) + + ql = out_sizes[0] // tp_size + kl = out_sizes[1] // tp_size + vl = out_sizes[2] // tp_size + + q_full = torch.cat([p[:, :ql] for p in y_parts], dim=-1) + k_full = torch.cat([p[:, ql : ql + kl] for p in y_parts], dim=-1) + v_full = torch.cat([p[:, ql + kl : ql + kl + vl] for p in y_parts], dim=-1) + y_full = torch.cat([q_full, k_full, v_full], dim=-1) + + max_err = (y_full - y_ref).abs().max().item() + ok = torch.allclose(y_full, y_ref, rtol=1e-4, atol=1e-4) + if tp_rank == 0: + print(f"[MergedColumnParallel] allclose={ok}, max_abs_err={max_err:.6f}") + + + # QKVColumnParallelLinear test + @torch.no_grad() + def test_qkv_column_parallel(device): + tp_rank = dist.get_rank() + tp_size = dist.get_world_size() + + # make input dim divisible by tp_size + input_size = 1024 * tp_size + head_size = 16 + num_heads = 4 * tp_size + num_kv_heads = 2 * tp_size + batch = 4 + + g = torch.Generator(device="cpu").manual_seed(2026) + x_full = torch.randn(batch, input_size, generator=g) + w_q = torch.randn(head_size * num_heads, input_size, generator=g) + w_k = torch.randn(head_size * num_kv_heads, input_size, generator=g) + w_v = torch.randn(head_size * num_kv_heads, input_size, generator=g) + + x_full = x_full.to(device) + w_q = 
w_q.to(device) + w_k = w_k.to(device) + w_v = w_v.to(device) + + # reference: full Q|K|V + y_ref = torch.cat( + [ + nn.functional.linear(x_full, w_q, None), + nn.functional.linear(x_full, w_k, None), + nn.functional.linear(x_full, w_v, None), + ], + dim=-1, ) - layer = LinearBase(input_size=10, output_size=5) - print("LinearBase layer initialized:", layer) \ No newline at end of file + + # TP QKV layer: rank output layout [q_local | k_local | v_local] + qkv = QKVColumnParallelLinear( + input_size=input_size, + head_size=head_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + bias=False, + ).to(device) + + qkv.weight_loader(qkv.weight, w_q, "q") + qkv.weight_loader(qkv.weight, w_k, "k") + qkv.weight_loader(qkv.weight, w_v, "v") + + y_local = qkv(x_full) # [batch, head_size*(local_h + 2*local_kv)] + + # all_gather then re-pack to [Q_all | K_all | V_all] + y_parts = [torch.empty_like(y_local) for _ in range(tp_size)] + dist.all_gather(y_parts, y_local) + + ql = head_size * (num_heads // tp_size) + kl = head_size * (num_kv_heads // tp_size) + vl = head_size * (num_kv_heads // tp_size) + + q_full = torch.cat([p[:, :ql] for p in y_parts], dim=-1) + k_full = torch.cat([p[:, ql : ql + kl] for p in y_parts], dim=-1) + v_full = torch.cat([p[:, ql + kl : ql + kl + vl] for p in y_parts], dim=-1) + y_full = torch.cat([q_full, k_full, v_full], dim=-1) + + max_err = (y_full - y_ref).abs().max().item() + ok = torch.allclose(y_full, y_ref, rtol=1e-4, atol=1e-4) + if tp_rank == 0: + print(f"[QKVColumnParallel] allclose={ok}, max_abs_err={max_err:.6f}") + + + # RowParallelLinear test + @torch.no_grad() + def test_row_parallel(device): + tp_rank = dist.get_rank() + tp_size = dist.get_world_size() + + in_features = 128 * tp_size + out_features = 256 + batch = 4 + + g = torch.Generator(device="cpu").manual_seed(2026) + x_full = torch.randn(batch, in_features, generator=g) + w_full = torch.randn(out_features, in_features, generator=g) + b_full = torch.randn(out_features, 
generator=g) + + x_full = x_full.to(device) + w_full = w_full.to(device) + b_full = b_full.to(device) + + # reference + single = ReplicatedLinear(in_features, out_features, bias=True).to(device) + single.weight.weight_loader(single.weight, w_full) + single.bias.weight_loader(single.bias, b_full) + y_ref = single(x_full) + + # RowParallel + row_tp = RowParallelLinear(in_features, out_features, bias=True).to(device) + row_tp.weight.weight_loader(row_tp.weight, w_full) + + if row_tp.bias is not None: + row_tp.bias.data.copy_(b_full / tp_size) + + shard = in_features // tp_size + start = tp_rank * shard + x_part = x_full.narrow(-1, start, shard) + + y_row = row_tp(x_part) + + max_err = (y_row - y_ref).abs().max().item() + ok = torch.allclose(y_row, y_ref, rtol=1e-4, atol=1e-4) + if tp_rank == 0: + print(f"[RowParallel] allclose={ok}, max_abs_err={max_err:.6f}") + + rank, world_size, local_rank, device = _init_dist() + if rank == 0: + print(f"Running TP tests with world_size={world_size} on device={device}") + + # The test output 'allclose=True' means passed. + test_column_parallel(device) + test_merged_column_parallel(device) + test_qkv_column_parallel(device) + test_row_parallel(device) + + dist.barrier() + dist.destroy_process_group() \ No newline at end of file