feat: Compute entropy across full vocab for logging #1200
base: main
Conversation
Signed-off-by: Parth Chadha <[email protected]>
📝 Walkthrough
Adds exact next-token entropy computation, including a tensor-parallel implementation. Integrates the distributed entropy into ClippedPGLossFn, handling sharded and non-sharded vocab cases, masking and reducing to per-sequence values. Returns both full and approximate entropy metrics. Introduces
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Trainer
participant ClippedPGLossFn as LossFn
participant ModelUtils as DistributedUtils
participant TPGroup as Tensor-Parallel Group
Trainer->>LossFn: __call__(logits, mask, vocab_parallel_group?)
alt Vocab is sharded
LossFn->>DistributedUtils: _compute_distributed_entropy(logits_shard, tp_group)
activate DistributedUtils
DistributedUtils->>TPGroup: all_reduce(MAX) on logits for stabilization
DistributedUtils->>DistributedUtils: compute local log-probs and probs
DistributedUtils->>TPGroup: all_reduce(SUM) entropy contributions
DistributedUtils-->>LossFn: per-token entropy
deactivate DistributedUtils
else Not sharded
LossFn->>LossFn: softmax + log to get per-token entropy
end
LossFn->>LossFn: apply mask, reduce to seq_entropy
LossFn-->>Trainer: loss, metrics{ full_entropy, approx_entropy, ... }
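The flow in the diagram above can be pictured with a short sketch. This is not the PR's _compute_distributed_entropy implementation (the real function lives in nemo_rl/distributed/model_utils.py, lines 59-110 per the review below); it is a minimal, hedged reconstruction of the max-stabilized all-reduce pattern, assuming each tensor-parallel rank holds a [batch, seq, vocab_shard] slice of the logits. Note that computing the softmax normalizer needs one additional SUM all-reduce beyond the two collectives shown in the diagram.

import torch
import torch.distributed as dist

def distributed_entropy_sketch(logits_shard: torch.Tensor, group=None) -> torch.Tensor:
    """Per-token entropy over a vocab-sharded logits tensor (illustrative only).

    logits_shard: [batch, seq, vocab_shard] local slice of the vocabulary.
    Returns a [batch, seq] tensor of entropies aggregated across the group.
    """
    logits = logits_shard.to(torch.float32)

    # Stabilize: subtract the global max logit (all-reduce MAX across vocab shards).
    local_max = logits.max(dim=-1, keepdim=True).values
    if group is not None:
        dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)
    logits = logits - local_max

    # Global softmax normalizer: sum of exp over the full vocabulary (all-reduce SUM).
    sum_exp = logits.exp().sum(dim=-1, keepdim=True)
    if group is not None:
        dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # Local log-probs and probs for this shard's slice of the vocab.
    log_probs = logits - sum_exp.log()
    probs = log_probs.exp()

    # Local entropy contribution -p * log p, summed over the local vocab slice,
    # then all-reduce SUM so every rank ends up with the full-vocab entropy.
    entropy = -(probs * log_probs).sum(dim=-1)
    if group is not None:
        dist.all_reduce(entropy, op=dist.ReduceOp.SUM, group=group)
    return entropy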
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✨ Finishing touches
🧪 Generate unit tests
Tip
👮 Agentic pre-merge checks are now available in preview! Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs. Please see the documentation for more information.
Example:
reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).
Actionable comments posted: 3
🧹 Nitpick comments (4)
nemo_rl/distributed/model_utils.py (1)
96-104: Avoid extra log and epsilon; use -p * log_p via log_probs directly.
Computing log(probs + eps) adds noise and work. Use log_probs_local for exact -p * log p without eps.
Apply this diff:
-    probs_local = log_probs_local.exp()
-
-    # Compute local entropy contribution: -p * log(p)
-    eps = 1e-8
-    entropy_local = -probs_local * torch.log(probs_local + eps)
+    probs_local = log_probs_local.exp()
+    # Local entropy contribution: -p * log(p)
+    entropy_local = -(probs_local * log_probs_local)
nemo_rl/algorithms/loss_functions.py (3)
345-356: Use log_softmax for stability and compute -∑ p log p without epsilon in the non‑TP path.
Keeps computation in fp32 and avoids extra log(+eps).
Apply this diff:
-    next_token_logits_wo_last = next_token_logits[:, :-1, :]
-    probs = torch.softmax(next_token_logits_wo_last, dim=-1)
-    eps = 1e-8
-    log_probs = torch.log(probs + eps)
-    token_entropy = -torch.sum(probs * log_probs, dim=-1)
+    logits = next_token_logits[:, :-1, :].to(torch.float32)
+    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+    probs = log_probs.exp()
+    token_entropy = -(probs * log_probs).sum(dim=-1)
25-26: Avoid importing a private symbol across modules.
Consider renaming _compute_distributed_entropy to compute_distributed_entropy (public) or re‑exporting it to avoid leading underscore usage in external modules.
385-386: Metric name is clear; confirm units.
full_entropy is globally token‑normalized (masked_mean with global_valid_toks). If you need per‑sequence entropy later, consider also logging an unnormalized sum.
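For context on the normalization the reviewer mentions, here is a minimal sketch of a masked mean with a global normalization factor. This is an assumed simplification, not the actual masked_mean in nemo_rl/algorithms/utils.py; it only mirrors the call pattern masked_mean(values, mask, global_normalization_factor=global_valid_toks) used in the diffs.

from typing import Optional
import torch

def masked_mean_sketch(
    values: torch.Tensor,
    mask: torch.Tensor,
    global_normalization_factor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Mean of `values` over positions where `mask` is nonzero (illustrative only)."""
    masked_sum = (values * mask).sum()
    # With a global factor (e.g. the valid-token count across all ranks/microbatches),
    # the result is globally token-normalized rather than per-batch normalized.
    denom = (
        global_normalization_factor
        if global_normalization_factor is not None
        else mask.sum().clamp(min=1)
    )
    return masked_sum / denom

Under that convention, full_entropy is a per-token average over all globally valid tokens, which is why the reviewer suggests also logging an unnormalized sum if per-sequence entropy is ever needed.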
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
nemo_rl/algorithms/loss_functions.py (3 hunks)
nemo_rl/distributed/model_utils.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/distributed/model_utils.py
nemo_rl/algorithms/loss_functions.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py
: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/distributed/model_utils.py
nemo_rl/algorithms/loss_functions.py
🧬 Code graph analysis (1)
nemo_rl/algorithms/loss_functions.py (2)
nemo_rl/distributed/model_utils.py (1)
_compute_distributed_entropy (59-110)
nemo_rl/algorithms/utils.py (1)
masked_mean (130-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
    _compute_distributed_entropy,
    from_parallel_logits_to_logprobs,
    get_logprobs_from_vocab_parallel_logits,
)
🛠️ Refactor suggestion
Import the CP all-gather helper for full-entropy when CP is enabled.
You’ll need allgather_cp_sharded_tensor to assemble per-CP entropy before masking.
Apply this diff:
from nemo_rl.distributed.model_utils import (
_compute_distributed_entropy,
from_parallel_logits_to_logprobs,
get_logprobs_from_vocab_parallel_logits,
+ allgather_cp_sharded_tensor,
)
🤖 Prompt for AI Agents
In nemo_rl/algorithms/loss_functions.py around lines 25 to 28, the all-gather
helper for context parallelism (CP) should be imported but is missing; add the
symbol allgather_cp_sharded_tensor to the existing import list (alongside
_compute_distributed_entropy, from_parallel_logits_to_logprobs,
get_logprobs_from_vocab_parallel_logits) so the code can assemble per-CP entropy
before applying the mask.
# Compute actual entropy across all vocab: H(π) = -∑_v π(v) * log(π(v))
with torch.no_grad():
    if vocab_parallel_group is not None:
        next_token_logits_trimmed = next_token_logits[
            :, : data["input_ids"].shape[1] - 1, :
        ]

        # Compute entropy across all vocabulary shards using distributed computation
        token_entropy = _compute_distributed_entropy(
            next_token_logits_trimmed,
            group=vocab_parallel_group,
        )

        # Apply masking and global reduction
        seq_entropy = masked_mean(
            token_entropy,
            mask,
            global_normalization_factor=global_valid_toks,
        )
    else:
Handle context parallelism when computing full entropy; current code will shape‑mismatch with mask under CP.
token_entropy is computed on the local CP shard, but mask is for the full sequence. This will crash on values * mask. Gather across CP first, then slice to the unpadded length.
Apply this diff:
with torch.no_grad():
if vocab_parallel_group is not None:
- next_token_logits_trimmed = next_token_logits[
- :, : data["input_ids"].shape[1] - 1, :
- ]
-
- # Compute entropy across all vocabulary shards using distributed computation
- token_entropy = _compute_distributed_entropy(
- next_token_logits_trimmed,
- group=vocab_parallel_group,
- )
-
- # Apply masking and global reduction
- seq_entropy = masked_mean(
- token_entropy,
- mask,
- global_normalization_factor=global_valid_toks,
- )
+ # Trim logits in sequence dim (safe if shorter under CP)
+ next_token_logits_trimmed = next_token_logits[:, : data["input_ids"].shape[1] - 1, :]
+ # Compute TP-aggregated per-token entropy on this CP shard
+ token_entropy = _compute_distributed_entropy(
+ next_token_logits_trimmed, group=vocab_parallel_group
+ )
+ # If CP is enabled, gather per-token entropy across CP to match mask shape
+ if context_parallel_group is not None:
+ token_entropy = allgather_cp_sharded_tensor(
+ token_entropy, context_parallel_group, seq_dim=1
+ )
+ token_entropy = token_entropy[:, : data["input_ids"].shape[1] - 1]
+ # Reduce with mask
+ seq_entropy = masked_mean(
+ token_entropy, mask, global_normalization_factor=global_valid_toks
+ )
🤖 Prompt for AI Agents
In nemo_rl/algorithms/loss_functions.py around lines 325 to 344, token_entropy
is computed only on the local context-parallel (CP) shard but the mask is for
the full sequence causing a shape mismatch; gather the per-shard token_entropy
across the CP group to reconstruct the full-vocab/full-sequence entropy tensor
(using the same CP group for collective gather), then slice/trim that gathered
tensor to data["input_ids"].shape[1] - 1 (to remove padding/extra time steps)
before calling masked_mean with the global mask and global_valid_toks; ensure
the collective uses the correct device/dtype and preserves batch dimension
ordering so masked_mean receives matching shapes.
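The gather step described above can be pictured with a plain all_gather along the sequence dimension. The helper allgather_cp_sharded_tensor referenced by the review is only known here by its call shape (tensor, group, seq_dim=1); the sketch below is an assumption of what such a helper could look like and ignores any chunk interleaving a real context-parallel layout may use.

import torch
import torch.distributed as dist

def allgather_along_seq_dim_sketch(
    local_tensor: torch.Tensor, group, seq_dim: int = 1
) -> torch.Tensor:
    """Gather equal-sized sequence shards from every CP rank and concatenate them."""
    world_size = dist.get_world_size(group=group)
    gathered = [torch.empty_like(local_tensor) for _ in range(world_size)]
    dist.all_gather(gathered, local_tensor.contiguous(), group=group)
    # Concatenating along seq_dim reconstructs a full-sequence tensor whose shape
    # matches the full-sequence mask, so masked_mean can be applied afterwards.
    return torch.cat(gathered, dim=seq_dim)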
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit