
Conversation


@parthchadha parthchadha commented Sep 24, 2025

What does this PR do ?

Add a one-line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests.
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features
    • Adds accurate next-token entropy computation that works across tensor-parallel vocabulary shards.
    • Exposes a new metric, full_entropy, alongside the existing approximate entropy for comparison.
    • Improves numerical stability of entropy calculations, providing more reliable training metrics.

@parthchadha parthchadha requested review from a team as code owners September 24, 2025 19:56
@parthchadha parthchadha changed the title from "Compute entropy across full vocab for logging" to "feat: Compute entropy across full vocab for logging" on Sep 24, 2025

coderabbitai bot commented Sep 24, 2025

📝 Walkthrough

Adds exact next-token entropy computation, including a tensor-parallel implementation. Integrates the distributed entropy into ClippedPGLossFn, handling sharded and non-sharded vocab cases, masking and reducing to per-sequence values. Returns both full and approximate entropy metrics. Introduces _compute_distributed_entropy in distributed utilities.

Changes

Loss metrics integration (nemo_rl/algorithms/loss_functions.py):
  Imports _compute_distributed_entropy. Computes true entropy over the next-token logits, using the distributed path when the vocab is sharded and a local softmax path otherwise. Applies masking and reduction to obtain seq_entropy. Returns metrics with full_entropy and preserves approx_entropy.

Distributed entropy computation (nemo_rl/distributed/model_utils.py):
  Adds _compute_distributed_entropy, which computes numerically stable entropy across tensor-parallel shards using max-stabilized normalization, local log-probs, and all-reduce aggregation (a sketch of this pattern follows below). No changes to existing functions.
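
The summary above names the pattern but not the code. As a rough illustration, the tensor-parallel entropy computation can be sketched as follows; this is a reconstruction from the walkthrough, not the actual _compute_distributed_entropy, and the function name, tensor shapes, and group argument are assumptions. It needs an initialized torch.distributed process group to run.

import torch
import torch.distributed as dist

def distributed_entropy_sketch(
    vocab_shard_logits: torch.Tensor,  # [batch, seq, vocab_shard] local slice of the logits
    group: dist.ProcessGroup | None = None,
) -> torch.Tensor:
    """Per-token entropy over a vocabulary sharded across a tensor-parallel group.

    Illustrative sketch only; assumes an initialized process group.
    """
    logits = vocab_shard_logits.to(torch.float32)

    # 1) Global max over the full vocab, used to stabilize exp().
    global_max = logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=group)
    shifted = logits - global_max

    # 2) Global partition function: sum of exp() over every shard.
    sum_exp = shifted.exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)
    log_z = sum_exp.log()

    # 3) Log-probs and probs for the local vocab slice only.
    log_probs_local = shifted - log_z
    probs_local = log_probs_local.exp()

    # 4) Local -p * log p contributions, summed over the shard, then across shards.
    entropy_local = -(probs_local * log_probs_local).sum(dim=-1)
    dist.all_reduce(entropy_local, op=dist.ReduceOp.SUM, group=group)
    return entropy_local  # [batch, seq]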

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer
  participant ClippedPGLossFn as LossFn
  participant ModelUtils as DistributedUtils
  participant TPGroup as Tensor-Parallel Group

  Trainer->>LossFn: __call__(logits, mask, vocab_parallel_group?)
  alt Vocab is sharded
    LossFn->>DistributedUtils: _compute_distributed_entropy(logits_shard, tp_group)
    activate DistributedUtils
    DistributedUtils->>TPGroup: all_reduce(MAX) on logits for stabilization
    DistributedUtils->>DistributedUtils: compute local log-probs and probs
    DistributedUtils->>TPGroup: all_reduce(SUM) entropy contributions
    DistributedUtils-->>LossFn: per-token entropy
    deactivate DistributedUtils
  else Not sharded
    LossFn->>LossFn: softmax + log to get per-token entropy
  end
  LossFn->>LossFn: apply mask, reduce to seq_entropy
  LossFn-->>Trainer: loss, metrics{ full_entropy, approx_entropy, ... }
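
Reading the diagram as code, the caller-side branch inside ClippedPGLossFn looks roughly like the sketch below. It is assembled from the review excerpts further down rather than copied from the PR; names such as next_token_logits, mask, global_valid_toks, and masked_mean follow those excerpts, and the function wrapper and return value are assumptions.

import torch
from nemo_rl.distributed.model_utils import _compute_distributed_entropy
from nemo_rl.algorithms.utils import masked_mean

def full_entropy_metric_sketch(
    next_token_logits: torch.Tensor,   # [B, S, V] local, or [B, S, V/tp] if vocab-sharded
    input_ids: torch.Tensor,           # [B, S]
    mask: torch.Tensor,                # valid-token mask aligned with the entropy positions
    global_valid_toks: torch.Tensor,   # valid-token count across data-parallel ranks
    vocab_parallel_group=None,
) -> torch.Tensor:
    """Sketch of the full_entropy metric path described in the walkthrough."""
    with torch.no_grad():
        if vocab_parallel_group is not None:
            # Vocab sharded across TP ranks: reduce entropy contributions across shards.
            token_entropy = _compute_distributed_entropy(
                next_token_logits[:, : input_ids.shape[1] - 1, :],
                group=vocab_parallel_group,
            )
        else:
            # Full vocab is local: plain log-softmax path.
            log_probs = torch.nn.functional.log_softmax(
                next_token_logits[:, :-1, :].float(), dim=-1
            )
            token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)

        # Mask out prompt/padding positions; normalize by the global valid-token count.
        return masked_mean(
            token_entropy, mask, global_normalization_factor=global_valid_toks
        )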

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

Test Results For Major Changes: ⚠️ Warning
Explanation: The PR adds a new distributed entropy computation and integrates it into the loss metrics, which can affect both numerics (stability across dtypes, masking, CP/TP interactions) and performance (additional all-reduces and per-token softmax-like work). The PR description, as provided, contains no test results, benchmarking, or convergence checks, and includes only a template with placeholders. Given the potential performance and numerical implications, this qualifies as a non-trivial change that requires evidence of testing and performance impact, which is currently absent.
Resolution: Please update the PR description to include: 1) unit/integration tests or logs confirming correctness of the full-vocab entropy under TP/CP with masking, 2) convergence or training stability checks showing no regression versus baseline, and 3) before/after performance numbers (e.g., tokens/sec or step time) with configuration details to quantify the overhead of the added all-reduces. With this information, we can reassess and pass the check.

✅ Passed checks (3 passed)

Description Check: ✅ Passed. Check skipped: CodeRabbit's high-level summary is enabled.
Title Check: ✅ Passed. The title succinctly describes the primary enhancement (computing entropy over the full vocabulary for logging) and aligns directly with the changes to add a distributed full_entropy metric alongside the existing approximate entropy. It is clear, specific, and free of extraneous details or generic phrasing.
Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (4)
nemo_rl/distributed/model_utils.py (1)

96-104: Avoid extra log and epsilon; use -p * log_p via log_probs directly.

Computing log(probs + eps) adds noise and work. Use log_probs_local for exact -p*log p without eps.

Apply this diff:

-    probs_local = log_probs_local.exp()
-
-    # Compute local entropy contribution: -p * log(p)
-    eps = 1e-8
-    entropy_local = -probs_local * torch.log(probs_local + eps)
+    probs_local = log_probs_local.exp()
+    # Local entropy contribution: -p * log(p)
+    entropy_local = -(probs_local * log_probs_local)
nemo_rl/algorithms/loss_functions.py (3)

345-356: Use log_softmax for stability and compute -∑ p log p without epsilon in the non‑TP path.

Keeps computation in fp32 and avoids extra log(+eps).

Apply this diff:

-            next_token_logits_wo_last = next_token_logits[:, :-1, :]
-            probs = torch.softmax(next_token_logits_wo_last, dim=-1)
-            eps = 1e-8
-            log_probs = torch.log(probs + eps)
-            token_entropy = -torch.sum(probs * log_probs, dim=-1)
+            logits = next_token_logits[:, :-1, :].to(torch.float32)
+            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+            probs = log_probs.exp()
+            token_entropy = -(probs * log_probs).sum(dim=-1)
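
To see why the suggestion matters, here is a small standalone comparison; the logit values and vocab size are hypothetical, not taken from the PR.

import torch

# (a) Extreme logits: softmax underflows to exact zeros, so log(0) = -inf
#     and 0 * -inf = nan unless an epsilon is added.
x = torch.tensor([[200.0, 0.0, -200.0]])
p = torch.softmax(x, dim=-1)
print(-(p * torch.log(p)).sum(dim=-1))          # tensor([nan])

lp = torch.log_softmax(x, dim=-1)               # stable: log(0) is never formed
print(-(lp.exp() * lp).sum(dim=-1))             # tensor([0.])

# (b) The epsilon itself biases the result: with a long tail of p ~ 1e-10,
#     log(p + 1e-8) clips the tail terms and underestimates the entropy.
y = torch.full((1, 10001), -23.0)
y[0, 0] = 0.0
p = torch.softmax(y, dim=-1)
lp = torch.log_softmax(y, dim=-1)
print(-(p * torch.log(p + 1e-8)).sum(dim=-1))   # ~2.0e-5 (biased low)
print(-(lp.exp() * lp).sum(dim=-1))             # ~2.5e-5 (exact)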

25-26: Avoid importing a private symbol across modules.

Consider renaming _compute_distributed_entropy to compute_distributed_entropy (public) or re‑exporting it to avoid leading underscore usage in external modules.


385-386: Metric name is clear; confirm units.

full_entropy is globally token‑normalized (masked_mean with global_valid_toks). If you need per‑sequence entropy later, consider also logging an unnormalized sum.
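
For reference, the reduction discussed here behaves roughly like the sketch below, inferred from masked_mean's call sites in this PR; the real implementation lives in nemo_rl/algorithms/utils.py and may differ in details such as dtype handling.

import torch

def masked_mean_sketch(
    values: torch.Tensor,
    mask: torch.Tensor,
    global_normalization_factor: torch.Tensor | float | None = None,
) -> torch.Tensor:
    """Mean of `values` over positions where `mask` is 1.

    When global_normalization_factor is given (e.g. the valid-token count across
    all data-parallel ranks), the masked sum is divided by that instead of the
    local mask sum, so the metric is comparable across workers.
    """
    masked_sum = (values * mask).sum()
    denom = (
        global_normalization_factor
        if global_normalization_factor is not None
        else mask.sum().clamp(min=1)
    )
    return masked_sum / denom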

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e22a340 and 41d901a.

📒 Files selected for processing (2)
  • nemo_rl/algorithms/loss_functions.py (3 hunks)
  • nemo_rl/distributed/model_utils.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/distributed/model_utils.py
  • nemo_rl/algorithms/loss_functions.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/distributed/model_utils.py
  • nemo_rl/algorithms/loss_functions.py
🧬 Code graph analysis (1)
nemo_rl/algorithms/loss_functions.py (2)
nemo_rl/distributed/model_utils.py (1)
  • _compute_distributed_entropy (59-110)
nemo_rl/algorithms/utils.py (1)
  • masked_mean (130-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR

Comment on lines +25 to 28
    _compute_distributed_entropy,
    from_parallel_logits_to_logprobs,
    get_logprobs_from_vocab_parallel_logits,
)

🛠️ Refactor suggestion

Import the CP all-gather helper for full-entropy when CP is enabled.

You’ll need allgather_cp_sharded_tensor to assemble per-CP entropy before masking.

Apply this diff:

 from nemo_rl.distributed.model_utils import (
     _compute_distributed_entropy,
     from_parallel_logits_to_logprobs,
     get_logprobs_from_vocab_parallel_logits,
+    allgather_cp_sharded_tensor,
 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

from nemo_rl.distributed.model_utils import (
    _compute_distributed_entropy,
    from_parallel_logits_to_logprobs,
    get_logprobs_from_vocab_parallel_logits,
    allgather_cp_sharded_tensor,
)
🤖 Prompt for AI Agents
In nemo_rl/algorithms/loss_functions.py around lines 25 to 28, the all-gather
helper for context parallelism (CP) should be imported but is missing; add the
symbol allgather_cp_sharded_tensor to the existing import list (alongside
_compute_distributed_entropy, from_parallel_logits_to_logprobs,
get_logprobs_from_vocab_parallel_logits) so the code can assemble per-CP entropy
before applying the mask.

Comment on lines +325 to +344
        # Compute actual entropy across all vocab: H(π) = -∑_v π(v) * log(π(v))
        with torch.no_grad():
            if vocab_parallel_group is not None:
                next_token_logits_trimmed = next_token_logits[
                    :, : data["input_ids"].shape[1] - 1, :
                ]

                # Compute entropy across all vocabulary shards using distributed computation
                token_entropy = _compute_distributed_entropy(
                    next_token_logits_trimmed,
                    group=vocab_parallel_group,
                )

                # Apply masking and global reduction
                seq_entropy = masked_mean(
                    token_entropy,
                    mask,
                    global_normalization_factor=global_valid_toks,
                )
            else:

⚠️ Potential issue

Handle context parallelism when computing full entropy; current code will shape‑mismatch with mask under CP.

token_entropy is computed on the local CP shard, but mask is for the full sequence. This will crash on values * mask. Gather across CP first, then slice to the unpadded length.

Apply this diff:

         with torch.no_grad():
             if vocab_parallel_group is not None:
-                next_token_logits_trimmed = next_token_logits[
-                    :, : data["input_ids"].shape[1] - 1, :
-                ]
-
-                # Compute entropy across all vocabulary shards using distributed computation
-                token_entropy = _compute_distributed_entropy(
-                    next_token_logits_trimmed,
-                    group=vocab_parallel_group,
-                )
-
-                # Apply masking and global reduction
-                seq_entropy = masked_mean(
-                    token_entropy,
-                    mask,
-                    global_normalization_factor=global_valid_toks,
-                )
+                # Trim logits in sequence dim (safe if shorter under CP)
+                next_token_logits_trimmed = next_token_logits[:, : data["input_ids"].shape[1] - 1, :]
+                # Compute TP-aggregated per-token entropy on this CP shard
+                token_entropy = _compute_distributed_entropy(
+                    next_token_logits_trimmed, group=vocab_parallel_group
+                )
+                # If CP is enabled, gather per-token entropy across CP to match mask shape
+                if context_parallel_group is not None:
+                    token_entropy = allgather_cp_sharded_tensor(
+                        token_entropy, context_parallel_group, seq_dim=1
+                    )
+                    token_entropy = token_entropy[:, : data["input_ids"].shape[1] - 1]
+                # Reduce with mask
+                seq_entropy = masked_mean(
+                    token_entropy, mask, global_normalization_factor=global_valid_toks
+                )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

        # Compute actual entropy across all vocab: H(π) = -∑_v π(v) * log(π(v))
        with torch.no_grad():
            if vocab_parallel_group is not None:
                # Trim logits in sequence dim (safe if shorter under CP)
                next_token_logits_trimmed = next_token_logits[:, : data["input_ids"].shape[1] - 1, :]
                # Compute TP-aggregated per-token entropy on this CP shard
                token_entropy = _compute_distributed_entropy(
                    next_token_logits_trimmed, group=vocab_parallel_group
                )
                # If CP is enabled, gather per-token entropy across CP to match mask shape
                if context_parallel_group is not None:
                    token_entropy = allgather_cp_sharded_tensor(
                        token_entropy, context_parallel_group, seq_dim=1
                    )
                    token_entropy = token_entropy[:, : data["input_ids"].shape[1] - 1]
                # Reduce with mask
                seq_entropy = masked_mean(
                    token_entropy, mask, global_normalization_factor=global_valid_toks
                )
            else:
🤖 Prompt for AI Agents
In nemo_rl/algorithms/loss_functions.py around lines 325 to 344, token_entropy
is computed only on the local context-parallel (CP) shard but the mask is for
the full sequence causing a shape mismatch; gather the per-shard token_entropy
across the CP group to reconstruct the full-vocab/full-sequence entropy tensor
(using the same CP group for collective gather), then slice/trim that gathered
tensor to data["input_ids"].shape[1] - 1 (to remove padding/extra time steps)
before calling masked_mean with the global mask and global_valid_toks; ensure
the collective uses the correct device/dtype and preserves batch dimension
ordering so masked_mean receives matching shapes.
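
For readers unfamiliar with the helper, a gather along the sequence dimension can be pictured roughly as follows. This is an illustrative stand-in, not the actual allgather_cp_sharded_tensor; it assumes equally sized per-rank chunks and ignores any load-balancing permutation a real context-parallel implementation may apply.

import torch
import torch.distributed as dist

def allgather_along_seq_dim_sketch(
    local_chunk: torch.Tensor,          # [B, S/cp, ...] this rank's sequence shard
    group: dist.ProcessGroup,
    seq_dim: int = 1,
) -> torch.Tensor:
    """Collect every rank's sequence shard and concatenate along seq_dim."""
    world = dist.get_world_size(group=group)
    gathered = [torch.empty_like(local_chunk) for _ in range(world)]
    dist.all_gather(gathered, local_chunk.contiguous(), group=group)
    return torch.cat(gathered, dim=seq_dim)  # [B, S, ...] full sequence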

@ffrujeri ffrujeri closed this Sep 24, 2025
@ffrujeri ffrujeri reopened this Sep 24, 2025