
Conversation

@faresobeid (Contributor) commented Dec 2, 2025

Note

Vectorizes advantage computation and overhauls RL loss with new masking/sequence controls, removing loss_scale from compute_loss and updating training/tests accordingly.

  • Advantage:
    • Vectorize compute_advantages using tensor reshaping; support length-weighted baseline via completion_lengths; remove per-group helper and return flat list.
  • Loss/Training:
    • Introduce new masking config in LossConfig: mask_low/high, seq_mask_low/high, seq_mask_neg_adv/pos_adv, seq_clip, and constant_norm.
    • Refactor compute_loss:
      • Remove loss_scale arg; add sequence- and token-level masking with new thresholds and adv-aware sequence masks; sequence ratio clipping and optional per-sequence normalization.
      • Return expanded metrics (token_masked*, seq_masked_*, and KL splits).
    • Update train.py to compute loss_scale (batch-size for sequence or constant_norm) and divide loss after compute_loss.
  • Tests:
    • Update unit tests to call compute_loss without loss_scale.

Written by Cursor Bugbot for commit 5324fd0. This will update automatically on new commits.
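
For orientation, here is a minimal sketch of what the vectorized, group-wise advantage computation described in the summary could look like. The function name matches the summary, but the argument names (rewards, group_size, completion_lengths) and the exact length-weighting scheme are illustrative assumptions, not the PR's actual signature.

# Hypothetical sketch of a vectorized, group-wise advantage computation.
# Argument names and the length-weighting scheme are assumptions for illustration.
import torch
from torch import Tensor


def compute_advantages(
    rewards: Tensor,                            # flat, shape (num_groups * group_size,)
    group_size: int,
    completion_lengths: Tensor | None = None,   # optional, same flat shape
) -> list[float]:
    # Reshape the flat reward vector into (num_groups, group_size) so the
    # per-group baseline is one reduction instead of a per-group Python loop.
    grouped = rewards.view(-1, group_size)

    if completion_lengths is None:
        baseline = grouped.mean(dim=1, keepdim=True)
    else:
        # Length-weighted baseline: longer completions contribute more weight.
        lengths = completion_lengths.view(-1, group_size).float()
        baseline = (grouped * lengths).sum(dim=1, keepdim=True) / lengths.sum(dim=1, keepdim=True)

    # Broadcast-subtract the baseline and return a flat list, matching the
    # "remove per-group helper and return flat list" item in the summary.
    return (grouped - baseline).flatten().tolist()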



def compute_loss(
    trainer_logprobs: Any,  # list of Float[Tensor, "seq_i"] with potentially different seq_i lengths
Member

lol why did we have any here before

def shift_logits(logits: Float[Tensor, "batch seq vocab"]) -> Float[Tensor, "batch seq vocab"]:
def shift_logits(logits: Tensor) -> Tensor:
    """Removes final token logits and adds a zero logit for the first token."""
    # We drop the last logit because it corresponds to the next token that will be sampled but is not here yet
Member

comments are nice no?


@jaxtyped(typechecker=typechecker)
@torch.compile(dynamic=True)
def selective_log_softmax(
Member

why remove all the jaxtyping?

Member

i think it's nice for ppl not so familiar with the code to know the input/output tensor shapes. also good to catch at runtime if we violate these

Contributor Author


Ya fair, just cus it's very set in place what the shapes would be. Will put them back

Member

it should be very set in place no?
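
For readers following the jaxtyping thread, here is a minimal sketch of the shape-annotated style being argued for, assuming the usual jaxtyping + beartype wiring; the repo's actual typechecker import, decorators, and function bodies may differ.

# Sketch of shape-annotated helpers in the style the review asks to keep.
# Assumes beartype as the runtime typechecker; the repo may wire this differently.
import torch
from torch import Tensor
from jaxtyping import Float, Int, jaxtyped
from beartype import beartype as typechecker


@jaxtyped(typechecker=typechecker)
def shift_logits(logits: Float[Tensor, "batch seq vocab"]) -> Float[Tensor, "batch seq vocab"]:
    """Removes final token logits and adds a zero logit for the first token."""
    batch, _, vocab = logits.shape
    zeros = logits.new_zeros(batch, 1, vocab)
    # Drop the last position (it predicts a token that has not been sampled yet)
    # and prepend zeros so logits line up with their input tokens.
    return torch.cat([zeros, logits[:, :-1, :]], dim=1)


@jaxtyped(typechecker=typechecker)
def selective_log_softmax(
    logits: Float[Tensor, "batch seq vocab"],
    token_ids: Int[Tensor, "batch seq"],
) -> Float[Tensor, "batch seq"]:
    # Log-probability of each observed token under the model's distribution;
    # the annotations document the shapes and catch violations at runtime.
    logprobs = torch.log_softmax(logits, dim=-1)
    return logprobs.gather(-1, token_ids.unsqueeze(-1).to(torch.int64)).squeeze(-1)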

seq_mask_pos_adv = (seq_ratio > cfg.seq_mask_pos_adv) & (seq_adv > 0)

if cfg.ratio_type == "sequence":
    log_ratio = trainer_logprobs - trainer_logprobs.detach() + torch.clamp(seq_log_ratio, max=cfg.seq_clip).detach()
Cursor Bugbot

Bug: Sequence mode gradient clipping behavior changed

In sequence mode, the old code applied torch.clamp after computing trainer_logprobs - trainer_logprobs.detach() + seq_log_ratio.detach(), which would zero out gradients when the sequence ratio exceeded the clip value. The new code applies clamp before .detach() on seq_log_ratio, meaning gradients always flow through trainer_logprobs - trainer_logprobs.detach() regardless of ratio magnitude. This removes the implicit gradient blocking behavior for extreme importance ratios, potentially affecting training stability.
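
To make the ordering difference concrete, here is a small self-contained toy. Names mirror the quoted snippet; the values, shapes, and inference_logprobs are made-up stand-ins rather than the PR's actual code.

# Toy illustration of the two clamp/detach orderings discussed above.
import torch

seq_clip = 0.5  # stand-in for cfg.seq_clip

trainer_logprobs = torch.tensor([-1.0, -2.0], requires_grad=True)
inference_logprobs = torch.tensor([-3.0, -4.0])
seq_log_ratio = (trainer_logprobs - inference_logprobs).sum()  # 4.0, well above seq_clip

# New ordering: the clamp is applied to a value that is then detached, so it is a
# constant either way and gradients always flow through trainer_logprobs.
log_ratio_new = trainer_logprobs - trainer_logprobs.detach() + torch.clamp(seq_log_ratio, max=seq_clip).detach()

# Old ordering: the clamp wraps the whole expression, so once the ratio exceeds
# seq_clip the output is the constant seq_clip and the gradient is blocked.
log_ratio_old = torch.clamp(trainer_logprobs - trainer_logprobs.detach() + seq_log_ratio.detach(), max=seq_clip)

log_ratio_new.sum().backward()
grad_new = trainer_logprobs.grad.clone()   # tensor([1., 1.]): gradient flows
trainer_logprobs.grad = None
log_ratio_old.sum().backward()
grad_old = trainer_logprobs.grad.clone()   # tensor([0., 0.]): gradient blocked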


Member

these look like they should be in a separate pr?

Contributor Author

ya gonna revamp this pr as there's lots of algorithm options we will want to add

@faresobeid faresobeid closed this Dec 8, 2025
