-
-
Notifications
You must be signed in to change notification settings - Fork 135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: typical_p threshold sampling #343
base: main
Are you sure you want to change the base?
Conversation
ExplanationHere is an explanation of what the modified code is supposed to do.
We can also scale This modification has a much greater impact at low Typical-P Tau ranges than at higher ones. As a starting point, Typical-P Tau around 0.35-0.50 with Sigma of 1.0 appears to work well all-around. Preferably, no other sampler (besides Temperature=1) nor repetition penalty should be used with the modified Typical-P with these suggested settings. Sigma can act as a temperature-like control (higher values will cause a larger amount of less likely tokens to be included and vice-versa). Testing with different language models under various conditions may of course reveal different good values for both Tau and Sigma. Why?A problem in LLM-generated text is degeneracy (boredom/repetition and confusion/incoherency traps) with boredom/repetition due to a "pathology" of too often sampling the most likely tokens [1]. Typical-P by default does a good job in fixing this at low Tau values (by discarding high-probability tokens when appropriate), but it needlessly restricts the set of rarer tokens taken into consideration, ultimately making the token selection oddly deterministic when configured to 0.0 (⇒ the only possible choice is not the most likely token, but the one closest to the logit distribution's "conditional entropy"). Below is one example simulated in a spreadsheet. It can be seen that by default at low Tau values the token selection range is skewed toward rarer tokens on one side, but restricted on the other. The modification in this PR extends that range up to positive deviations that (at least up to a Sigma of 1.0) should safe to pick, keeping the average perplexity of the generate text high(er), combating repetition while still avoiding confusion/incoherency and preventing the default deterministic behavior at low-to-zero Typical-P Tau values. By allowing more tokens with positive deviation to be sampled, we also promote a longer positive tail in the distribution of deviations of the output tokens (as gauged by a language model at temperature=1). The graphs below from the Typical P paper show the distribution of deviations ( These graphs notably exhibit a long positive deviation tail (having roughly 2-3 times the absolute value of the negative deviation and an upper limit of about 10—although this might be use-case and model-dependent) and an average close to 0 or possibly very slightly positive (this is a universal observation). Boring, machine-sounding text composed by picking mostly highly-probable tokens, will exhibit a deviation distribution markedly skewed toward the negative side with very limited (if any) excursions into the positive side. The proposed Typical-p modification can allow to obtain more human-like Conditional Entropy deviation distributions from language models compared to the original version References
|
Co-authored-by: BugReporterZ <[email protected]>
|
||
neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True) | ||
# NOTE: We don't take the absolute value of the surprisal deviations | ||
# This deviates from the original implementation | ||
surprisal_deviations = neg_entropy - shifted_logits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I disagree with this change (or the intentions in the code here). The modification in my original hack (not posted in this PR) was intended to retain the basic behavior of Typical_P, which first sorts the surprisal deviations by their absolute value.
Only after this is done, then, using the signed surprisal deviations (copied into a different tensor before computing the absolute values for the other), you would obtain a second subset for extending the token selection as in the algorithm described in the explanation in the discussion.
Further testing over the past few days has revealed that,
was actually not a fair assumption and may too easily lead to incoherent text. While lower values could be used with good results, like A saner approach to this (while still retaining an additional hyperparameter over the original algorithm) would be scaling the negative deviations by a factor (Lambda) that defaults to 1.0 (same behavior as the original algorithm), before they are converted to absolute values in the code. By doing this, the tokens corresponding to the scaled deviations would be de-emphasized, and the token selection progressively skewed toward rarer tokens the higher Lambda is. Below is a simulated example using Tau=0.35 and Lambda in the 1.0~2.0 range. It can be seen how at Lambda > 1 the selection (yellow boxes) skews more toward rarer tokens:
I have code for the logic in def _apply_typical_sampling(
logits: torch.Tensor,
typical_p: torch.Tensor,
) -> torch.Tensor:
typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
shifted_logits = torch.log_softmax(logits, dim=-1)
probs = shifted_logits.exp()
# Tensor names and logic have been slightly altered to make the procedure a bit more consistent
# with what the original Typical-P algorithm is intended to perform, thus clearer.
surprises = shifted_logits.neg()
conditional_entropy = (probs * surprises).nansum(dim=-1, keepdim=True)
surprisal_deviations = surprises - conditional_entropy
# Scale negative surprisal deviations by a scaling factor Lambda. This de-emphasizes
# high-probability tokens. Lambda would ideally be an additional hyperparameter to typical-p.
lambda_factor = 1.50
lambda_mask = surprisal_deviations < 0
# Don't affect special tokens (generally the first few tokens) by setting the mask to False
# for them. This is an ugly hack; ideally we would want to identify such tokens more directly
# as there is no guarantee that the first few tokens are special tokens or bytes (e.g. Qwen).
tokens_to_ignore = 3
lambda_mask[..., :tokens_to_ignore] = False
# Actual scaling performed here.
surprisal_deviations[lambda_mask] = surprisal_deviations[lambda_mask] * lambda_factor
# From now on, the algorithm proceeds as in the original one, sorting tokens by absolute
# surprisal deviations and picking them depending on their cumulative probability.
surprisal_deviations = surprisal_deviations.abs()
_, indices = torch.sort(surprisal_deviations)
reordered_probs = probs.gather(-1, indices)
typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
min_tokens_to_keep = 1
# Keep at least min_tokens_to_keep
typ_mask_sorted[..., :min_tokens_to_keep] = 0
typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
logits[typ_mask] = -float("inf")
return logits |
A different strategy with minimal modifications from the above could be, instead of scaling the positive deviations by a Lambda factor, shifting the entire deviations by a small Delta value. Again, code that would apply to the main branch: def _apply_typical_sampling(
logits: torch.Tensor,
typical_p: torch.Tensor,
) -> torch.Tensor:
typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
shifted_logits = torch.log_softmax(logits, dim=-1)
probs = shifted_logits.exp()
# Tensor names and logic have been slightly altered to make the procedure a bit more consistent
# with what the original Typical-P algorithm is intended to perform, thus clearer.
surprises = shifted_logits.neg()
conditional_entropy = (probs * surprises).nansum(dim=-1, keepdim=True)
surprisal_deviations = surprises - conditional_entropy
# Shift surprisal deviations by a Delta value. This can both emphasize (-) or de-emphasize (+)
# high-probability tokens. Delta would ideally be an additional hyperparameter to typical-p.
# Note that small values in the 0.0~0.5 range are mostly useful here.
delta = 0.20
surprisal_deviations = surprisal_deviations - delta
# From now on, the algorithm proceeds as in the original one, sorting tokens by absolute
# surprisal deviations and picking them depending on their cumulative probability.
surprisal_deviations = surprisal_deviations.abs()
_, indices = torch.sort(surprisal_deviations)
reordered_probs = probs.gather(-1, indices)
typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
min_tokens_to_keep = 1
# Keep at least min_tokens_to_keep
typ_mask_sorted[..., :min_tokens_to_keep] = 0
typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
logits[typ_mask] = -float("inf")
return logits |
PR adds a new hyperparameter to typical_p sampling, which scales the maximum threshold for positive deviations in typ_p. Credits to Suikamelon (@BugReporterZ ). Untested yet.