simplify adv and loss code #1382
Conversation
def compute_loss(
    trainer_logprobs: Any,  # list of Float[Tensor, "seq_i"] with potentially different seq_i lengths
lol why did we have `Any` here before
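Not part of the PR, but for illustration: one way this parameter could be annotated with jaxtyping instead of `Any`, assuming it really is a list of variable-length 1-D logprob tensors (the function name and body below are made up):

```python
import torch
from torch import Tensor
from jaxtyping import Float, jaxtyped
from typeguard import typechecked as typechecker

# Illustrative sketch only: annotate the list instead of using `Any`.
# The anonymous "_" axis accepts any length, so entries with different
# sequence lengths are not checked against each other.
@jaxtyped(typechecker=typechecker)
def total_logprob(trainer_logprobs: list[Float[Tensor, " _"]]) -> Float[Tensor, ""]:
    # Sum each sequence's token logprobs, then sum over sequences into a scalar.
    return torch.stack([lp.sum() for lp in trainer_logprobs]).sum()

print(total_logprob([torch.randn(5), torch.randn(9)]))  # variable-length entries are fine
```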
def shift_logits(logits: Float[Tensor, "batch seq vocab"]) -> Float[Tensor, "batch seq vocab"]:
def shift_logits(logits: Tensor) -> Tensor:
    """Removes final token logits and adds a zero logit for the first token."""
    # We drop the last logit because it corresponds to the next token that will be sampled but is not here yet
comments are nice no?
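For reference, a minimal self-contained sketch of what the docstring and the dropped comment describe, assuming a (batch, seq, vocab) logits layout (an illustration, not necessarily the PR's exact implementation):

```python
import torch
from torch import Tensor

def shift_logits(logits: Tensor) -> Tensor:
    """Removes final token logits and adds a zero logit for the first token."""
    # The last position's logits predict a token that has not been sampled yet, so drop them.
    # Prepending zeros keeps the shape and aligns position i with the token produced at i.
    batch, _, vocab = logits.shape
    zeros = logits.new_zeros(batch, 1, vocab)
    return torch.cat([zeros, logits[:, :-1, :]], dim=1)

print(shift_logits(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])
```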
@jaxtyped(typechecker=typechecker)
@torch.compile(dynamic=True)
def selective_log_softmax(
why remove all the jaxtyping?
i think it's nice for ppl not so familiar with the code to know the input/output tensor shapes. also good to catch at runtime if we violate these
Ya fair, just cus it's very set in place what the shapes would be. Will put them back
it should be very set in place no?
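To make the shape discussion concrete, here is a hedged sketch of what a jaxtyping-annotated `selective_log_softmax` could look like; only the decorators appear in the diff above, so the `(logits, input_ids)` signature and body are assumptions:

```python
import torch
from torch import Tensor
from jaxtyping import Float, Int, jaxtyped
from typeguard import typechecked as typechecker

@jaxtyped(typechecker=typechecker)
@torch.compile(dynamic=True)
def selective_log_softmax(
    logits: Float[Tensor, "batch seq vocab"],
    input_ids: Int[Tensor, "batch seq"],
) -> Float[Tensor, "batch seq"]:
    # Log-softmax over the vocabulary, then pick out the log-prob of each realized token.
    logprobs = torch.log_softmax(logits, dim=-1)
    return logprobs.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

out = selective_log_softmax(torch.randn(2, 5, 11), torch.randint(0, 11, (2, 5)))
print(out.shape)  # torch.Size([2, 5]); a wrong input shape fails at call time
```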
seq_mask_pos_adv = (seq_ratio > cfg.seq_mask_pos_adv) & (seq_adv > 0)

if cfg.ratio_type == "sequence":
    log_ratio = trainer_logprobs - trainer_logprobs.detach() + torch.clamp(seq_log_ratio, max=cfg.seq_clip).detach()
Bug: Sequence mode gradient clipping behavior changed
In sequence mode, the old code applied `torch.clamp` after computing `trainer_logprobs - trainer_logprobs.detach() + seq_log_ratio.detach()`, which would zero out gradients when the sequence ratio exceeded the clip value. The new code applies the clamp before `.detach()` on `seq_log_ratio`, meaning gradients always flow through `trainer_logprobs - trainer_logprobs.detach()` regardless of ratio magnitude. This removes the implicit gradient-blocking behavior for extreme importance ratios, potentially affecting training stability.
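A small standalone reproduction of the two orderings (the values and variable names are made up for illustration, not taken from the repo):

```python
import torch

trainer_logprobs = torch.tensor([0.0, -1.0], requires_grad=True)
old_logprobs = torch.tensor([-3.0, -1.0])
seq_log_ratio = trainer_logprobs - old_logprobs  # element 0 is far above the clip value
seq_clip = 1.0

# Previous ordering: clamp the combined expression. Where the ratio exceeds the clip,
# the output saturates at the constant seq_clip, so the gradient through
# trainer_logprobs is zeroed for that sequence.
old = torch.clamp(trainer_logprobs - trainer_logprobs.detach() + seq_log_ratio.detach(), max=seq_clip)

# New ordering: clamp only the detached ratio. The straight-through term
# trainer_logprobs - trainer_logprobs.detach() keeps a gradient of 1 no matter
# how extreme the importance ratio is.
new = trainer_logprobs - trainer_logprobs.detach() + torch.clamp(seq_log_ratio, max=seq_clip).detach()

print(torch.autograd.grad(old.sum(), trainer_logprobs, retain_graph=True)[0])  # tensor([0., 1.])
print(torch.autograd.grad(new.sum(), trainer_logprobs)[0])                     # tensor([1., 1.])
```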
these look like they should be in a separate pr?
ya gonna revamp this pr as there's lots of algorithm options we will want to add
Note
Vectorizes advantage computation and overhauls RL loss with new masking/sequence controls, removing `loss_scale` from `compute_loss` and updating training/tests accordingly.

- Vectorize `compute_advantages` using tensor reshaping; support length-weighted baseline via `completion_lengths`; remove per-group helper and return flat list.
- `LossConfig`: add `mask_low/high`, `seq_mask_low/high`, `seq_mask_neg_adv/pos_adv`, `seq_clip`, and `constant_norm`.
- `compute_loss`: remove `loss_scale` arg; add sequence- and token-level masking with new thresholds and adv-aware sequence masks; sequence ratio clipping and optional per-sequence normalization.
- Metrics (`token_masked*`, `seq_masked_*`, and KL splits).
- Update `train.py` to compute `loss_scale` (batch size for `sequence` or `constant_norm`) and divide the loss after `compute_loss`.
- Tests call `compute_loss` without `loss_scale`.

Written by Cursor Bugbot for commit 5324fd0.
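As an illustration of the first bullet above, a rough sketch of a vectorized, group-baseline advantage computation; the function name and the length-weighted baseline via `completion_lengths` come from the note, while the exact weighting, signature, and group handling are assumptions rather than the PR's actual code:

```python
import torch
from torch import Tensor

def compute_advantages(
    rewards: Tensor,             # flat, length num_groups * group_size
    completion_lengths: Tensor,  # same shape, token count per completion
    group_size: int,
) -> list[Tensor]:
    # Reshape the flat rewards into (num_groups, group_size) so the per-group
    # baseline is one vectorized reduction instead of a Python loop over groups.
    r = rewards.reshape(-1, group_size).float()
    w = completion_lengths.reshape(-1, group_size).float()
    # Length-weighted baseline: longer completions contribute more to the group mean.
    baseline = (r * w).sum(dim=1, keepdim=True) / w.sum(dim=1, keepdim=True).clamp_min(1e-8)
    # Flatten back and return a flat list, one advantage per sample.
    return list((r - baseline).reshape(-1).unbind(0))

advs = compute_advantages(torch.randn(8), torch.randint(1, 50, (8,)), group_size=4)
print(len(advs))  # 8
```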
compute_advantagesusing tensor reshaping; support length-weighted baseline viacompletion_lengths; remove per-group helper and return flat list.LossConfig:mask_low/high,seq_mask_low/high,seq_mask_neg_adv/pos_adv,seq_clip, andconstant_norm.compute_loss:loss_scalearg; add sequence- and token-level masking with new thresholds and adv-aware sequence masks; sequence ratio clipping and optional per-sequence normalization.token_masked*,seq_masked_*, and KL splits).train.pyto computeloss_scale(batch-size forsequenceorconstant_norm) and divide loss aftercompute_loss.compute_losswithoutloss_scale.Written by Cursor Bugbot for commit 5324fd0. This will update automatically on new commits. Configure here.