Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions chatlearn/models/fsdp_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,20 +285,22 @@ def model_setup(self):
model.to_empty(device="cuda")

# load real state dict
options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True)

# module-wise sync avoid OOM while run model like qwen3-moe-235B
for name, module in model.named_modules():
has_weights = any(k.startswith(name + ".") for k in full_state.keys()) and len(list(module.children()))==0
if has_weights:
set_model_state_dict(
module,
{k.replace(name + ".", ""): v for k, v in full_state.items() if k.startswith(name + ".")},
options=options
)
# set_model_state_dict(model, full_state, options=options)

# load buffer data
options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True, strict=False)

# bucket-wise sync avoid OOM while run model like qwen3-moe-235B
update_bucket = {}
bucket_size = 3 * 1024 ** 3
numel_cnt = 0
for name, param in full_state.items():
numel_cnt += param.numel()
update_bucket[name] = full_state[name]
if numel_cnt >= bucket_size:
set_model_state_dict(model, update_bucket, options=options)
update_bucket = {}
numel_cnt = 0
set_model_state_dict(model, update_bucket, options=options)

# load buffer data because persistent maybe False
if dist.get_rank()==0:
for name, buf in model.named_buffers():
buf.data.copy_(buffer_dict[name])
Expand All @@ -323,7 +325,7 @@ def model_setup(self):
# resume model weights
if self.resume_training:
self.load_checkpoint(self._episode_id)
del full_state
del full_state, update_bucket
self.offload()

def get_fsdp_param_name(self, block_size=3_000_000_000) -> List[List]:
Expand Down