Skip to content

Commit 2e7b82a

Browse files
committed
flush cache before update begins
1 parent 86757f6 commit 2e7b82a

File tree

4 files changed

+47
-21
lines changed

4 files changed

+47
-21
lines changed

nemo_rl/algorithms/grpo.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,12 @@ def refit_policy_generation(
984984

985985
# Stream weights via HTTP
986986
# Each training worker will match its GPU UUID to the corresponding SGLang server
987+
# Megatron-style: Flush cache before weight updates
988+
print("[sglang refit] Flushing KV cache before weight updates (Megatron-style)...", flush=True)
989+
flush_success = policy_generation.invalidate_kv_cache()
990+
if not flush_success:
991+
print("[sglang refit] WARNING - Cache flush had issues, but continuing with weight update", flush=True)
992+
987993
print("[sglang refit] Starting weight streaming via HTTP...", flush=True)
988994
futures_train = policy.stream_weights_via_http(
989995
sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids,

nemo_rl/models/generation/sglang/sglang_generation.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -353,22 +353,27 @@ def __del__(self) -> None:
353353
self.shutdown()
354354

355355
def invalidate_kv_cache(self) -> bool:
356-
"""Invalidate KV cache after weight updates.
356+
"""Invalidate KV cache before weight updates (Megatron-style).
357357
358-
For SGLang, this might need to call a different method or might not be needed
359-
if the server handles it automatically.
358+
This flushes the cache before weight updates to clear stale cache.
359+
Only primary workers (TP rank 0, model owners) will flush their cache.
360+
361+
Returns:
362+
bool: True if all caches were flushed successfully, False otherwise
360363
"""
361364
try:
362-
# For SGLang, we can call a method on each worker if it exists
363-
futures = []
364-
for worker in self.worker_group.workers:
365-
if hasattr(worker, "invalidate_kv_cache"):
366-
futures.append(worker.invalidate_kv_cache.remote())
367-
368-
if futures:
369-
results = ray.get(futures)
370-
return all(result for result in results if result is not None)
371-
return True
365+
futures = self.worker_group.run_all_workers_single_data(
366+
"invalidate_kv_cache",
367+
run_rank_0_only_axes=["tensor_parallel"],
368+
)
369+
results = ray.get(futures)
370+
results = [r for r in results if r is not None]
371+
success = all(result for result in results) if results else True
372+
if success:
373+
print("[sglang refit] All SGLang server caches flushed successfully", flush=True)
374+
else:
375+
print("[sglang refit] WARNING - Some SGLang server caches failed to flush", flush=True)
376+
return success
372377
except Exception as e:
373-
print(f"Error invalidating SGLang caches: {e}")
378+
print(f"[sglang refit] Error flushing SGLang caches: {e}", flush=True)
374379
return False

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,8 +1757,10 @@ def stream_weights_via_http(
17571757
current_device_uuid = self.report_device_id()
17581758

17591759
def dtensor_params_generator():
1760-
"""Generator that yields (name, tensor) pairs, converting DTensors to local tensors."""
1761-
for name, tensor in self.model.state_dict().items():
1760+
"""Generator that yields (name, tensor) pairs, converting DTensors to local tensors.
1761+
"""
1762+
state_dict_items = sorted(self.model.state_dict().items(), key=lambda x: x[0])
1763+
for name, tensor in state_dict_items:
17621764
if isinstance(tensor, DTensor):
17631765
# Convert DTensor to full tensor for streaming
17641766
full_tensor = tensor.full_tensor()
@@ -1770,7 +1772,6 @@ def dtensor_params_generator():
17701772
else:
17711773
# Convert to target dtype
17721774
yield name, tensor.to(self.dtype, non_blocking=True).contiguous()
1773-
17741775
# Use the HTTP implementation
17751776
stream_weights_via_http_impl(
17761777
params_generator=dtensor_params_generator(),

nemo_rl/models/policy/utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ def stream_weights_via_http_impl(
524524
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
525525
except ImportError:
526526
from sglang.srt.patch_torch import monkey_patch_torch_reductions
527+
print(f"[sglang refit details] entering stream_weights_via_http_impl")
527528

528529
monkey_patch_torch_reductions()
529530

@@ -559,6 +560,13 @@ def stream_weights_via_http_impl(
559560
tensor_list = list(params_generator)
560561
total_tensors = len(tensor_list)
561562

563+
if rank == ipc_gather_src:
564+
print(
565+
f"[sglang refit details] {worker_name}: Starting weight update - "
566+
f"Total parameters to update: {total_tensors}",
567+
flush=True
568+
)
569+
562570
for idx, (name, tensor) in enumerate(tensor_list):
563571
torch.cuda.current_stream().synchronize()
564572
tensor = tensor.contiguous().cuda()
@@ -574,10 +582,9 @@ def stream_weights_via_http_impl(
574582
)
575583

576584
if rank == ipc_gather_src:
577-
is_last = (idx == total_tensors - 1)
578585
_send_tensor_to_sglang(
579586
url, name, gathered_handlers, tensor.shape, str(tensor.dtype),
580-
flush_cache=is_last
587+
flush_cache=False
581588
)
582589
tensor_count += 1
583590

@@ -586,11 +593,18 @@ def stream_weights_via_http_impl(
586593
del gathered_handlers
587594
torch.cuda.empty_cache()
588595

589-
if rank == 0:
596+
if rank == ipc_gather_src:
590597
print(
591-
f"[sglang refit] {worker_name}: Sent {tensor_count} tensors to SGLang server: {base_url}",
598+
f"[sglang refit details] {worker_name}: Weight update completed - "
599+
f"Successfully updated {tensor_count}/{total_tensors} parameters to SGLang server: {base_url}",
592600
flush=True
593601
)
602+
if tensor_count != total_tensors:
603+
print(
604+
f"[sglang refit details] {worker_name}: WARNING - Expected {total_tensors} tensors, "
605+
f"but only sent {tensor_count}",
606+
flush=True
607+
)
594608

595609
except Exception as e:
596610
print(

0 commit comments

Comments
 (0)