Speculative sampling #17
base: main
Conversation
…added args to script for sampling
Example outputs demonstrating new sampling capabilities (see the sketch after this list for what top_k and temperature control):
- Greedy baseline
- top_k=5, temperature=2: no slowdown
- top_k=5, temperature=5: slowdown due to low likelihood of (ridiculous) output
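For context, a minimal sketch of what top_k/temperature sampling over a logits tensor typically looks like. This is illustrative only, not the code added in this PR; the function name and tensor shapes are assumptions.

```python
import torch

def sample_top_k(logits: torch.Tensor, top_k: int = 5, temperature: float = 2.0) -> torch.Tensor:
    """Illustrative top-k / temperature sampling.

    logits: (batch, vocab) tensor of unnormalized scores.
    Returns: (batch,) tensor of sampled token ids.
    """
    # Temperature > 1 flattens the distribution, < 1 sharpens it
    scaled = logits / temperature
    # Keep only the k highest-scoring tokens per row
    topk_vals, topk_idx = scaled.topk(top_k, dim=-1)
    probs = torch.softmax(topk_vals, dim=-1)
    # Sample within the top-k set, then map back to vocabulary ids
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx.gather(-1, choice).squeeze(-1)
```

Very high temperatures flatten the distribution toward uniform over the top-k set, which is consistent with the low-likelihood output and speculator slowdown noted for temperature=5 above.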
fms_extras/utils/generation.py
Outdated
```python
# Composite greedy and non greedy outputs
greedy = logits.argmax(-1)
mask = do_sample[:, None, None].int()
return samples * mask + (1 - mask) * greedy
```
if the mask is really a mask and not a weighting, might be better to use torch.where.
we're calculating the sampled results even if we don't use them? I guess that's something to do with compilation but I would have thought the generation code would be outside the compile path?
Good point, I'll swap to torch.where. We are calculating the sampled result for every case, and while that will be useful for compile down the road, in this case it's mostly just for efficient GPU usage - pretty sure that partitioning the greedy/non-greedy lines and then re-mixing them after is more work than just sampling everything.
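For reference, a minimal sketch of what the torch.where variant could look like, assuming the same names and shapes as the quoted snippet; this is illustrative, not necessarily the committed change.

```python
import torch

def composite_outputs(logits: torch.Tensor, samples: torch.Tensor, do_sample: torch.Tensor) -> torch.Tensor:
    """Pick sampled tokens where do_sample is True, greedy argmax tokens otherwise."""
    greedy = logits.argmax(-1)
    # torch.where treats do_sample as a true selection mask rather than a 0/1 weighting,
    # so no integer multiply-and-blend is needed
    return torch.where(do_sample[:, None, None].bool(), samples, greedy)
```

The result matches the 0/1 arithmetic blend, but the intent (per-sequence selection) is explicit and the two intermediate products are never materialized.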
For example, if the base model predicts tokens A and B with equal 50% probability, and the speculator produces one candidate with A and another with B, with independent sampling there's a 25% chance of rejecting both, even though one must be correct. Consistent sampling allows us to avoid this.
if the goal is to speculate on a mutually exclusive set of possible continuations, why are we sampling at all and not just speculating on the top-k predictions?
We could do this, but we're more concerned with the ability to sample here than we are with the non-greediness of the approach. In this case "not greedy" is meant strictly literally, in that sampling involves not selecting greedily (assuming I'm understanding the question).
Implements simple speculative sampling via candidate-consistent ground-truth sampling. See #12 for a discussion on implementation details and why this is needed in the first place.
- Adds a __generate_targets() function, implementing both greedy and non-greedy selection, for use in speculative_generate()
- Adds sampling args to the paged_speculative_inference.py demo script

Notably, for low temperature and top_k, we anecdotally observe no reduction in speculator performance compared to the greedy case!
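To make the candidate-consistent idea concrete, here is a minimal, illustrative sketch (not the fms_extras implementation; the function name and shapes are assumptions): sample the ground-truth token once from the base model's distribution and accept whichever candidates match it, rather than running an independent accept/reject per candidate.

```python
import torch

def consistent_accept(base_probs: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
    """Accept speculator candidates against a single ground-truth draw.

    base_probs: (vocab,) next-token distribution from the base model.
    candidates: (num_candidates,) token ids proposed by the speculator.
    Returns a boolean mask over candidates.
    """
    # One draw from the base model serves as the ground truth for all candidates...
    ground_truth = torch.multinomial(base_probs, num_samples=1)
    # ...so mutually exclusive candidates are never all rejected when one of them
    # matches the drawn token
    return candidates == ground_truth
```

In the A/B example quoted above (a 50/50 base distribution with candidates covering both tokens), exactly one candidate is accepted on every draw, whereas independent per-candidate sampling rejects both 25% of the time.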