From 06474a0fe5ae8e3a912878e6a1796d4645bf2d86 Mon Sep 17 00:00:00 2001
From: sroy745 <142070531+sroy745@users.noreply.github.com>
Date: Mon, 17 Jun 2024 19:29:09 -0700
Subject: [PATCH] [Speculative Decoding 1/2 ] Add typical acceptance sampling
 as one of the sampling techniques in the verifier (#5131)

---
 .../test_typical_acceptance_sampler.py        | 464 ++++++++++++++++++
 .../layers/rejection_sampler.py               | 174 +------
 .../layers/spec_decode_base_sampler.py        | 206 ++++++++
 .../layers/typical_acceptance_sampler.py      | 186 +++++++
 4 files changed, 866 insertions(+), 164 deletions(-)
 create mode 100644 tests/samplers/test_typical_acceptance_sampler.py
 create mode 100644 vllm/model_executor/layers/spec_decode_base_sampler.py
 create mode 100644 vllm/model_executor/layers/typical_acceptance_sampler.py

diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py
new file mode 100644
index 000000000000..87cf37bc926b
--- /dev/null
+++ b/tests/samplers/test_typical_acceptance_sampler.py
@@ -0,0 +1,464 @@
+"""Tests for typical acceptance sampling."""
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.typical_acceptance_sampler import (
+    TypicalAcceptanceSampler)
+from vllm.model_executor.utils import set_random_seed
+
+CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]
+
+
+def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
+    """
+    Generates a fake temperature zero probability distribution.
+    Returns:
+        1. A fake temperature zero probability distribution of shape
+           [batch_size, k, vocab_size]
+        2. Tensor of shape [batch_size, k] containing the token ids
+           of the probability 1.0 tokens at each position.
+    """
+    # Simulate a temperature 0 probability distribution for the target
+    # probabilities such that exactly 1 token id has probability 1.0 at
+    # each position.
+    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    probs = torch.rand(batch_size, k, vocab_size)
+    _, zero_temperature_token_ids = torch.max(probs, dim=-1)
+    # Set the probability of the tokens with ids in zero_temperature_token_ids
+    # to 1 and the rest to 0.
+    target_probs = torch.zeros_like(probs).scatter_(
+        -1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
+    return target_probs, zero_temperature_token_ids
+
+
+def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
+                        token_ids_to_exclude: torch.Tensor):
+    """
+    Returns a tensor of shape [batch_size, k] of fake draft token ids
+    drawn randomly from a vocab of size vocab_size. However, we ensure
+    that token ids from token_ids_to_exclude are excluded at the
+    corresponding positions.
+    """
+    draft_token_ids = torch.empty(batch_size, k, dtype=torch.long)
+    for i in range(batch_size):
+        for j in range(k):
+            # Generate a random token ID excluding token_ids_to_exclude[i, j]
+            while True:
+                token_id = torch.randint(0, vocab_size, (1, )).item()
+                if token_id != token_ids_to_exclude[i, j]:
+                    draft_token_ids[i, j] = token_id
+                    break
+    return draft_token_ids
+
+
+@pytest.mark.parametrize("k", list(range(1, 6)))
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", list(range(1, 32)))
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
+                                    device: str):
+    """
+    Tests that the TypicalAcceptanceSampler forward pass succeeds for
+    different combinations of k, vocab_size, batch_size and number of
+    devices.
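+
+    For reference, the tensor shapes exercised here (matching the
+    sampler's forward signature documented in
+    typical_acceptance_sampler.py) are:
+
+        target_probs    : [batch_size, k, vocab_size], float32
+        bonus_token_ids : [batch_size, 1], int64
+        draft_token_ids : [batch_size, k], int64
+
+    and the returned tensor has shape [batch_size, k + 1].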
+    """
+    torch.set_default_device(device)
+    typical_acceptance_sampler = TypicalAcceptanceSampler()
+    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+    # Verify that sampling succeeds for all cases.
+    typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
+
+
+@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
+@pytest.mark.parametrize("which_token_ids",
+                         ["bonus_token_ids", "draft_token_ids"])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
+                               which_token_ids: str, device: str):
+    """
+    Tests that an exception is thrown if the token ids fall outside
+    the bounds of the provided vocabulary.
+    """
+    k = 3
+    batch_size = 5
+    vocab_size = 30_000
+    torch.set_default_device(device)
+    typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
+    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+    # Verify that an appropriate exception is thrown for out-of-bounds
+    # token ids.
+    oob_token_ids = None
+    if which_token_ids == "bonus_token_ids":
+        oob_token_ids = bonus_token_ids
+    elif which_token_ids == "draft_token_ids":
+        oob_token_ids = draft_token_ids
+    else:
+        raise AssertionError()
+
+    if above_or_below_vocab_range == "above":
+        rogue_token_id = vocab_size + 1
+    elif above_or_below_vocab_range == "below":
+        rogue_token_id = -1
+    else:
+        raise AssertionError()
+
+    oob_token_ids[0][0] = rogue_token_id
+
+    with pytest.raises(AssertionError):
+        typical_acceptance_sampler(target_probs, bonus_token_ids,
+                                   draft_token_ids)
+
+
+@pytest.mark.parametrize("seed", list(range(10)))
+@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_uniform_target_distribution_accepts_all_tokens(
+        seed: int, disable_bonus_tokens: bool, device: str):
+    """
+    Test the TypicalAcceptanceSampler with a uniform target probability
+    distribution.
+
+    This test verifies that when provided with a uniform target probability
+    distribution, the TypicalAcceptanceSampler accepts all draft tokens.
+    Because the entropy of the uniform target distribution is high, all
+    draft tokens should be accepted. The test also ensures that the behavior
+    regarding bonus tokens is consistent with the `disable_bonus_tokens`
+    flag.
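+
+    As a rough sanity check (using a normalized uniform distribution
+    rather than the unnormalized torch.rand values generated below):
+    with vocab_size = 30,000 the entropy is ln(30000) ~= 10.3 nats, so
+    with the default hyperparameters the acceptance threshold is
+    min(0.09, 0.3 * exp(-10.3)) ~= 1e-5, while every token has
+    probability 1/30000 ~= 3.3e-5 > 1e-5, so every draft token is
+    accepted.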
+    """
+    set_random_seed(seed)
+    k = 3
+    batch_size = 5
+    vocab_size = 30_000
+    torch.set_default_device(device)
+    typical_acceptance_sampler = TypicalAcceptanceSampler(
+        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    output_token_ids = typical_acceptance_sampler(target_probs,
+                                                  bonus_token_ids,
+                                                  draft_token_ids)
+    # We are using a uniform target probability distribution.
+    # For a uniform distribution the entropy is very high, which should
+    # lead to all draft tokens being accepted. Verify that.
+    assert output_token_ids.shape[0] == batch_size
+    assert output_token_ids.shape[1] == (k + 1)
+    if disable_bonus_tokens:
+        assert torch.all(output_token_ids[:, -1] == -1)
+    else:
+        assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
+
+    assert torch.all(output_token_ids[:, :k] == draft_token_ids)
+
+
+@pytest.mark.parametrize("seed", list(range(10)))
+@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_temperature_zero_target_distribution(seed: int,
+                                              disable_bonus_tokens: bool,
+                                              device: str):
+    """
+    Test the TypicalAcceptanceSampler with a zero-temperature target
+    probability distribution.
+
+    This test verifies that when using a zero-temperature target probability
+    distribution, where only one token has a probability of 1.0, the
+    TypicalAcceptanceSampler correctly rejects all draft tokens that do not
+    match this probability. Additionally, it ensures that when all draft
+    tokens are rejected, the sampler falls back to greedy sampling to select a
+    single token from the target distribution.
+    """
+    set_random_seed(seed)
+    k = 3
+    batch_size = 5
+    vocab_size = 30_000
+    torch.set_default_device(device)
+
+    typical_acceptance_sampler = TypicalAcceptanceSampler(
+        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    # Simulate a temperature 0 probability distribution for the target
+    # probabilities such that only 1 token id has probability 1.0.
+    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
+        batch_size, k, vocab_size)
+    # Populate draft_token_ids such that they exclude the token_ids
+    # with probability = 1.0
+    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
+                                          zero_temperature_token_ids)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    # The target probability distribution is a temperature zero distribution
+    # with zero entropy. Since our draft token ids don't match the probability
+    # 1.0 tokens in the target distribution, we will reject all of them and
+    # fall back to greedy sampling to select 1 token for each sequence.
+    # Verify the same.
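+    # Concretely, with k = 3 every output row should look like
+    # [argmax_token, -1, -1, -1]: position 0 falls back to the argmax of
+    # the target distribution at that position and all remaining columns,
+    # including the bonus slot, are -1.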
+ output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, -1] == -1) + assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:, + 0]) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, + device: str): + """ + Test the TypicalAcceptanceSampler with a mixed target probability + distribution. + + This test ensures that the TypicalAcceptanceSampler handles a mixed + target probability distribution correctly. Specifically, it uses a + zero-temperature distribution for some sequences and a uniform + distribution for others. The test verifies that: + + - For sequences with a zero-temperature distribution, only the token + with a probability of 1.0 is accepted, and all other tokens are rejected. + - For sequences with a uniform distribution, all draft tokens are + accepted. + - When `disable_bonus_tokens` is False, the bonus tokens are also accepted + for sequences with a uniform distribution. + """ + set_random_seed(seed) + k = 3 + batch_size = 4 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # For sequences 0 and 2 set the distribution to a temperature + # zero distribution. For sequences 1 and 3 set it to a uniform + # distribution. + target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( + batch_size, k, vocab_size)) + draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, + zero_temperature_token_ids) + uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) + target_probs[[1, 3]] = uniform_probs + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + # verify the shape of output_token_ids + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + # For sequences 0 and 2 verify that only 1 token is accepted + # which is the token with probability 1.0 in the target distribution + # at position 0. + assert torch.all(output_token_ids[[0, 2], 1:] == -1) + assert (torch.all(output_token_ids[[0, 2], + 0] == zero_temperature_token_ids[[0, 2], + 0])) + # For sequences 1 and 3 verify that all tokens are accepted since the + # target probability distribution is uniform. In addition verify that + # if disable_bonus_tokens is false then we also accept the bonus tokens. + assert torch.all( + output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) + if disable_bonus_tokens: + assert torch.all(output_token_ids[[1, 3], -1] == -1) + else: + assert torch.all(output_token_ids[[1, 3], -1] != -1) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, + device: str): + """ + Test the TypicalAcceptanceSampler's behavior when only a subset of draft + tokens should be accepted. 
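+    For instance, with k = 5 and only the first two draft tokens matching
+    the temperature-zero target tokens, the expected output row is
+    [d0, d1, -1, -1, -1, -1], where d0 and d1 are the two accepted draft
+    token ids and the final -1 is the (rejected) bonus slot.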
+ + This test verifies that the TypicalAcceptanceSampler correctly accepts or + rejects draft tokens based on a zero-temperature target probability + distribution. Specifically, it ensures that: + + - When all draft tokens match tokens with a probability of 1.0 in the + target distribution, all draft tokens are accepted. + - When only some draft tokens match tokens with a probability of 1.0 in + the target distribution, only those matching tokens are accepted, and the + rest are rejected. + """ + set_random_seed(seed) + k = 5 + batch_size = 1 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # Create a temperature zero target probability distribution and ensure + # all draft token ids correspond to the tokens with 1.0 probability. + # Verify that all of them are accepted. + target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( + batch_size, k, vocab_size)) + draft_token_ids = zero_temperature_token_ids + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) + if disable_bonus_tokens: + assert torch.all(output_token_ids[:, -1] == -1) + else: + assert torch.all(output_token_ids[:, -1] == bonus_token_ids) + # Next only keep the first 2 draft tokens same as the zero temperature + # tokens. For the remaining 3 choose some other tokens. In the + # response we will expect the first 2 tokens to be the same as the + # draft tokens and the rest as -1 + draft_token_ids_to_replace = get_draft_token_ids( + batch_size, k, vocab_size, zero_temperature_token_ids) + draft_token_ids = torch.cat( + (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) + assert torch.all(output_token_ids[:, -3:] == -1) + + +@pytest.mark.parametrize("seed", list(range(1))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_accept_tokens_set_non_default_posteriors(seed: int, + disable_bonus_tokens: bool, + device: str): + """ + Test the TypicalAcceptanceSampler with custom posterior thresholds and + alpha values. This test verifies that by modifying the posterior + thresholds and alpha values we can change the acceptance behavior of the + sampler. + """ + set_random_seed(seed) + k = 5 + batch_size = 1 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # Simulate temperature 0 probability distribution for target + # probabilities and create target probabilities such that only 1 token + # id has probability 1.0 and others have a very low probability of + # 0.00001. Populate draft_token_ids such that they exclude the token_ids + # with probability = 1.0. 
Without any changes to the posterior thresholds + # none of the draft tokens are accepted. + target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( + batch_size, k, vocab_size)) + target_probs[target_probs == 0] = 0.00001 + draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, + zero_temperature_token_ids) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, 1:-1] == -1) + + # Change the posterior threshold values to 0.0 so that we will + # now accept even draft tokens with very low probability in the + # target distribution. Simulate and verify the same. + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, + disable_bonus_tokens=disable_bonus_tokens, + posterior_threshold=0.0, + posterior_alpha=0.0) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) + if disable_bonus_tokens: + assert torch.all(output_token_ids[:, -1] == -1) + else: + assert torch.all(output_token_ids[:, -1] == bonus_token_ids) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool, + device: str): + """ + Test the TypicalAcceptanceSampler's method for generating + replacement token IDs. + + This test verifies that the `_replacement_token_ids` method of the + TypicalAcceptanceSampler correctly identifies the token IDs to be used + as replacements based on the target probability distribution. + Specifically, it ensures that the method correctly identifies the + tokens with the highest probability for each sequence in the batch. 
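+    The expected result is a [batch_size, k] tensor whose first column
+    holds argmax(target_probs[:, 0, :]) for each sequence and whose
+    remaining columns are all -1.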
+ """ + set_random_seed(seed) + k = 10 + batch_size = 5 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + expected_replacement_tokens = -torch.ones( + (batch_size, k), dtype=torch.long) + expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :], + dim=1) + actual_replacement_tokens = ( + typical_acceptance_sampler._replacement_token_ids(target_probs)) + assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index a80703155c0b..fe9b2fac1117 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -1,12 +1,15 @@ from functools import cached_property -from typing import Optional, Tuple +from typing import Tuple import torch import torch.jit import torch.nn as nn +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) -class RejectionSampler(nn.Module): + +class RejectionSampler(SpecDecodeBaseSampler, nn.Module): """Apply modified rejection sampling as described in "Accelerating Large Language Model Decoding with Speculative Sampling" https://arxiv.org/pdf/2302.01318.pdf. @@ -22,39 +25,11 @@ def __init__(self, Require when bonus tokens will cause corrupt KV cache for proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks - during sampling. This catches correctness issues but adds - nontrivial latency. + during sampling. This catches correctness issues but adds + nontrivial latency. """ - super().__init__() - self._disable_bonus_tokens = disable_bonus_tokens - self._strict_mode = strict_mode - - # NOTE: A "bonus token" is accepted iff all proposal tokens are - # accepted. There is always only one possible bonus token. We store this - # value in a variable for readability. - self._num_bonus_tokens = 1 - - self.num_accepted_tokens: Optional[torch.Tensor] = None - self.num_emitted_tokens: Optional[torch.Tensor] = None - self.num_draft_tokens: int = 0 - - def init_gpu_tensors(self, rank: int) -> None: - assert self.num_accepted_tokens is None - device = f"cuda:{rank}" - self.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - self.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - - @property - def probs_dtype(self): - return torch.float32 - - @property - def token_id_dtype(self): - return torch.int64 + SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode) + nn.Module.__init__(self) def forward( self, @@ -100,15 +75,8 @@ def forward( # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. 
if self._strict_mode: - self._raise_if_incorrect_shape(target_probs, bonus_token_ids, - draft_probs, draft_token_ids) - self._raise_if_incorrect_dtype(target_probs, bonus_token_ids, + self._raise_if_incorrect_input(target_probs, bonus_token_ids, draft_probs, draft_token_ids) - self._raise_if_inconsistent_device(target_probs, bonus_token_ids, - draft_probs, draft_token_ids) - self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], - bonus_token_ids, - draft_token_ids) accepted, recovered_token_ids = self._batch_modified_rejection_sampling( target_probs, @@ -272,128 +240,6 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny - def _create_output( - self, - accepted: torch.Tensor, # [batch_size, k] - recovered_token_ids: torch.Tensor, # [batch_size, k] - draft_token_ids: torch.Tensor, # [batch_size, k] - bonus_token_ids: torch.Tensor, # [batch_size] - ) -> torch.Tensor: - """Format output. Returns a matrix of token ids. When - a token is rejected via rejection sampling, all subsequent - token ids are set to -1 for the sequence. - - shape = [batch_size, k + num_bonus_tokens] - """ - bonus_token_ids = bonus_token_ids.squeeze() - batch_size, k = recovered_token_ids.shape - - # Determine the index of the first False value for each row. - limits = (accepted == 0).max(1).indices - limits[~(accepted == 0).any(1)] = k - - # Create masks using the indices. - indices = torch.arange(k, device=accepted.device).unsqueeze(0) - accepted_mask = indices < limits.unsqueeze(1) - after_false_mask = indices == limits.unsqueeze(1) - - # Create an extended output tensor - output_with_bonus_tokens = -torch.ones( - (batch_size, k + self._num_bonus_tokens), - dtype=self.token_id_dtype, - device=accepted.device) - output = output_with_bonus_tokens[:, :k] - - # Fill in the first k columns of the output tensor using masks and data - # tensors. - torch.where(accepted_mask, - draft_token_ids, - -torch.ones_like(draft_token_ids), - out=output) - - # Fill the last column. - # We check output directly as accepted may have True values inconsistent - # with causal acceptance. - output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, - bonus_token_ids, -1) - - # We disable bonus tokens because it causes corrupt KV cache for - # proposal methods that require KV cache. We can fix it by "prefilling" - # the bonus token in the proposer. The following issue tracks the fix. - # https://github.com/vllm-project/vllm/issues/4212 - if self._disable_bonus_tokens: - output_with_bonus_tokens[:, -1] = -1 - - # Fill the recovered token ids. 
- output.mul_(~after_false_mask).add_( - recovered_token_ids.mul(after_false_mask)) - - self.num_accepted_tokens += accepted.sum() - self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() - self.num_draft_tokens += batch_size * k - - return output_with_bonus_tokens - - def _raise_if_incorrect_shape( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - (target_batch_size, num_target_probs, - target_vocab_size) = target_probs.shape - bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape - draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape - draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape - - assert draft_batch_size == target_batch_size - assert num_draft_probs == num_target_probs - assert (draft_vocab_size == target_vocab_size - ), f"{draft_vocab_size=} {target_vocab_size=}" - - assert draft_token_ids_batch_size == draft_batch_size - assert num_draft_token_ids == num_draft_probs - - assert bonus_batch_size == target_batch_size - assert num_bonus_tokens == self._num_bonus_tokens - - def _raise_if_incorrect_dtype( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - assert all(probs.dtype == self.probs_dtype - for probs in [target_probs, draft_probs]) - assert all(token_ids.dtype == self.token_id_dtype - for token_ids in [bonus_token_ids, draft_token_ids]) - - def _raise_if_inconsistent_device( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - devices = [ - t.device for t in - [target_probs, bonus_token_ids, draft_probs, draft_token_ids] - ] - assert all([devices[0] == device for device in devices]) - - def _raise_if_out_of_bounds_vocab( - self, - vocab_size: int, - bonus_token_ids: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - assert torch.all(bonus_token_ids < vocab_size) - assert torch.all(bonus_token_ids >= 0) - assert torch.all(draft_token_ids < vocab_size) - assert torch.all(draft_token_ids >= 0) - # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead that skips the sync. diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py new file mode 100644 index 000000000000..9856a7e7ddea --- /dev/null +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -0,0 +1,206 @@ +from typing import Optional + +import torch + + +class SpecDecodeBaseSampler(): + """Base class for samplers used for Speculative Decoding verification + step. + """ + + def __init__(self, + disable_bonus_tokens: bool = True, + strict_mode: bool = False): + """Base class constructor. + Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + """ + super().__init__() + self._disable_bonus_tokens = disable_bonus_tokens + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. 
+ self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, rank: int) -> None: + assert self.num_accepted_tokens is None + device = f"cuda:{rank}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + @property + def probs_dtype(self): + return torch.float32 + + @property + def token_id_dtype(self): + return torch.int64 + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + substitute_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via sampling, all subsequent token ids are + set to -1 for the sequence. + + Args: + accepted: A boolean tensor indicating if the corresponding + draft token in draft_token_ids should be accepted or not. + substitute_token_ids: A tensor of token_ids that can be used + as substitutes for the draft token ids if the proposed token + is rejected. + draft_token_ids: A tensor of token ids speculated by the + draft model. + bonus_token_ids: Token ids to use as the bonus token if + all the draft tokens are accepted. + Returns: + A tensor containing the accepted token ids. The shape of the + tensor is [batch_size, k + num_bonus_tokens] + """ + batch_size, k = substitute_token_ids.shape + bonus_token_ids = bonus_token_ids.squeeze() + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # We disable bonus tokens because it causes corrupt KV cache for + # proposal methods that require KV cache. We can fix it by "prefilling" + # the bonus token in the proposer. The following issue tracks the fix. + # https://github.com/vllm-project/vllm/issues/4212 + if self._disable_bonus_tokens: + output_with_bonus_tokens[:, -1] = -1 + + # Fill the recovered token ids. 
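+        # after_false_mask is True only at the first rejected position of
+        # each row (if any). Multiplying by ~after_false_mask clears that
+        # slot (currently -1) and the add writes the substitute token into
+        # it; fully accepted rows are untouched because their
+        # after_false_mask is all False. For example, accepted = [T, T, F, F]
+        # with k = 4 keeps the draft tokens in columns 0-1, places the
+        # substitute in column 2 and leaves column 3 (and the bonus slot)
+        # at -1.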
+ output.mul_(~after_false_mask).add_( + substitute_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_input( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + self._raise_if_incorrect_shape(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_incorrect_dtype(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_inconsistent_device(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + draft_token_ids, bonus_token_ids) + + def _raise_if_incorrect_shape( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_probs.shape + + # validate the shape of draft token ids. + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + assert draft_token_ids_batch_size == target_batch_size + assert num_draft_token_ids == num_target_probs + + # validate the shape of bonus token ids + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + # validate the shape of draft probs if it is set + if draft_probs is not None: + (draft_batch_size, num_draft_probs, + draft_vocab_size) = draft_probs.shape + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + def _raise_if_incorrect_dtype( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + assert target_probs.dtype == self.probs_dtype + assert draft_token_ids.dtype == self.token_id_dtype + assert bonus_token_ids.dtype == self.token_id_dtype + if draft_probs is not None: + assert draft_probs.dtype == self.probs_dtype + + def _raise_if_inconsistent_device( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + devices = [ + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] + if t is not None + ] + assert all([devices[0] == device for device in devices]) + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py new file mode 100644 index 000000000000..f12d6a03b4d1 --- /dev/null +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -0,0 +1,186 @@ +import torch +import torch.jit +import torch.nn as nn + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) + + +class TypicalAcceptanceSampler(SpecDecodeBaseSampler, 
nn.Module): + """Apply typical acceptance sampling as described in section 3.3.1 in + "MEDUSA: Simple LLM Inference Acceleration Framework with + Multiple Decoding Heads" + https://arxiv.org/pdf/2401.10774 + """ + + def __init__( + self, + disable_bonus_tokens: bool = False, + strict_mode: bool = False, + posterior_threshold: float = 0.09, + posterior_alpha: float = 0.3, + ): + """Create a Typical Acceptance Sampler. + + Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + posterior_threshold : A threshold value that sets a lower bound + on the posterior probability of a token in target model for it + to be accepted. Default is 0.09 + posterior_alpha : A scaling factor for the entropy-based + threshold in typical acceptance sampling. Typically defaults to + sqrt of posterior_threshold and is set to 0.3. + """ + SpecDecodeBaseSampler.__init__( + self, + disable_bonus_tokens=disable_bonus_tokens, + strict_mode=strict_mode) + nn.Module.__init__(self) + self._posterior_threshold = posterior_threshold + self._posterior_alpha = posterior_alpha + + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using typical acceptance sampling. This accepts + or rejects tokens proposed by the draft model using the probability + of each token according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one token will be emitted. + + In the case where all draft tokens are accepted, the bonus token will be + accepted conditioned on self._disable_bonus_tokens being false. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. + if self._strict_mode: + self._raise_if_incorrect_input(target_probs, draft_token_ids, + bonus_token_ids) + accepted = self._evaluate_accepted_tokens(target_probs, + draft_token_ids) + recovered_token_ids = self._replacement_token_ids(target_probs) + output_token_ids = self._create_output(accepted, recovered_token_ids, + draft_token_ids, + bonus_token_ids) + return output_token_ids + + def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): + r""" + Evaluates and returns a mask of accepted tokens based on the + posterior probabilities. + + Parameters: + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) representing + the probabilities of each token in the vocabulary for each + position in the proposed sequence. This is the distribution + generated by the target model. 
+ draft_token_ids : torch.Tensor + A tensor of shape (batch_size, k) representing the proposed + token ids. + + A draft token_id x_{n+k} is accepted if it satisfies the + following condition + + .. math:: + p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > + \min \left( \epsilon, \delta * \exp \left( + -H(p_{\text{original}}( + \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + + where :math:`p_{\text{original}}` corresponds to target_probs + and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters + specified using self._posterior_threshold and self._posterior_alpha + + This method computes the posterior probabilities for the given + draft token ids based on the provided target probabilities. It + calculates the entropy of the posterior distribution and determines + a dynamic threshold for each token position using the provided + posterior_threshold and posterior_alpha values. The method then + returns a boolean mask indicating which tokens can be accepted. + + Returns: + ------- + torch.Tensor + A boolean tensor of shape (batch_size, k) where each element + indicates whether the corresponding draft token has been accepted + or rejected. True indicates acceptance and false indicates + rejection. + + """ + device = target_probs.device + candidates_prob = torch.gather( + target_probs, dim=-1, + index=draft_token_ids.unsqueeze(-1)).squeeze(-1) + # A small constant added to prevent computing the logarithm of zero, + # which can lead to undefined values. + epsilon = 1e-5 + posterior_entropy = -torch.sum( + target_probs * torch.log(target_probs + epsilon), dim=-1) + threshold = torch.minimum( + torch.ones_like(posterior_entropy, device=device) * + self._posterior_threshold, + torch.exp(-posterior_entropy) * self._posterior_alpha, + ) + accepted_mask = candidates_prob > threshold + return accepted_mask + + def _replacement_token_ids(self, target_probs): + """ + Generate one replacement token ID for each sequence based on target + probabilities. The replacement token is used as the fallback option + if typical acceptance sampling does not accept any draft tokens for + that particular sequence. + + This method computes the token IDs to be replaced by selecting the + token with the highest probability for each sequence in the first + position. The rest of the output is filled with -1. + + Parameters + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) containing + the target probability distribution + + Returns + ------- + torch.Tensor + A tensor of shape (batch_size, k) with the replacement + token IDs. Only the first column is set, and the rest of the + columns are filled with -1. + """ + max_indices = torch.argmax(target_probs[:, 0, :], dim=1) + output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), + dtype=self.token_id_dtype, + device=target_probs.device) + output[:, 0] = max_indices + return output
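+
+
+# Illustrative usage sketch (kept as a comment; not part of the module's
+# public surface). Shapes and dtypes follow the forward() docstring above;
+# the concrete sizes are arbitrary:
+#
+#     import torch
+#     sampler = TypicalAcceptanceSampler()
+#     sampler.init_gpu_tensors(rank=0)
+#     batch_size, k, vocab_size = 2, 3, 32_000
+#     target_probs = torch.softmax(
+#         torch.rand(batch_size, k, vocab_size, device="cuda"), dim=-1)
+#     bonus_token_ids = torch.randint(
+#         0, vocab_size, (batch_size, 1), dtype=torch.int64, device="cuda")
+#     draft_token_ids = torch.randint(
+#         0, vocab_size, (batch_size, k), dtype=torch.int64, device="cuda")
+#     # output has shape [batch_size, k + 1]; rejected positions are -1 and
+#     # the last column holds the bonus token only if all k drafts are
+#     # accepted and bonus tokens are not disabled.
+#     output = sampler(target_probs, bonus_token_ids, draft_token_ids)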