Add Repetition Range ('rep_range') #888
base: main
Conversation
Thanks for the PR. There are a few issues in its current state, as outlined in my comments. Mainly, this seems designed for sequential inference and will fail for batched inference, because there is no handling for per-sequence differences within a batch.
Also, please run ./formatting.sh to fix the linting issues.
  _SAMPLING_EPS = 1e-5
  _MAX_TEMP = 1e-2

- APHRODITE_NO_DEPRECATION_WARNING = envs.APHRODITE_NO_DEPRECATION_WARNING
+ APHRODITE_NO_DEPRECATION_WARNING = bool(int(os.environ.get("APHRODITE_NO_DEPRECATION_WARNING", "0")))
Why this change?
After installing as editable and then modifying the files, I was getting circular import errors. As far as I could tell, this was because envs.py lives in ./aphrodite/ rather than ./aphrodite/common/.
I was trying to make the least impactful change that still let it run, so I didn't want to move envs.py.
@@ -400,6 +402,9 @@ def _verify_args(self) -> None:
      if self.repetition_penalty < 1.0:
          raise ValueError("repetition_penalty must be in [1, inf), got "
                           f"{self.repetition_penalty}.")
+     if self.rep_range is not None and self.rep_range < 1:
We should probably allow 0 for an infinite (or all-tokens) range. Unless other inference software does it this way.
You're absolutely right, I'll implement on Monday
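For reference, a minimal sketch of what permitting 0 (meaning "all tokens") could look like inside _verify_args, mirroring the hunk above; the exact check and error wording are illustrative, not the final implementation:

    # Sketch only: rep_range of 0 would mean "apply to all tokens";
    # only negative values are rejected.
    if self.rep_range is not None and self.rep_range < 0:
        raise ValueError("rep_range must be a non-negative integer "
                         "(0 means unlimited range), got "
                         f"{self.rep_range}.")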
@@ -34,7 +34,7 @@
  # If enabled, we switch to a more performant implementation
  # of top-k and top-p
- APHRODITE_USE_SAMPLING_KERNELS = envs.APHRODITE_USE_SAMPLING_KERNELS
+ APHRODITE_USE_SAMPLING_KERNELS = bool(int(os.environ.get("APHRODITE_USE_SAMPLING_KERNELS", "0")))
Same as before.
See other instance
  repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  logits = torch.where(logits > 0, logits / repetition_penalties,
                       logits * repetition_penalties)

- # We follow the definition in OpenAI API.
Don't remove comment.
o7
Sorry, danger of AI-assisted programming
@@ -272,7 +272,8 @@ def forward(
          sampling_tensors.output_tokens,
          sampling_tensors.presence_penalties,
          sampling_tensors.frequency_penalties,
-         sampling_tensors.repetition_penalties)
+         sampling_tensors.repetition_penalties,
+         rep_range=rep_range)
Suggested change:
- rep_range=rep_range)
+ sampling_tensors.rep_range)
This parameter needs to be added to the sampling_metadata module.
I'll fix on Monday!
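For reference, a hedged sketch of how the per-sequence values could be collected and tensorized alongside the other sampling tensors; the helper name, the flattened sampling-params list, and the 0-means-unlimited convention are assumptions here, not the module's actual API:

    import torch

    def build_rep_range_tensor(sampling_params_per_seq, device):
        # One entry per sequence in the batch, in the same order as the other
        # per-sequence tensors (e.g. repetition_penalties). None/unset falls
        # back to 0, which stands for "unlimited range" per the discussion above.
        values = [getattr(p, "rep_range", None) or 0 for p in sampling_params_per_seq]
        return torch.tensor(values, dtype=torch.long, device=device)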
@@ -507,25 +508,39 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,
                      frequency_penalties: torch.Tensor,
-                     repetition_penalties: torch.Tensor) -> torch.Tensor:
+                     repetition_penalties: torch.Tensor,
+                     rep_range: Optional[int] = None) -> torch.Tensor:
Suggested change:
- rep_range: Optional[int] = None) -> torch.Tensor:
+ rep_range: torch.Tensor) -> torch.Tensor:
Must be tensorized after it's added to sampling_metadata. If we treat it as a single integer, it'll apply to all sequences within the batch; tensorizing it lets us match the batch dimension, since different sequences may want different ranges.
Great point, will handle
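One way a tensorized rep_range could be applied per sequence is with a position mask over the padded output-token matrix. This is a sketch under assumptions (left-aligned rows padded on the right, separate output_lens and rep_range long tensors of shape [num_seqs]; all names are illustrative), not the module's existing layout:

    import torch

    def rep_range_mask(output_tokens: torch.Tensor,
                       output_lens: torch.Tensor,
                       rep_range: torch.Tensor) -> torch.Tensor:
        # output_tokens: [num_seqs, max_len], left-aligned, padded on the right.
        # A position is "in range" if it is one of the last rep_range real
        # tokens of its own sequence; rep_range == 0 keeps every real token.
        num_seqs, max_len = output_tokens.shape
        pos = torch.arange(max_len, device=output_tokens.device).expand(num_seqs, -1)
        dist_from_end = output_lens.unsqueeze(1) - pos   # last real token -> 1
        valid = dist_from_end > 0                        # excludes padding
        unlimited = (rep_range == 0).unsqueeze(1)
        return valid & (unlimited | (dist_from_end <= rep_range.unsqueeze(1)))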
    if rep_range is not None and rep_range > 0:
        # Just take the last rep_range tokens from output_tokens_tensor
        # This is much more efficient as we're only looking at recent history
        output_tokens_tensor = output_tokens_tensor[:, -rep_range:]
This is applying the same range to all sequences in the batch, no?
Also, creating a new tensor here is probably less efficient. I think the slicing should be done in the bin counting function above this.
Agreed, I'll have to go over it again with batching in mind, and also agreed about rolling in the slicing
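Folding the range into the bin counting step, as suggested, could avoid slicing or reallocating the token tensors entirely. A sketch building on the mask above (the helper name and the clamp guarding against padding ids >= vocab_size are assumptions):

    import torch

    def masked_bin_count(tokens: torch.Tensor,
                         in_range: torch.Tensor,
                         vocab_size: int) -> torch.Tensor:
        # Counts token occurrences per sequence, adding 1 only at positions
        # whose mask entry is True, so out-of-range and padding positions
        # contribute nothing.
        counts = torch.zeros(tokens.shape[0], vocab_size,
                             dtype=torch.long, device=tokens.device)
        counts.scatter_add_(1, tokens.clamp_max(vocab_size - 1), in_range.long())
        return counts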
        if output_len < rep_range:
            # Calculate how many prompt tokens we should include
            prompt_tokens_to_include = min(rep_range - output_len, prompt_end_idx)
            prompt_tokens_tensor = prompt_tokens_tensor[:, -prompt_tokens_to_include:]
Not all sequences in a batch have the same output len (and consequently may not include the same number of prompt tokens).
Agreed, I'll have to go over it again with batching in mind
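A per-sequence version of the prompt handling could compute the remaining budget as a tensor rather than from a single output_len. Again a sketch under the same layout assumptions as above (names and padding layout are illustrative):

    import torch

    def prompt_range_mask(prompt_tokens: torch.Tensor,
                          prompt_lens: torch.Tensor,
                          output_lens: torch.Tensor,
                          rep_range: torch.Tensor) -> torch.Tensor:
        # Per sequence, the prompt budget is max(rep_range - output_len, 0),
        # capped at the prompt length; rep_range == 0 keeps the whole prompt.
        num_seqs, max_len = prompt_tokens.shape
        budget = (rep_range - output_lens).clamp_min(0)
        budget = torch.where(rep_range == 0, prompt_lens, budget)
        budget = torch.minimum(budget, prompt_lens)
        pos = torch.arange(max_len, device=prompt_tokens.device).expand(num_seqs, -1)
        dist_from_end = prompt_lens.unsqueeze(1) - pos
        return (dist_from_end > 0) & (dist_from_end <= budget.unsqueeze(1))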
            prompt_tokens_tensor = prompt_tokens_tensor[:, -prompt_tokens_to_include:]
        else:
            # If we have enough output tokens, ignore prompt completely
            prompt_tokens_tensor = torch.empty((num_seqs, 0), dtype=torch.long,
                                               device=logits.device)
Do we need to do this here? I think most of the range ops can be done in bin counting.
I'll refactor it, thanks again
Adds a range to the repetition penalties (all samplers under do_penalties).
Counts back from the current token, applying the penalty to all output tokens within range, then to prompt tokens if the range extends that far.
The most expensive operations are simple slicing operations, which are relatively fast in PyTorch.
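A toy, single-sequence illustration of the counting-back behavior described above (token values and the range of 6 are arbitrary):

    prompt_tokens = [11, 12, 13, 14, 15]   # 5 prompt tokens
    output_tokens = [21, 22, 23, 24]       # 4 generated tokens
    rep_range = 6

    # The range covers all 4 output tokens, then reaches 2 tokens into the prompt.
    in_range_outputs = output_tokens[-rep_range:]
    prompt_budget = max(rep_range - len(output_tokens), 0)
    in_range_prompt = prompt_tokens[-prompt_budget:] if prompt_budget else []

    assert in_range_outputs == [21, 22, 23, 24]
    assert in_range_prompt == [14, 15]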