
Reduce peak memory for prompt_logprobs requests #907

Open · wants to merge 2 commits into main
Conversation

50h100a (Collaborator) commented on Dec 16, 2024

First order of business: make prompt_logprobs "compatible" with prefix caching. It can't take advantage of the caching, but at least it will run.

Second order of business: reduce the peak memory usage of the samplers.
This PR slightly reduces the memory load, but not nearly enough:
On single-GPU, sampling can still take dozens of gigabytes at peak memory (an 8B model at 16k context was >10 GB).
On multi-GPU, sampling is no cheaper, and there's also a colossal memory spike when gathering the logits.
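For scale (my assumption, not stated above: fp32 logits and a ~128k-entry vocabulary, as in Llama 3 8B), the raw prompt logits alone come to roughly 16,384 positions × 128,256 vocab × 4 bytes ≈ 8.4 GB, and a log_softmax over them allocates a second tensor of the same size, which lines up with the >10 GB peak observed.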

Thoughts:

  • In this PR, some operations are split into smaller batches. Can we split the entire sampling process the same way, leaving it mostly unchanged but handling only a fixed k of rows at a time? (A rough sketch of that idea follows this list.)
  • No idea what the fix is for the gather spikes, deferring to @AlpinDale on that. That might not even be the specific issue, just where it ran out of VRAM for me, but there's something about multi-GPU that's aggravating the memory peaks.
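To illustrate the fixed-k idea, here is a minimal sketch, not Aphrodite's actual sampler code; `chunked_prompt_logprobs`, `hidden`, `lm_head`, and `chunk_size` are hypothetical names. By projecting hidden states to logits one chunk of rows at a time, only a `chunk_size × vocab` slice is ever live, instead of the full `seq_len × vocab` tensor:

```python
import torch

def chunked_prompt_logprobs(hidden: torch.Tensor,
                            lm_head: torch.Tensor,
                            token_ids: torch.Tensor,
                            chunk_size: int = 1024) -> torch.Tensor:
    """hidden: [seq_len, d], lm_head: [vocab, d], token_ids: [seq_len] (int64).
    Returns per-token logprobs [seq_len] without materializing the full
    [seq_len, vocab] logits tensor."""
    out = torch.empty(token_ids.shape[0], dtype=torch.float32,
                      device=hidden.device)
    for start in range(0, hidden.shape[0], chunk_size):
        end = min(start + chunk_size, hidden.shape[0])
        # Only [chunk, vocab] logits exist at any point in the loop.
        logits = hidden[start:end] @ lm_head.t()
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        out[start:end] = logprobs.gather(
            -1, token_ids[start:end].unsqueeze(-1)).squeeze(-1)
    return out
```

Under these assumptions, peak extra memory drops from `seq_len × vocab` to `chunk_size × vocab` floats, at the cost of a short Python loop over chunks.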

@AlpinDale self-requested a review December 19, 2024 17:35
@AlpinDale (Member) commented:

Will probably need some restructuring after #925
