Skip to content

Commit

Permalink
Update speculator.py
Browse files · Browse the repository at this point in the history
  • Loading branch information
sahilsuneja1 authored Jun 20, 2024
1 parent 0887569 commit 9a7c96e
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions fms_extras/models/speculator.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -32,12 +32,8 @@ class MLPSpeculator(nn.Module):
Number of entries in the tokenizer associated with the base model.
n_predict : int
Number of heads / number of tokens to guess ahead. Model size and speed scale with this value.
tie_emb : bool
If true, use a single set of embedding weights for every model head/stage
tie_head : bool
If true, use a single set of prediction weights for every model head/stage
tie_transitions : bool
If true, use a single set of internal projection weights for every model head/stage after the first.
tie_weights : bool
If true, use a single set of weights for every model head/stage after the first.
The initial projection from the base model may have a different size, so that stays separate.
"""

Expand All @@ -47,10 +43,7 @@ def __init__(
inner_dim=0,
vocab_size=32000,
n_predict=3,
tie_emb=False,
tie_head=False,
tie_transition=False,
tie_wts=False,
tie_weights=False,
scale_input=False,
):
super().__init__()
Expand Down Expand Up @@ -90,23 +83,18 @@ def __init__(
self.activation = nn.GELU()

# Handle weight tying as specified
if tie_wts:
tie_emb = tie_head = tie_transition = True
if tie_emb:
assert n_predict > 1, "You cannot tie embeddings when only 1 exists"
if tie_weights:
assert n_predict > 1, "You cannot tie weights between stages when only 1 exists"
for emb in self.emb:
emb.weight = self.emb[0].weight
if tie_head:
assert n_predict > 1, "You cannot tie heads when only 1 exists"

for head in self.head:
head.weight = self.head[0].weight
if tie_transition:
assert (
n_predict > 2
), "You cannot tie internal transitions when only 1 internal transition exists"

for ln in self.ln:
ln.weight = self.ln[0].weight
ln.bias = self.ln[0].bias

# Since first proj has different size, allow different initial proj from base into model
for i in range(2, n_predict):
self.proj[i].weight = self.proj[1].weight
Expand Down

0 comments on commit 9a7c96e

Please sign in to comment.