From 2f682982bb42759b514084da9da3ccf59eda9203 Mon Sep 17 00:00:00 2001 From: "yanhaiqiang.yhq" Date: Thu, 18 Sep 2025 13:54:11 +0800 Subject: [PATCH 1/4] update metainit for qwen next --- chatlearn/models/fsdp_module.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index 68e0a0a2..3315315b 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -285,20 +285,22 @@ def model_setup(self): model.to_empty(device="cuda") # load real state dict - options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True) + options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True, strict=False) # module-wise sync avoid OOM while run model like qwen3-moe-235B - for name, module in model.named_modules(): - has_weights = any(k.startswith(name + ".") for k in full_state.keys()) and len(list(module.children()))==0 - if has_weights: - set_model_state_dict( - module, - {k.replace(name + ".", ""): v for k, v in full_state.items() if k.startswith(name + ".")}, - options=options - ) - # set_model_state_dict(model, full_state, options=options) - - # load buffer data + update_bucket = {} + bucket_size = 3 * 1024 ** 3 + numel_cnt = 0 + for name, param in full_state.items(): + numel_cnt += param.numel() + update_bucket[name] = full_state[name] + if numel_cnt >= bucket_size: + set_model_state_dict(model, update_bucket, options=options) + update_bucket = {} + numel_cnt = 0 + set_model_state_dict(model, update_bucket, options=options) + + # load buffer data because persistent maybe False if dist.get_rank()==0: for name, buf in model.named_buffers(): buf.data.copy_(buffer_dict[name]) @@ -323,7 +325,7 @@ def model_setup(self): # resume model weights if self.resume_training: self.load_checkpoint(self._episode_id) - del full_state + del full_state, update_bucket self.offload() def 
get_fsdp_param_name(self, block_size=3_000_000_000) -> List[List]: From 48e21961a8075accd7cf1504047dbc86fe1fbe09 Mon Sep 17 00:00:00 2001 From: "yanhaiqiang.yhq" Date: Thu, 18 Sep 2025 13:55:22 +0800 Subject: [PATCH 2/4] update metainit for qwen next --- chatlearn/models/fsdp_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index 3315315b..5e17f294 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -287,7 +287,7 @@ def model_setup(self): # load real state dict options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True, strict=False) - # module-wise sync avoid OOM while run model like qwen3-moe-235B + # bucket-wise sync avoid OOM while run model like qwen3-moe-235B update_bucket = {} bucket_size = 3 * 1024 ** 3 numel_cnt = 0 From 1238451c2b5f92e7f6cde88375a65cd8272c08d1 Mon Sep 17 00:00:00 2001 From: "yanhaiqiang.yhq" Date: Thu, 18 Sep 2025 16:26:17 +0800 Subject: [PATCH 3/4] init buffer_dict to None and del it instead of update_bucket --- chatlearn/models/fsdp_module.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index 5e17f294..112a65b1 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -254,6 +254,7 @@ def model_setup(self): # get state_dict to init model for meta init full_state = None + buffer_dict = None if self.module_args.meta_init: full_state = model.state_dict() @@ -325,7 +326,7 @@ def model_setup(self): # resume model weights if self.resume_training: self.load_checkpoint(self._episode_id) - del full_state, update_bucket + del full_state, buffer_dict self.offload() def get_fsdp_param_name(self, block_size=3_000_000_000) -> List[List]: From 93239d326ae57e7cc1a641b6c140b7ae0f2433fd Mon Sep 17 00:00:00 2001 From: "yanhaiqiang.yhq" Date: Thu, 18 Sep 2025 16:27:50 +0800 Subject: [PATCH 4/4] init update_bucket to None instead of buffer_dict and del it --- chatlearn/models/fsdp_module.py | 4
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index 112a65b1..5e61f3c4 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -254,7 +254,7 @@ def model_setup(self): # get state_dict to init model for meta init full_state = None - buffer_dict = None + update_bucket = None if self.module_args.meta_init: full_state = model.state_dict() @@ -326,7 +326,7 @@ def model_setup(self): # resume model weights if self.resume_training: self.load_checkpoint(self._episode_id) - del full_state, buffer_dict + del full_state, update_bucket self.offload() def get_fsdp_param_name(self, block_size=3_000_000_000) -> List[List]: