Skip to content

Commit 285512b

Browse files
committed
Fix SGLang compatibility: add hasattr checks for vLLM-specific methods
1 parent 1bb6f25 commit 285512b

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

nemo_rl/algorithms/grpo.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,8 @@ def grpo_train(
11461146
dynamic_sampling_num_gen_batches += 1
11471147
with timer.time("generation"):
11481148
# Clear vLLM logger metrics for each generation step
1149-
policy_generation.clear_vllm_logger_metrics()
1149+
if hasattr(policy_generation, "clear_vllm_logger_metrics"):
1150+
policy_generation.clear_vllm_logger_metrics()
11501151
# Use penguin rollouts if enabled. We cascade penguin first since penguin requires async rollouts.
11511152
if _should_use_penguin(master_config):
11521153
generation_config = master_config["policy"]["generation"]
@@ -1198,7 +1199,11 @@ def grpo_train(
11981199
policy_generation.finish_generation()
11991200
# Collect vLLM logger metrics for performance reporting after each generation step
12001201
# inflight batch sizes and num pending samples are collected from each vLLM worker
1201-
vllm_logger_metrics = policy_generation.get_vllm_logger_metrics()
1202+
vllm_logger_metrics = (
1203+
policy_generation.get_vllm_logger_metrics()
1204+
if hasattr(policy_generation, "get_vllm_logger_metrics")
1205+
else None
1206+
)
12021207

12031208
repeated_batch = scale_rewards(
12041209
repeated_batch, master_config["grpo"]["reward_scaling"]
@@ -1984,9 +1989,9 @@ def async_grpo_train(
19841989
trajectory_collector.resume.remote()
19851990

19861991
print("✅ All setup complete, starting buffer wait...")
1987-
19881992
# Clear vLLM logger metrics at the start of training
1989-
policy_generation.clear_vllm_logger_metrics()
1993+
if hasattr(policy_generation, "clear_vllm_logger_metrics"):
1994+
policy_generation.clear_vllm_logger_metrics()
19901995

19911996
# Wait for initial buffer fill
19921997
print(
@@ -2235,7 +2240,11 @@ def async_grpo_train(
22352240

22362241
# Collect vLLM logger metrics for performance reporting
22372242
# inflight batch sizes and num pending samples are collected from each vLLM worker
2238-
vllm_logger_metrics = policy_generation.get_vllm_logger_metrics()
2243+
vllm_logger_metrics = (
2244+
policy_generation.get_vllm_logger_metrics()
2245+
if hasattr(policy_generation, "get_vllm_logger_metrics")
2246+
else None
2247+
)
22392248

22402249
# Only the actual refit/weight transfer should be counted as weight_sync
22412250
print("🔄 Performing policy generation refit...")
@@ -2250,8 +2259,8 @@ def async_grpo_train(
22502259
trajectory_collector.set_weight_version.remote(weight_version)
22512260
trajectory_collector.resume_after_refit.remote()
22522261

2253-
# Clear vLLM logger metrics after each refit (weight sync), starting a new logging cycle
2254-
policy_generation.clear_vllm_logger_metrics()
2262+
if hasattr(policy_generation, "clear_vllm_logger_metrics"):
2263+
policy_generation.clear_vllm_logger_metrics()
22552264

22562265
# Validation
22572266
val_metrics, validation_timings = None, None

0 commit comments

Comments
 (0)