From 9a7c96e820315b2fb42e46994facbe06b206bf81 Mon Sep 17 00:00:00 2001 From: Sahil Suneja <6835847+sahilsuneja1@users.noreply.github.com> Date: Thu, 20 Jun 2024 18:57:22 -0400 Subject: [PATCH] Update speculator.py --- fms_extras/models/speculator.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/fms_extras/models/speculator.py b/fms_extras/models/speculator.py index f640cba..e4de672 100644 --- a/fms_extras/models/speculator.py +++ b/fms_extras/models/speculator.py @@ -32,12 +32,8 @@ class MLPSpeculator(nn.Module): Number of entries in the tokenizer associated with the base model. n_predict : int Number of heads / number of tokens to guess ahead. Model size and speed scale with this value. - tie_emb : bool - If true, use a single set of embedding weights for every model head/stage - tie_head : bool - If true, use a single set of prediction weights for every model head/stage - tie_transitions : bool - If true, use a single set of internal projection weights for every model head/stage after the first. + tie_weights : bool + If true, share a single set of embedding, prediction-head and layer-norm weights across all model heads/stages, and a single set of internal projection weights across every stage after the first. The initial projection from the base model may have a different size, so that stays separate. 
""" @@ -47,10 +43,7 @@ def __init__( inner_dim=0, vocab_size=32000, n_predict=3, - tie_emb=False, - tie_head=False, - tie_transition=False, - tie_wts=False, + tie_weights=False, scale_input=False, ): super().__init__() @@ -90,23 +83,18 @@ def __init__( self.activation = nn.GELU() # Handle weight tying as specified - if tie_wts: - tie_emb = tie_head = tie_transition = True - if tie_emb: - assert n_predict > 1, "You cannot tie embeddings when only 1 exists" + if tie_weights: + assert n_predict > 1, "You cannot tie weights between stages when only 1 exists" for emb in self.emb: emb.weight = self.emb[0].weight - if tie_head: - assert n_predict > 1, "You cannot tie heads when only 1 exists" + for head in self.head: head.weight = self.head[0].weight - if tie_transition: - assert ( - n_predict > 2 - ), "You cannot tie internal transitions when only 1 internal transition exists" + for ln in self.ln: ln.weight = self.ln[0].weight ln.bias = self.ln[0].bias + # Since first proj has different size, allow different initial proj from base into model for i in range(2, n_predict): self.proj[i].weight = self.proj[1].weight