Add Repetition Range ('rep_range') #888

Open · wants to merge 6 commits into base: main

Conversation

discordianbelle (Contributor):

Adds a range to the repetition penalties (all samplers under do_penalties).

Counts back from the current token, applying the penalty to all output tokens within range, then to prompt tokens if the range extends that far.

The most expensive operations are simple slicing operations, which are relatively fast in PyTorch.
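
For illustration, a minimal single-sequence sketch of the windowing idea; the helper name and the example values are illustrative, not the PR's code:

import torch

def penalty_window(prompt_tokens: torch.Tensor,
                   output_tokens: torch.Tensor,
                   rep_range: int) -> tuple:
    # Keep only the last rep_range generated tokens; if the window is larger
    # than the output so far, let it spill backwards into the prompt.
    output_in_range = output_tokens[-rep_range:]
    remaining = rep_range - output_in_range.numel()
    prompt_in_range = prompt_tokens[-remaining:] if remaining > 0 else prompt_tokens[:0]
    return prompt_in_range, output_in_range

# A 5-token prompt and 3 generated tokens with rep_range=6 penalizes the
# 3 output tokens plus the last 3 prompt tokens.
prompt = torch.tensor([11, 12, 13, 14, 15])
output = torch.tensor([21, 22, 23])
print(penalty_window(prompt, output, rep_range=6))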

@AlpinDale (Member) left a comment:

Thanks for the PR. There are a few issues in its current state, as outlined in my comments. Mainly, this seems designed for sequential inference and will fail for batched inference, because there is no handling of per-sequence differences within a batch.

Also please run ./formatting.sh to fix linting issues.


_SAMPLING_EPS = 1e-5
_MAX_TEMP = 1e-2

- APHRODITE_NO_DEPRECATION_WARNING = envs.APHRODITE_NO_DEPRECATION_WARNING
+ APHRODITE_NO_DEPRECATION_WARNING = bool(int(os.environ.get("APHRODITE_NO_DEPRECATION_WARNING", "0")))

AlpinDale (Member):

Why this change?

discordianbelle (Author):

After installing as editable and then modifying the files, I was getting circular import errors. As far as I could tell, this was because envs.py is in ./aphrodite/ and not ./aphrodite/common/.

I was trying to make the least impactful change that still let it run, so I didn't want to move envs.py.

@@ -400,6 +402,9 @@ def _verify_args(self) -> None:
        if self.repetition_penalty < 1.0:
            raise ValueError("repetition_penalty must be in [1, inf), got "
                             f"{self.repetition_penalty}.")
+       if self.rep_range is not None and self.rep_range < 1:

AlpinDale (Member):

We should probably allow 0 for an infinite (or all-tokens) range, unless other inference software does it this way.

discordianbelle (Author):

You're absolutely right, I'll implement on Monday
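
A minimal sketch of what the relaxed check could look like (hypothetical helper, assuming 0 or None is taken to mean "no limit"):

def _verify_rep_range(rep_range: "int | None") -> None:
    # Hypothetical validation: None or 0 disables the window entirely;
    # only negative values are rejected.
    if rep_range is not None and rep_range < 0:
        raise ValueError("rep_range must be a non-negative integer "
                         f"(0 means all tokens), got {rep_range}.")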

@@ -34,7 +34,7 @@

# If enabled, we switch to a more performant implementation
# of top-k and top-p
- APHRODITE_USE_SAMPLING_KERNELS = envs.APHRODITE_USE_SAMPLING_KERNELS
+ APHRODITE_USE_SAMPLING_KERNELS = bool(int(os.environ.get("APHRODITE_USE_SAMPLING_KERNELS", "0")))

AlpinDale (Member):

Same as before.

discordianbelle (Author):

See other instance


repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
logits = torch.where(logits > 0, logits / repetition_penalties,
                     logits * repetition_penalties)

# We follow the definition in OpenAI API.

AlpinDale (Member):

Don't remove comment.

discordianbelle (Author):

o7
Sorry, danger of AI-assisted programming

@@ -272,7 +272,8 @@ def forward(
sampling_tensors.output_tokens,
sampling_tensors.presence_penalties,
sampling_tensors.frequency_penalties,
- sampling_tensors.repetition_penalties)
+ sampling_tensors.repetition_penalties,
+ rep_range=rep_range)

AlpinDale (Member):

Suggested change:
- rep_range=rep_range)
+ sampling_tensors.rep_range)

This parameter needs to be added to the sampling_metadata module.

discordianbelle (Author):

I'll fix on Monday!

@@ -507,25 +508,39 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
- repetition_penalties: torch.Tensor) -> torch.Tensor:
+ repetition_penalties: torch.Tensor,
+ rep_range: Optional[int] = None) -> torch.Tensor:

AlpinDale (Member), Dec 14, 2024:

Suggested change:
- rep_range: Optional[int] = None) -> torch.Tensor:
+ rep_range: torch.Tensor) -> torch.Tensor:

Must be tensorized after it's added to sampling_metadata. If we treat it as a single integer, it'll apply to all sequences within the batch; tensorizing it lets us match the batch dimension, since different sequences may want different ranges.

discordianbelle (Author):

Great point, will handle
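
A rough sketch of that direction, assuming rep_range becomes a per-sequence tensor like the other penalty fields; the names and values here are illustrative, not the eventual implementation:

import torch

# One rep_range per sequence (0 meaning "no limit"), carried in the batch
# alongside the other penalty tensors.
rep_ranges = torch.tensor([2, 0])
output_tokens = torch.tensor([[5, 7, 7, 9],
                              [3, 3, 8, 0]])   # (num_seqs, max_output_len), padded
output_lens = torch.tensor([4, 3])             # row 1 has one padding slot

# A single slice like output_tokens[:, -rep_range:] cannot express a different
# window per row, but a position mask can (age 1 = the newest token).
positions = torch.arange(output_tokens.shape[1])
age = output_lens[:, None] - positions[None, :]
in_window = (age >= 1) & ((rep_ranges[:, None] == 0) | (age <= rep_ranges[:, None]))
# in_window -> [[False, False, True, True],    # last 2 tokens of sequence 0
#               [True,  True,  True, False]]   # all real tokens of sequence 1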

    if rep_range is not None and rep_range > 0:
        # Just take the last rep_range tokens from output_tokens_tensor
        # This is much more efficient as we're only looking at recent history
        output_tokens_tensor = output_tokens_tensor[:, -rep_range:]

AlpinDale (Member):

This is applying the same range to all sequences in the batch, no?

Also creating a new tensor here is probably less efficient. I think the slicing should be done in the bin counting function above this.

discordianbelle (Author):

Agreed, I'll have to go over it again with batching in mind, and also agreed about rolling in the slicing

        if output_len < rep_range:
            # Calculate how many prompt tokens we should include
            prompt_tokens_to_include = min(rep_range - output_len, prompt_end_idx)
            prompt_tokens_tensor = prompt_tokens_tensor[:, -prompt_tokens_to_include:]

AlpinDale (Member):

Not all sequences in a batch have the same output len (and consequently may not include the same number of prompt tokens).

discordianbelle (Author):

Agreed, I'll have to go over it again with batching in mind

            prompt_tokens_tensor = prompt_tokens_tensor[:, -prompt_tokens_to_include:]
        else:
            # If we have enough output tokens, ignore prompt completely
            prompt_tokens_tensor = torch.empty((num_seqs, 0), dtype=torch.long,
                                               device=logits.device)

AlpinDale (Member):

Do we need to do this here? I think most of the range ops can be done in bin counting.

discordianbelle (Author):

I'll refactor it, thanks again
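
As a rough illustration of folding the range handling into the bin-counting step, here is a self-contained sketch; it mimics the usual scatter_add_ bin-count trick but does not reuse Aphrodite's exact function names or signatures:

import torch

def windowed_bin_counts(tokens: torch.Tensor,      # (num_seqs, max_len), right-padded, long dtype
                        seq_lens: torch.Tensor,    # (num_seqs,) real length of each row
                        rep_ranges: torch.Tensor,  # (num_seqs,) window sizes, 0 = no limit
                        vocab_size: int):
    num_seqs, max_len = tokens.shape
    positions = torch.arange(max_len, device=tokens.device)
    age = seq_lens[:, None] - positions[None, :]          # 1 = newest token in each row
    in_window = (age >= 1) & ((rep_ranges[:, None] == 0) |
                              (age <= rep_ranges[:, None]))
    # Route out-of-window tokens (and right padding) into a scratch bin at
    # index vocab_size, then drop that bin, so they never affect the counts.
    routed = torch.where(in_window, tokens, torch.full_like(tokens, vocab_size))
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long, device=tokens.device)
    bin_counts.scatter_add_(1, routed, torch.ones_like(routed))
    bin_counts = bin_counts[:, :vocab_size]
    return bin_counts, bin_counts > 0                     # counts and presence mask

Prompt tokens could go through the same helper with their ages offset by each sequence's output length, so the spill-over into the prompt would also stay per-sequence without slicing any tensors downstream.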
