
Add Speculator Architecture #6

Merged (13 commits, Feb 23, 2024)

Conversation

daviswer (Collaborator) commented Feb 22, 2024:

Add support for a speculative decoding architecture, with distinct forward-pass implementations for parallel pretraining and generative inference.
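As background for reviewers, the speculative-decoding loop this architecture serves can be sketched with toy stand-ins (`draft`, `base_next`, and the deterministic toy "models" below are illustrative assumptions, not fms APIs):

```python
def base_next(prefix):
    # Toy stand-in for the base model: deterministically emits prev + 1.
    return prefix[-1] + 1

def draft(prefix, n):
    # Toy stand-in for the speculator: cheaply guesses the next n tokens.
    return [prefix[-1] + i + 1 for i in range(n)]

def speculative_decode(prompt, steps, n):
    seq = list(prompt)
    for _ in range(steps):
        cands = draft(seq, n)
        accepted = []
        for tok in cands:
            # In practice the base model scores all candidates in one parallel pass.
            if tok == base_next(seq + accepted):
                accepted.append(tok)
            else:
                break
        if len(accepted) < len(cands):
            # On the first disagreement, fall back to the base model's own token,
            # so at least one token is emitted per verification step.
            accepted.append(base_next(seq + accepted))
        seq.extend(accepted)
    return seq
```

With a perfect toy speculator, each verification step emits n tokens at once; the parallel pretraining pass in this PR trains the speculator so that its drafts agree with the base model as often as possible.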

import torch.nn.functional as F
from fms.modules.layernorm import LayerNormParameterized

class Speculator(nn.Module):
Contributor:

I wonder if we might end up with multiple speculators, and if there's something more specific we might want to name this

Contributor:

might be helpful to have an explanation of our speculation strategy in the comments too

daviswer (Author):

Good point, I'll add an explanation. We could call this MLP_Speculator

Contributor:

FYI, snake_case class names aren't conventional in Python: https://visualgit.readthedocs.io/en/latest/pages/naming_convention.html

We might want to rename the dataset classes at some point too.

Contributor:

Should the path be models instead of modules? At least in the main fms repo, we distinguish between the two. I think the speculator runs standalone, not as a child of another model.

daviswer (Author):

Haha, I was using the exact opposite logic: I had this in modules, following the main repo, because it's a standalone object and not built from any dedicated sub-modules. But I can move it into models; at least for now that would make the organization of this repo slightly simpler.


class MLP_Speculator(nn.Module):
"""
This is a simple MLP-based speculator that functions similarly to Medusa, ingesting context via
Contributor:

might want to link to the medusa paper

The architecture is as flat and simple as possible: for each prediction head, the current
state vector is projected into a new latent space and added to the previous token's embedding.
This sum goes through layernorm and activation, forming the new state vector. This state predicts
the next token (or set of candidate tokens) for the current head, and then is passed on to the next.
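The recurrence described in the docstring above can be sketched as a toy module (class name, sizes, and the greedy feed-forward of each head's prediction are assumptions for illustration, not the PR's actual code):

```python
import torch
import torch.nn as nn

class TinySpeculatorSketch(nn.Module):
    # Illustrative re-implementation of the per-head recurrence:
    # project state, add previous token's embedding, layernorm + activation,
    # predict head i's token, pass the new state to head i + 1.
    def __init__(self, emb_dim=32, inner_dim=32, vocab=100, n_predict=3):
        super().__init__()
        self.n_predict = n_predict
        self.emb = nn.ModuleList(nn.Embedding(vocab, inner_dim) for _ in range(n_predict))
        self.proj = nn.ModuleList(
            nn.Linear(emb_dim if i == 0 else inner_dim, inner_dim, bias=False)
            for i in range(n_predict)
        )
        self.ln = nn.ModuleList(nn.LayerNorm(inner_dim) for _ in range(n_predict))
        self.head = nn.ModuleList(nn.Linear(inner_dim, vocab, bias=False) for _ in range(n_predict))
        self.act = nn.GELU()

    def forward(self, state, prev_token):
        # state: (b, emb_dim) base-model hidden state; prev_token: (b,) last emitted token
        logits = []
        for i in range(self.n_predict):
            z = self.proj[i](state) + self.emb[i](prev_token)  # project state, add token embedding
            state = self.act(self.ln[i](z))                    # layernorm + activation -> new state
            step_logits = self.head[i](state)                  # predict head i's token
            prev_token = step_logits.argmax(-1)                # greedy pick feeds the next head
            logits.append(step_logits)
        return torch.stack(logits, dim=1)  # (b, n_predict, vocab)
```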
Contributor:

nit: in diffs on GitHub, the text will be easier to read if each line is shorter (short enough not to need wrapping in a side-by-side window).

Lines long enough that whole paragraphs wrap work too, though then they become harder to comment on.

import torch.nn.functional as F
from fms.modules.layernorm import LayerNormParameterized

class MLP_Speculator(nn.Module):
Contributor:

nit: snake_case for class names is unconventional in Python

)
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
self.state_weight = .5**(.5/n_predict)
self.emb_weight = math.sqrt(1-self.state_weight**2)
Contributor:

will need to run black; it expects spaces around `-`
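The variance arithmetic behind the weights in this hunk can be checked numerically (a sketch; `speculator_weights` is a hypothetical helper name, not code from the PR):

```python
import math

def speculator_weights(n_predict):
    # Recomputes the weighting from the diff above: after n_predict heads,
    # state_0's contribution to the state's variance is
    # state_weight ** (2 * n_predict) = 0.5, and each step preserves total
    # variance since state_weight**2 + emb_weight**2 == 1.
    state_weight = 0.5 ** (0.5 / n_predict)
    emb_weight = math.sqrt(1 - state_weight ** 2)
    return state_weight, emb_weight

sw, ew = speculator_weights(4)
assert abs(sw ** (2 * 4) - 0.5) < 1e-12     # state_0 keeps 50% of variance by the final head
assert abs(sw ** 2 + ew ** 2 - 1.0) < 1e-12  # unit-variance-preserving mix
```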

daviswer (Author):

I made the Blacking and Snake_Casing changes but it looks like a maintainer needs to approve the automated tests again

JRosenkranz (Collaborator) left a comment:

just comments regarding type hints and docstrings

m.weight.data.fill_(1)
m.bias.data.zero_()

def generate_suffixes(self, state, ind, topk=[5, 4, 3], n=5):
Collaborator:

can we add type hints and docstrings for this
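For illustration, hints along the requested lines might look like this sketch (the tensor shapes and docstring are guesses from the surrounding diff, not the PR's actual annotations):

```python
import torch
from typing import List

def generate_suffixes(
    self,
    state: torch.Tensor,  # (b, d) hidden states for the latest tokens (shape assumed)
    ind: torch.Tensor,    # (b,) indices of the tokens just emitted (shape assumed)
    topk: List[int] = [5, 4, 3],  # candidates kept per prediction head
    n: int = 5,           # number of candidate suffixes to return
) -> torch.Tensor:        # (b, n, n_predict) candidate token suffixes
    """Generate top-probability candidate token suffixes for each batch entry."""
    ...
```

A tuple default (`topk=(5, 4, 3)`) would also sidestep Python's shared-mutable-default pitfall.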

1, best_guesses.unsqueeze(2).expand(-1, -1, self.n_predict)
) # b n h

def forward(self, state, inds):
Collaborator:

can we add type hints and docstrings for this

daviswer merged commit 71e5600 into foundation-model-stack:main on Feb 23, 2024.
3 checks passed
daviswer deleted the speculator-v2 branch on Feb 23, 2024, 20:00.
3 participants