Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions HowToApproachvLLM.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,20 @@ for name, param in model.named_parameters():
- Each GPU stores complete heads (don't shard head_size)
- Enables independent attention computation per GPU

**Benchmark Results:**
Run the following commands to enter the `src/myvllm/layers` directory and verify that the layers produce correct results in a distributed environment.
```
cd src/myvllm/layers
CUDA_VISIBLE_DEVICES=0,1,2,3 uv run torchrun --nproc_per_node=4 linear.py
```
If the output shows `allclose=True`, the multi-GPU tensor-parallel results are consistent with the single-GPU reference results.
```
[ColumnParallel] allclose=True, max_abs_err=0.000107
[MergedColumnParallel] allclose=True, max_abs_err=0.000103
[QKVColumnParallel] allclose=True, max_abs_err=0.000061
[RowParallel] allclose=True, max_abs_err=0.000011
```

**MLP Layer Pattern:**
- One ColumnParallel → One RowParallel → `dist.all_reduce`
- Output sharding of first layer = Input sharding of second layer
Expand Down
15 changes: 15 additions & 0 deletions HowToApproachvLLM_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,25 @@ for name, param in model.named_parameters():
- 每张 GPU 存完整的 heads(不对 `head_size` 维度做切分)
- 使每张 GPU 可以独立完成注意力计算

**基准测试:**
使用如下指令进入 `src/myvllm/layers` 目录,在分布式环境下验证运行结果是否正确:
```
cd src/myvllm/layers
CUDA_VISIBLE_DEVICES=0,1,2,3 uv run torchrun --nproc_per_node=4 linear.py
```
输出结果为 `allclose=True` 则代表多卡并行结果与单卡一致
```
[ColumnParallel] allclose=True, max_abs_err=0.000107
[MergedColumnParallel] allclose=True, max_abs_err=0.000103
[QKVColumnParallel] allclose=True, max_abs_err=0.000061
[RowParallel] allclose=True, max_abs_err=0.000011
```
**MLP 层的常见模式:**
- 一个 `ColumnParallel` → 一个 `RowParallel` → `dist.all_reduce`
- 第一层的输出切分方式 = 第二层的输入切分方式



---

### 1.4 词表嵌入(Vocab Embedding)与 LM Head ✅
Expand Down
265 changes: 256 additions & 9 deletions src/myvllm/layers/linear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch.nn as nn
import torch
import torch.distributed as dist
import os

class LinearBase(nn.Module):
"""
Expand Down Expand Up @@ -225,14 +226,260 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return result




if __name__ == "__main__":
# Example usage
if dist.is_available() and not dist.is_initialized():
dist.init_process_group(
backend="gloo",
init_method="tcp://127.0.0.1:29500",
rank=0,
world_size=1,
# how to run?
# 1. cd src/myvllm/layers
# 2. CUDA_VISIBLE_DEVICES=0,1,2,3 uv run torchrun --nproc_per_node=4 linear.py # 4 GPUs

def _init_dist():
    """Initialize torch.distributed from the env vars set by torchrun.

    Returns:
        (rank, world_size, local_rank, device): global rank, total process
        count, rank within this node, and the device this process should use
        (one CUDA device per process, or CPU with a gloo fallback).
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    backend = "nccl" if torch.cuda.is_available() else "gloo"

    if torch.cuda.is_available():
        # Pin this process to its own GPU before creating the process group.
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        # `device_id` is documented as a torch.device, not a bare int; older
        # torch releases reject an int here. Passing the device also lets
        # NCCL bind the communicator to this GPU eagerly.
        dist.init_process_group(
            backend=backend,
            init_method="env://",
            device_id=device,
        )
    else:
        device = torch.device("cpu")
        dist.init_process_group(backend=backend, init_method="env://")
    return rank, world_size, local_rank, device

# Consistency check: ColumnParallelLinear vs. a dense single-device layer.
@torch.no_grad()
def test_column_parallel(device):
    """Verify ColumnParallelLinear matches ReplicatedLinear after all_gather."""
    rank = dist.get_rank()
    world = dist.get_world_size()  # number of tensor-parallel ranks

    batch = 4
    in_features = 1024 * world
    out_features = 1024 * world

    # Same fixed CPU seed on every rank => identical full input/weight/bias.
    gen = torch.Generator(device="cpu").manual_seed(2026)
    x = torch.randn(batch, in_features, generator=gen).to(device)
    w = torch.randn(out_features, in_features, generator=gen).to(device)
    b = torch.randn(out_features, generator=gen).to(device)

    # Reference: full (non-sharded) linear layer on a single device.
    ref = ReplicatedLinear(in_features, out_features, bias=True).to(device)
    ref.weight.weight_loader(ref.weight, w)
    ref.bias.weight_loader(ref.bias, b)
    expected = ref(x)

    # Tensor-parallel layer: each rank stores out_features / world columns.
    col = ColumnParallelLinear(in_features, out_features, bias=True).to(device)
    col.weight.weight_loader(col.weight, w)
    col.bias.weight_loader(col.bias, b)

    local_out = col(x)  # [batch, out_features / world]

    # Reassemble the full output: gather every rank's shard, then concat.
    shards = [torch.empty_like(local_out) for _ in range(world)]
    dist.all_gather(shards, local_out)
    gathered = torch.cat(shards, dim=-1)  # [batch, out_features]

    # Alignment check (only rank 0 reports).
    max_err = (gathered - expected).abs().max().item()
    ok = torch.allclose(gathered, expected, rtol=1e-4, atol=1e-4)
    if rank == 0:
        print(f"[ColumnParallel] allclose={ok}, max_abs_err={max_err:.6f}")


# Consistency check: MergedColumnParallelLinear vs. three dense projections.
@torch.no_grad()
def test_merged_column_parallel(device):
    """Verify the fused merged-column layer matches concatenated dense outputs."""
    rank = dist.get_rank()
    world = dist.get_world_size()

    # Scale every dimension by the world size so each shard divides evenly.
    in_features = 1024 * world
    per_proj = 512 * world
    out_sizes = [per_proj, per_proj, per_proj]  # fused Q, K and V projections
    batch = 4

    gen = torch.Generator(device="cpu").manual_seed(2026)
    x = torch.randn(batch, in_features, generator=gen).to(device)
    weights = [
        torch.randn(sz, in_features, generator=gen).to(device) for sz in out_sizes
    ]

    # Reference: run each projection densely, concatenate as [Q | K | V].
    expected = torch.cat(
        [nn.functional.linear(x, w, None) for w in weights], dim=-1
    )

    # TP merged layer.
    # NOTE: MergedColumnParallelLinear.weight_loader takes an extra shard id,
    # so the base-class bias loader signature doesn't match; hence bias=False.
    merged = MergedColumnParallelLinear(in_features, out_sizes, bias=False).to(device)
    for shard_id, w in enumerate(weights):
        merged.weight_loader(merged.weight, w, shard_id)

    # Per-rank output layout: [q_local | k_local | v_local].
    local_out = merged(x)  # [batch, sum(out_sizes) / world]

    # Gather every rank's shard, then re-pack into [Q_all | K_all | V_all].
    shards = [torch.empty_like(local_out) for _ in range(world)]
    dist.all_gather(shards, local_out)

    ql, kl, vl = (sz // world for sz in out_sizes)
    q_all = torch.cat([s[:, :ql] for s in shards], dim=-1)
    k_all = torch.cat([s[:, ql : ql + kl] for s in shards], dim=-1)
    v_all = torch.cat([s[:, ql + kl : ql + kl + vl] for s in shards], dim=-1)
    gathered = torch.cat([q_all, k_all, v_all], dim=-1)

    max_err = (gathered - expected).abs().max().item()
    ok = torch.allclose(gathered, expected, rtol=1e-4, atol=1e-4)
    if rank == 0:
        print(f"[MergedColumnParallel] allclose={ok}, max_abs_err={max_err:.6f}")


# QKVColumnParallelLinear test
@torch.no_grad()
def test_qkv_column_parallel(device):
    """Verify QKVColumnParallelLinear matches dense Q/K/V projections.

    Each rank holds num_heads/tp query heads and num_kv_heads/tp KV heads;
    the per-rank outputs are gathered, re-packed into [Q_all | K_all | V_all],
    and compared against the single-device reference.
    """
    tp_rank = dist.get_rank()
    tp_size = dist.get_world_size()

    # make input dim divisible by tp_size
    input_size = 1024 * tp_size
    head_size = 16
    num_heads = 4 * tp_size
    num_kv_heads = 2 * tp_size
    batch = 4

    # Same fixed CPU seed on every rank => identical full input/weights.
    g = torch.Generator(device="cpu").manual_seed(2026)
    x_full = torch.randn(batch, input_size, generator=g)
    w_q = torch.randn(head_size * num_heads, input_size, generator=g)
    w_k = torch.randn(head_size * num_kv_heads, input_size, generator=g)
    w_v = torch.randn(head_size * num_kv_heads, input_size, generator=g)

    x_full = x_full.to(device)
    w_q = w_q.to(device)
    w_k = w_k.to(device)
    w_v = w_v.to(device)

    # reference: full Q|K|V
    y_ref = torch.cat(
        [
            nn.functional.linear(x_full, w_q, None),
            nn.functional.linear(x_full, w_k, None),
            nn.functional.linear(x_full, w_v, None),
        ],
        dim=-1,
    )
    # Removed stray debug code that instantiated an unrelated LinearBase and
    # printed it on every rank — it had no bearing on this test.

    # TP QKV layer: rank output layout [q_local | k_local | v_local]
    qkv = QKVColumnParallelLinear(
        input_size=input_size,
        head_size=head_size,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        bias=False,
    ).to(device)

    qkv.weight_loader(qkv.weight, w_q, "q")
    qkv.weight_loader(qkv.weight, w_k, "k")
    qkv.weight_loader(qkv.weight, w_v, "v")

    y_local = qkv(x_full)  # [batch, head_size*(local_h + 2*local_kv)]

    # all_gather then re-pack to [Q_all | K_all | V_all]
    y_parts = [torch.empty_like(y_local) for _ in range(tp_size)]
    dist.all_gather(y_parts, y_local)

    ql = head_size * (num_heads // tp_size)
    kl = head_size * (num_kv_heads // tp_size)
    vl = head_size * (num_kv_heads // tp_size)

    q_full = torch.cat([p[:, :ql] for p in y_parts], dim=-1)
    k_full = torch.cat([p[:, ql : ql + kl] for p in y_parts], dim=-1)
    v_full = torch.cat([p[:, ql + kl : ql + kl + vl] for p in y_parts], dim=-1)
    y_full = torch.cat([q_full, k_full, v_full], dim=-1)

    # Alignment check (only rank 0 reports).
    max_err = (y_full - y_ref).abs().max().item()
    ok = torch.allclose(y_full, y_ref, rtol=1e-4, atol=1e-4)
    if tp_rank == 0:
        print(f"[QKVColumnParallel] allclose={ok}, max_abs_err={max_err:.6f}")


# Consistency check: RowParallelLinear (input sharded on dim -1) vs. dense.
@torch.no_grad()
def test_row_parallel(device):
    """Verify RowParallelLinear on a sharded input matches the dense layer."""
    rank = dist.get_rank()
    world = dist.get_world_size()

    batch = 4
    in_features = 128 * world
    out_features = 256

    # Same fixed CPU seed on every rank => identical full input/weight/bias.
    gen = torch.Generator(device="cpu").manual_seed(2026)
    x = torch.randn(batch, in_features, generator=gen).to(device)
    w = torch.randn(out_features, in_features, generator=gen).to(device)
    b = torch.randn(out_features, generator=gen).to(device)

    # Reference: dense single-device layer.
    ref = ReplicatedLinear(in_features, out_features, bias=True).to(device)
    ref.weight.weight_loader(ref.weight, w)
    ref.bias.weight_loader(ref.bias, b)
    expected = ref(x)

    # RowParallel: weight is sharded along the input dimension.
    row = RowParallelLinear(in_features, out_features, bias=True).to(device)
    row.weight.weight_loader(row.weight, w)

    if row.bias is not None:
        # NOTE(review): each rank loads bias/world — presumably the layer
        # all-reduces its partial outputs, so the reduced sum contributes
        # the full bias exactly once. Confirm against RowParallelLinear.
        row.bias.data.copy_(b / world)

    # Feed this rank its contiguous slice of the input features.
    width = in_features // world
    x_shard = x.narrow(-1, rank * width, width)

    actual = row(x_shard)

    max_err = (actual - expected).abs().max().item()
    ok = torch.allclose(actual, expected, rtol=1e-4, atol=1e-4)
    if rank == 0:
        print(f"[RowParallel] allclose={ok}, max_abs_err={max_err:.6f}")

rank, world_size, local_rank, device = _init_dist()
if rank == 0:
    print(f"Running TP tests with world_size={world_size} on device={device}")

# The test output 'allclose=True' means passed.
try:
    test_column_parallel(device)
    test_merged_column_parallel(device)
    test_qkv_column_parallel(device)
    test_row_parallel(device)
    # Make sure every rank finished its collectives before teardown.
    dist.barrier()
finally:
    # Always tear down the process group — even when a test raises — so
    # torchrun workers don't hang on a dangling NCCL/gloo communicator.
    dist.destroy_process_group()