Set attention_mask to None by default. #1213
Conversation
📝 Walkthrough

Centralizes attention_mask handling in dtensor_policy_worker_v2 by removing per-branch mask construction and always passing attention_mask=None in both the train and get_logprobs paths. No public API changes and no control-flow alterations beyond argument preparation.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant W as PolicyWorkerV2
    participant M as Model
    rect rgba(230,240,255,0.5)
        Note over W: Old (before)
        W->>W: Prepare inputs
        alt Non-packed
            W->>W: Build attention_mask = ones(...)
        else Packed / certain branches
            W->>W: attention_mask = None
        end
        W->>M: forward(inputs, attention_mask)
        M-->>W: outputs
    end
    rect rgba(235,255,235,0.5)
        Note over W: New (after)
        W->>W: Prepare inputs
        W->>W: attention_mask = None (centralized)
        W->>M: forward(inputs, attention_mask=None)
        M-->>W: outputs
    end
```
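For readers who prefer code to diagrams, here is a minimal sketch of the before/after argument preparation. The function names and the `packed` flag are illustrative stand-ins, not the worker's actual helpers:

```python
import torch


def prepare_forward_kwargs_old(input_ids: torch.Tensor, packed: bool) -> dict:
    """Old behavior: per-branch mask construction (illustrative only)."""
    if packed:
        # The packed path already relied on FA2 cu_seqlens, so no dense mask.
        attention_mask = None
    else:
        # The non-packed path built an explicit all-ones mask.
        attention_mask = torch.ones(
            input_ids.shape, dtype=torch.bool, device=input_ids.device
        )
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def prepare_forward_kwargs_new(input_ids: torch.Tensor) -> dict:
    """New behavior: attention_mask is always None, selecting the causal
    kernel path; padding is handled later when logprobs/loss are masked."""
    return {"input_ids": input_ids, "attention_mask": None}
```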
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning) | ✅ Passed checks (3 passed)
Actionable comments posted: 0
🧹 Nitpick comments (2)
nemo_rl/models/policy/dtensor_policy_worker_v2.py (2)
Lines 970-975: Fix stale/incorrect comment and typo ("causal"). The comment refers to a prior "all-ones" mask while the code now passes None. Update it to reflect the new behavior and correct the typo.

```diff
-        # DTensor requires the casual attention kernel to hit,
-        # yet our attention mask above is not always all 1s
-        # this is fine because we mask with the actual attention mask
-        # later, but for input it has to be all 1s
-        attention_mask = None
+        # We deliberately pass attention_mask=None to select the causal kernel path.
+        # Padding is applied later via post_attention_mask when computing token logprobs.
+        attention_mask = None
```
Lines 1226-1241: Unify score() to also pass attention_mask=None by default. For consistency with train/get_logprobs and the PR intent, consider dropping the explicit all-ones mask here too.

```diff
-        attention_mask = torch.ones(
-            (batch_size, seq_len),
-            dtype=torch.bool,
-            device=input_ids.device,
-        )
+        attention_mask = None
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
nemo_rl/models/policy/dtensor_policy_worker_v2.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/models/policy/dtensor_policy_worker_v2.py
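As a compact illustration of several of the naming, initialization, and docstring rules above, here is a hypothetical example class (not taken from the repository):

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# (Header abbreviated for illustration; real files carry the full NVIDIA header.)

G_STEP_LOG_INTERVAL = 10  # global: UPPER_SNAKE_CASE with a G_ prefix


class StepTimer:  # class name: PascalCase
    """Tracks simple per-step timing statistics.

    Attributes:
        step_count: Number of steps recorded so far.
        last_p99_ms: Most recently recorded 99th-percentile latency.
    """

    def __init__(self) -> None:
        # All externally visible members are initialized in the constructor.
        self.step_count = 0
        self.last_p99_ms = 0.0

    def record_step(self, k_99th_percentile_ms: float) -> None:  # snake_case
        """Records one step's latency.

        Args:
            k_99th_percentile_ms: 99th-percentile latency in milliseconds;
                names that would start with a digit get a leading ``k``.
        """
        self.step_count += 1
        self.last_p99_ms = k_99th_percentile_ms
```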
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/models/policy/dtensor_policy_worker_v2.py
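A short sketch of how these config rules tend to look in practice; the TypedDict, its keys, and the worker class below are hypothetical examples, not the actual NeMo-RL definitions:

```python
from typing import NotRequired, TypedDict

import ray


class ExamplePolicyConfig(TypedDict):
    """Hypothetical config schema.

    precision: Compute precision, e.g. "bf16" or "fp32". The recommended
        default lives in YAML, not in code.
    sequence_packing: Optional toggle for packed-sequence training; optionality
        is expressed with NotRequired rather than a hidden in-code default.
    """

    precision: str
    sequence_packing: NotRequired[bool]


@ray.remote  # pragma: no cover
class ExamplePolicyWorker:
    def __init__(self, cfg: ExamplePolicyConfig):
        # Required keys are accessed directly; presence is assumed.
        self.precision = cfg["precision"]
```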
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: build-container / main
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/models/policy/dtensor_policy_worker_v2.py (1)
Lines 647-648: Passing attention_mask=None in train is reasonable; please verify VLM/back-compat. Defaulting to None should select the causal kernel path and is correct for right-padded LM training (padding tokens don't affect earlier tokens; loss masking handles pads). For sequence packing, FA2 cu_seqlens handles masking. Please sanity-check multimodal variants and any legacy models that might have depended on an explicit all-ones mask.
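As a rough, self-contained illustration of the right-padding argument: the tensors and the `post_attention_mask` values below are made up for the example; only the masking pattern mirrors the description above.

```python
import torch
import torch.nn.functional as F

# Pretend logits from a causal LM called with attention_mask=None.
# With causal attention, position t only attends to positions <= t,
# so right-padding cannot change the logits of real (earlier) tokens.
batch, seq_len, vocab = 2, 6, 11
logits = torch.randn(batch, seq_len, vocab)
input_ids = torch.randint(0, vocab, (batch, seq_len))

# Hypothetical validity mask (1 = real token, 0 = right pad), standing in
# for the post-attention mask applied after the forward pass.
post_attention_mask = torch.tensor(
    [[1, 1, 1, 1, 0, 0],
     [1, 1, 1, 1, 1, 0]]
)

# Next-token logprobs: position t scores token t+1.
logprobs = F.log_softmax(logits[:, :-1], dim=-1)
token_logprobs = logprobs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

# Padded positions are zeroed out so they never contribute to the result.
token_logprobs = token_logprobs * post_attention_mask[:, 1:]
print(token_logprobs.shape)  # torch.Size([2, 5])
```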
Signed-off-by: Jonas Yang <[email protected]>
Force-pushed from bfc8f67 to f3e63c6
Signed-off-by: Jonas Yang <[email protected]>
Signed-off-by: Jonas Yang <[email protected]>
Force-pushed from b2195e1 to 2333bbd
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Bug Fixes
Performance
Reliability