Add Speculator Architecture #6

Merged: daviswer merged 13 commits into foundation-model-stack:main from JRosenkranz:speculator-v2 on Feb 23, 2024. The diff below shows changes from 8 of the 13 commits.

Commits:
57ed0c0  Add speculator with two forward modes (daviswer)
1e79d45  Small qol adds (daviswer)
d9d8530  Fix import (daviswer)
ef8ac71  Swap n_heads for n_predict (daviswer)
423ee80  Swap self.n_predict for self.npredict (daviswer)
49b1ea0  Further docs, legibility, comments (daviswer)
58c1b09  Move speculator to models subfolder (daviswer)
0c98baa  Blacking, casing (daviswer)
6bf3944  Add type hints / docstrings (daviswer)
3c248c2  Fix typing import (daviswer)
838c728  Blacking (sigh) (daviswer)
bd45b87  Fix typing imports pt2 (daviswer)
a4c4ba3  isorting (daviswer)
New file, 134 lines added:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fms.modules.layernorm import LayerNormParameterized


class MLPSpeculator(nn.Module):
    """
    This is a simple MLP-based speculator that functions similarly to Medusa
    (https://arxiv.org/abs/2401.10774), ingesting context via the final embedding
    vector from the base model. However, this model also conditions on previously
    predicted tokens, similarly to an RNN, allowing it to generate better-quality
    n-grams.

    The architecture is as flat and simple as possible: for each prediction head,
    the current state vector is projected into a new latent space and added to the
    previous token's embedding. This sum goes through layernorm and activation,
    forming the new state vector. This state predicts the next token (or set of
    candidate tokens) for the current head, and then is passed on to the next.
    ...
    Args
    ----
    emb_dim : int
        Dimensionality of the input vector from the base model.
    inner_dim : int
        Latent dimensionality of the speculator model.
    vocab_size : int
        Number of entries in the tokenizer associated with the base model.
    n_predict : int
        Number of heads / number of tokens to guess ahead. Model size and speed
        scale with this value.
    """
    def __init__(self, emb_dim=4096, inner_dim=0, vocab_size=32000, n_predict=3):
        super().__init__()
        self.n_predict = n_predict
        self.emb_dim = emb_dim
        inner_dim = inner_dim if inner_dim != 0 else emb_dim
        self.inner_dim = inner_dim
        self.vsize = vocab_size
        self.emb = nn.ModuleList(
            [nn.Embedding(vocab_size, inner_dim) for _ in range(n_predict)]
        )
        self.proj = nn.ModuleList(
            [
                nn.Linear((emb_dim if i == 0 else inner_dim), inner_dim, bias=False)
                for i in range(n_predict)
            ]
        )
        self.head = nn.ModuleList(
            [nn.Linear(inner_dim, vocab_size, bias=False) for _ in range(n_predict)]
        )
        self.ln = nn.ModuleList(
            [
                LayerNormParameterized(
                    inner_dim, elementwise_shift=True, elementwise_scale=True
                )
                for _ in range(n_predict)
            ]
        )
        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
        self.state_weight = 0.5 ** (0.5 / n_predict)
        self.emb_weight = math.sqrt(1 - self.state_weight**2)
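        # (Added note on the arithmetic: each head scales the carried state by
        # state_weight, so after n_predict heads the initial state is scaled by
        # state_weight**n_predict = 0.5**0.5, i.e. 50% of the magnitude in
        # variance terms. emb_weight satisfies state_weight**2 + emb_weight**2 == 1,
        # keeping per-step variance roughly constant, assuming independent,
        # equal-variance summands.)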
        self.activation = nn.GELU()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Embedding) or isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, 0, 1 / math.sqrt(self.inner_dim))
            elif isinstance(m, LayerNormParameterized):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def generate_suffixes(self, state, ind, topk=[5, 4, 3], n=5):
        """
        FOR INFERENCE
        ----
        Generate a tree of candidate sequences given the latest base model
        embedding (state) and chosen token (ind). topk indicates the # of tree
        "branches" at each head. n pares down the candidate list from prod(topk)
        to the top n most confident.
        """
        # state: b 1 d
        # ind: b 1
        # k indicates # of candidates
        # h indicates # of generated tokens
        b = state.size(0)
        out = torch.empty(b, 1, 0, device=state.device).int()  # b k h
        log_probs = torch.zeros(b, 1, device=state.device)  # b k
        assert (
            len(topk) == self.n_predict
        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(topk)} provided)"
        for i in range(self.n_predict):
            # Project and predict
            z = self.emb[i](ind)
            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
            state = self.proj[i](state) * self.state_weight + z
            state = self.activation(self.ln[i](state))  # b k d
            probs = F.log_softmax(self.head[i](state), dim=2)  # b k v
            probs, preds = probs.topk(topk[i], dim=2)  # b k k'

            # Update candidate set with new predictions
            out = out.unsqueeze(2).expand(-1, -1, topk[i], -1)  # b k k' h
            out = torch.cat([out, preds.unsqueeze(3)], dim=3)  # b k k' h+1
            out = out.view(b, -1, i + 1)  # b kk' h+1

            # Update state, log_probs and ind for new predictions
            state = state.unsqueeze(2).expand(-1, -1, topk[i], -1)  # b k k' d
            state = state.reshape(b, -1, state.size(3))  # b kk' d
            ind = preds.view(b, -1)  # b kk'
            log_probs = log_probs.unsqueeze(2).expand(b, -1, topk[i])  # b k k'
            log_probs = log_probs.add(probs).reshape(b, -1)  # b kk'

        # Take only top n best guesses
        best_guesses = log_probs.topk(n, dim=1)[1]  # b n
        return out.gather(
            1, best_guesses.unsqueeze(2).expand(-1, -1, self.n_predict)
        )  # b n h

    def forward(self, state, inds):
        """
        FOR TRAINING
        ----
        Since we're assuming all prior tokens are "correct", don't act recursively;
        just pull from the provided inds. Produces self.n_predict predicted tokens
        for each token embedding in state. inds requires self.n_predict extra
        tokens on the right to "simulate" recursive behavior for end positions.
        """
        # state: b n d
        # inds: b n+h (..., pred token, n+2, n+3, n+4)
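        # Teacher forcing: head i embeds the ground-truth window
        # inds[:, i : i + state.size(1)] rather than its own sampled prediction,
        # mirroring the inference-time recurrence without any sampling.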
        out = []
        for i in range(self.n_predict):
            z = self.emb[i](inds[:, i : i + state.size(1)])
            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b n d
            state = self.proj[i](state) * self.state_weight + z
            state = self.activation(self.ln[i](state))  # b n d
            out.append(self.head[i](state))  # b n v
        return torch.stack(out, dim=0)  # h b n v
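
A minimal usage sketch of the two forward modes (added for illustration, not part of the PR). The import path is an assumption based on the "Move speculator to models subfolder" commit, and all tensor values are random placeholders; shapes follow the comments in the code above.

    import torch

    # Path assumed from commit 58c1b09 "Move speculator to models subfolder"
    from fms.models.speculator import MLPSpeculator

    spec = MLPSpeculator(emb_dim=4096, inner_dim=0, vocab_size=32000, n_predict=3)
    spec.reset_parameters()

    # Training mode: one base-model embedding per position, plus n_predict extra
    # ground-truth tokens on the right for the end positions.
    b, n = 2, 16
    state = torch.randn(b, n, 4096)             # b n d
    inds = torch.randint(0, 32000, (b, n + 3))  # b n+h
    logits = spec(state, inds)                  # h b n v -> (3, 2, 16, 32000)

    # Inference mode: latest embedding plus the chosen token; returns the top-n
    # candidate suffixes from the prediction tree.
    state1 = torch.randn(b, 1, 4096)            # b 1 d
    ind1 = torch.randint(0, 32000, (b, 1))      # b 1
    suffixes = spec.generate_suffixes(state1, ind1, topk=[5, 4, 3], n=5)  # (2, 5, 3)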
Review comment on the forward method: "can we add type hints and docstrings for this"