
Add another LayerNorm to MLPSpeculator #35

Closed
fms_extras/models/hf/modeling_mlp_speculator.py (16 additions, 1 deletion)
@@ -22,6 +22,8 @@ def __init__(
         n_predict: int = 3,
         top_k_tokens_per_head: List[int] = [5, 4, 3],
         n_candidates: int = 5,
+        tie_weights: bool = False,
+        scale_input: bool = False,
         **kwargs
     ):
         """
@@ -49,6 +51,8 @@ def __init__(
         self.n_predict = n_predict
         self.top_k_tokens_per_head = top_k_tokens_per_head
         self.n_candidates = n_candidates
+        self.tie_weights = tie_weights
+        self.scale_input = scale_input
         super().__init__(**kwargs)


@@ -68,10 +72,17 @@ def __init__(
             inner_dim=config.inner_dim,
             vocab_size=config.vocab_size,
             n_predict=config.n_predict,
+            tie_weights=config.tie_weights,
+            scale_input=config.scale_input,
         )
         if speculator is None:
             self.speculator = MLPSpeculator(
-                config.emb_dim, config.inner_dim, config.vocab_size, config.n_predict
+                config.emb_dim,
+                config.inner_dim,
+                config.vocab_size,
+                config.n_predict,
+                tie_weights=config.tie_weights,
+                scale_input=config.scale_input,
             )
             self.speculator.reset_parameters()
         else:
@@ -83,6 +94,8 @@ def from_fms_model(
         model: MLPSpeculator,
         top_k_tokens_per_head: List[int],
         n_candidates: int,
+        tie_weights: bool = False,
+        scale_input: bool = False,
         *args,
         **kwargs
     ):
@@ -93,6 +106,8 @@ def from_fms_model(
             n_predict=model.n_predict,
             top_k_tokens_per_head=top_k_tokens_per_head,
             n_candidates=n_candidates,
+            tie_weights=tie_weights,
+            scale_input=scale_input,
         )
         return cls(config, model)
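For context, a hedged sketch of how the forwarded flags would be used when wrapping a trained fms speculator in the HF class. `MLPSpeculatorPreTrainedModel` is assumed to be the wrapper defined in this file, and the positional argument order follows the diff above; treat this as an illustration, not the canonical conversion path.

    # Hypothetical round trip: fms speculator -> HF wrapper, forwarding the new flags
    from fms_extras.models.speculator import MLPSpeculator
    from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel

    fms_spec = MLPSpeculator(4096, 3072, 128256, 4, tie_weights=True, scale_input=True)
    hf_spec = MLPSpeculatorPreTrainedModel.from_fms_model(
        fms_spec,
        top_k_tokens_per_head=[5, 3, 2, 2],  # one entry per head (n_predict = 4)
        n_candidates=5,
        tie_weights=True,
        scale_input=True,
    )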

fms_extras/models/speculator.py (90 additions, 18 deletions)
@@ -32,12 +32,8 @@ class MLPSpeculator(nn.Module):
         Number of entries in the tokenizer associated with the base model.
     n_predict : int
         Number of heads / number of tokens to guess ahead. Model size and speed scale with this value.
-    tie_emb : bool
-        If true, use a single set of embedding weights for every model head/stage
-    tie_head : bool
-        If true, use a single set of prediction weights for every model head/stage
-    tie_transitions : bool
-        If true, use a single set of internal projection weights for every model head/stage after the first.
+    tie_weights : bool
+        If true, use a single set of weights for every model head/stage after the first.
+        The initial projection from the base model may have a different size, so that stays separate.
     """

@@ -47,16 +43,16 @@ def __init__(
         inner_dim=0,
         vocab_size=32000,
         n_predict=3,
-        tie_emb=False,
-        tie_head=False,
-        tie_transition=False,
+        tie_weights=False,
+        scale_input=False,
     ):
         super().__init__()
         self.n_predict = n_predict
         self.emb_dim = emb_dim
         inner_dim = inner_dim if inner_dim != 0 else emb_dim
         self.inner_dim = inner_dim
         self.vsize = vocab_size
+        self.scale_input = scale_input
         self.emb = nn.ModuleList(
             [nn.Embedding(vocab_size, inner_dim) for _ in range(n_predict)]
         )
@@ -77,27 +73,31 @@ def __init__(
                 for _ in range(n_predict)
             ]
         )
+        if self.scale_input:
+            self.ln0 = LayerNormParameterized(
+                emb_dim, elementwise_shift=False, elementwise_scale=False
+            )
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / n_predict)
         self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()

         # Handle weight tying as specified
-        if tie_emb:
-            assert n_predict > 1, "You cannot tie embeddings when only 1 exists"
+        if tie_weights:
+            assert (
+                n_predict > 1
+            ), "You cannot tie weights between stages when only 1 exists"

             for emb in self.emb:
                 emb.weight = self.emb[0].weight
-        if tie_head:
-            assert n_predict > 1, "You cannot tie heads when only 1 exists"

             for head in self.head:
                 head.weight = self.head[0].weight
-        if tie_transition:
-            assert (
-                n_predict > 2
-            ), "You cannot tie internal transitions when only 1 internal transition exists"

+            for ln in self.ln:
+                ln.weight = self.ln[0].weight
+                ln.bias = self.ln[0].bias
+
             # Since first proj has different size, allow different initial proj from base into model
             for i in range(2, n_predict):
                 self.proj[i].weight = self.proj[1].weight
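The tying above is plain attribute assignment: every stage's `nn.Parameter` is pointed at stage 0's, so a single tensor is trained, counted, and checkpointed. A minimal standalone sketch of the same trick (toy sizes, not the library class):

    import torch.nn as nn

    n_predict, inner_dim, vocab = 3, 8, 100
    emb = nn.ModuleList([nn.Embedding(vocab, inner_dim) for _ in range(n_predict)])
    for e in emb:
        e.weight = emb[0].weight  # same assignment trick as the tie_weights branch

    # nn.Module deduplicates shared tensors, so only one table is counted/stored
    print(sum(p.numel() for p in emb.parameters()))  # 800, not 2400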
@@ -106,7 +106,7 @@ def reset_parameters(self):
         for m in self.modules():
             if isinstance(m, nn.Embedding) or isinstance(m, nn.Linear):
                 nn.init.trunc_normal_(m.weight, 0, 1 / math.sqrt(self.inner_dim))
-            elif isinstance(m, LayerNormParameterized):
+            elif isinstance(m, LayerNormParameterized) and hasattr(m, "weight"):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
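The new `hasattr` guard is what lets `reset_parameters` coexist with `ln0`: that norm is built with `elementwise_shift=False` and `elementwise_scale=False`, so (assuming `LayerNormParameterized` only registers `weight`/`bias` when those flags are set) it owns neither attribute, and the old unconditional `m.weight.data.fill_(1)` would raise `AttributeError`. A toy stand-in showing the guard in action:

    import torch
    import torch.nn as nn

    class MiniNorm(nn.Module):
        # Stand-in for LayerNormParameterized: registers affine params only on request
        def __init__(self, dim, elementwise_scale=True):
            super().__init__()
            if elementwise_scale:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim))

    for m in (MiniNorm(8), MiniNorm(8, elementwise_scale=False)):
        if hasattr(m, "weight"):  # the same guard reset_parameters now applies
            m.weight.data.fill_(1)
            m.bias.data.zero_()
    print("skipped the parameter-free norm without an AttributeError")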

@@ -147,6 +149,8 @@ def generate_suffixes(
         assert (
             len(topk) == self.n_predict
         ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(topk)} provided)"
+        if self.scale_input:
+            state = self.ln0(state) / (2**0.5)
         for i in range(self.n_predict):
             # Project and predict
             z = self.emb[i](ind)
@@ -201,6 +203,8 @@ def forward(
             Has size [self.n_predict b n v] where v is vocab size.
         """
         out = []
+        if self.scale_input:
+            state = self.ln0(state) / (2**0.5)
         for i in range(self.n_predict):
             z = self.emb[i](inds[:, i : i + state.size(1)])
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b n d
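Both `generate_suffixes` and `forward` now normalize the incoming base-model state identically: a parameter-free LayerNorm followed by division by sqrt(2), giving the state a fixed RMS of about 0.707 regardless of the base model's hidden-state scale, presumably so the speculator sees a consistent input magnitude across base models. A standalone sketch, using `torch.nn.LayerNorm(elementwise_affine=False)` as a stand-in for the parameter-free `LayerNormParameterized` (an assumption: the two should match when both affine terms are off):

    import torch

    emb_dim = 16
    ln0 = torch.nn.LayerNorm(emb_dim, elementwise_affine=False)

    state = torch.randn(2, 5, emb_dim) * 37.0  # base-model state at an arbitrary scale
    scaled = ln0(state) / (2**0.5)

    # ln0 yields zero mean and unit mean-square per vector; the 1/sqrt(2) halves the energy
    print(scaled.pow(2).mean(-1).sqrt().mean())  # ~0.707 regardless of the input scale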
@@ -313,6 +317,47 @@ def flatten_batch(inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     "inner_dim": 4096,
 }

+_llama_34b_code = {
+    "emb_dim": 8192,
+    "vocab_size": 32000,
+    "n_predict": 5,
+    "inner_dim": 8192,
+    "scale_input": True,
+    "tie_weights": True,
+}
+
+_llama3_8b_3_2b = {
+    "emb_dim": 4096,
+    "vocab_size": 128256,
+    "n_predict": 4,
+    "inner_dim": 3072,
+}
+
+_ibm_20b_code_instruct = {
+    "emb_dim": 6144,
+    "vocab_size": 49152,
+    "n_predict": 4,
+    "inner_dim": 4096,
+}
+
+_ibm_34b_code_instruct = {
+    "emb_dim": 6144,
+    "vocab_size": 49152,
+    "n_predict": 5,
+    "inner_dim": 6144,
+    "scale_input": True,
+    "tie_weights": True,
+}
+
+_llama3_70b_961m = {
+    "emb_dim": 8192,
+    "vocab_size": 128256,
+    "n_predict": 4,
+    "inner_dim": 3584,
+    "scale_input": True,
+    "tie_weights": True,
+}

 _architecture_name = "mlp_speculator"
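(Note: the original dicts spelled the key `tie_wts`, which does not match the `tie_weights` keyword accepted by `MLPSpeculator.__init__` above; it is corrected here.) The size suffixes in the variant names can be sanity-checked from these dicts. For the untied `_llama3_8b_3_2b` config, a back-of-envelope count (ignoring biases and the small LayerNorms):

    # Rough parameter count for _llama3_8b_3_2b (no weight tying in that config)
    emb_dim, inner_dim, vocab, n = 4096, 3072, 128256, 4

    embs = n * vocab * inner_dim                            # one embedding table per stage
    heads = n * inner_dim * vocab                           # one prediction head per stage
    projs = emb_dim * inner_dim + (n - 1) * inner_dim ** 2  # first proj maps emb_dim -> inner_dim
    print(f"{(embs + heads + projs) / 1e9:.2f}B")           # ~3.19B, matching the 3_2b suffix

With `tie_weights` the same arithmetic keeps a single embedding table, head, and inner projection, which is how the tied 34b and 70b variants land near the 658M, 680M, and 961M in their names.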


@@ -344,6 +389,33 @@ def factory(**user_kwargs):
     "llama.13b.code.2b",
     _mlp_speculator_factory_factory(_llama_13b_code),
 )
+models.register_model(
+    _architecture_name,
+    "llama.34b.code.658m",
+    _mlp_speculator_factory_factory(_llama_34b_code),
+)
+models.register_model(
+    _architecture_name,
+    "llama.llama3.8b.3_2b",
+    _mlp_speculator_factory_factory(_llama3_8b_3_2b),
+)
+models.register_model(
+    _architecture_name,
+    "llama.llama3.70b.961m",
+    _mlp_speculator_factory_factory(_llama3_70b_961m),
+)
+
+models.register_model(
+    _architecture_name,
+    "gpt_bigcode.ibm.20b.1_7b",
+    _mlp_speculator_factory_factory(_ibm_20b_code_instruct),
+)
+
+models.register_model(
+    _architecture_name,
+    "gpt_bigcode.ibm.34b.680m",
+    _mlp_speculator_factory_factory(_ibm_34b_code_instruct),
+)


 def _rename_hf_weights_to_fms(orig_sd):
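Once `fms_extras.models.speculator` is imported, the new variants should be constructible through the fms registry. A hedged usage sketch: `models.get_model(architecture, variant)` is assumed to be the registry entry point, and exact arguments may differ across fms versions.

    import fms_extras.models.speculator  # noqa: F401  (importing registers the variants above)
    from fms import models

    # Build an untrained speculator from the registered config; weight loading,
    # if desired, would go through the usual fms checkpoint arguments.
    spec = models.get_model("mlp_speculator", "llama.34b.code.658m")
    print(f"{sum(p.numel() for p in spec.parameters()) / 1e6:.0f}M")  # ~658M per the name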