
Commit 3fea4fa

support flattened_bucket type param sync for fsdp2sglang (#373)

Speed up fsdp2sglang param sync time: Qwen3-30B-A3B: 30s -> 11s; Qwen3-Next-80B-A3B-Instruct: 529s -> 32s.

1 parent de820c9 commit 3fea4fa
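
To illustrate why the flattened-bucket format is faster: instead of gathering and serializing every parameter tensor separately, each block of weights is packed into one contiguous 1-D buffer plus a list of slice metadata, so the rollout engine receives one object per block instead of one per tensor. Below is a minimal sketch of the packing idea with hypothetical names; the real implementation uses sglang's FlattenedTensorMetadata, shown in the diff that follows.

import torch

def pack(named_tensors):
    # Pack many named tensors into one flat buffer plus slice metadata,
    # so one serialization/IPC round replaces one round per tensor.
    flat, meta, offset = [], [], 0
    for name, t in named_tensors:
        n = t.numel()
        meta.append({"name": name, "shape": tuple(t.shape),
                     "start": offset, "end": offset + n})
        flat.append(t.contiguous().view(-1))
        offset += n
    return torch.cat(flat), meta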

File tree

3 files changed (+141, -114 lines)


chatlearn/models/fsdp_module.py

Lines changed: 103 additions & 25 deletions
@@ -17,10 +17,11 @@
 import os
 import random
 import gc
-from typing import List
+from typing import List, Dict

 import numpy as np
 import torch
+from torch import Tensor
 import torch.distributed as dist
 from torch.distributed.tensor import DTensor
 from torch import optim, nn
@@ -304,6 +305,7 @@ def model_setup(self):
         torch.cuda.synchronize()
         for name, buf in model.named_buffers():
             dist.broadcast(buf, src=0)
+
         self.model = model
         self.model.to(torch.float32)

@@ -324,12 +326,16 @@ def model_setup(self):
         del full_state
         self.offload()

-    def get_fsdp_param_name(self, block_size=300_000_000) -> List[List]:
+    def get_fsdp_param_name(self, block_size=3_000_000_000) -> List[List]:
         name_list = []
         param_cnt = 0
         current_group = []
         for name, param in self.model.named_parameters():
-            param_cnt += param.numel()
+            param_cnt += (
+                param.numel() * self.fsdp_size
+                if isinstance(param, DTensor)
+                else param.numel()
+            )
             current_group.append(name)
             if param_cnt >= block_size:
                 name_list.append(current_group)
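
To make the grouping above concrete, here is a standalone sketch of the greedy packing that get_fsdp_param_name performs; the named_numel mapping is a hypothetical stand-in for iterating model.named_parameters(), and the diff additionally scales a DTensor's element count by self.fsdp_size so the cap reflects full-tensor sizes.

from typing import Dict, List

def group_names_by_block(named_numel: Dict[str, int],
                         block_size: int = 3_000_000_000) -> List[List[str]]:
    # Greedily pack parameter names into groups of roughly block_size elements.
    groups, current, count = [], [], 0
    for name, numel in named_numel.items():
        count += numel
        current.append(name)
        if count >= block_size:  # flush once the group reaches the cap
            groups.append(current)
            current, count = [], 0
    if current:  # trailing partial group
        groups.append(current)
    return groups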
@@ -339,43 +345,115 @@ def get_fsdp_param_name(self, block_size=300_000_000) -> List[List]:
             name_list.append(current_group)
         return name_list

+    def convert_block2flattened_bucket(self, block_parameter: Dict[str, Tensor]):
+        from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorMetadata
+
+        flatten_tensor_list = []
+        metadatas: List[FlattenedTensorMetadata] = []
+
+        def convert_tensor(
+            name: str,
+            param: Tensor,
+            flatten_tensor_list: List[Tensor],
+            metadatas: List[FlattenedTensorMetadata],
+            buffer_offset=0,
+            is_experts=False,
+            num_block=1,
+        ):
+            """
+            Convert a param tensor (single or grouped MLP) to flatten_tensor_list,
+            which is used in the sglang update_weights_from_tensor API.
+            """
+            assert (
+                param.shape[0] % num_block == 0
+            ), "param can't be chunked by num_block in dim 0"
+            interval = param.numel() // num_block
+            shape = torch.Size((param.shape[0] // num_block,) + param.shape[1:])
+
+            for i in range(num_block):
+                start_idx = buffer_offset
+                end_idx = buffer_offset + interval
+                buffer_offset = end_idx
+                local_name = name.replace("group_mlp", f"experts.{i}") if is_experts else name
+                metadata = FlattenedTensorMetadata(
+                    name=local_name,
+                    shape=shape,
+                    dtype=param.dtype,
+                    start_idx=start_idx,
+                    end_idx=end_idx,
+                    numel=interval,
+                )
+                metadatas.append(metadata)
+            flattened_param = param.contiguous().view(-1)
+            flatten_tensor_list.append(flattened_param)
+            return flatten_tensor_list, metadatas, buffer_offset
+
+        buffer_offset = 0
+        for name, param in block_parameter.items():
+            param = (
+                param.full_tensor().detach()
+                if isinstance(param, DTensor)
+                else param.detach()
+            )
+            if self.module_args.groupgemm and "group_mlp" in name:
+                num_experts = self.model_config.num_experts
+                flatten_tensor_list, metadatas, buffer_offset = convert_tensor(
+                    name=name,
+                    param=param,
+                    flatten_tensor_list=flatten_tensor_list,
+                    metadatas=metadatas,
+                    buffer_offset=buffer_offset,
+                    is_experts=True,
+                    num_block=num_experts,
+                )
+            else:
+                flatten_tensor_list, metadatas, buffer_offset = convert_tensor(
+                    name, param, flatten_tensor_list, metadatas, buffer_offset
+                )
+        flattened_tensor = torch.cat(flatten_tensor_list)
+        return flattened_tensor, metadatas
+
     def get_weight_ipc_handles_by_name(self, block_name: List[str]):
         """
         Get FSDP-wrapped module weights by names obtained from named_parameters,
         avoiding fetching the full model state_dict.
         """
+        if self.module_args.use_expandable_segments:
+            torch.cuda.memory._set_allocator_settings("expandable_segments:False")
+        # get matched param full tensors
+        block_parameter = {}
+        reduce_tensor_dict = {}  # used for vllm
+        for name, param in self.model.named_parameters():
+            if name in block_name:
+                block_parameter[name] = (
+                    param.full_tensor().detach()
+                    if isinstance(param, DTensor)
+                    else param.detach()
+                )
+
         rollout_engine = self._runtime_args.rollout_backend
         if rollout_engine == "sglang":
             # lazy import sglang
             from sglang.srt.utils import MultiprocessingSerializer
             from sglang.srt.patch_torch import monkey_patch_torch_reductions
+
             monkey_patch_torch_reductions()
-            if self.module_args.use_expandable_segments:
-                torch.cuda.memory._set_allocator_settings("expandable_segments:False")
-        reduce_tensor_dict = {}
-        serialize_func = reduce_tensor if rollout_engine=='vllm' else MultiprocessingSerializer.serialize
-        for name, param in self.model.named_parameters():
-            if name in block_name:
-                if self.module_args.groupgemm and "group_mlp" in name:
-                    # This model is using groupgemm for moe forward
-                    param = param.full_tensor().detach()
-                    num_experts = self.model_config.num_experts
-                    #split_size = param.shape[0] // num_experts
-                    param_per_expert = torch.chunk(param, num_experts, dim=0)
-                    #param_per_expert = torch.split(param, split_size, dim=0)
-                    for i in range(num_experts):
-                        local_name = name.replace('group_mlp', f"experts.{i}")
-                        reduce_tensor_dict[local_name] = serialize_func(param_per_expert[i])
-                else:
-                    reduce_tensor_dict[name] = serialize_func(param.full_tensor().detach() \
-                        if isinstance(param, DTensor) else param.detach())
+            flattened_tensor, metadatas = self.convert_block2flattened_bucket(
+                block_parameter
+            )
+            bucket_dict = {"flattened_tensor": flattened_tensor, "metadata": metadatas}
+            serialized_bucket = MultiprocessingSerializer.serialize(
+                bucket_dict, output_str=True
+            )
+            return serialized_bucket
+        elif rollout_engine == "vllm":
+            for name, param in block_parameter.items():
+                reduce_tensor_dict[name] = reduce_tensor(param)
+
         if self.module_args.use_expandable_segments:
             torch.cuda.memory._set_allocator_settings("expandable_segments:True")
         return reduce_tensor_dict

-    def update_weights_from_buckets(self, buckets):
-        pass
-
     @torch.no_grad()
     def onload_weights(self, empty_cache=True):
         device_id = torch.cuda.current_device()
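
For reference, a minimal sketch of how a consumer could rebuild named weights from the bucket produced by convert_block2flattened_bucket, assuming only the FlattenedTensorMetadata fields used in the diff (name, shape, start_idx, end_idx); sglang's actual flattened_bucket loader may differ.

import torch

def unpack_bucket(flattened_tensor: torch.Tensor, metadatas):
    # Rebuild each named weight by slicing and reshaping the flat buffer.
    return {
        m.name: flattened_tensor[m.start_idx:m.end_idx].view(m.shape)
        for m in metadatas
    }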

chatlearn/models/patches/transformers/qwen3_next_moe_patch.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+from transformers.activations import ACT2FN

 from chatlearn.models.patches.transformers.layers.groupgemm import MoeGroupMLP

chatlearn/models/sglang_module.py

Lines changed: 37 additions & 89 deletions
@@ -50,7 +50,6 @@
     ResumeMemoryOccupationReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
@@ -397,49 +396,21 @@ def generate(self, query: List[Dict], is_eval: bool) -> List[Dict]:
         self.flush_cache()
         return outputs

-    def update_weights_from_ipc_handles(self, reduce_data, load_format=None):
-        if load_format == "flattened_bucket":
-            gathered_data = None
-            if self.is_engine():
-                gathered_data = [None] * self._tp_size
-            dist.gather_object(
-                obj=reduce_data,
-                object_gather_list=gathered_data,
-                dst=self.cpu_mesh["tp"].mesh.tolist()[0],
-                group=self.cpu_mesh["tp"].get_group(),
-            )
-            if self.is_engine():
-                self.llm.update_weights_from_tensor(
-                    named_tensors=gathered_data,
-                    load_format=load_format,
-                )
-            torch.cuda.synchronize()
-            return
-
-        for index, (name, serialized_tensor) in enumerate(reduce_data.items()):
-            if self.is_engine():
-                gathered_serialized_tensors = [None] * self._tp_size
-            else:
-                gathered_serialized_tensors = None
-
-            dist.gather_object(
-                obj=serialized_tensor,
-                object_gather_list=gathered_serialized_tensors,
-                dst=self.cpu_mesh["tp"].mesh.tolist()[0],
-                group=self.cpu_mesh["tp"].get_group(),
+    def update_weights_from_ipc_handles(self, reduce_data):
+        gathered_data = None
+        if self.is_engine():
+            gathered_data = [None] * self._tp_size
+        dist.gather_object(
+            obj=reduce_data,
+            object_gather_list=gathered_data,
+            dst=self.cpu_mesh["tp"].mesh.tolist()[0],
+            group=self.cpu_mesh["tp"].get_group(),
+        )
+        if self.is_engine():
+            self.llm.update_weights_from_tensor(
+                named_tensors=gathered_data,
+                load_format="flattened_bucket",
             )
-
-            if self.is_engine():
-                self.llm.update_weights_from_tensor(
-                    named_tensors=[
-                        (
-                            name,
-                            LocalSerializedTensor(values=gathered_serialized_tensors),
-                        )
-                    ],
-                    # load_format=load_format,
-                    flush_cache=index == len(reduce_data) - 1,
-                )
         torch.cuda.synchronize()

     def flush_cache(self):
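
A condensed sketch of the control flow above, with hypothetical stand-ins for the module state (tp_group, dst_rank, llm): every TP rank contributes its serialized bucket string, and only the engine rank receives the gathered list and applies it in a single update_weights_from_tensor call.

import torch.distributed as dist

def sync_bucket(serialized, tp_group, dst_rank, is_engine, tp_size, llm=None):
    gathered = [None] * tp_size if is_engine else None
    dist.gather_object(serialized, object_gather_list=gathered,
                       dst=dst_rank, group=tp_group)
    if is_engine:
        # one engine-side call per bucket instead of one per tensor
        llm.update_weights_from_tensor(
            named_tensors=gathered, load_format="flattened_bucket"
        )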
@@ -567,6 +538,8 @@ def parameter_sync(self):

     @torch.no_grad()
     def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
+        """Used for Mcore2SGLang parameter sync.
+        """
         from sglang.srt.patch_torch import monkey_patch_torch_reductions
         monkey_patch_torch_reductions()
         param_id_to_update = set()
@@ -584,10 +557,12 @@ def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
             for shard_idx, (offset, sharded_tensor_info) in enumerate(bucket.recv_layout):
                 param_id_to_bucket[sharded_tensor_info.param_id].append((bucket_idx, shard_idx))

+        # 1-D concatenated flattened tensor
         buffer = None
         buffer_offset = 0
         buffer_size = 4 * 1024 ** 3
-        metadatas = []
+        # metadata: name, shape, dtype, start_idx, end_idx, numel for every tensor item in buffer
+        metadatas: List[FlattenedTensorMetadata] = []
         for param_id in param_id_to_update:
             param_name = self.param_id_to_local_name[param_id]
             shard_info = self.param_id_to_metadata[param_id]
@@ -600,7 +575,7 @@ def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
                 serialized_bucket = MultiprocessingSerializer.serialize(
                     bucket_dict, output_str=True
                 )
-                self.update_weights_from_ipc_handles(serialized_bucket, load_format="flattened_bucket")
+                self.update_weights_from_ipc_handles(serialized_bucket)
                 buffer = torch.empty(buffer_size, dtype=shard_info.dtype, device='cuda')
                 buffer_offset = 0
                 metadatas = []
@@ -630,7 +605,7 @@ def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
             serialized_bucket = MultiprocessingSerializer.serialize(
                 bucket_dict, output_str=True
             )
-            self.update_weights_from_ipc_handles(serialized_bucket, load_format="flattened_bucket")
+            self.update_weights_from_ipc_handles(serialized_bucket)

         del buffer, weight, shard, bucket_dict
         torch.cuda.synchronize()
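
The two flush sites above follow a fill-then-flush pattern over a fixed staging buffer of 4 * 1024 ** 3 elements; here is a simplified sketch, with a hypothetical send_fn and tuple metadata standing in for the serialize/gather path and FlattenedTensorMetadata.

import torch

BUFFER_SIZE = 4 * 1024 ** 3  # element count, matching buffer_size above

def stage_and_flush(named_tensors, send_fn, dtype=torch.bfloat16):
    # Assumes all tensors share dtype and no single tensor exceeds BUFFER_SIZE.
    buffer = torch.empty(BUFFER_SIZE, dtype=dtype, device="cuda")
    offset, metadatas = 0, []
    for name, t in named_tensors:
        n = t.numel()
        if offset + n > BUFFER_SIZE:  # next tensor won't fit: flush
            send_fn(buffer[:offset], metadatas)
            offset, metadatas = 0, []
        buffer[offset:offset + n].copy_(t.reshape(-1))
        metadatas.append((name, tuple(t.shape), offset, offset + n))
        offset += n
    if offset:  # flush the tail
        send_fn(buffer[:offset], metadatas)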
@@ -727,49 +702,22 @@ async def generate_per_request(self, query: Dict, is_eval: bool) -> Dict:
         )
         return outputs

-    async def update_weights_from_ipc_handles(self, reduce_data, load_format=None):
-        if load_format == "flattened_bucket":
-            gathered_data = None
-            if self.is_engine():
-                gathered_data = [None] * self._tp_size
-            dist.gather_object(
-                obj=reduce_data,
-                object_gather_list=gathered_data,
-                dst=self.cpu_mesh["tp"].mesh.tolist()[0],
-                group=self.cpu_mesh["tp"].get_group(),
-            )
-            if self.is_engine():
-                await self.llm.update_weights_from_tensor(
-                    named_tensors=gathered_data,
-                    load_format=load_format,
-                )
-            torch.cuda.synchronize()
-            return
+    async def update_weights_from_ipc_handles(self, reduce_data):

-        for index, (name, serialized_tensor) in enumerate(reduce_data.items()):
-            if self.is_engine():
-                gathered_serialized_tensors = [None] * self._tp_size
-            else:
-                gathered_serialized_tensors = None
-
-            dist.gather_object(
-                obj=serialized_tensor,
-                object_gather_list=gathered_serialized_tensors,
-                dst=self.cpu_mesh["tp"].mesh.tolist()[0],
-                group=self.cpu_mesh["tp"].get_group(),
+        gathered_data = None
+        if self.is_engine():
+            gathered_data = [None] * self._tp_size
+        dist.gather_object(
+            obj=reduce_data,
+            object_gather_list=gathered_data,
+            dst=self.cpu_mesh["tp"].mesh.tolist()[0],
+            group=self.cpu_mesh["tp"].get_group(),
+        )
+        if self.is_engine():
+            await self.llm.update_weights_from_tensor(
+                named_tensors=gathered_data,
+                load_format="flattened_bucket",
             )
-
-            if self.is_engine():
-                await self.llm.update_weights_from_tensor(
-                    named_tensors=[
-                        (
-                            name,
-                            LocalSerializedTensor(values=gathered_serialized_tensors),
-                        )
-                    ],
-                    # load_format=load_format,
-                    flush_cache=index == len(reduce_data) - 1,
-                )
         torch.cuda.synchronize()

     @torch.no_grad()
@@ -807,7 +755,7 @@ async def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
             serialized_bucket = MultiprocessingSerializer.serialize(
                 bucket_dict, output_str=True
             )
-            await self.update_weights_from_ipc_handles(serialized_bucket, load_format="flattened_bucket")
+            await self.update_weights_from_ipc_handles(serialized_bucket)
             buffer = torch.empty(buffer_size, dtype=shard_info.dtype, device='cuda')
             buffer_offset = 0
             metadatas = []
@@ -837,7 +785,7 @@ async def update_weights_from_buckets(self, buckets: List[Optional['BucketInfo']]):
             serialized_bucket = MultiprocessingSerializer.serialize(
                 bucket_dict, output_str=True
             )
-            await self.update_weights_from_ipc_handles(serialized_bucket, load_format="flattened_bucket")
+            await self.update_weights_from_ipc_handles(serialized_bucket)

         del buffer, weight, shard, bucket_dict
         torch.cuda.synchronize()
