
Conversation

@nfergu (Contributor) commented May 9, 2025

This is a proposal for adding a "search strategy" to token generation, enabling different strategies for selecting generated tokens. For example, see this (untested) prototype of beam search built on top of it.

I have left the interface of generate_step unchanged by defaulting to a "linear search" strategy, which does the same thing as before. In addition, there is a new function called generate_with_search which accepts a SearchStrategy implementation, which allows other strategies to be used. I haven't plumbed generate_with_search into the main stream_generate but perhaps it should be?

There's perhaps some tidying up to be done, but I thought I'd get this out for feedback before doing too much more on it.

Let me know what you think. Happy to make changes. One alternative to this PR would be for search strategies to duplicate the code that is currently in generate_with_search within the SearchStrategy implementation itself. Or we could refactor the code that is currently in generate_with_search into a helper function that search strategies would call. If you think either of those is a nicer approach, I'll close this PR.
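For illustration, here is a minimal sketch of what such a strategy interface might look like. The names and shapes are hypothetical (not the PR's actual API), and plain Python lists stand in for mx.array to keep the example self-contained:

```python
from typing import List, Protocol


class SearchStrategy(Protocol):
    """Hypothetical interface: given one step's log-probabilities,
    choose the next token id. (Illustrative only; the PR's real
    interface operates on mx.array and may differ.)"""

    def select(self, logprobs: List[float]) -> int: ...


class LinearSearch:
    """Greedy selection mirroring the existing single-path behaviour
    of generate_step: always take the most probable token."""

    def select(self, logprobs: List[float]) -> int:
        # Index of the maximum log-probability.
        return max(range(len(logprobs)), key=lambda i: logprobs[i])
```

A generate_with_search-style loop would then call `strategy.select(...)` each step instead of hard-coding the selection, which is what lets a beam-search strategy plug into the same loop.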

Yields:
-    Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+    Tuple[int, mx.array]: One token and a vector of log probabilities.

Not related to the main change, but I think the typing was wrong here. I'm pretty sure this generator yields int tokens, since it calls .item() on the token array, but I might be missing something.

*,
max_tokens: int = 256,
-sampler: Optional[Callable[mx.array, mx.array]] = None,
+sampler: Optional[Callable[[mx.array], mx.array]] = None,
@nfergu May 9, 2025

Not related to the main change, but I think the typing was wrong here. The sampler has a single mx.array argument and returns an mx.array AFAICT.
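To illustrate the fix (using a stand-in type, since mx.array is MLX-specific): in typing.Callable, the parameter types must be given as a list, so a one-argument callable is `Callable[[X], Y]`, not `Callable[X, Y]`.

```python
from typing import Callable, List, Optional

# Stand-in for mx.array, to keep this example self-contained.
Array = List[float]

# Callable[[Array], Array]: one Array argument, returns an Array.
# (Writing Callable[Array, Array] is malformed -- the parameter
# types must be wrapped in a list.)
sampler: Optional[Callable[[Array], Array]] = None


def greedy_sampler(logprobs: Array) -> Array:
    # Matches the annotation: takes one array, returns a one-element array.
    return [float(max(range(len(logprobs)), key=lambda i: logprobs[i]))]


sampler = greedy_sampler
```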

quantized_kv_start: int = 0,
-prompt_progress_callback: Optional[Callable[int, int]] = None,
+prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
@nfergu May 9, 2025

Not related to the main change, but I think the typing was wrong here. The callback has two int arguments, and doesn't return anything AFAICT.
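The same idea applies to the callback annotation, with the return type spelled out as well. A hypothetical stand-in (names illustrative, not the PR's code):

```python
from typing import Callable, List, Optional, Tuple

# Matches the corrected annotation: two int parameters, None return.
prompt_progress_callback: Optional[Callable[[int, int], None]] = None

progress_log: List[Tuple[int, int]] = []


def log_progress(processed: int, total: int) -> None:
    # Record (processed, total); returns None, as the annotation requires.
    progress_log.append((processed, total))


prompt_progress_callback = log_progress
prompt_progress_callback(128, 512)
```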

prompt: mx.array,
prompt_cache: List[Any],
quantize_cache_fn: Callable[[Any], None],
total_prompt_tokens: int,

I think it's a bit awkward that generate needs to take total_prompt_tokens and prompt_progress_callback, but I couldn't immediately see a nice way around this.
