From 9a7c96e820315b2fb42e46994facbe06b206bf81 Mon Sep 17 00:00:00 2001
From: Sahil Suneja <6835847+sahilsuneja1@users.noreply.github.com>
Date: Thu, 20 Jun 2024 18:57:22 -0400
Subject: [PATCH] speculator: consolidate tie_emb/tie_head/tie_transition/tie_wts into a single tie_weights flag

---
 fms_extras/models/speculator.py | 28 ++++++++--------------------
 1 file changed, 8 insertions(+), 20 deletions(-)

diff --git a/fms_extras/models/speculator.py b/fms_extras/models/speculator.py
index f640cba..e4de672 100644
--- a/fms_extras/models/speculator.py
+++ b/fms_extras/models/speculator.py
@@ -32,12 +32,8 @@ class MLPSpeculator(nn.Module):
         Number of entries in the tokenizer associated with the base model.
     n_predict : int
         Number of heads / number of tokens to guess ahead. Model size and speed scale with this value.
-    tie_emb : bool
-        If true, use a single set of embedding weights for every model head/stage
-    tie_head : bool
-        If true, use a single set of prediction weights for every model head/stage
-    tie_transitions : bool
-        If true, use a single set of internal projection weights for every model head/stage after the first.
+    tie_weights : bool
+        If true, use a single set of weights for every model head/stage after the first.
         The initial projection from the base model may have a different size, so that stays separate.
     """
 
@@ -47,10 +43,7 @@ def __init__(
         inner_dim=0,
         vocab_size=32000,
         n_predict=3,
-        tie_emb=False,
-        tie_head=False,
-        tie_transition=False,
-        tie_wts=False,
+        tie_weights=False,
         scale_input=False,
     ):
         super().__init__()
@@ -90,23 +83,18 @@ def __init__(
         self.activation = nn.GELU()
 
         # Handle weight tying as specified
-        if tie_wts:
-            tie_emb = tie_head = tie_transition = True
-        if tie_emb:
-            assert n_predict > 1, "You cannot tie embeddings when only 1 exists"
+        if tie_weights:
+            assert n_predict > 1, "You cannot tie weights between stages when only 1 exists"
             for emb in self.emb:
                 emb.weight = self.emb[0].weight
-        if tie_head:
-            assert n_predict > 1, "You cannot tie heads when only 1 exists"
+
             for head in self.head:
                 head.weight = self.head[0].weight
-        if tie_transition:
-            assert (
-                n_predict > 2
-            ), "You cannot tie internal transitions when only 1 internal transition exists"
+
             for ln in self.ln:
                 ln.weight = self.ln[0].weight
                 ln.bias = self.ln[0].bias
+
             # Since first proj has different size, allow different initial proj from base into model
             for i in range(2, n_predict):
                 self.proj[i].weight = self.proj[1].weight