feat: add SGLang rollout backend, part1 [WIP] #1580
Conversation
force-pushed from 3daa5d3 to 2e7b82a
force-pushed from cb2f593 to 1bb6f25
force-pushed from 285512b to 035db3d
force-pushed from f50351d to 0b91f05
terrykong left a comment:
awesome work @PrinsYin
just an FYI of a PR that's in flight https://github.com/NVIDIA-NeMo/RL/pull/1567/files
left some comments
force-pushed from dc4cb9e to 9564a03
Signed-off-by: Ryan <[email protected]> Signed-off-by: Zhuoran Yin <[email protected]>
sglang: add 1B example
- Convert SGLangConfig from regular class to TypedDict inheriting GenerationConfig
- Align structure with VllmConfig pattern for consistency
- Mark all fields as NotRequired for backward compatibility
- Add sglang_kwargs field for additional ServerArgs parameters
- Add type casting in grpo.py for type safety

This maintains backward compatibility while aligning with the existing generation config structure pattern.
Co-authored-by: Shanmugam Ramasamy <[email protected]>
force-pushed from b6d8ba8 to 843a06a
📝 Walkthrough

This PR introduces SGLang as a new distributed generation backend alongside vLLM. It includes SGLang-specific configuration, a worker implementation, HTTP-based weight streaming for colocated inference, integration into GRPO training, generalized metrics handling, and comprehensive setup/documentation changes.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant GRPO as GRPO Training Loop
    participant GenInit as Initialize Generation<br/>(with policy)
    participant Policy as Policy Worker
    participant SGLangGen as SGLang Generation<br/>Worker
    participant SGLang as SGLang Server

    GRPO->>GenInit: Call initialize_generation_with_policy<br/>(colocated_inference=true)
    alt Parallel Initialization (colocated)
        par Init Policy
            GenInit->>Policy: Initialize policy
            Policy-->>GenInit: policy ready
        and Init Generation
            GenInit->>SGLangGen: Initialize generation
            SGLangGen->>SGLang: Launch server process
            SGLang-->>SGLangGen: Server running
            SGLangGen-->>GenInit: generation ready
        end
    end
    GenInit-->>GRPO: (policy, generation, timing_metrics)

    GRPO->>SGLangGen: Generate samples (batched data)
    SGLangGen->>SGLangGen: Shard data across DP axis
    SGLangGen->>SGLang: POST /generate (HTTP request)
    SGLang->>SGLang: Execute generation per sample
    SGLang-->>SGLangGen: Batch results (token ids, logprobs)
    SGLangGen->>SGLangGen: Aggregate & pad results
    SGLangGen-->>GRPO: BatchedDataDict output

    GRPO->>Policy: Call refit_policy_generation()
    Policy->>GRPO: Generate samples with updated policy
    Note over GRPO,Policy: If colocated and SGLang:
    GRPO->>Policy: stream_weights_via_http(sglang_urls)
    Policy->>Policy: Convert DTensor params to local tensors
    Policy->>SGLang: POST /update_weights_from_tensor (HTTP)
    SGLang->>SGLang: Update model weights
    SGLang-->>Policy: ACK
    Policy-->>GRPO: Weights streamed & synced

    GRPO->>SGLangGen: Get metrics: generation_logger_metrics
    SGLangGen-->>GRPO: {inflight_batch_sizes, pending_samples}
    GRPO->>GRPO: Log metrics to WandB
```
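The generate path in the diagram (shard across the DP axis, generate per shard, then aggregate and right-pad) can be sketched in plain Python. This is an illustrative toy, not the actual `BatchedDataDict` implementation; `shard_across_dp` and `pad_and_merge` are hypothetical names.

```python
def shard_across_dp(samples: list, dp_size: int) -> list[list]:
    """Split a batch into dp_size contiguous shards (the last may be smaller)."""
    per_shard = (len(samples) + dp_size - 1) // dp_size
    return [samples[i * per_shard:(i + 1) * per_shard] for i in range(dp_size)]


def pad_and_merge(shard_outputs: list[list[list[int]]], pad_id: int = 0) -> list[list[int]]:
    """Right-pad every generated token list to the global max length, then concatenate."""
    max_len = max(len(toks) for shard in shard_outputs for toks in shard)
    return [toks + [pad_id] * (max_len - len(toks))
            for shard in shard_outputs for toks in shard]


batch = [[1, 2], [3], [4, 5, 6], [7, 8]]
shards = shard_across_dp(batch, dp_size=2)  # two shards of two samples each
merged = pad_and_merge(shards)              # every row right-padded to length 3
print(merged)  # → [[1, 2, 0], [3, 0, 0], [4, 5, 6], [7, 8, 0]]
```

The real workers send each shard to their SGLang server over HTTP and pad with the tokenizer's pad token; the toy preserves only the shape of the data flow.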
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ❌ 1 failed check (warning), ✅ 3 passed
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/distributed/virtual_cluster.py (1)
43-63: Add `--directory {git_root}` to new SGLang executables for consistency

The new `AUTOMODEL_SGLANG` and `SGLANG` entries omit `--directory {git_root}`, while all other `uv run` entries include it. This inconsistency means these executables will behave differently depending on the caller's current working directory, potentially failing to locate `pyproject.toml`. Suggested fix:

```diff
 # Use NeMo-RL direct dependencies, nemo-automodel, and SGLang.
-AUTOMODEL_SGLANG = "uv run --locked --extra automodel --extra sglang"
+AUTOMODEL_SGLANG = (
+    f"uv run --locked --extra automodel --extra sglang --directory {git_root}"
+)
 # Use NeMo-RL direct dependencies and SGLang.
-SGLANG = "uv run --locked --extra sglang"
+SGLANG = f"uv run --locked --extra sglang --directory {git_root}"
```
🧹 Nitpick comments (17)
docs/design-docs/generation.md (1)
11-21: Clarify `model_name` optionality to match `GenerationConfig`

In code, `GenerationConfig.model_name` is `NotRequired[str]` and is sometimes filled by helpers like `configure_generation_config`, but here it is documented as a required `str`. To avoid confusion, consider clarifying that:

- Users typically set `model_name`, but
- Some flows may populate it automatically and treat it as optional in the TypedDict.

E.g. change the snippet comment to something like "Name or path of the model (may be populated by helpers)."
nemo_rl/models/generation/interfaces.py (1)
261-278: Optional metrics hooks are fine; consider silencing B027

Adding `clear_logger_metrics`/`get_logger_metrics` as optional hooks with no-op defaults is a reasonable design for `GenerationInterface`, and it matches the vLLM implementation. Ruff's B027 warning (empty method in an abstract base class) is effectively a false positive here. Two options:

- Keep them non-abstract (recommended) and silence B027, e.g.:

```python
def clear_logger_metrics(self) -> None:  # noqa: B027
    """Clear logger metrics for performance reporting."""
    # Optional hook; default is a no-op.
    return None
```

- Or add minimal behavior (e.g., a `return None`) plus a clarifying comment, which also addresses the "empty" concern.

I'd avoid making them abstract since that would force every backend to implement metrics even when not needed.
nemo_rl/models/policy/dtensor_policy_worker_v2.py (1)
1762-1805: HTTP weight streaming integration looks consistent with existing IPC path

The new `stream_weights_via_http` correctly handles `cpu_offload`, derives the current GPU UUID, converts DTensors to full tensors with the target dtype, and delegates to the shared HTTP implementation; the sorted `state_dict` iteration should give deterministic ordering across ranks. You might optionally factor the DTensor-to-local generator into a shared helper used by both HTTP and ZMQ paths to avoid duplication.
nemo_rl/models/policy/utils.py (1)

15-27: HTTP weight streaming flow is coherent; consider a few small cleanups

The end-to-end HTTP streaming path (GPU UUID → server selection → IPC handler gather → POST to `/update_weights_from_tensor`) looks structurally correct and matches the intended SGLang contract. A few low-priority refinements you might consider:

- Remove or use unused parameters/locals: `rank` and `sglang_url_to_gpu_uuids` in `_setup_ipc_gather_group`, `gather_group` in `_gather_ipc_handlers`, and `shape`/`dtype` in `_send_tensor_to_sglang`, to reduce cognitive load.
- In `stream_weights_via_http_impl`, if `_setup_ipc_gather_group` returns `(None, None, None)` (e.g., unexpected `dist` state or UUID mismatch), the function quietly becomes a no-op; raising a clear error in that case would make misconfiguration easier to diagnose.
- In `_send_tensor_to_sglang`, replace the bare `except:` around `response.text` with `except Exception:` (and optionally log the exception) to avoid masking non-HTTP runtime issues while still enriching the error message.
- Optionally rename the unused loop index `idx` in `for idx, (name, tensor) in enumerate(tensor_list):` to `_` to match intent and silence linters.

Also applies to: 498-743
nemo_rl/models/generation/sglang/config.py (1)
20-97: SGLang config types look reasonable; align used keys with docs/YAML

The `SglangSpecificArgs`/`SGLangConfig` TypedDicts cleanly expose SGLang's `ServerArgs`-style fields without hard-coding defaults, which matches the config guidelines. For the subset of keys you actually expect users to set via NeMo RL configs, ensure their purpose, valid values, and recommended defaults are documented and reflected in exemplar YAMLs (e.g., the new `grpo_math_1B_sglang.yaml`), so this large surface stays discoverable.

Based on learnings, config TypedDict additions should be documented and mirrored in examples.
nemo_rl/models/generation/sglang/sglang_generation.py (6)
48-54: Remove unused `workers_per_node` parameter

The `workers_per_node` parameter is declared but never used in the constructor. Consider removing it to avoid confusion, or document why it's reserved for future use.

```diff
 def __init__(
     self,
     cluster: RayVirtualCluster,
     config: SGLangConfig,
     name_prefix: str = "sglang_policy",
-    workers_per_node: Optional[Union[int, list[int]]] = None,
 ):
```
84-88: Use proper logging instead of `print`

Use the `logging` module or the project's logger rather than direct `print` calls.

```diff
+import logging
+
+logger = logging.getLogger(__name__)
+
 if total_gpus % gpus_per_server != 0:
-    print(
+    logger.warning(
         f"[WARNING] Total GPUs ({total_gpus}) is not divisible by GPUs per server ({gpus_per_server}). "
         f"Will use {num_servers} servers, leaving {total_gpus % gpus_per_server} GPUs unused."
     )
```
319-328: Add `strict=True` to `zip()` for safer iteration

Without `strict=True`, mismatched list lengths would silently truncate data. This could mask bugs where `urls` and `uuids_list` have different sizes.

```diff
 # Create mapping
 url_to_uuids = {}
-for url, uuids in zip(urls, uuids_list):
+for url, uuids in zip(urls, uuids_list, strict=True):
     if url is not None and uuids is not None:
         url_to_uuids[url] = uuids
```
330-336: Stub methods should return consistent types

`prepare_for_generation` and `finish_generation` have `-> bool` implied by the interface but return `None` via implicit `pass`. Either add an explicit `return True` or update the docstrings to explain the current stub behavior.

```diff
 def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool:
     """Wake workers up for colocated inference."""
-    pass
+    return True

 def finish_generation(self, *args: Any, **kwargs: Any) -> bool:
     """Sleep workers and reset prefix cache."""
-    pass
+    return True
```
338-345: Narrow the exception type in shutdown

Catching a broad `Exception` can hide unexpected errors. Consider catching more specific exceptions, or at least logging the exception type for debugging.

```diff
 def shutdown(self) -> bool:
     """Shut down all SGLang workers and clean up resources."""
     try:
         # Use the worker group's shutdown method with the worker's cleanup method
         return self.worker_group.shutdown(cleanup_method="shutdown")
-    except Exception as e:
-        print(f"Error during SGLang policy shutdown: {e}")
+    except (RuntimeError, ray.exceptions.RayError) as e:
+        print(f"Error during SGLang policy shutdown: {type(e).__name__}: {e}")
         return False
```
365-380: Consider restructuring to use an else block for the success path

The static analysis correctly identifies that the success print and return could live in an `else` block for cleaner control flow.

```diff
 try:
     futures = self.worker_group.run_all_workers_single_data(
         "invalidate_kv_cache",
         run_rank_0_only_axes=["tensor_parallel"],
     )
     results = ray.get(futures)
     results = [r for r in results if r is not None]
     success = all(result for result in results) if results else True
-    if success:
-        print("[sglang refit] All SGLang server caches flushed successfully", flush=True)
-    else:
-        print("[sglang refit] WARNING - Some SGLang server caches failed to flush", flush=True)
-    return success
 except Exception as e:
     print(f"[sglang refit] Error flushing SGLang caches: {e}", flush=True)
     return False
+else:
+    if success:
+        print("[sglang refit] All SGLang server caches flushed successfully", flush=True)
+    else:
+        print("[sglang refit] WARNING - Some SGLang server caches failed to flush", flush=True)
+    return success
```

nemo_rl/models/generation/sglang/sglang_worker.py (5)
116-122: Unused `fraction_of_gpus` parameter

The `fraction_of_gpus` parameter is passed from `configure_worker` but never used in `__init__`. Either use it or remove it from both places.

```diff
 def __init__(
     self,
     config: SGLangConfig,
     bundle_indices: Optional[list[int]] = None,
-    fraction_of_gpus: float = 1.0,
     seed: Optional[int] = None,
 ):
```

Also update `configure_worker` to not pass this:

```diff
-init_kwargs["fraction_of_gpus"] = num_gpus
```
322-331: Unused `stop_strings` parameter

The `stop_strings` parameter is declared but not used in `_build_sampling_params`. The actual stop string handling happens per-sample in `_generate_single_sample`. Remove this parameter to avoid confusion.

```diff
 def _build_sampling_params(
     self,
     *,
     greedy: bool,
-    stop_strings,
     max_new_tokens: Optional[int] = None,
     input_len: Optional[int] = None,
     context_length: Optional[int] = None,
     sample_index: Optional[int] = None,
 ) -> dict[str, Any]:
     """Build sampling parameters dictionary for SGLang API.

     Args:
         greedy: Whether to use greedy decoding (temperature=0.0)
-        stop_strings: Merged stop strings (not used here, handled per sample)
         max_new_tokens: Override max_new_tokens from config if provided
```
463-467: Remove redundant exception handler

This exception handler catches and immediately re-raises without any additional handling. It adds no value and reduces code clarity.

```diff
 async def wrap(idx, coro):
     async with semaphore:
-        try:
-            result = await coro
-            return idx, result
-        except Exception as e:
-            raise
+        result = await coro
+        return idx, result
```
590-593: Remove redundant exception handler

Same issue as lines 463-467: catching and re-raising without any handling.

```diff
 # Execute all requests concurrently using the dedicated event loop thread
-try:
-    all_results = self.async_loop_thread.run(self._generate_async(tasks))
-except Exception as e:
-    raise
+all_results = self.async_loop_thread.run(self._generate_async(tasks))
```
606-606: Unused loop variable `new_logprobs`

In the first pass calculating `max_length`, `new_logprobs` is unpacked but not used. Use `_` to indicate it's intentionally ignored.

```diff
 # First pass: calculate max_length
-for i, (new_tokens, new_logprobs) in enumerate(all_results):
+for i, (new_tokens, _) in enumerate(all_results):
     input_len = input_lengths[i].item()
     generation_length = len(new_tokens)
```

nemo_rl/algorithms/grpo.py (1)
493-553: Remove unused `generation_name` parameter

The `generation_name` parameter is passed but never used in the function body. Either remove it or use it in logging/output.

```diff
 def initialize_generation_with_policy(
     init_generation_fn,
-    generation_name: str,
     init_time_key: str,
     colocated_inference: bool,
     worker_init_timing_metrics: dict,
 ):
     """
     Generic function to initialize a generation engine (vLLM or SGLang) along with policy.

     Args:
         init_generation_fn: Function that initializes the generation engine (init_vllm or init_sglang)
-        generation_name: Name of the generation engine ("vLLM" or "SGLang")
         init_time_key: Key name for storing initialization time in metrics ("vllm_init_time_s" or "sglang_init_time_s")
```

And update the call sites:

```diff
 policy_generation, policy = initialize_generation_with_policy(
     init_generation_fn=init_vllm,
-    generation_name="vLLM",
     init_time_key="vllm_init_time_s",
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (19)
- docs/design-docs/generation.md (5 hunks)
- examples/configs/grpo_math_1B_sglang.yaml (1 hunk)
- nemo_rl/algorithms/grpo.py (14 hunks)
- nemo_rl/algorithms/utils.py (1 hunk)
- nemo_rl/distributed/ray_actor_environment_registry.py (2 hunks)
- nemo_rl/distributed/virtual_cluster.py (2 hunks)
- nemo_rl/models/generation/interfaces.py (1 hunk)
- nemo_rl/models/generation/sglang/__init__.py (1 hunk)
- nemo_rl/models/generation/sglang/config.py (1 hunk)
- nemo_rl/models/generation/sglang/sglang_generation.py (1 hunk)
- nemo_rl/models/generation/sglang/sglang_worker.py (1 hunk)
- nemo_rl/models/generation/sglang/utils.py (1 hunk)
- nemo_rl/models/generation/vllm/vllm_generation.py (1 hunk)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (1 hunk)
- nemo_rl/models/policy/interfaces.py (1 hunk)
- nemo_rl/models/policy/lm_policy.py (1 hunk)
- nemo_rl/models/policy/utils.py (2 hunks)
- pyproject.toml (1 hunk)
- run.sh (1 hunk)
🧰 Additional context used
📓 Path-based instructions (6)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/models/generation/sglang/__init__.py
- nemo_rl/models/policy/lm_policy.py
- nemo_rl/models/generation/sglang/utils.py
- nemo_rl/models/generation/interfaces.py
- nemo_rl/models/policy/dtensor_policy_worker_v2.py
- nemo_rl/models/generation/sglang/sglang_worker.py
- nemo_rl/models/generation/sglang/config.py
- nemo_rl/distributed/virtual_cluster.py
- nemo_rl/models/policy/interfaces.py
- nemo_rl/models/generation/sglang/sglang_generation.py
- nemo_rl/models/policy/utils.py
- nemo_rl/algorithms/utils.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/distributed/ray_actor_environment_registry.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/models/generation/sglang/__init__.py
- nemo_rl/models/policy/lm_policy.py
- nemo_rl/models/generation/sglang/utils.py
- nemo_rl/models/generation/interfaces.py
- nemo_rl/models/policy/dtensor_policy_worker_v2.py
- nemo_rl/models/generation/sglang/sglang_worker.py
- nemo_rl/models/generation/sglang/config.py
- nemo_rl/distributed/virtual_cluster.py
- nemo_rl/models/policy/interfaces.py
- nemo_rl/models/generation/sglang/sglang_generation.py
- nemo_rl/models/policy/utils.py
- nemo_rl/algorithms/utils.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/distributed/ray_actor_environment_registry.py
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/models/generation/sglang/__init__.py
- nemo_rl/models/policy/lm_policy.py
- nemo_rl/models/generation/sglang/utils.py
- nemo_rl/models/generation/interfaces.py
- nemo_rl/models/policy/dtensor_policy_worker_v2.py
- nemo_rl/models/generation/sglang/sglang_worker.py
- run.sh
- nemo_rl/models/generation/sglang/config.py
- nemo_rl/distributed/virtual_cluster.py
- nemo_rl/models/policy/interfaces.py
- examples/configs/grpo_math_1B_sglang.yaml
- nemo_rl/models/generation/sglang/sglang_generation.py
- docs/design-docs/generation.md
- nemo_rl/models/policy/utils.py
- nemo_rl/algorithms/utils.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/distributed/ray_actor_environment_registry.py
- pyproject.toml
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/models/generation/sglang/__init__.py
- nemo_rl/models/policy/lm_policy.py
- nemo_rl/models/generation/sglang/utils.py
- nemo_rl/models/generation/interfaces.py
- nemo_rl/models/policy/dtensor_policy_worker_v2.py
- nemo_rl/models/generation/sglang/sglang_worker.py
- run.sh
- nemo_rl/models/generation/sglang/config.py
- nemo_rl/distributed/virtual_cluster.py
- nemo_rl/models/policy/interfaces.py
- nemo_rl/models/generation/sglang/sglang_generation.py
- nemo_rl/models/policy/utils.py
- nemo_rl/algorithms/utils.py
- nemo_rl/algorithms/grpo.py
- nemo_rl/distributed/ray_actor_environment_registry.py
**/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.sh: Use uv run instead of python to execute scripts
Follow the Google Shell Style Guide for shell scripts
Files:
run.sh
docs/**/*.md
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Update docs/index.md when a new markdown doc is added under docs/**/*.md or a markdown file is renamed, ensuring the document appears in the most appropriate section
Files:
docs/design-docs/generation.md
🧠 Learnings (6)
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to **/*.sh : Use uv run instead of python to execute scripts
Applied to files:
run.sh
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to tests/test_suites/**/*.sh : Driver shell scripts should match the YAML base name with .sh extension and invoke training entrypoint with uv run
Applied to files:
run.sh
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to **/*.py : When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Applied to files:
nemo_rl/models/generation/sglang/config.py
📚 Learning: 2025-09-19T03:00:58.662Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:85-101
Timestamp: 2025-09-19T03:00:58.662Z
Learning: In distillation and GRPO configurations, max_new_tokens is intentionally set to the full context window (max_total_sequence_length) for consistency across the codebase. Overflow cases when prompt + generation tokens exceed max_model_len are handled by safeguards implemented in vllm_worker.py.
Applied to files:
examples/configs/grpo_math_1B_sglang.yaml
📚 Learning: 2025-09-18T14:57:31.003Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: nemo_rl/algorithms/distillation.py:312-354
Timestamp: 2025-09-18T14:57:31.003Z
Learning: The distillation algorithm's cluster setup logic is designed to follow the same patterns used in GRPO for handling distributed training clusters and resource allocation.
Applied to files:
examples/configs/grpo_math_1B_sglang.yaml
📚 Learning: 2025-09-18T14:20:36.297Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml:113-120
Timestamp: 2025-09-18T14:20:36.297Z
Learning: In distillation workflows, the teacher policy does not perform generation - it only does inference/logprob computation on sequences generated by the student policy. Therefore, teacher generation configuration mismatches (like vLLM tensor parallelism settings) and colocation concerns are not relevant.
Applied to files:
nemo_rl/algorithms/grpo.py
🧬 Code graph analysis (10)
nemo_rl/models/generation/sglang/__init__.py (3)
- nemo_rl/models/generation/sglang/config.py (1): `SGLangConfig` (93-96)
- nemo_rl/models/generation/sglang/sglang_generation.py (1): `SGLangGeneration` (47-380)
- nemo_rl/models/generation/sglang/sglang_worker.py (1): `SGLangGenerationWorker` (50-734)
nemo_rl/models/policy/lm_policy.py (3)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (1): `stream_weights_via_http` (1764-1805)
- nemo_rl/models/policy/interfaces.py (1): `stream_weights_via_http` (189-199)
- nemo_rl/distributed/worker_groups.py (1): `run_all_workers_single_data` (755-799)
nemo_rl/models/generation/sglang/utils.py (1)
- nemo_rl/models/generation/sglang/sglang_worker.py (1): `shutdown` (665-715)
nemo_rl/models/generation/interfaces.py (1)
- nemo_rl/models/generation/vllm/vllm_generation.py (2): `clear_logger_metrics` (879-881), `get_logger_metrics` (883-885)
nemo_rl/models/policy/dtensor_policy_worker_v2.py (3)
- nemo_rl/utils/nsys.py (1): `wrap_with_nvtx_name` (82-94)
- nemo_rl/models/policy/interfaces.py (1): `stream_weights_via_http` (189-199)
- nemo_rl/models/policy/utils.py (1): `stream_weights_via_http_impl` (498-618)
nemo_rl/models/generation/sglang/sglang_worker.py (5)
- nemo_rl/distributed/batched_data_dict.py (1): `BatchedDataDict` (75-860)
- nemo_rl/distributed/virtual_cluster.py (3): `_get_node_ip_local` (73-77), `_get_free_port_local` (80-88), `shutdown` (483-502)
- nemo_rl/models/generation/interfaces.py (3): `GenerationDatumSpec` (134-165), `GenerationOutputSpec` (168-212), `verify_right_padding` (23-99)
- nemo_rl/models/generation/sglang/config.py (1): `SGLangConfig` (93-96)
- nemo_rl/models/generation/sglang/utils.py (3): `AsyncLoopThread` (19-62), `run` (41-54), `shutdown` (56-62)
nemo_rl/models/generation/sglang/config.py (2)
- nemo_rl/models/generation/interfaces.py (1): `GenerationConfig` (118-131)
- nemo_rl/distributed/worker_groups.py (1): `dp_size` (627-629)
nemo_rl/models/generation/sglang/sglang_generation.py (3)
- nemo_rl/distributed/batched_data_dict.py (2): `BatchedDataDict` (75-860), `from_batches` (102-172)
- nemo_rl/distributed/named_sharding.py (4): `NamedSharding` (19-222), `layout` (99-101), `names` (84-86), `get_axis_size` (209-211)
- nemo_rl/models/policy/interfaces.py (2): `init_collective` (166-169), `prepare_refit_info` (180-181)
nemo_rl/algorithms/utils.py (1)
- tests/check_metrics.py (1): `max` (30-32)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
- nemo_rl/distributed/virtual_cluster.py (1): `PY_EXECUTABLES` (43-65)
🪛 Ruff (0.14.7)
nemo_rl/models/generation/sglang/utils.py
33-33: Avoid specifying long messages outside the exception class
(TRY003)
51-51: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/models/generation/interfaces.py
261-267: GenerationInterface.clear_logger_metrics is an empty method in an abstract base class, but has no abstract decorator
(B027)
nemo_rl/models/generation/sglang/sglang_worker.py
120-120: Unused method argument: fraction_of_gpus
(ARG002)
187-187: Possible binding to all interfaces
(S104)
254-254: Do not catch blind exception: Exception
(BLE001)
326-326: Unused method argument: stop_strings
(ARG002)
465-465: Consider moving this statement to an else block
(TRY300)
466-467: Remove exception handler; error is immediately re-raised
(TRY203)
466-466: Local variable e is assigned to but never used
Remove assignment to unused variable e
(F841)
504-504: Create your own exception
(TRY002)
504-504: Avoid specifying long messages outside the exception class
(TRY003)
556-556: Avoid specifying long messages outside the exception class
(TRY003)
592-593: Remove exception handler; error is immediately re-raised
(TRY203)
592-592: Local variable e is assigned to but never used
Remove assignment to unused variable e
(F841)
606-606: Loop control variable new_logprobs not used within loop body
(B007)
675-675: Do not catch blind exception: Exception
(BLE001)
691-691: Do not catch blind exception: Exception
(BLE001)
709-709: Consider moving this statement to an else block
(TRY300)
711-711: Do not catch blind exception: Exception
(BLE001)
732-732: Probable use of requests call without timeout
(S113)
nemo_rl/models/generation/sglang/sglang_generation.py
53-53: Unused method argument: workers_per_node
(ARG002)
70-72: Avoid specifying long messages outside the exception class
(TRY003)
79-82: Avoid specifying long messages outside the exception class
(TRY003)
170-170: Avoid specifying long messages outside the exception class
(TRY003)
205-209: Avoid specifying long messages outside the exception class
(TRY003)
214-214: Unused method argument: ip
(ARG002)
214-214: Unused method argument: port
(ARG002)
214-214: Unused method argument: world_size
(ARG002)
214-214: Unused method argument: train_world_size
(ARG002)
265-267: Avoid specifying long messages outside the exception class
(TRY003)
287-287: Avoid specifying long messages outside the exception class
(TRY003)
307-307: Avoid specifying long messages outside the exception class
(TRY003)
324-324: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
343-343: Do not catch blind exception: Exception
(BLE001)
377-377: Consider moving this statement to an else block
(TRY300)
378-378: Do not catch blind exception: Exception
(BLE001)
nemo_rl/models/policy/utils.py
527-527: f-string without any placeholders
Remove extraneous f prefix
(F541)
537-540: Avoid specifying long messages outside the exception class
(TRY003)
570-570: Loop control variable idx not used within loop body
Rename unused idx to _idx
(B007)
622-622: Unused function argument: rank
(ARG001)
625-625: Unused function argument: sglang_url_to_gpu_uuids
(ARG001)
639-639: Local variable my_rank is assigned to but never used
Remove assignment to unused variable my_rank
(F841)
660-660: Unused function argument: gather_group
(ARG001)
700-700: Unused function argument: shape
(ARG001)
701-701: Unused function argument: dtype
(ARG001)
736-736: Do not use bare except
(E722)
736-737: try-except-pass detected, consider logging the exception
(S110)
741-743: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/algorithms/grpo.py
495-495: Unused function argument: generation_name
(ARG001)
1009-1009: Local variable flush_success is assigned to but never used
Remove assignment to unused variable flush_success
(F841)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (16)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
879-885: vLLM logger metrics integration looks good

The new `clear_logger_metrics`/`get_logger_metrics` wrappers cleanly expose vLLM-specific metrics through the generic `GenerationInterface` hooks and match the optional-implementation pattern. No changes needed here.
nemo_rl/models/policy/lm_policy.py (1)
754-777: Align HTTP streaming API with IPC streaming (return futures)

This wrapper is consistent with `stream_weights_via_ipc_zmq`: it fans out to all workers and returns a `list[ray.ObjectRef]` so callers can `ray.get` externally. However, `ColocatablePolicyInterface.stream_weights_via_http` is currently typed to return `None`, which is inconsistent with this implementation and with the IPC streaming API.

Update `ColocatablePolicyInterface.stream_weights_via_http` to return `list[ray.ObjectRef]` (same as `stream_weights_via_ipc_zmq`). No change to this function's body is needed once the interface is updated.

nemo_rl/models/policy/interfaces.py (1)
189-199: Verify return type consistency for `stream_weights_via_http`

The interface currently declares `stream_weights_via_http` to return `None`, but the review suggests the concrete implementation in `Policy.stream_weights_via_http` returns `list[ray.ObjectRef]`, analogous to `stream_weights_via_ipc_zmq`. If this is accurate, the interface return type should be updated to match the actual implementation and align the HTTP and IPC streaming APIs.

Suggested adjustment (pending verification of concrete implementation):
-    def stream_weights_via_http(
-        self, sglang_url_to_gpu_uuids: dict[str, list[str]]
-    ) -> None:
+    def stream_weights_via_http(
+        self, sglang_url_to_gpu_uuids: dict[str, list[str]]
+    ) -> list[ray.ObjectRef]:
         """Stream model weights to SGLang servers via HTTP API.

         Args:
             sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses
         """
         raise NotImplementedError(
             "stream_weights_via_http is not implemented for this policy worker"
         )

examples/configs/grpo_math_1B_sglang.yaml (1)
209-236: Update misleading comment in SGLang config

The `data.max_input_seq_length` comment at lines 237-243 references `vllm.max_model_len`, which is specific to the vLLM backend. In this SGLang-focused configuration, the comment should reference `sglang_cfg.context_length` or use backend-agnostic language to avoid confusing readers about which backend-specific limits apply.

run.sh (1)
1-19: Add NVIDIA header and use `uv run` instead of `python`

This script should follow the repo shell guidelines (if it is a non-test script):

- Add the NVIDIA copyright header at the top.
- Replace the bare `python` invocation with `uv run` (you already use `uv` for venv/pip).

A minimal fix:

-#!/bin/bash
-set -e
+#!/bin/bash
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
@@
-echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
-
-
-python examples/run_grpo_math.py --config "$CONFIG_FILE"
+echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
+
+uv run examples/run_grpo_math.py --config "$CONFIG_FILE"

nemo_rl/models/generation/sglang/__init__.py (1)
1-22: Public SGLang exports are wired correctly

The header, imports, and `__all__` consistently expose `SGLangConfig`, `SGLangGeneration`, and `SGLangGenerationWorker`; no changes needed.

pyproject.toml (1)
81-100: SGLang optional dependency group is well-formed

The new `sglang` extra cleanly groups SGLang and its companion packages alongside the existing `vllm`/`mcore` extras; no code changes required here.

nemo_rl/algorithms/utils.py (1)
524-564: Generation logger metrics refactor is sound

Switching to `metrics.get("generation_logger_metrics", {})` under the existing vLLM-logger flag and asserting the expected keys/types keeps behavior backward compatible while making the metrics source backend-agnostic; this block looks good.

nemo_rl/distributed/ray_actor_environment_registry.py (1)
23-25: Actor-to-environment mappings for SGLang and DTensorPolicyWorkerV2 look correct

The new `SGLANG_EXECUTABLE` constant and registry entries cleanly route `SGLangGenerationWorker` to the SGLang extra, and `DTensorPolicyWorkerV2` to `AUTOMODEL_SGLANG`, which matches the added HTTP streaming and SGLang backend requirements; the additional SYSTEM/PENGUIN mappings are also consistent with their dependencies.

Also applies to: 33-39
nemo_rl/models/generation/sglang/sglang_generation.py (1)
213-221: Stub implementation is acceptable but needs documentation clarification.

The `init_collective` method returns an empty list while accepting multiple parameters. The docstring mentions a TODO but doesn't explain when this will be implemented. This is acceptable as an interface stub, but consider adding a note about current behavior.

nemo_rl/models/generation/sglang/sglang_worker.py (1)

185-188: Binding to all interfaces (0.0.0.0) is intentional for distributed setup.

This allows the SGLang server to accept connections from other nodes in the cluster. While the static analysis flags this as a security concern, it appears necessary for the distributed inference architecture.
Verify that the network environment where this runs has appropriate firewall rules to limit access to these ports.
nemo_rl/algorithms/grpo.py (5)
65-65: SGLang imports added correctly.

The import follows the existing pattern for `VllmConfig`/`VllmGeneration`.

486-491: SGLang initialization follows vLLM pattern correctly.

The `init_sglang` function mirrors the `init_vllm` pattern with proper timing and `finish_generation()` call.

609-627: SGLang backend initialization looks correct.

The SGLang path:

- Casts config to `SGLangConfig`
- Sets `model_path` from policy config if not already set
- Uses the common `initialize_generation_with_policy` helper
- Prints confirmation message

This follows the pattern established for vLLM and addresses the past review comment about avoiding code duplication.

1221-1223: Generation logger metrics handling is properly generalized.

The code now uses generic `get_logger_metrics()` and `clear_logger_metrics()` instead of vLLM-specific methods, supporting both vLLM and SGLang backends as discussed in past review comments.

Also applies to: 1273-1276
2104-2106: Async GRPO correctly handles generation logger metrics.

The async training path properly:
- Clears metrics at training start (line 2104-2106)
- Collects metrics before refit (line 2353-2356)
- Clears metrics after refit (line 2371-2373)
This ensures metrics are collected per-cycle and reset appropriately.
Also applies to: 2346-2356, 2371-2373
if isinstance(policy_generation, SGLangGeneration):
    sglang_url_to_gpu_uuids = policy_generation.get_sglang_url_to_gpu_uuids()
    # Stream weights via HTTP
    flush_success = policy_generation.invalidate_kv_cache()
    futures_train = policy.stream_weights_via_http(
        sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids,
    )
    # Wait for all workers to complete
    ray.get(futures_train)
    update_success = True
else:
    # Original ZMQ IPC path for vLLM
    futures_train = policy.stream_weights_via_ipc_zmq(
        buffer_size_bytes=buffer_size_bytes
    )
    futures_inference = policy_generation.update_weights_via_ipc_zmq()
    # wait for all futures to complete
    ray.get(futures_train)
    results = ray.get(futures_inference)
    update_success = all(result for result in results if result is not None)
Handle flush_success result from invalidate_kv_cache.
The flush_success variable is assigned but never used. If the cache flush fails, the weight update proceeds anyway which could lead to stale cache issues. Consider at least logging a warning.
 if isinstance(policy_generation, SGLangGeneration):
     sglang_url_to_gpu_uuids = policy_generation.get_sglang_url_to_gpu_uuids()
     # Stream weights via HTTP
-    flush_success = policy_generation.invalidate_kv_cache()
+    flush_success = policy_generation.invalidate_kv_cache()
+    if not flush_success:
+        print(
+            "[WARNING] KV cache invalidation failed before weight update. "
+            "Proceeding with weight update, but cached results may be stale."
+        )
     futures_train = policy.stream_weights_via_http(
         sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids,
     )

🧰 Tools
🪛 Ruff (0.14.7)
1009-1009: Local variable flush_success is assigned to but never used
Remove assignment to unused variable flush_success
(F841)
🤖 Prompt for AI Agents
In nemo_rl/algorithms/grpo.py around lines 1006 to 1025, the variable
flush_success returned from policy_generation.invalidate_kv_cache() is assigned
but ignored; check its boolean result and handle failures: if flush_success is
False, log a warning (use the module/process logger available in this scope or
policy_generation.logger) and set update_success = False (and skip or
short-circuit the subsequent stream_weights_via_http call) so weight updates do
not proceed with a stale cache; otherwise proceed as before and set
update_success = True after successful ray.get.
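Stripped of the surrounding plumbing, the guard described above is small. A minimal sketch with hypothetical stand-in callables (`invalidate_kv_cache` and `stream_weights` are not the real NeMo RL APIs):

```python
def refit_with_cache_flush(invalidate_kv_cache, stream_weights) -> bool:
    """Skip the weight update when the KV-cache flush fails.

    Both arguments are caller-supplied callables standing in for the real
    generation/policy methods; this is an illustrative sketch, not the PR's code.
    """
    if not invalidate_kv_cache():
        print("[WARNING] KV cache invalidation failed; skipping weight update")
        return False
    stream_weights()
    return True
```

The point of the shape is simply that the flush result gates the update rather than being discarded.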
def _launch_server_process(self, server_args: ServerArgs) -> multiprocessing.Process:
    """Launch the SGLang server process and wait for it to be ready."""
    p = multiprocessing.Process(target=launch_server, args=(server_args,))
    p.start()

    # Wait for server to be ready by checking health endpoint
    # Use the base_url we stored earlier
    headers = {
        "Content-Type": "application/json; charset=utf-8",
    }

    with requests.Session() as session:
        while True:
            try:
                response = session.get(f"{self.base_url}/health_generate", headers=headers)
                if response.status_code == 200:
                    print(f"[SGLang Server] Rank {self.global_rank} Server is ready at {self.base_url}")
                    break
            except requests.RequestException:
                pass

            if not p.is_alive():
                raise Exception(f"[SGLang Server] Rank {self.global_rank} Server process terminated unexpectedly.")

            time.sleep(2)
    return p
Server startup loop lacks timeout - potential infinite hang.
_launch_server_process polls the health endpoint indefinitely. If the server fails to start properly but the process remains alive, this will hang forever. Add a maximum wait time.
 def _launch_server_process(self, server_args: ServerArgs) -> multiprocessing.Process:
     """Launch the SGLang server process and wait for it to be ready."""
     p = multiprocessing.Process(target=launch_server, args=(server_args,))
     p.start()
     # Wait for server to be ready by checking health endpoint
-    # Use the base_url we stored earlier
     headers = {
         "Content-Type": "application/json; charset=utf-8",
     }
+    max_wait_time = 300  # 5 minutes timeout
+    start_time = time.time()
     with requests.Session() as session:
         while True:
+            if time.time() - start_time > max_wait_time:
+                kill_process_tree(p.pid)
+                raise TimeoutError(
+                    f"[SGLang Server] Rank {self.global_rank} Server failed to start within {max_wait_time}s"
+                )
             try:
                 response = session.get(f"{self.base_url}/health_generate", headers=headers)
                 if response.status_code == 200:
                     print(f"[SGLang Server] Rank {self.global_rank} Server is ready at {self.base_url}")
                     break
             except requests.RequestException:
                 pass
             if not p.is_alive():
-                raise Exception(f"[SGLang Server] Rank {self.global_rank} Server process terminated unexpectedly.")
+                raise RuntimeError(f"[SGLang Server] Rank {self.global_rank} Server process terminated unexpectedly.")
             time.sleep(2)
     return p

Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.7)
504-504: Create your own exception
(TRY002)
504-504: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In nemo_rl/models/generation/sglang/sglang_worker.py around lines 482 to 507,
the server startup loop polls the health endpoint forever which can hang if the
process stays alive but never becomes ready; add a maximum wait timeout (e.g.,
default 60 seconds or configurable via ServerArgs) by capturing a start time
before the loop, checking elapsed time on each iteration, and if exceeded: stop
polling, terminate or kill the child process, and raise a clear exception that
includes elapsed time, base_url and process status; ensure the timeout value is
configurable and documented in ServerArgs so callers can adjust it.
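The bounded-poll pattern requested here can be separated from HTTP entirely; in this sketch, `check` and `process_alive` are hypothetical stand-ins for the health probe and `p.is_alive()`:

```python
import time


def wait_until_ready(check, process_alive, max_wait_s=300.0, poll_interval_s=2.0):
    """Poll `check()` until it returns True, with an overall deadline in
    addition to the process-death check; sketch only, names are invented."""
    deadline = time.monotonic() + max_wait_s
    while True:
        if check():
            return True
        if not process_alive():
            raise RuntimeError("server process terminated unexpectedly")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"server not ready within {max_wait_s}s")
        time.sleep(poll_interval_s)
```

Using `time.monotonic()` for the deadline avoids surprises if the wall clock is adjusted while the server is starting.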
try:
    if hasattr(self, "session") and self.session is not None:
        try:
            async def close_session():
                await self.session.close()
                if self.connector is not None:
                    await self.connector.close()

            self.async_loop_thread.run(close_session())
            print(f"[SGLang Worker] Rank {self.global_rank} aiohttp session closed.")
        except Exception as e:
            print(f"[SGLang Worker] Rank {self.global_rank} Error closing aiohttp session: {e}")
Session cleanup may fail after async_loop_thread shutdown.
At line 689, self.async_loop_thread.run(close_session()) is called, but async_loop_thread.shutdown() was already called at line 673. This could cause errors since the event loop thread is already stopped.
 def shutdown(self) -> bool:
     """Shutdown the SGLang server process and cleanup async resources."""
-    if hasattr(self, "async_loop_thread"):
-        try:
-            self.async_loop_thread.shutdown()
-            print(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.")
-        except Exception as e:
-            print(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}")
-
     if not self.is_model_owner:
+        if hasattr(self, "async_loop_thread"):
+            try:
+                self.async_loop_thread.shutdown()
+            except Exception as e:
+                print(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}")
         return True
     try:
         if hasattr(self, "session") and self.session is not None:
             try:
                 async def close_session():
                     await self.session.close()
                     if self.connector is not None:
                         await self.connector.close()
                 self.async_loop_thread.run(close_session())
                 print(f"[SGLang Worker] Rank {self.global_rank} aiohttp session closed.")
             except Exception as e:
                 print(f"[SGLang Worker] Rank {self.global_rank} Error closing aiohttp session: {e}")
+    # Shutdown async loop thread after session cleanup
+    if hasattr(self, "async_loop_thread"):
+        try:
+            self.async_loop_thread.shutdown()
+            print(f"[SGLang Worker] Rank {self.global_rank} Async loop thread shut down.")
+        except Exception as e:
+            print(f"[SGLang Worker] Rank {self.global_rank} Error shutting down async loop thread: {e}")
+
     if not hasattr(self, "server_process") or self.server_process is None:
         return True
🧰 Tools
🪛 Ruff (0.14.7)
691-691: Do not catch blind exception: Exception
(BLE001)
url = f"{self.base_url}/{endpoint}"
headers = {
    "Content-Type": "application/json; charset=utf-8",
}
response = requests.post(url, json=payload or {}, headers=headers)
response.raise_for_status()
Add timeout to requests.post to prevent indefinite hangs.
The _make_request method makes HTTP requests without a timeout, which could hang indefinitely if the server becomes unresponsive.
 def _make_request(self, endpoint: str, payload: Optional[dict] = None):
     """Make a POST request to the specified endpoint with the given payload."""
     url = f"{self.base_url}/{endpoint}"
     headers = {
         "Content-Type": "application/json; charset=utf-8",
     }
-    response = requests.post(url, json=payload or {}, headers=headers)
+    response = requests.post(url, json=payload or {}, headers=headers, timeout=60)
     response.raise_for_status()
     return response.json()

🧰 Tools
🪛 Ruff (0.14.7)
732-732: Probable use of requests call without timeout
(S113)
def shutdown(self):
    """Shutdown the event loop and wait for the thread to finish."""
    if self.loop.is_running():
        self.loop.call_soon_threadsafe(self.loop.stop)
    self._thread.join(timeout=2.0)
    if self.loop.is_running():
        self.loop.close()
Fix event-loop shutdown logic to avoid leaking the loop
shutdown() only calls loop.close() when loop.is_running() after the join, which means in the normal case the loop is never closed and its resources are leaked; and in the rare case it’s still running, close() would be unsafe. Consider:
 def shutdown(self):
     """Shutdown the event loop and wait for the thread to finish."""
-    if self.loop.is_running():
-        self.loop.call_soon_threadsafe(self.loop.stop)
-    self._thread.join(timeout=2.0)
-    if self.loop.is_running():
-        self.loop.close()
+    if self.loop.is_running():
+        self.loop.call_soon_threadsafe(self.loop.stop)
+    self._thread.join(timeout=2.0)
+    if not self.loop.is_closed():
+        self.loop.close()
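For context, a minimal self-contained version of such a background-loop helper with the corrected close logic might look like this (the PR's actual `AsyncLoopThread` may differ in details):

```python
import asyncio
import threading


class AsyncLoopThread:
    """Run an asyncio event loop in a daemon thread; illustrative sketch only."""

    def __init__(self):
        self.loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self.loop.run_forever, daemon=True)
        self._thread.start()

    def run(self, coro):
        # Submit a coroutine to the loop thread and block for its result.
        return asyncio.run_coroutine_threadsafe(coro, self.loop).result()

    def shutdown(self, timeout: float = 2.0) -> None:
        # Stop the loop if it is still running, then close it exactly once.
        if self.loop.is_running():
            self.loop.call_soon_threadsafe(self.loop.stop)
        self._thread.join(timeout=timeout)
        if not self.loop.is_closed():
            self.loop.close()
```

With this shape, calling `shutdown()` twice is safe: the second call finds the loop already stopped and closed and does nothing.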
What does this PR do ?
Add SGLang rollout backend, including initialization, generation, and refit.
RFC: https://icn9gp2qrfay.feishu.cn/wiki/DfC1wU1UkiRJyGklg1ActrQpnEb
Current Status
Colocated mode for SGLang + FSDP is fully functional:
1. Initiated Ray workers and allocated GPUs to SGLang servers according to the provided configs.
2. Implemented DP sharding, parameter handling, generation workflow, and result collection.
3. Added support for updating weights from tensors (FSDP).
4. Provided environment setup and example usage for running the full pipeline end to end.
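For readers unfamiliar with item 2, the DP-sharding idea reduces to dealing prompts across server replicas and restoring input order on gather. An illustrative sketch (not the PR's actual implementation; names are invented):

```python
def shard_round_robin(prompts, num_servers):
    """Deal prompts across `num_servers` shards, remembering original positions."""
    shards = [[] for _ in range(num_servers)]
    index = [[] for _ in range(num_servers)]
    for i, prompt in enumerate(prompts):
        shards[i % num_servers].append(prompt)
        index[i % num_servers].append(i)
    return shards, index


def gather_in_order(shard_results, index, total):
    """Re-assemble per-shard results back into the original input order."""
    out = [None] * total
    for results, positions in zip(shard_results, index):
        for result, pos in zip(results, positions):
            out[pos] = result
    return out
```

Keeping the position index alongside each shard is what lets results come back in any per-server order while still matching the caller's batch layout.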
Remaining Work / TODOs
Results compared to vllm baseline (Qwen 2.5-1.5B)

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Documentation
Configuration
`top_k` and `model_name` fields.