Add search strategy #164
base: main
Conversation
```diff
     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[int, mx.array]: One token and a vector of log probabilities.
```
Not related to the main change, but I think the typing was wrong here. I'm pretty sure this method generates `int`, as it calls `.item()` on the token array, but I might be missing something.
```diff
     *,
     max_tokens: int = 256,
-    sampler: Optional[Callable[mx.array, mx.array]] = None,
+    sampler: Optional[Callable[[mx.array], mx.array]] = None,
```
Not related to the main change, but I think the typing was wrong here. The sampler has a single `mx.array` argument and returns an `mx.array`, AFAICT.
```diff
     quantized_kv_start: int = 0,
-    prompt_progress_callback: Optional[Callable[int, int]] = None,
+    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
```
Not related to the main change, but I think the typing was wrong here. The callback has two `int` arguments and doesn't return anything, AFAICT.
```diff
     prompt: mx.array,
     prompt_cache: List[Any],
     quantize_cache_fn: Callable[[Any], None],
+    total_prompt_tokens: int,
```
I think it's a bit awkward that `generate` needs to take `total_prompt_tokens` and `prompt_progress_callback`, but I couldn't immediately see a nice way around this.
This is a proposal for adding a "search strategy" to token generation, enabling different strategies for selecting generated tokens. For example, see this (untested) prototype of beam search built on top of it.
I have left the interface of `generate_step` unchanged by defaulting to a "linear search" strategy, which does the same thing as before. In addition, there is a new function called `generate_with_search` which accepts a `SearchStrategy` implementation, allowing other strategies to be used. I haven't plumbed `generate_with_search` into the main `stream_generate`, but perhaps it should be?

There's probably some tidying up to be done, but I thought I'd get this out for feedback before doing too much more on it.
Let me know what you think. Happy to make changes. One alternative to this PR would be for search strategies to duplicate the code that is currently in `generate_with_search` within the `SearchStrategy` implementation itself. Or we could refactor the code that is currently in `generate_with_search` into a helper function that search strategies would call. If you think either of those is a nicer approach, I'll close this PR.
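For reference, the beam-search idea mentioned above could, against such an interface, look roughly like the following. This is an illustrative sketch, not the linked prototype: it keeps the top-k partial sequences ranked by cumulative log-probability, with a hypothetical `logprob_fn` standing in for the model and plain Python floats in place of `mx.array`.

```python
from typing import Callable, List, Tuple

# Hypothetical model interface: maps a token sequence to next-token
# log-probabilities over the vocabulary.
LogProbFn = Callable[[List[int]], List[float]]


def beam_search(logprob_fn: LogProbFn, start: List[int],
                steps: int, beam_width: int = 3) -> List[int]:
    """Keep the beam_width highest-scoring sequences at each step and
    return the best one after the requested number of steps."""
    beams: List[Tuple[float, List[int]]] = [(0.0, start)]
    for _ in range(steps):
        candidates: List[Tuple[float, List[int]]] = []
        for score, seq in beams:
            # Extend each surviving sequence by every possible token.
            for tok, lp in enumerate(logprob_fn(seq)):
                candidates.append((score + lp, seq + [tok]))
        # Prune back down to the beam_width best candidates.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:beam_width]
    return beams[0][1]
```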