@@ -1146,7 +1146,8 @@ def grpo_train(
11461146 dynamic_sampling_num_gen_batches += 1
11471147 with timer .time ("generation" ):
11481148 # Clear vLLM logger metrics for each generation step
1149- policy_generation .clear_vllm_logger_metrics ()
1149+ if hasattr (policy_generation , "clear_vllm_logger_metrics" ):
1150+ policy_generation .clear_vllm_logger_metrics ()
11501151 # Use penguin rollouts if enabled. We cascade penguin first since penguin requires async rollouts.
11511152 if _should_use_penguin (master_config ):
11521153 generation_config = master_config ["policy" ]["generation" ]
@@ -1198,7 +1199,11 @@ def grpo_train(
11981199 policy_generation .finish_generation ()
11991200 # Collect vLLM logger metrics for performance reporting after each generation step
12001201 # inflight batch sizes and num pending samples are collected from each vLLM worker
1201- vllm_logger_metrics = policy_generation .get_vllm_logger_metrics ()
1202+ vllm_logger_metrics = (
1203+ policy_generation .get_vllm_logger_metrics ()
1204+ if hasattr (policy_generation , "get_vllm_logger_metrics" )
1205+ else None
1206+ )
12021207
12031208 repeated_batch = scale_rewards (
12041209 repeated_batch , master_config ["grpo" ]["reward_scaling" ]
@@ -1984,9 +1989,9 @@ def async_grpo_train(
19841989 trajectory_collector .resume .remote ()
19851990
19861991 print ("✅ All setup complete, starting buffer wait..." )
1987-
19881992 # Clear vLLM logger metrics after at start of training
1989- policy_generation .clear_vllm_logger_metrics ()
1993+ if hasattr (policy_generation , "clear_vllm_logger_metrics" ):
1994+ policy_generation .clear_vllm_logger_metrics ()
19901995
19911996 # Wait for initial buffer fill
19921997 print (
@@ -2235,7 +2240,11 @@ def async_grpo_train(
22352240
22362241 # Collect vLLM logger metrics for performance reporting
22372242 # inflight batch sizes and num pending samples are collected from each vLLM worker
2238- vllm_logger_metrics = policy_generation .get_vllm_logger_metrics ()
2243+ vllm_logger_metrics = (
2244+ policy_generation .get_vllm_logger_metrics ()
2245+ if hasattr (policy_generation , "get_vllm_logger_metrics" )
2246+ else None
2247+ )
22392248
22402249 # Only the actual refit/weight transfer should be counted as weight_sync
22412250 print ("🔄 Performing policy generation refit..." )
@@ -2250,8 +2259,8 @@ def async_grpo_train(
22502259 trajectory_collector .set_weight_version .remote (weight_version )
22512260 trajectory_collector .resume_after_refit .remote ()
22522261
2253- # Clear vLLM logger metrics after each refit (weight sync), starting a new logging cycle
2254- policy_generation .clear_vllm_logger_metrics ()
2262+ if hasattr ( policy_generation , "clear_vllm_logger_metrics" ):
2263+ policy_generation .clear_vllm_logger_metrics ()
22552264
22562265 # Validation
22572266 val_metrics , validation_timings = None , None
0 commit comments