sampling: simplify min-p sampling (#713)
As suggested by @aikitoria and @rolandtannous in #710, we don't need a
rejection algorithm for min-p sampling; this PR simplifies the design.

There is a breaking change to the API: we no longer return a `success`
array from min-p sampling, because there is no longer any risk of
rejecting samples. In later PRs, we will also remove the `success` array
from the top-p/top-k sampling APIs.
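
A minimal before/after sketch of the user-facing change, adapted from the updated docstring example in `flashinfer/sampling.py` below (tensor contents are arbitrary; the shapes and return values are the point):

```python
import torch
import flashinfer

batch_size, vocab_size = 4, 5
pre_norm = torch.rand(batch_size, vocab_size, device="cuda")
probs = pre_norm / pre_norm.sum(dim=-1, keepdim=True)
min_p = torch.full((batch_size,), 0.05, device="cuda")

# Before this PR: 2-D uniform samples and a (samples, success) pair.
# uniform_samples = torch.rand(max_rounds, batch_size, device="cuda")
# samples, success = flashinfer.sampling.min_p_sampling_from_probs(probs, uniform_samples, min_p)

# After this PR: 1-D uniform samples, and samples returned alone.
uniform_samples = torch.rand(batch_size, device="cuda")
samples = flashinfer.sampling.min_p_sampling_from_probs(probs, uniform_samples, min_p)
```

For backward compatibility, a 2-D `uniform_samples` is still accepted, and only its first row is used (see the shim added in `flashinfer/sampling.py`).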
yzh119 authored Jan 3, 2025
1 parent 561f646 commit 0f80329
Showing 5 changed files with 83 additions and 119 deletions.
4 changes: 2 additions & 2 deletions csrc/flashinfer_sampling_ops.cu
@@ -27,8 +27,8 @@ void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at:
unsigned int top_k_val, bool deterministic, int64_t cuda_stream);

void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
at::Tensor success, std::optional<at::Tensor> maybe_min_p_arr,
double min_p_val, bool deterministic, int64_t cuda_stream);
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
bool deterministic, int64_t cuda_stream);

void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples,
at::Tensor samples, at::Tensor success,
13 changes: 6 additions & 7 deletions csrc/sampling.cu
@@ -90,26 +90,25 @@ void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at:
}

void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
at::Tensor success, std::optional<at::Tensor> maybe_min_p_arr,
double min_p_val, bool deterministic, int64_t cuda_stream) {
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
bool deterministic, int64_t cuda_stream) {
CHECK_INPUT(probs);
CHECK_INPUT(uniform_samples);
auto device = probs.device();
CHECK_EQ(uniform_samples.device(), device);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
CHECK_DIM(2, uniform_samples); // uniform_samples: (max_rounds, batch_size)
CHECK_DIM(1, uniform_samples); // uniform_samples: (batch_size)
unsigned int batch_size = probs.size(0);
unsigned int vocab_size = probs.size(1);
unsigned int max_rounds = uniform_samples.size(0);
CHECK_EQ(uniform_samples.size(1), batch_size);
CHECK_EQ(uniform_samples.size(0), batch_size);
bool has_min_p_arr = maybe_min_p_arr.has_value();

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::MinPSamplingFromProb<float, int>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
has_min_p_arr ? static_cast<float*>(maybe_min_p_arr->data_ptr()) : nullptr,
static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), batch_size,
min_p_val, vocab_size, max_rounds, deterministic, stream);
static_cast<int*>(samples.data_ptr()), batch_size, min_p_val, vocab_size, deterministic,
stream);
TORCH_CHECK(status == cudaSuccess, "MinPSamplingFromProb failed with error code " +
std::string(cudaGetErrorString(status)));
}
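
The wrapper now validates a 1-D `uniform_samples`. A hypothetical Python-side equivalent of the new shape checks (`check_min_p_inputs` is illustrative, not a library function):

```python
import torch

def check_min_p_inputs(probs: torch.Tensor, uniform_samples: torch.Tensor) -> None:
    # probs: (batch_size, vocab_size)
    assert probs.dim() == 2, "probs must be 2-D"
    # uniform_samples: (batch_size,) -- previously (max_rounds, batch_size)
    assert uniform_samples.dim() == 1, "uniform_samples must be 1-D"
    assert uniform_samples.size(0) == probs.size(0), "batch sizes must match"
```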
26 changes: 11 additions & 15 deletions flashinfer/sampling.py
@@ -164,26 +164,24 @@ def min_p_sampling_from_probs(
maybe_min_p_arr: Optional[torch.Tensor],
min_p_val: float,
deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
uniform_samples = uniform_samples.float()
maybe_min_p_arr = (
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
module.min_p_sampling_from_probs(
probs,
uniform_samples,
samples,
success,
maybe_min_p_arr,
min_p_val,
deterministic,
get_cuda_stream(device),
)
return samples, success
return samples

# torch library for top_k_top_p_sampling_from_probs

@@ -634,7 +632,7 @@ def min_p_sampling_from_probs(
min_p: Union[torch.Tensor, float],
deterministic: bool = True,
check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
r"""Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
@@ -647,8 +645,7 @@ def min_p_sampling_from_probs(
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
uniform_samples: torch.Tensor
The uniform samples used as needle for sampling, shape ``(max_top_k_rounds, batch_size,)``,
where the first dimension is the maximum number of rounds for rejection sampling.
The uniform samples used as needle for sampling, shape ``(batch_size,)``,
Expected to be uniformly distributed in ``[0, 1)``.
min_p: torch.Tensor
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
@@ -663,9 +660,6 @@ def min_p_sampling_from_probs(
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
success: torch.Tensor
Whether the sampling is successful within ``max_top_k_rounds`` rounds,
shape ``(batch_size,)``.
Examples
--------
@@ -676,7 +670,6 @@
<torch._C.Generator object at 0x7f8b3db06df0>
>>> batch_size = 4
>>> vocab_size = 5
>>> max_rounds = 3
>>> min_p = torch.full((batch_size,), 0.05).to(0)
>>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
>>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
@@ -685,19 +678,22 @@
[0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
[0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
[0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
>>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
>>> samples, success = flashinfer.sampling.min_p_sampling_from_probs(norm_prob, uniform_samples, min_p)
>>> uniform_samples = torch.rand(batch_size).to(0)
>>> samples = flashinfer.sampling.min_p_sampling_from_probs(norm_prob, uniform_samples, min_p)
>>> samples
tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
>>> success
tensor([True, True, True, True], device='cuda:0')
Note
----
This function expects float32 inputs, and the output is int32.
We encourage users to set ``max_rounds`` to a reasonable value, e.g., 32. The actual
implementation usually use much fewer rounds for rejection sampling because of early stopping.
"""
# NOTE(Zihao): for backward compatibility (https://github.com/flashinfer-ai/flashinfer/pull/713)
if uniform_samples.dim() == 2:
# Take the first row (round) of uniform_samples
uniform_samples = uniform_samples[0]

if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
140 changes: 56 additions & 84 deletions include/flashinfer/sampling.cuh
@@ -184,18 +184,18 @@ __device__ __forceinline__ void DeterministicInclusiveSum(
}

template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T>
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T, typename Predicate>
__device__ __forceinline__ void DeviceSamplingFromProb(
uint32_t i, uint32_t d, T threshold, T u, vec_t<T, VEC_SIZE> prob_vec, T& aggregate,
uint32_t i, uint32_t d, Predicate pred, T u, vec_t<T, VEC_SIZE> prob_vec, T& aggregate,
SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>* temp_storage) {
const uint32_t tx = threadIdx.x;
T prob_greater_than_threshold[VEC_SIZE];
T inclusive_cdf[VEC_SIZE];
bool greater_than_u[VEC_SIZE], valid[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
prob_greater_than_threshold[j] = (prob_vec[j] > threshold) ? prob_vec[j] : T(0);
valid[j] = prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : T(0);
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
}
T aggregate_local =
BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
@@ -219,7 +219,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(

#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
greater_than_u[j] = inclusive_cdf[j] + aggregate > u;
greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j];
}

bool greater_than_u_diff[VEC_SIZE];
@@ -234,13 +234,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb(

#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
if (greater_than_u_diff[j] && valid[j]) {
if constexpr (DETERMINISTIC) {
temp_storage->sampled_id = (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
} else {
// cub's block scan result might not be monotonic, so we need to find the first element
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
}
if (greater_than_u_diff[j]) {
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
}
}
__syncthreads();
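
`DeviceSamplingFromProb` now takes a caller-supplied predicate instead of a fixed threshold, so min-p can filter with `x >= pivot` while the other samplers keep `x > pivot`, and the first qualifying index is always resolved via `atomicMin`. A rough scalar Python analogue of this filtered inverse-CDF step (illustrative only; the real code is a vectorized CUDA block primitive):

```python
def sample_from_prob_filtered(probs_row, u, pred):
    # Accumulate only probabilities that pass the predicate; return the
    # first index whose filtered inclusive CDF exceeds u.
    aggregate = 0.0
    for idx, p in enumerate(probs_row):
        if pred(p):
            aggregate += p
            if aggregate > u:
                return idx
    return len(probs_row) - 1  # fallback, like temp_storage.sampled_id = d - 1

# top-k/top-p style: pred = lambda x: x > pivot
# min-p style:       pred = lambda x: x >= pivot
```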
@@ -275,7 +270,8 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
}

DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
DType>(i, d, DType(0), u, probs_vec, aggregate, &temp_storage);
DType>(
i, d, [](DType x) { return x > 0; }, u, probs_vec, aggregate, &temp_storage);
if (float(aggregate) > u) {
break;
}
@@ -316,8 +312,8 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}

DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
&temp_storage);
DETERMINISTIC, DType>(
i, d, [&](DType x) { return x > pivot; }, u, probs_vec, aggregate, &temp_storage);
if (aggregate > u) {
break;
}
@@ -404,8 +400,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}

DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
&temp_storage);
DETERMINISTIC, DType>(
i, d, [&](DType x) { return x > pivot; }, u, probs_vec, aggregate, &temp_storage);
if (aggregate > u) {
break;
}
@@ -459,8 +455,7 @@ template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
__global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType* min_p_arr,
IdType* output, bool* success, float min_p_val,
uint32_t d, uint32_t max_min_p_rounds) {
IdType* output, float min_p_val, uint32_t d) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
DType p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx];
@@ -472,9 +467,6 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

vec_t<DType, VEC_SIZE> probs_vec;
DType aggregate;
DType q = DType(1);
DType pivot = DType(0);

DType max_p = 0;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
@@ -495,70 +487,50 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
temp_storage.block_aggregate.max_p = max_p;
}
__syncthreads();
DType scaled_p = temp_storage.block_aggregate.max_p * p;
DType pivot = temp_storage.block_aggregate.max_p * p;

IdType sampled_id;
for (uint32_t round = 0; round < max_min_p_rounds; ++round) {
temp_storage.sampled_id = d - 1;
__syncthreads();
DType u = uniform_samples[round * batch_size + bx] * q;
aggregate = DType(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
&temp_storage);
if (aggregate > u) {
break;
}
}
__syncthreads();
sampled_id = temp_storage.sampled_id;
pivot = max(pivot, probs[bx * d + sampled_id]);
if (pivot >= scaled_p) {
break;
DType aggregate_gt_pivot = DType(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

DType aggregate_gt_pivot = DType(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

DType probs_gt_pivot[VEC_SIZE];
DType probs_gt_pivot[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
}
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : DType(0);
}

aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot);
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot;
}
__syncthreads();
aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot);
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot;
}
q = temp_storage.block_aggregate.value;
__syncthreads();
}

DType aggregate(0);
DType q = temp_storage.block_aggregate.value;

IdType sampled_id;
temp_storage.sampled_id = d - 1;
__syncthreads();
if (tx == 0) {
output[bx] = sampled_id;
if (pivot < scaled_p) {
// failed to sample within MAX_ROUNDS
if (success != nullptr) {
success[bx] = false;
}
} else {
if (success != nullptr) {
success[bx] = true;
}
DType u = uniform_samples[bx] * q;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
DType>(
i, d, [&](DType x) { return x >= pivot; }, u, probs_vec, aggregate, &temp_storage);
if (aggregate > u) {
break;
}
}
output[bx] = temp_storage.sampled_id;
}
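
For intuition, here is a dense PyTorch sketch of the single-pass algorithm the simplified kernel implements (a reference illustration under the same `>= pivot` convention, not the fused kernel itself; `min_p` is assumed to be a `(batch_size,)` tensor):

```python
import torch

def min_p_sample_reference(probs, uniform_samples, min_p):
    # Scale the min-p threshold by each row's maximum probability.
    max_p = probs.max(dim=-1, keepdim=True).values
    pivot = max_p * min_p.view(-1, 1)
    # Mask out tokens below the pivot; the kernel keeps probs >= pivot.
    masked = torch.where(probs >= pivot, probs, torch.zeros_like(probs))
    # q is the surviving mass; scaling u by q stands in for renormalization,
    # so one uniform sample per row suffices and no rejection rounds are needed.
    q = masked.sum(dim=-1)
    u = uniform_samples * q
    # Inverse-CDF sampling: first index whose inclusive cumulative sum exceeds u.
    cdf = masked.cumsum(dim=-1)
    return torch.argmax((cdf > u.unsqueeze(-1)).int(), dim=-1).to(torch.int32)
```

Since every row is sampled exactly once from the surviving mass, there is no failure mode to report, which is why the `success` array can be dropped from the API.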

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
@@ -596,8 +568,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
}

DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
&temp_storage);
DETERMINISTIC, DType>(
i, d, [&](DType x) { return x > pivot; }, u, probs_vec, aggregate, &temp_storage);
if (aggregate > u) {
break;
}
@@ -749,16 +721,15 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b

template <typename T, typename IdType>
cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr, IdType* output,
bool* success, uint32_t batch_size, float min_p_val, uint32_t d,
uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) {
uint32_t batch_size, float min_p_val, uint32_t d,
bool deterministic, cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t smem_size = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &uniform_samples, &min_p_arr, &output,
&success, &min_p_val, &d, &max_rounds};
void* args[] = {&probs, &uniform_samples, &min_p_arr, &output, &min_p_val, &d};

DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
@@ -1350,8 +1321,9 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
}

DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
DType>(i, d, DType(0), u, relu_q_minus_p_vec, aggregate_relu_q_minus_p,
&temp_storage);
DType>(
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p,
&temp_storage);
if (aggregate_relu_q_minus_p > u) {
break;
}