@@ -285,20 +285,22 @@ def model_setup(self):
285
285
model .to_empty (device = "cuda" )
286
286
287
287
# load real state dict
288
- options = StateDictOptions (full_state_dict = True , cpu_offload = False , broadcast_from_rank0 = True )
289
-
290
- # module-wise sync avoid OOM while run model like qwen3-moe-235B
291
- for name , module in model .named_modules ():
292
- has_weights = any (k .startswith (name + "." ) for k in full_state .keys ()) and len (list (module .children ()))== 0
293
- if has_weights :
294
- set_model_state_dict (
295
- module ,
296
- {k .replace (name + "." , "" ): v for k , v in full_state .items () if k .startswith (name + "." )},
297
- options = options
298
- )
299
- # set_model_state_dict(model, full_state, options=options)
300
-
301
- # load buffer data
288
+ options = StateDictOptions (full_state_dict = True , cpu_offload = False , broadcast_from_rank0 = True , strict = False )
289
+
290
+ # bucket-wise sync to avoid OOM when running models like qwen3-moe-235B
291
+ update_bucket = {}
292
+ bucket_size = 3 * 1024 ** 3
293
+ numel_cnt = 0
294
+ for name , param in full_state .items ():
295
+ numel_cnt += param .numel ()
296
+ update_bucket [name ] = full_state [name ]
297
+ if numel_cnt >= bucket_size :
298
+ set_model_state_dict (model , update_bucket , options = options )
299
+ update_bucket = {}
300
+ numel_cnt = 0
301
+ set_model_state_dict (model , update_bucket , options = options )
302
+
303
+ # load buffer data because persistent may be False
302
304
if dist .get_rank ()== 0 :
303
305
for name , buf in model .named_buffers ():
304
306
buf .data .copy_ (buffer_dict [name ])
@@ -323,7 +325,7 @@ def model_setup(self):
323
325
# resume model weights
324
326
if self .resume_training :
325
327
self .load_checkpoint (self ._episode_id )
326
- del full_state
328
+ del full_state , update_bucket
327
329
self .offload ()
328
330
329
331
def get_fsdp_param_name (self , block_size = 3_000_000_000 ) -> List [List ]:
0 commit comments