diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index 68e0a0a2..5e17f294 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -285,20 +285,22 @@ def model_setup(self): model.to_empty(device="cuda") # load real state dict - options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True) - - # module-wise sync avoid OOM while run model like qwen3-moe-235B - for name, module in model.named_modules(): - has_weights = any(k.startswith(name + ".") for k in full_state.keys()) and len(list(module.children()))==0 - if has_weights: - set_model_state_dict( - module, - {k.replace(name + ".", ""): v for k, v in full_state.items() if k.startswith(name + ".")}, - options=options - ) - # set_model_state_dict(model, full_state, options=options) - - # load buffer data + options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True, strict=False) + + # bucket-wise sync avoid OOM while run model like qwen3-moe-235B + update_bucket = {} + bucket_size = 3 * 1024 ** 3 + numel_cnt = 0 + for name, param in full_state.items(): + numel_cnt += param.numel() + update_bucket[name] = full_state[name] + if numel_cnt >= bucket_size: + set_model_state_dict(model, update_bucket, options=options) + update_bucket = {} + numel_cnt = 0 + set_model_state_dict(model, update_bucket, options=options) + + # load buffer data because persistent maybe False if dist.get_rank()==0: for name, buf in model.named_buffers(): buf.data.copy_(buffer_dict[name]) @@ -323,7 +325,7 @@ def model_setup(self): # resume model weights if self.resume_training: self.load_checkpoint(self._episode_id) - del full_state + del full_state, update_bucket self.offload() def get_fsdp_param_name(self, block_size=3_000_000_000) -> List[List]: