Skip to content

Commit 2313452

Browse files
authored
update metainit for qwen next (#375)
* update metainit for qwen next
1 parent 3fea4fa commit 2313452

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

chatlearn/models/fsdp_module.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -285,20 +285,22 @@ def model_setup(self):
285285
model.to_empty(device="cuda")
286286

287287
# load real state dict
288-
options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True)
289-
290-
# module-wise sync avoid OOM while run model like qwen3-moe-235B
291-
for name, module in model.named_modules():
292-
has_weights = any(k.startswith(name + ".") for k in full_state.keys()) and len(list(module.children()))==0
293-
if has_weights:
294-
set_model_state_dict(
295-
module,
296-
{k.replace(name + ".", ""): v for k, v in full_state.items() if k.startswith(name + ".")},
297-
options=options
298-
)
299-
# set_model_state_dict(model, full_state, options=options)
300-
301-
# load buffer data
288+
options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True, strict=False)
289+
290+
# bucket-wise sync avoid OOM while run model like qwen3-moe-235B
291+
update_bucket = {}
292+
bucket_size = 3 * 1024 ** 3
293+
numel_cnt = 0
294+
for name, param in full_state.items():
295+
numel_cnt += param.numel()
296+
update_bucket[name] = full_state[name]
297+
if numel_cnt >= bucket_size:
298+
set_model_state_dict(model, update_bucket, options=options)
299+
update_bucket = {}
300+
numel_cnt = 0
301+
set_model_state_dict(model, update_bucket, options=options)
302+
303+
# load buffer data because persistent maybe False
302304
if dist.get_rank()==0:
303305
for name, buf in model.named_buffers():
304306
buf.data.copy_(buffer_dict[name])
@@ -323,7 +325,7 @@ def model_setup(self):
323325
# resume model weights
324326
if self.resume_training:
325327
self.load_checkpoint(self._episode_id)
326-
del full_state
328+
del full_state, update_bucket
327329
self.offload()
328330

329331
def get_fsdp_param_name(self, block_size=3_000_000_000) -> List[List]:

0 commit comments

Comments
 (0)