diff --git a/csrc/utils/cp_balance/cp_balance/context_parallel_utils.py b/csrc/utils/cp_balance/cp_balance/context_parallel_utils.py
new file mode 100644
index 00000000000..b94b1675aef
--- /dev/null
+++ b/csrc/utils/cp_balance/cp_balance/context_parallel_utils.py
@@ -0,0 +1,869 @@
+#!/usr/bin/env python3
+
+"""
+Context Parallel FlashMask Attention Implementation
+
+This module provides a context parallel implementation of flashmask attention with load
+balancing using the DualChunkSwap strategy. Context parallelism partitions tensors along
+the sequence dimension to enable long-context LLMs in a distributed fashion.
+"""
+
+import paddle
+import paddle.nn.functional as F
+from paddle import _C_ops
+from paddle import distributed as dist
+from paddle.distributed import fleet
+from paddle.nn.functional.flash_attention import flashmask_attention
+from paddle.autograd.py_layer import PyLayer
+import numpy as np
+
+def scatter_balance(input_tensor, group=None, axis=0, mode="dual_chunk", buckets=None):
+    """
+    Evenly split the input tensor along the specified axis across model parallel ranks.
+
+    This function implements balanced scattering by taking chunks from both ends
+    of the tensor to ensure load balancing across ranks.
+
+    Args:
+        input_tensor (paddle.Tensor): Input tensor to be scattered
+        group (paddle.distributed.Group, optional): Communication group.
+            If None, uses model parallel group from fleet
+        axis (int, optional): Axis along which to scatter. Defaults to 0
+        mode (str, optional): Splitting strategy, either "dual_chunk" (one chunk from each
+            end of the sequence) or "balanced_swap" (the chunks listed in ``buckets[rank]``).
+            Defaults to "dual_chunk"
+        buckets (list, optional): Per-rank lists of (weight, chunk_index) pairs produced by
+            the load balancer; required when mode is "balanced_swap"
+
+    Returns:
+        paddle.Tensor: Scattered tensor chunk for current rank
+
+    Note:
+        This API is different from distributed.scatter - it performs balanced
+        splitting by taking chunks from both ends of the sequence.
+    """
+    if group is None:
+        hcg = fleet.get_hybrid_communicate_group()
+        group = hcg.get_model_parallel_group()
+
+    parallelism = group.nranks
+    if parallelism == 1:
+        return input_tensor.clone()
+
+    rank = group.rank
+    seq_len = input_tensor.shape[axis]
+
+    if mode == "dual_chunk":
+        # Ensure sequence length is divisible by parallelism * 2 for balanced splitting
+        assert (
+            seq_len % (parallelism * 2) == 0
+        ), f"Input sequence length {seq_len} can't be divided exactly by sequence parallelism * 2 {parallelism * 2}"
+
+        interval = seq_len // parallelism // 2
+        total_len = input_tensor.shape[axis]
+
+        # Take chunk from the beginning
+        chunk_start = paddle.slice(input_tensor, axes=[axis], starts=[interval * rank], ends=[interval * (rank + 1)])
+
+        # Take chunk from the end (in reverse order)
+        chunk_end = paddle.slice(
+            input_tensor, axes=[axis], starts=[total_len - interval * (rank + 1)], ends=[total_len - interval * rank]
+        )
+
+        # Concatenate chunks
+        result = paddle.concat([chunk_start, chunk_end], axis=axis)
+    elif mode == "balanced_swap":
+        assert buckets is not None, "buckets must be provided when mode is balanced_swap"
+        assert len(buckets) == parallelism, "buckets should have same size as parallelism"
+        assert seq_len % (parallelism * len(buckets[rank])) == 0, "seq_len must be divisible by parallelism * len(buckets[rank])"
+        local_chunks = []
+        balance_chunksize = seq_len // (parallelism * len(buckets[rank]))
+        for _, idx in buckets[rank]:
+            # Start and end offsets of this chunk along the split axis
+            chunk_start = idx * balance_chunksize
+            chunk_end = (idx + 1) * balance_chunksize
+            chunk = paddle.slice(input_tensor, axes=[axis], starts=[chunk_start], ends=[chunk_end])
+            local_chunks.append(chunk)
+        result = paddle.concat(local_chunks, axis=axis)
+    # Use assign to free the memory of the whole input tensor to avoid OOM
+    # since slice uses stride and
maintains reference to original tensor + result = paddle.assign(result) + return result + +def all_gather_order(input_tensor, group=None, axis=0): + """ + All-gather operation to reconstruct the original tensor from ordered scattered chunks. + + Args: + input_tensor (paddle.Tensor): Input tensor chunk + group (paddle.distributed.Group, optional): Communication group + axis (int, optional): Axis along which to gather. Defaults to 0 + + Returns: + paddle.Tensor: Reconstructed full tensor + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + + parallelism = group.nranks + if parallelism == 1: + return input_tensor.clone() + + # Create a list to hold gathered chunks + gathered_list = [paddle.empty(input_tensor.shape, dtype=input_tensor.dtype) for _ in range(parallelism)] + dist.stream.all_gather(gathered_list, input_tensor, group=group, use_calc_stream=True) + + # Concatenate in order + result = paddle.concat(gathered_list, axis=axis) + return result + +def all_gather_balance(input_tensor, group=None, axis=0, mode="dual_chunk", buckets=None): + """ + All-gather operation with balanced reconstruction. + + This function performs all-gather to reconstruct the original tensor + from balanced scattered chunks. + + Args: + input_tensor (paddle.Tensor): Input tensor chunk + group (paddle.distributed.Group, optional): Communication group + axis (int, optional): Axis along which to gather. Defaults to 0 + + Returns: + paddle.Tensor: Reconstructed full tensor + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + + parallelism = group.nranks + if parallelism == 1: + return input_tensor.clone() + + if(mode == "dual_chunk"): + # Split input into two halves (start and end chunks) + chunk_start, chunk_end = paddle.split(input_tensor, 2, axis=axis) + + if axis == 0: + # Handle axis=0 case with optimized memory layout + output_shape_start = list(chunk_start.shape) + output_shape_start[axis] = output_shape_start[axis] * parallelism + + gathered_start = paddle.empty(shape=output_shape_start, dtype=input_tensor.dtype) + dist.stream.all_gather(gathered_start, chunk_start, group=group, use_calc_stream=True) + + # Gather end chunks + gathered_end_list = [paddle.empty(chunk_end.shape, dtype=input_tensor.dtype) for _ in range(parallelism)] + dist.stream.all_gather(gathered_end_list, chunk_end, group=group, use_calc_stream=True) + + # Reverse the end chunks to reconstruct original order + gathered_end_list = gathered_end_list[::-1] + + result = paddle.concat([gathered_start] + gathered_end_list, axis=axis) + return result + else: + # Handle other axes + gathered_start_list = [paddle.empty(chunk_start.shape, dtype=input_tensor.dtype) for _ in range(parallelism)] + dist.stream.all_gather(gathered_start_list, chunk_start, group=group, use_calc_stream=True) + + gathered_end_list = [paddle.empty(chunk_end.shape, dtype=input_tensor.dtype) for _ in range(parallelism)] + dist.stream.all_gather(gathered_end_list, chunk_end, group=group, use_calc_stream=True) + + # Reverse the end chunks + gathered_end_list = gathered_end_list[::-1] + + result = paddle.concat(gathered_start_list + gathered_end_list, axis=axis) + return result + elif(mode == "balanced_swap"): + assert buckets is not None, "buckets must be provided when mode is balanced_swap" + assert len(buckets) == parallelism, "buckets should have same size as parallelism" + chunk_shape = input_tensor.shape + chunk_size = chunk_shape[axis] // len(buckets[0]) + 
gathered_list = [paddle.empty(chunk_shape, dtype=input_tensor.dtype) for _ in range(parallelism)] + dist.stream.all_gather(gathered_list, input_tensor, group=group, use_calc_stream=True) + + total_shape = chunk_shape[:axis] + (chunk_shape[axis] * parallelism,) + chunk_shape[axis + 1 :] + gathered_tensor = paddle.zeros(total_shape, dtype=input_tensor.dtype) + for j in range(parallelism): + for k in range(len(buckets[j])): + _, idx = buckets[j][k] + start_idx_total = idx * chunk_size + end_idx_total = (idx + 1) * chunk_size + slices_total = [slice(None)] * len(gathered_tensor.shape) + slices_total[axis] = slice(start_idx_total, end_idx_total) + + slices_chunk = [slice(None)] * len(gathered_list[j].shape) + slices_chunk[axis] = slice(k*chunk_size, (k+1)*chunk_size) + gathered_tensor[tuple(slices_total)] = gathered_list[j][tuple(slices_chunk)] + result = gathered_tensor + return result + + +def reduce_scatter_any_axis(input_tensor, axis, group=None): + """ + Reduce-scatter operation along any axis. + + Performs element-wise reduction (sum) across ranks and scatters the result + so each rank gets a portion of the reduced tensor. + + Args: + input_tensor (paddle.Tensor): Input tensor to reduce and scatter + axis (int): Axis along which to perform reduce-scatter + group (paddle.distributed.Group, optional): Communication group + + Returns: + paddle.Tensor: Reduced and scattered tensor chunk + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_context_parallel_group() + + parallelism = group.nranks + if parallelism == 1: + return input_tensor.clone() + + assert input_tensor.shape[axis] % parallelism == 0, ( + f"Input sequence length {input_tensor.shape[axis]} can't be ", + f"divided exactly by context parallelism {parallelism}", + ) + + if axis == 0: + # Optimized path for axis=0 + output_shape = list(input_tensor.shape) + output_shape[0] = output_shape[0] // parallelism + + output = paddle.empty(shape=output_shape, dtype=input_tensor.dtype) + dist.stream.reduce_scatter(output, input_tensor, op=dist.ReduceOp.SUM, group=group, use_calc_stream=True) + return output + else: + # General case for other axes using alltoall + input_chunks = paddle.split(input_tensor, parallelism, axis=axis) + + output_buffers = [paddle.empty(input_chunks[0].shape, dtype=input_tensor.dtype) for _ in range(parallelism)] + + dist.stream.alltoall(output_buffers, input_chunks, group=group, use_calc_stream=True) + + # Sum the received chunks + result = paddle.stack(output_buffers, axis=0).sum(axis=0) + return result + + +def reduce_scatter_any_axis_balance(input_tensor, axis, group=None): + """ + Balanced reduce-scatter operation along any axis. + + Similar to reduce_scatter_any_axis but uses balanced splitting strategy + by processing chunks from both ends of the tensor. 
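+
+    For illustration (editor's example): with parallelism = 2 the full-sequence
+    gradient is viewed as four equal chunks [g0, g1, g2, g3]; the DualChunkSwap
+    layout assigns [g0, g3] to rank 0 and [g1, g2] to rank 1, so this routine
+    splits each half, reverses the chunk order of the second half, and sums the
+    chunks exchanged via all-to-all.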
+ + Args: + input_tensor (paddle.Tensor): Input tensor to reduce and scatter + axis (int): Axis along which to perform reduce-scatter + group (paddle.distributed.Group, optional): Communication group + + Returns: + paddle.Tensor: Reduced and scattered tensor chunk with balanced distribution + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_context_parallel_group() + + parallelism = group.nranks + if parallelism == 1: + return input_tensor.clone() + + assert input_tensor.shape[axis] % (parallelism * 2) == 0, ( + f"Input sequence length {input_tensor.shape[axis]} can't be ", + f"divided exactly by context parallelism * 2 {parallelism * 2}", + ) + + # Split input into two halves + input_start, input_end = paddle.split(input_tensor, 2, axis=axis) + + # Split each half across ranks + chunks_start = paddle.split(input_start, parallelism, axis=axis) + chunks_end = paddle.split(input_end, parallelism, axis=axis) + + # Reverse end chunks for balanced distribution + chunks_end = chunks_end[::-1] + + # Combine corresponding start and end chunks + combined_chunks = [ + paddle.concat([start_chunk, end_chunk], axis=axis) for start_chunk, end_chunk in zip(chunks_start, chunks_end) + ] + + # Perform alltoall communication + output_buffers = [paddle.empty(combined_chunks[0].shape, dtype=input_tensor.dtype) for _ in range(parallelism)] + + dist.stream.alltoall(output_buffers, combined_chunks, group=group, use_calc_stream=True) + + # Sum the received chunks + result = paddle.stack(output_buffers, axis=0).sum(axis=0) + return result + +def reduce_scatter_any_axis_order(input_tensor, axis, group=None): + """ + Ordered reduce-scatter operation along any axis. + + Splits the input tensor sequentially along the specified axis, + sends corresponding chunks to each rank, and sums the received chunks. + + Args: + input_tensor (paddle.Tensor): Input tensor to reduce and scatter + axis (int): Axis along which to perform reduce-scatter + group (paddle.distributed.Group, optional): Communication group + + Returns: + paddle.Tensor: Reduced and scattered tensor chunk with ordered distribution + """ + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_context_parallel_group() + + parallelism = group.nranks + if parallelism == 1: + return input_tensor.clone() + + assert input_tensor.shape[axis] % parallelism == 0, ( + f"Input sequence length {input_tensor.shape[axis]} can't be " + f"divided exactly by context parallelism {parallelism}", + ) + + # Split input into chunks along the axis (sequentially, no reverse) + chunks = paddle.split(input_tensor, parallelism, axis=axis) + + # Perform alltoall communication + output_buffers = [paddle.empty(chunks[0].shape, dtype=input_tensor.dtype) for _ in range(parallelism)] + dist.stream.alltoall(output_buffers, chunks, group=group, use_calc_stream=True) + + # Sum the received chunks + result = paddle.stack(output_buffers, axis=0).sum(axis=0) + return result + +class ContextParallelScatterOp(PyLayer): + """ + Context parallel scatter operation using PyLayer for automatic differentiation. + + Forward: Scatter input tensor using balanced splitting + Backward: All-gather gradients using balanced reconstruction + """ + + @staticmethod + def forward(ctx, input_tensor, axis=0): + """ + Forward pass: scatter input tensor across context parallel ranks. 
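+
+        Example (editor's illustrative sketch; assumes the hybrid communicate group
+        has been initialized with a context parallel degree greater than 1):
+
+        ```python
+        # hidden: [batch, seq_len, hidden], replicated on every CP rank
+        local_hidden = ContextParallelScatterOp.apply(hidden, axis=1)
+        # local_hidden: [batch, seq_len / cp_size, hidden], holding one chunk from
+        # each end of the sequence (DualChunkSwap layout)
+        ```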
+ + Args: + ctx: Context object for saving information for backward pass + input_tensor (paddle.Tensor): Input tensor to scatter + axis (int): Axis along which to scatter + + Returns: + paddle.Tensor: Scattered tensor chunk + """ + ctx.axis = axis + hcg = fleet.get_hybrid_communicate_group() + + assert hcg.get_context_parallel_world_size() > 1, ( + f"ScatterOpCP must be used with context parallel, ", + f"context_parallel_world_size={hcg.get_context_parallel_world_size()}", + ) + + group = hcg.get_context_parallel_group() + ctx.group = group + + return scatter_balance(input_tensor, axis=axis, group=group) + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass: all-gather gradients. + + Args: + ctx: Context object with saved information + grad_output (paddle.Tensor): Gradient of output + + Returns: + tuple: Gradients for input arguments + """ + grad_input = all_gather_balance(grad_output, axis=ctx.axis, group=ctx.group) + return grad_input + + +class ContextParallelGatherOp(PyLayer): + """ + Context parallel gather operation using PyLayer for automatic differentiation. + + Forward: All-gather input tensor using balanced reconstruction + Backward: Scatter gradients using balanced splitting + """ + + @staticmethod + def forward(ctx, input_tensor, axis=0): + """ + Forward pass: all-gather input tensor across context parallel ranks. + + Args: + ctx: Context object for saving information for backward pass + input_tensor (paddle.Tensor): Input tensor to gather + axis (int): Axis along which to gather + + Returns: + paddle.Tensor: Gathered full tensor + """ + ctx.axis = axis + hcg = fleet.get_hybrid_communicate_group() + + assert hcg.get_context_parallel_world_size() > 1, ( + f"GatherOpCP must be used with context parallel, ", + f"context_parallel_world_size={hcg.get_context_parallel_world_size()}", + ) + + group = hcg.get_context_parallel_group() + ctx.group = group + + return all_gather_balance(input_tensor, axis=axis, group=group) + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass: scatter gradients. + + Args: + ctx: Context object with saved information + grad_output (paddle.Tensor): Gradient of output + + Returns: + tuple: Gradients for input arguments + """ + grad_input = scatter_balance(grad_output, axis=ctx.axis, group=ctx.group) + return grad_input + + +class ContextParallelAllGatherOp(PyLayer): + """ + Context parallel all-gather operation with gradient reduction. + + Forward: All-gather input tensor (e.g., [batch, seq_len/n, hidden] -> [batch, seq_len, hidden]) + Backward: Reduce-scatter gradients with balanced distribution + + This operation is similar to AllGatherOp but maintains context parallel state + after gradient aggregation. + """ + + @staticmethod + def forward(ctx, input_tensor, axis): + """ + Forward pass: all-gather input tensor. 
+ + Args: + ctx: Context object for saving information + input_tensor (paddle.Tensor): Input tensor with shape [batch, seq_len/n, hidden] + axis (int): Axis along which to gather + + Returns: + paddle.Tensor: Gathered tensor with shape [batch, seq_len, hidden] + """ + ctx.axis = axis + hcg = fleet.get_hybrid_communicate_group() + + assert hcg.get_context_parallel_world_size() > 1, ( + f"AllGatherOpCP must be used with context parallel, ", + f"context_parallel_world_size={hcg.get_context_parallel_world_size()}", + ) + + group = hcg.get_context_parallel_group() + ctx.group = group + + return all_gather_balance(input_tensor, axis=axis, group=group) + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass: reduce-scatter gradients. + + Args: + ctx: Context object with saved information + grad_output (paddle.Tensor): Gradient with shape [batch, seq_len, hidden] + + Returns: + tuple: Gradients with shape [batch, seq_len/n, hidden] + """ + grad_input = reduce_scatter_any_axis_balance(grad_output, axis=ctx.axis, group=ctx.group) + return grad_input + + +def preprocess_index(startend_row_indices, chunk_id, seq_blocksize, max_seqlen_q): + """ + Preprocess startend row indices for a single chunk. + + Adjusts the startend_row_indices relative to the chunk's starting position and + clips them to valid range. + + Args: + startend_row_indices (paddle.Tensor): Original startend row indices + chunk_id (int): ID of the current chunk + seq_blocksize (int): Size of each sequence block + max_seqlen_q (int): Maximum sequence length for queries + + Returns: + paddle.Tensor: Preprocessed row indices + """ + rows_min = chunk_id * seq_blocksize + adjusted_indices = startend_row_indices - rows_min + clipped_indices = paddle.clip(adjusted_indices, min=0, max=max_seqlen_q) + return clipped_indices + + +def preprocess_index_dual_chunks(startend_row_indices, chunk_id_first, chunk_id_second, seq_blocksize, max_seqlen_q): + """ + Preprocess row indices for dual chunks (DualChunkSwap strategy). + + This function handles the index preprocessing for the balanced dual-chunk + strategy where each rank processes chunks from both ends of the sequence. + + Args: + startend_row_indices (paddle.Tensor): Original row indices + chunk_id_first (int): ID of the first chunk + chunk_id_second (int): ID of the second chunk + seq_blocksize (int): Size of each sequence block + max_seqlen_q (int): Maximum sequence length for queries + + Returns: + paddle.Tensor: Preprocessed row indices for dual chunks + """ + # Calculate starting positions for both chunks + rows_min_first = chunk_id_first * seq_blocksize + rows_min_second = chunk_id_second * seq_blocksize + + # Process first chunk indices + indices_first = startend_row_indices - rows_min_first + indices_first = paddle.clip(indices_first, min=0, max=max_seqlen_q) + + # Process second chunk indices + indices_second = startend_row_indices - rows_min_second + indices_second = paddle.clip(indices_second, min=0, max=max_seqlen_q) + + # Offset second chunk indices to avoid overlap + indices_second = paddle.where(indices_second != 0, indices_second + max_seqlen_q, indices_second) + + # Combine indices from both chunks + combined_indices = paddle.maximum(indices_first, indices_second) + return combined_indices + + +def cp_flashmask_allgatherkv_balance_forward(query, key, value, startend_row_indices, group, causal, is_training, mode): + """ + Forward pass of context parallel flashmask attention with balanced all-gather strategy. 
+ + This function implements the forward pass of flash attention with context parallelism + using the DualChunkSwap strategy for load balancing. + + Args: + query (paddle.Tensor): Query tensor with shape [batch, seq_len/n, num_heads, head_dim] + key (paddle.Tensor): Key tensor with shape [batch, seq_len/n, num_heads, head_dim] + value (paddle.Tensor): Value tensor with shape [batch, seq_len/n, num_heads, head_dim] + startend_row_indices (paddle.Tensor): Row indices for attention mask + group (paddle.distributed.Group): Communication group + causal (bool): Whether to use causal attention + is_training (bool): Whether in training mode + + Returns: + tuple: (output, log_sum_exp, processed_indices) + """ + paddle.base.core.nvprof_nvtx_push("cp_flashmask_allgatherkv_balance_forward") + + rank = group.rank + cp_size = group.world_size + + # All-gather key tensors across context parallel ranks + if(mode == "allgather_kv"): + key_gathered = all_gather_balance(key, axis=1, group=group) + + # All-gather value tensors across context parallel ranks + value_gathered = all_gather_balance(value, axis=1, group=group) + + # Calculate sequence block size for dual-chunk strategy + seq_blocksize = query.shape[1] // 2 + + # Preprocess indices for dual-chunk strategy + startend_row_indices = preprocess_index_dual_chunks( + startend_row_indices, + chunk_id_first=rank, + chunk_id_second=2 * cp_size - rank - 1, + seq_blocksize=seq_blocksize, + max_seqlen_q=seq_blocksize, + ) + elif(mode == "balance_q"): + key_gathered = all_gather_order(key, axis=1, group=group) + + value_gathered = all_gather_order(value, axis=1, group=group) + + + # Perform flashmask attention with startend_row_indices + output, log_sum_exp = flashmask_attention( + query, + key_gathered, + value_gathered, + startend_row_indices=startend_row_indices, + causal=causal, + return_softmax_lse=True, + training=is_training, + ) + + paddle.base.core.nvprof_nvtx_pop() + return output, log_sum_exp, startend_row_indices + + +def cp_flashmask_allgatherkv_balance_backward( + query, key, value, startend_row_indices, output, log_sum_exp, output_grad, group, causal, mode +): + """ + Backward pass of context parallel flashmask attention with balanced all-gather strategy. + + This function implements the backward pass of flashmask attention with context parallelism, + computing gradients for query, key, and value tensors. 
+ + Args: + query (paddle.Tensor): Query tensor + key (paddle.Tensor): Key tensor + value (paddle.Tensor): Value tensor + startend_row_indices (paddle.Tensor): Processed startend_row_indices + output (paddle.Tensor): Forward pass output + log_sum_exp (paddle.Tensor): Log-sum-exp from forward pass + output_grad (paddle.Tensor): Gradient of output + group (paddle.distributed.Group): Communication group + causal (bool): Whether causal attention was used + + Returns: + tuple: (query_grad, key_grad, value_grad) + """ + paddle.base.core.nvprof_nvtx_push("cp_flashmask_allgatherkv_balance_backward") + + cp_size = group.world_size + + if(mode == "allgather_kv"): + # All-gather key and value tensors (same as forward pass) + key_gathered = all_gather_balance(key, axis=1, group=group) + value_gathered = all_gather_balance(value, axis=1, group=group) + elif(mode == "balance_q"): + key_gathered = all_gather_order(key, axis=1, group=group) + value_gathered = all_gather_order(value, axis=1, group=group) + else: + raise NotImplementedError + + if paddle.get_flags(["FLAGS_cudnn_deterministic"])["FLAGS_cudnn_deterministic"]: + fa_version = 2 + else: + fa_version = paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] + if fa_version == 2: + # Create seed offset tensor (required for gradient computation) + seed_offset = paddle.zeros(shape=[query.shape[1], query.shape[2]], dtype=paddle.int64) + + # Compute gradients using flashmask attention backward pass + query_grad, key_grad_gathered, value_grad_gathered = paddle._C_ops.flashmask_attention_grad( + query, + key_gathered, + value_gathered, + startend_row_indices, + output, + log_sum_exp, + seed_offset, + output_grad, + 0.0, # dropout probability + causal, + ) + elif fa_version == 3: + query_grad, key_grad_gathered, value_grad_gathered = paddle._C_ops.flashmask_attention_v2_grad( + query, + key_gathered, + value_gathered, + output, + log_sum_exp, + startend_row_indices, + None, + output_grad, + query.shape[-1] ** (-0.5), + False, + ) + else: + raise ValueError(f"FlashAttention version {fa_version} is not supported.") + + # Reduce-scatter key and value gradients + if(mode == "allgather_kv"): + key_grad = reduce_scatter_any_axis_balance(key_grad_gathered, axis=1, group=group) + value_grad = reduce_scatter_any_axis_balance(value_grad_gathered, axis=1, group=group) + elif(mode == "balance_q"): + key_grad = reduce_scatter_any_axis_order(key_grad_gathered, axis=1, group=group) + value_grad = reduce_scatter_any_axis_order(value_grad_gathered, axis=1, group=group) + else: + raise NotImplementedError + + paddle.base.core.nvprof_nvtx_pop() + return query_grad, key_grad, value_grad + +class FlashMaskContextParallel(PyLayer): + """ + FlashMask attention with context parallelism implementation. + + This class implements flashmask attention with context parallelism (CP) using PyLayer + for automatic differentiation. CP partitions tensors along the sequence dimension + to enable long-context LLMs in a distributed fashion. + + The implementation uses the DualChunkSwap strategy to ensure load balancing + across CP ranks by processing chunks from both ends of the sequence. + """ + + @staticmethod + def forward( + ctx, + query, + key, + value, + startend_row_indices, + fixed_seed_offset=None, + dropout=0.0, + causal=False, + training=True, + mode="allgather_kv" + ): + """ + Forward pass of FlashMask attention with context parallelism. 
+ + Args: + ctx: Context object for saving information for backward pass + query (paddle.Tensor): Query tensor, pre-divided by CP size + key (paddle.Tensor): Key tensor, pre-divided by CP size + value (paddle.Tensor): Value tensor, pre-divided by CP size + startend_row_indices (paddle.Tensor): Row indices for attention mask + fixed_seed_offset (paddle.Tensor, optional): Fixed seed offset for dropout + dropout (float): Dropout probability + causal (bool): Whether to use causal attention + training (bool): Whether in training mode + mode (str): Attention mode, currently supports "allgather_kv" + + Returns: + paddle.Tensor: Attention output + + Raises: + NotImplementedError: If dropout > 0.0 or causal=True + AssertionError: If query sequence length is not divisible by 2 + """ + # Validate input parameters + if dropout > 0.0: + raise NotImplementedError("Dropout is not supported in FlashMask context parallel yet.") + + if causal: + raise NotImplementedError("FlashMaskContextParallel does not support causal=True yet.") + + if fixed_seed_offset is not None: + raise NotImplementedError("Fixed seed offset is not supported yet.") + + # Get communication group + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_context_parallel_group() + + # Validate query sequence length for DualChunkSwap strategy + assert query.shape[1] % 2 == 0, ( + f"Query sequence length must be divisible by 2. " + f"FlashMaskContextParallel uses DualChunkSwap strategy for load balancing. " + f"Current query sequence length: {query.shape[1]}" + ) + + # Perform forward pass + output, log_sum_exp, startend_row_indices = cp_flashmask_allgatherkv_balance_forward( + query, key, value, startend_row_indices, group, causal, training, mode + ) + + # Save tensors for backward pass + ctx.save_for_backward(query, key, value, output, log_sum_exp, startend_row_indices) + # rank1 = paddle.distributed.get_rank() + # paddle.save(startend_row_indices,f'/root/paddlejob/workspace/env_run/xiehaoyang/flashmask/flashmask-cp/test_indices/startend_row_indices_balance_{rank1}.pd') + ctx.group = group + ctx.causal = causal + ctx.mode = mode + + return output + + @staticmethod + def backward(ctx, output_grad): + """ + Backward pass of FlashMask attention with context parallelism. + + Args: + ctx: Context object with saved information + output_grad (paddle.Tensor): Gradient of output + + Returns: + tuple: Gradients for all input arguments + """ + # Retrieve saved tensors + query, key, value, output, log_sum_exp, startend_row_indices = ctx.saved_tensor() + group = ctx.group + causal = ctx.causal + mode = ctx.mode + + # Compute gradients + query_grad, key_grad, value_grad = cp_flashmask_allgatherkv_balance_backward( + query, key, value, startend_row_indices, output, log_sum_exp, output_grad, group, causal, mode + ) + + return query_grad, key_grad, value_grad + + +def flashmask_attention_cp( + query, + key, + value, + startend_row_indices, + fixed_seed_offset=None, + dropout=0.0, + causal=False, + training=True, + mode="allgather_kv" +): + """ + FlashMask attention with context parallelism - public API. + + This is the main entry point for using FlashMask attention with context parallelism. + It provides a convenient interface that wraps the FlashMaskContextParallel PyLayer. 
+
+    Args:
+        query (paddle.Tensor): Query tensor with shape [batch, seq_len/n, num_heads, head_dim]
+        key (paddle.Tensor): Key tensor with shape [batch, seq_len/n, num_heads, head_dim]
+        value (paddle.Tensor): Value tensor with shape [batch, seq_len/n, num_heads, head_dim]
+        startend_row_indices (paddle.Tensor): Row indices for attention mask
+        fixed_seed_offset (paddle.Tensor, optional): Fixed seed offset for dropout
+        dropout (float, optional): Dropout probability. Defaults to 0.0
+        causal (bool, optional): Whether to use causal attention. Defaults to False
+        training (bool, optional): Whether in training mode. Defaults to True
+        mode (str, optional): Attention mode. Defaults to "allgather_kv"
+
+    Returns:
+        paddle.Tensor: Attention output with shape [batch, seq_len/n, num_heads, head_dim]
+
+    Example:
+        ```python
+        # Initialize tensors (assuming context parallelism is set up)
+        query = paddle.randn([2, 512, 8, 64])  # [batch, seq_len/n, heads, head_dim]
+        key = paddle.randn([2, 512, 8, 64])    # [batch, seq_len/n, heads, head_dim]
+        value = paddle.randn([2, 512, 8, 64])  # [batch, seq_len/n, heads, head_dim]
+        mask_indices = paddle.randint(0, 1024, [100, 2])
+
+        # Apply FlashMask attention with context parallelism
+        output = flashmask_attention_cp(
+            query=query,
+            key=key,
+            value=value,
+            startend_row_indices=mask_indices,
+            training=True
+        )
+        ```
+    """
+    output = FlashMaskContextParallel.apply(
+        query,
+        key,
+        value,
+        startend_row_indices,
+        fixed_seed_offset,
+        dropout,
+        causal,
+        training,
+        mode
+    )
+    return output
\ No newline at end of file
diff --git a/csrc/utils/cp_balance/cp_balance/cp_balance.py b/csrc/utils/cp_balance/cp_balance/cp_balance.py
new file mode 100644
index 00000000000..4175262b327
--- /dev/null
+++ b/csrc/utils/cp_balance/cp_balance/cp_balance.py
@@ -0,0 +1,379 @@
+import heapq
+import paddle
+import numpy as np
+from .cp_balance_cuda_kernels import scanMaxMinChunkedKernel, reduce_workload, indices_to_chunks_cuda, indices_rerank_cuda
+import paddle.distributed as dist
+import hashlib
+from typing import List, Tuple, Dict, Optional
+
+# --- Debugging helpers ---
+
+def save_tensor(x: paddle.Tensor, name: str):
+    """Dump a Paddle Tensor to a txt file for debugging."""
+    x_np = x.numpy()
+    np.savetxt(f'{name}.txt', x_np.reshape(-1, x_np.shape[-1]), fmt='%d')
+
+def tensor_md5(tensor: paddle.Tensor) -> str:
+    """Compute the MD5 hash of a Paddle Tensor to verify data consistency."""
+    x_bytes = tensor.numpy().tobytes()
+    md5_hash = hashlib.md5(x_bytes).hexdigest()
+    print(f"Tensor MD5: {md5_hash}")
+    return md5_hash
+
+# --- Core workload estimation and assignment ---
+
+def get_q_workload(
+    start_row_indices: paddle.Tensor,
+    q_chunk_size: int,
+    m_block_size: int,
+    n_block_size: int
+) -> paddle.Tensor:
+    """
+    Estimate the compute workload of each query chunk from the sparse-attention
+    start/end indices. This is the first step of load balancing: it quantifies the
+    compute cost of every data chunk.
+
+    Args:
+        start_row_indices (paddle.Tensor): Tensor of shape [B, H, S, 2] or [B, H, S, 4]
+            giving the start/end row ranges of the FlashMask mask for each key column.
+            When the last dimension is 4 the order is [LTS, LTE, UTS, UTE];
+            when it is 2 the order is [LTS, UTE].
+        q_chunk_size (int): Chunk size used on the query side for load-balancing analysis.
+        m_block_size (int): Query-side block size (Br) of the FlashAttention kernel.
+        n_block_size (int): Key-side block size (Bc) of the FlashAttention kernel.
+
+    Returns:
+        paddle.Tensor: Tensor of shape [1, H, Tchunks, 2], where Tchunks is the number of
+            chunks. Each chunk entry is [workload, original_index], i.e. the estimated
+            workload of the chunk and its original index.
+    """
+    assert start_row_indices is not None, "start_row_indices cannot be None"
+    assert q_chunk_size % m_block_size == 0, "q_chunk_size must be divisible by m_block_size"
+
+    # 1. Parse the input start/end indices.
+    # start_row_indices may carry start/end information for both the lower triangle (LT)
+    # and the upper triangle (UT).
+    LTS, LTE, UTS, UTE = None, None, None, None
+    if start_row_indices.shape[-1] == 4:
+        LTS, LTE, UTS, UTE = paddle.split(start_row_indices, 4, axis=-1)
+        LTS, LTE, UTS, UTE = [t.squeeze(-1) for t in (LTS, LTE, UTS, UTE)]
+    elif start_row_indices.shape[-1] == 2:
+        LTS, UTE = paddle.split(start_row_indices, 2, axis=-1)
+        LTS, UTE = LTS.squeeze(-1), UTE.squeeze(-1)
+
+    # 2. Collect shape information.
+    # Batch, Head and Sequence Length can be read from any non-None tensor.
+    valid_tensor = next(t for t in [LTS, LTE, UTS, UTE] if t is not None)
+    B, H, S = valid_tensor.shape
+
+    # Number of blocks.
+    Tr = S // m_block_size       # total query-side blocks
+    Tc = S // n_block_size       # total key-side blocks
+    Tchunks = S // q_chunk_size  # total chunks used for load balancing
+    assert Tr % Tchunks == 0, "Total row blocks must be divisible by total chunks"
+    blocks_per_chunk = Tr // Tchunks
+
+    # 3. Pre-compute the per-key-block max/min of the indices with a custom CUDA kernel.
+    # This is the key optimization: it reduces the O(S) scan to O(S/Bc) and greatly
+    # speeds up the subsequent workload estimation.
+    def scan_max_min(tensor):
+        if tensor is not None:
+            return scanMaxMinChunkedKernel(tensor, n_block_size, B, H, S)
+        return None, None
+
+    LTStartMax_gpu, LTStartMin_gpu = scan_max_min(LTS)
+    LTEndMax_gpu, LTEndMin_gpu = scan_max_min(LTE)
+    UTStartMax_gpu, UTStartMin_gpu = scan_max_min(UTS)
+    UTEndMax_gpu, UTEndMin_gpu = scan_max_min(UTE)
+
+    # 4. Compute the workload of each query block with a custom CUDA kernel.
+    # The kernel mimics FlashAttention's blockwise schedule but only counts the blocks
+    # that would be active instead of running the actual matrix multiplies, which gives
+    # a cheap workload estimate.
+    all_indices_max_min = [
+        LTStartMax_gpu, LTStartMin_gpu, LTEndMax_gpu, LTEndMin_gpu,
+        UTStartMax_gpu, UTStartMin_gpu, UTEndMax_gpu, UTEndMin_gpu
+    ]
+    workload_per_block = reduce_workload(all_indices_max_min, B, H, Tr, Tc, m_block_size, S)
+
+    # 5. Aggregate the per-block workload to chunk level.
+    workload_grouped = workload_per_block.reshape([B, H, Tchunks, blocks_per_chunk, 1])
+    workload_per_chunk = paddle.sum(workload_grouped, axis=3).sum(axis=0).reshape([1, H, Tchunks])
+
+    # 6. Assemble the final output containing the workload and the original index.
+    final_res = paddle.zeros([1, H, Tchunks, 2], dtype='int32', device=start_row_indices.place)
+    final_res[:, :, :, 0] = workload_per_chunk
+    final_res[:, :, :, 1] = paddle.arange(0, Tchunks, dtype="int32")
+
+    return final_res
+
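+
+# ---------------------------------------------------------------------------
+# Editor's note: the helper below is an illustrative, pure-NumPy reference of the
+# workload estimate for the simple 2-column [LTS, UTE] mask of a single
+# (batch, head) pair. It is NOT used by the pipeline above (which relies on the
+# scanMaxMinChunkedKernel / reduce_workload CUDA ops); the name
+# `_reference_q_workload_2col` is ours and only documents the intended logic:
+# a (Br x Bc) tile contributes one unit of work unless it is fully masked.
+def _reference_q_workload_2col(lts, ute, q_chunk_size, m_block_size, n_block_size):
+    """lts / ute: 1-D int arrays of length S; returns per-chunk workloads."""
+    S = lts.shape[0]
+    Tr, Tc = S // m_block_size, S // n_block_size
+    # Per key-block extrema, mirroring what scanMaxMinChunkedKernel precomputes.
+    lts_max = lts.reshape(Tc, n_block_size).max(axis=1)
+    ute_min = ute.reshape(Tc, n_block_size).min(axis=1)
+    work = np.zeros(Tr, dtype=np.int64)
+    for tr in range(Tr):
+        row_s, row_e = tr * m_block_size, (tr + 1) * m_block_size
+        # A tile is fully masked if every row lies at or beyond the lower-triangle
+        # mask start of every column in the key block, or strictly before its
+        # upper-triangle mask end.
+        fully_masked = (row_s >= lts_max) | (row_e <= ute_min)
+        work[tr] = int((~fully_masked).sum())
+    blocks_per_chunk = q_chunk_size // m_block_size
+    return work.reshape(-1, blocks_per_chunk).sum(axis=1)
+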
+
+def assign_tasks_heap(
+    tasks: np.ndarray,
+    num_buckets: int
+) -> Tuple[List[List[Tuple[int, int]]], List[int], int]:
+    """
+    Greedy assignment with a min-heap: distribute a list of weighted, indexed tasks
+    into M buckets so that the bucket totals stay balanced.
+
+    Args:
+        tasks (np.ndarray): Task array of shape (N, 2); each row is [weight, index].
+        num_buckets (int): Number of buckets (usually the communication-group world size).
+
+    Returns:
+        Tuple:
+            - buckets (List[List[Tuple[int, int]]]): Assignment result; each sub-list is one bucket's tasks.
+            - bucket_weights (List[int]): Total weight of each bucket.
+            - cuts (int): Number of cuts, a measure of how contiguous the data stays after reordering.
+    """
+    n = len(tasks)
+    if n == 0:
+        return [[] for _ in range(num_buckets)], [0] * num_buckets, 0
+
+    # Expected number of tasks per bucket.
+    batch_size = n // num_buckets
+
+    # Sort tasks by weight in descending order so the heaviest tasks are placed first.
+    tasks_sorted = sorted(tasks, key=lambda x: -x[0])
+
+    # Initialize the buckets and the per-bucket bookkeeping.
+    buckets = [[] for _ in range(num_buckets)]
+    bucket_weights = [0] * num_buckets
+    bucket_counts = [0] * num_buckets
+
+    # Initialize the min-heap used to quickly find the bucket with the smallest total weight.
+    # Heap entries are (current_weight, bucket_index).
+    heap = [(0, i) for i in range(num_buckets)]
+
+    # Greedy assignment: give the heaviest remaining task to the lightest bucket that is not full.
+    for weight, idx in tasks_sorted:
+        # Find a bucket that can take the task.
+        temp_popped = []
+        found_bucket = False
+        while heap:
+            bucket_sum, bucket_idx = heapq.heappop(heap)
+            if bucket_counts[bucket_idx] < batch_size:
+                # Found a bucket: update its state and push it back into the heap.
+                buckets[bucket_idx].append((weight, idx))
+                bucket_weights[bucket_idx] += weight
+                bucket_counts[bucket_idx] += 1
+                heapq.heappush(heap, (bucket_weights[bucket_idx], bucket_idx))
+                found_bucket = True
+                break
+            else:
+                # This bucket is full; stash it and keep looking.
+                temp_popped.append((bucket_sum, bucket_idx))
+
+        # Push the full buckets that were popped above back into the heap.
+        for item in temp_popped:
+            heapq.heappush(heap, item)
+
+        if not found_bucket:
+            # All buckets are full (usually when n % num_buckets != 0):
+            # give the remaining task to the bucket with the smallest total weight.
+            bucket_sum, bucket_idx = heapq.heappop(heap)
+            buckets[bucket_idx].append((weight, idx))
+            bucket_weights[bucket_idx] += weight
+            bucket_counts[bucket_idx] += 1
+            heapq.heappush(heap, (bucket_weights[bucket_idx], bucket_idx))
+
+
+    # (Optional) sort each bucket by the original task index to ease debugging.
+    for i in range(num_buckets):
+        buckets[i] = sorted(buckets[i], key=lambda x: x[1])
+
+    # Count cuts: a measure of how contiguous the reordered chunks are.
+    all_assigned_indices = sorted([idx for bucket in buckets for _, idx in bucket])
+    cuts = sum(1 for i in range(1, len(all_assigned_indices)) if all_assigned_indices[i] != all_assigned_indices[i-1] + 1)
+
+    return buckets, bucket_weights, cuts
+
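+
+# Editor's illustrative example (not called anywhere): with 8 chunks and 4 buckets,
+# each bucket receives exactly n // num_buckets = 2 chunks and the greedy min-heap
+# pairing keeps the bucket totals equal.
+def _example_assign_tasks_heap():
+    tasks = np.array([[9, 0], [1, 1], [8, 2], [2, 3], [7, 4], [3, 5], [6, 6], [4, 7]])
+    buckets, weights, cuts = assign_tasks_heap(tasks, num_buckets=4)
+    # weights == [10, 10, 10, 10]; each bucket pairs a heavy chunk with a light one,
+    # e.g. bucket 0 holds [(9, 0), (1, 1)]. cuts == 0 because every chunk index is
+    # assigned, so the sorted indices stay contiguous.
+    return buckets, weights, cuts
+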
+
+# --- Data communication and reordering helpers ---
+
+def get_send_dict(buckets: List[List[Tuple[int, int]]], cp_size: int, rank: int) -> Dict[int, List[int]]:
+    """
+    Build the send dictionary for the all-to-all communication of the current rank,
+    based on the load-balancing assignment.
+
+    Args:
+        buckets (List): Task assignment for all ranks.
+        cp_size (int): Communication-group size.
+        rank (int): Rank of the current process.
+
+    Returns:
+        Dict[int, List[int]]: Send dictionary. The key is the destination rank and the value
+            is the list of local chunk indices to send to that rank.
+    """
+    send_dict = {i: [] for i in range(cp_size)}
+    # Walk over all buckets (i.e. the task lists of every destination rank).
+    for target_rank, bucket in enumerate(buckets):
+        for _, chunk_idx in bucket:
+            # If the chunk is originally owned by the current rank, it has to be sent.
+            # (The original layout is assumed to be cp_size consecutive chunks per rank.)
+            if chunk_idx // cp_size == rank:
+                # chunk_idx % cp_size is the local index of the chunk on this rank.
+                send_dict[target_rank].append(chunk_idx % cp_size)
+    return send_dict
+
+def get_recv_dict(bucket: List[Tuple[int, int]], cp_size: int) -> Dict[int, List[int]]:
+    """
+    Build the receive dictionary for the all-to-all communication, based on the
+    task assignment of the current rank.
+
+    Args:
+        bucket (List): Tasks assigned to the current rank.
+        cp_size (int): Communication-group size.
+
+    Returns:
+        Dict[int, List[int]]: Receive dictionary. The key is the source rank and the value is
+            the list of local positions where the chunks received from that rank should be placed.
+    """
+    recv_dict = {i: [] for i in range(cp_size)}
+    # Walk over all tasks assigned to this rank.
+    for local_pos, (_, chunk_idx) in enumerate(bucket):
+        # chunk_idx.item() // cp_size is the rank that originally owns this chunk.
+        source_rank = chunk_idx.item() // cp_size
+        recv_dict[source_rank].append(local_pos)
+    return recv_dict
+
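+
+# Editor's worked example (illustrative only; the bucket layout below is made up).
+# With cp_size = 2 there are 4 global chunks, and get_send_dict assumes each rank
+# initially holds cp_size consecutive chunks: rank 0 owns chunks 0-1, rank 1 owns 2-3.
+def _example_send_recv_dicts():
+    buckets = [
+        [(np.int32(9), np.int32(0)), (np.int32(8), np.int32(3))],  # chunks kept by rank 0
+        [(np.int32(7), np.int32(1)), (np.int32(6), np.int32(2))],  # chunks kept by rank 1
+    ]
+    send_dict = get_send_dict(buckets, cp_size=2, rank=0)   # -> {0: [0], 1: [1]}
+    recv_dict = get_recv_dict(buckets[0], cp_size=2)        # -> {0: [0], 1: [1]}
+    # i.e. rank 0 keeps its local chunk 0, sends local chunk 1 to rank 1, and fills
+    # its second local slot with the chunk received from rank 1 (global chunk 3).
+    return send_dict, recv_dict
+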
+def balance_alltoall(
+    input_tensor: paddle.Tensor,
+    cp_size: int,
+    cp_group,
+    chunk_size: int,
+    send_dict: Dict[int, List[int]],
+    recv_dict: Dict[int, List[int]]
+) -> paddle.Tensor:
+    """
+    Run the all-to-all communication that rearranges `input_tensor` according to the
+    send/recv dictionaries. This function has been refactored so that tensors of
+    different dimensionalities are handled uniformly.
+
+    Args:
+        input_tensor (paddle.Tensor): Tensor to rearrange, e.g. Q, K or V.
+        cp_size (int): Communication-group size.
+        cp_group (dist.Group): Paddle distributed communication group.
+        chunk_size (int): Size of each data chunk.
+        send_dict (Dict): Send dictionary.
+        recv_dict (Dict): Receive dictionary.
+
+    Returns:
+        paddle.Tensor: The rearranged tensor.
+    """
+    original_shape = input_tensor.shape
+    B, S = original_shape[0], original_shape[1]
+
+    # Reshape the input to 3-D (B, S, -1) so all cases are handled the same way.
+    tensor_3d = input_tensor.reshape((B, S, -1))
+    HD = tensor_3d.shape[-1]
+
+    # 1. Prepare the data to send (gather).
+    # Following send_dict, collect the local chunks that must be sent to other ranks.
+    send_list = []
+    for target_rank in range(cp_size):
+        indices_to_send = send_dict[target_rank]
+        if indices_to_send:
+            # Concatenate all chunks destined for the same rank.
+            data_to_send = paddle.concat(
+                [tensor_3d[:, idx * chunk_size:(idx + 1) * chunk_size, :] for idx in indices_to_send],
+                axis=1
+            )
+            send_list.append(data_to_send)
+        else:
+            # Note: NCCL alltoall does not support zero-sized tensors, so send a tiny
+            # dummy tensor as a placeholder. The receiver allocates a matching buffer.
+            send_list.append(paddle.zeros((B, 1, HD), dtype=input_tensor.dtype))
+
+    # 2. Prepare the receive buffers (scatter).
+    # Following recv_dict, allocate empty buffers sized for the data coming from each rank.
+    recv_list = []
+    for source_rank in range(cp_size):
+        num_chunks_to_recv = len(recv_dict[source_rank])
+        if num_chunks_to_recv > 0:
+            recv_list.append(
+                paddle.empty((B, chunk_size * num_chunks_to_recv, HD), dtype=input_tensor.dtype)
+            )
+        else:
+            # Matches the sender's dummy tensor: receive an equally small dummy buffer.
+            recv_list.append(paddle.empty((B, 1, HD), dtype=input_tensor.dtype))
+
+    # 3. Run the all-to-all communication.
+    dist.alltoall(out_tensor_list=recv_list, in_tensor_list=send_list, group=cp_group)
+
+    # 4. Reassemble the received data into the final tensor.
+    final_res_3d = paddle.empty_like(tensor_3d)
+    for source_rank in range(cp_size):
+        local_positions = recv_dict[source_rank]
+        if local_positions:
+            received_data = recv_list[source_rank]
+            # Place the chunks received from source_rank at their local positions.
+            for i, local_pos in enumerate(local_positions):
+                start_s = local_pos * chunk_size
+                end_s = (local_pos + 1) * chunk_size
+                data_start = i * chunk_size
+                data_end = (i + 1) * chunk_size
+                final_res_3d[:, start_s:end_s, :] = received_data[:, data_start:data_end, :]
+
+    # Restore the original shape.
+    return final_res_3d.reshape(original_shape)
+
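+
+# Editor's sketch of the intended call sequence (hedged pseudo-usage; the names
+# q_local, cp_group, cp_size, cp_rank and balance_chunk_size are placeholders and
+# the exact integration point lives outside this file):
+#
+#     local_indices, buckets = balance_flashmask_input(
+#         startend_row_indices, cp_size, cp_rank, balance_chunk_size)
+#     send_dict = get_send_dict(buckets, cp_size, cp_rank)
+#     recv_dict = get_recv_dict(buckets[cp_rank], cp_size)
+#     q_balanced = balance_alltoall(q_local, cp_size, cp_group, balance_chunk_size,
+#                                   send_dict, recv_dict)
+#     # q_balanced and local_indices then feed the context-parallel flashmask attention.
+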
+
+# --- Main pipeline ---
+
+def balance_flashmask_input(
+    startend_row_indices: paddle.Tensor,
+    cp_size: int,
+    cp_rank: int,
+    balance_chunk_size: int = 2048,
+    q_block_size: int = 128,
+    k_block_size: int = 128
+) -> Tuple[paddle.Tensor, List[List[Tuple[int, int]]]]:
+    """
+    Main load-balancing pipeline for FlashMask inputs.
+    It coordinates the whole process: estimate the workload -> assign tasks ->
+    build the communication plan -> rearrange the data.
+
+    Args:
+        startend_row_indices (paddle.Tensor): Original start/end indices of the sparse attention mask.
+        cp_size (int): Communication-group size.
+        cp_rank (int): Rank of the current process.
+        balance_chunk_size (int): Chunk size used for load-balancing analysis and data movement.
+        q_block_size (int): Query block size of the FlashAttention kernel.
+        k_block_size (int): Key block size of the FlashAttention kernel.
+
+    Returns:
+        Tuple:
+            - local_startend_row_indices (paddle.Tensor): The local start/end indices this rank
+              has to process after load balancing and reordering.
+            - buckets (List[List[Tuple[int, int]]]): The global assignment plan, used later to
+              apply the same reordering to Q, K, V and other tensors.
+    """
+    # Step 1: estimate the workload of each chunk.
+    paddle.base.core.nvprof_nvtx_push("get_q_workload")
+    workload = get_q_workload(startend_row_indices, balance_chunk_size, q_block_size, k_block_size)
+    paddle.base.core.nvprof_nvtx_pop()
+
+    # Step 2: assign tasks on the CPU with the greedy min-heap algorithm.
+    paddle.base.core.nvprof_nvtx_push("assign_tasks_heap")
+    # Convert the workload tensor into a numpy array for heapq.
+    tasks_np = workload.reshape([-1, 2]).cpu().numpy()
+    buckets, _, _ = assign_tasks_heap(tasks_np, cp_size)
+    paddle.base.core.nvprof_nvtx_pop()
+
+    # Step 3: reorder the original index tensor according to the global plan `buckets` (gather).
+    # This builds a globally reordered view of `startend_row_indices`.
+    paddle.base.core.nvprof_nvtx_push("startend_row_indices_rerank")
+    # Flatten `buckets` to obtain the new chunk order.
+    rerank_indices = np.array([idx for bucket in buckets for _, idx in bucket], dtype=np.int32)
+    indices_tensor = paddle.to_tensor(rerank_indices, dtype='int32', place=startend_row_indices.place)
+
+    # Run the gather efficiently with a CUDA kernel.
+    startend_row_indices_rerank = indices_rerank_cuda(startend_row_indices, indices_tensor)
+    paddle.base.core.nvprof_nvtx_pop()
+
+    # Step 4: derive this rank's local indices from the reordered global indices (localize).
+    # This converts indices that may span the whole sequence length S into indices relative
+    # to the local chunks.
+    paddle.base.core.nvprof_nvtx_push("indices_to_chunks")
+    local_bucket_indices = [x[1] for x in buckets[cp_rank]]
+    local_indices_tensor = paddle.to_tensor(local_bucket_indices, dtype='int32', place=startend_row_indices.place)
+
+    # Clip and offset the indices efficiently with a CUDA kernel.
+    local_startend_row_indices = indices_to_chunks_cuda(
+        startend_row_indices_rerank, local_indices_tensor, balance_chunk_size
+    )
+    paddle.base.core.nvprof_nvtx_pop()
+
+    return local_startend_row_indices, buckets
\ No newline at end of file
diff --git a/csrc/utils/cp_balance/cp_balance/cp_balance_cuda_kernels.py b/csrc/utils/cp_balance/cp_balance/cp_balance_cuda_kernels.py
new file mode 100644
index 00000000000..8e34c008cd0
--- /dev/null
+++ b/csrc/utils/cp_balance/cp_balance/cp_balance_cuda_kernels.py
@@ -0,0 +1,54 @@
+import cupy as cp
+import numpy as np
+import paddle
+from paddle.utils.cpp_extension import load
+import flashmask_cpbalance_cudaops as cp_balance_ops
+
+def scanMaxMinChunkedKernel(input_tensor, Bc, B, H,
S): + maxo,mino = cp_balance_ops.scan_max_min( + input_tensor, + H, + S, + S, + Bc, + False, + 0.0, + 0, + 0 + ) + + # 取出结果(假设每行只有一个warp,maxo[:, 0]) + # LTStartMax_gpu = cp.asnumpy(maxo[:, 0]) + # LTStartMin_gpu = cp.asnumpy(mino[:, 0]) + # print(maxo) + return maxo, mino + + +def reduce_workload(start_row_maxmin_indice_list, B, H, Tr, Tc, Br, S): + ( + LTStartMax, + LTStartMin, + LTEndMax, + LTEndMin, + UTStartMax, + UTStartMin, + UTEndMax, + UTEndMin, + ) = start_row_maxmin_indice_list + + workload = cp_balance_ops.reduce_workload( + LTStartMax, LTStartMin, LTEndMax, LTEndMin, UTStartMax, UTStartMin, UTEndMax, UTEndMin, + B, H, Tr, Tc, S, Br, False, 128 + ) + + return workload + +def indices_to_chunks_cuda(startend_row_indices, bucket_idx, chunksize=2048): + result = cp_balance_ops.indices_to_chunks(startend_row_indices, bucket_idx, chunksize) + return result + +def indices_rerank_cuda(startend_row_indices, indices, balance_chunk_size=2048): + B, H, S, D = startend_row_indices.shape + num_chunks = (S + balance_chunk_size - 1) // balance_chunk_size + startend_row_indices_rerank = cp_balance_ops.indices_rerank(startend_row_indices, indices, B, H, S,D,num_chunks,balance_chunk_size) + return startend_row_indices_rerank diff --git a/csrc/utils/cp_balance/csrc/cp_balance_utils.cu b/csrc/utils/cp_balance/csrc/cp_balance_utils.cu new file mode 100644 index 00000000000..74eb3239c60 --- /dev/null +++ b/csrc/utils/cp_balance/csrc/cp_balance_utils.cu @@ -0,0 +1,654 @@ +#include "paddle/extension.h" + +#define CHECK_CUDA_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") + +int get_kBlockN(int head_size_rounded, bool is_flashmask, bool is_causal, bool has_softcap, + bool is_local, int seqlen_q, int seqlen_k, bool has_lt_end, bool has_ut_start) { + if (head_size_rounded <= 64) { + if (is_flashmask && !is_causal) { + return 96; + } else if ((is_causal && has_softcap) || is_flashmask) { + return 128; + } else { + return 128; + } + } else if (head_size_rounded <= 128) { + if (is_causal || is_local || has_softcap) { + return 128; + } else { + if ((seqlen_q >= 1024 || seqlen_k >= 1024) && !(has_lt_end && has_ut_start)) { + return 128; + } else { + return 64; + } + } + } else if (head_size_rounded <= 256) { + if (has_lt_end && has_ut_start) { + return 32; + } else { + return 64; + } + } else { + // 不支持的情况 + throw std::runtime_error("head_size_rounded not supported"); + } +} + +template +__global__ +void scanMaxMinChunkedKernel( + const int *input, int b, int n, int *maxo, int *mino) { + int bid = threadIdx.y + blockIdx.y * blockDim.y; + if (bid >= b) return; + int i_offset = bid * n; + input = input + i_offset; + const int nblock_seqlen = ((n + kBlockN - 1) / kBlockN + 3) & 0xfffffffc; + constexpr int nums = (kBlockN + 31) / 32; + int warpId = blockIdx.x; + int tid = threadIdx.x; + int lane_id = threadIdx.x % 32; + int maxv, minv; + int idx = warpId * kBlockN + tid; + if (warpId * kBlockN + kBlockN > n) { + maxv = 0; + minv = INT_MAX; + #pragma unroll + for (int i = 0; i < nums; i++) { + if (idx < n && lane_id + i * 32 < kBlockN) { + maxv = max(maxv, input[idx]); + minv = min(minv, input[idx]); + } + idx += 32; + } + } else { + maxv = 0; + minv = INT_MAX; + #pragma unroll + for (int i = 0; i < nums; i++) { + if(lane_id + i * 32 < kBlockN) { + maxv = max(maxv, input[idx]); + minv = min(minv, input[idx]); + idx += 32; + } + } + } + __syncwarp(); + maxv = __reduce_max_sync(0xffffffff, maxv); + minv = __reduce_min_sync(0xffffffff, minv); + if (tid == 0) { + maxo[bid * nblock_seqlen + warpId] = maxv; + 
mino[bid * nblock_seqlen + warpId] = minv; + } +} + +// Enum for pointer dispatching in reduce_workload_kernel +enum PtrDispatch { SINGLE_PTR = 1, DUAL_PTR = 2, FULL_PTR = 4 }; + +template +__global__ void reduce_workload_kernel( + const int* LTStartMax, const int* LTStartMin, + const int* LTEndMax, const int* LTEndMin, + const int* UTStartMax, const int* UTStartMin, + const int* UTEndMax, const int* UTEndMin, + int* workload, // [B, H, Tr, 1] + int BH, int Tr, int Tc, int S, + int Br // m_block_size +) { + int bh = blockIdx.y; + int tr = blockIdx.x; + int tc = threadIdx.x; + int warpId = threadIdx.x / 32; + int laneId = threadIdx.x % 32; + + if(tr >= Tr) return; + + int wl = 0; + bool fully_masked = true; + bool partially_masked = false; + int lt_start_max = INT_MAX; + int lt_start_min = INT_MAX; + int lt_end_max = INT_MAX; + int lt_end_min = INT_MAX; + int ut_start_max = INT_MIN; + int ut_start_min = INT_MIN; + int ut_end_max = INT_MIN; + int ut_end_min = INT_MIN; + + __shared__ int smem[32]; + + const int idx = bh * Tc + tc; + const int q_idx = bh * Tr + tr; + + const int m_block_s = q_idx * kBlockM; + const int m_block_e = m_block_s + kBlockM < S ? m_block_s + kBlockM : S; + + lt_start_max = tc < Tc ? LTStartMax[idx] : INT_MAX; + lt_start_min = tc < Tc ? LTStartMin[idx] : INT_MAX; + + // 分支展开 + if constexpr (PTR_DISPATCH_TAG == FULL_PTR) { + lt_end_max = tc < Tc ? LTEndMax[idx] : INT_MAX; + lt_end_min = tc < Tc ? LTEndMin[idx] : INT_MAX; + ut_start_max = tc < Tc ? UTStartMax[idx] : INT_MIN; + ut_start_min = tc < Tc ? UTStartMin[idx] : INT_MIN; + ut_end_max = tc < Tc ? UTEndMax[idx] : INT_MIN; + ut_end_min = tc < Tc ? UTEndMin[idx] : INT_MIN; + + fully_masked = (m_block_s >= lt_start_max && m_block_e <= lt_end_min) || + (m_block_s >= ut_start_max && m_block_e <= ut_end_min); + partially_masked = (m_block_s < lt_end_max && m_block_e > lt_start_min) || + (m_block_s < ut_end_max && m_block_e > ut_start_min); + } + else if constexpr (PTR_DISPATCH_TAG == DUAL_PTR) { + if constexpr (is_causal) { + lt_end_max = tc < Tc ? LTEndMax[idx] : INT_MAX; + lt_end_min = tc < Tc ? LTEndMin[idx] : INT_MAX; + fully_masked = m_block_s >= lt_start_max && m_block_e <= lt_end_min; + partially_masked = m_block_s < lt_end_max && m_block_e > lt_start_min; + } else { + ut_end_max = tc < Tc ? UTEndMax[idx] : INT_MIN; + ut_end_min = tc < Tc ? UTEndMin[idx] : INT_MIN; + fully_masked = (m_block_s >= lt_start_max) || (m_block_e <= ut_end_min); + partially_masked = (m_block_e > lt_start_min) || (m_block_s < ut_end_max); + } + } + else if constexpr (PTR_DISPATCH_TAG == SINGLE_PTR) { + fully_masked = m_block_s >= lt_start_max; + partially_masked = m_block_e > lt_start_min; + } + + if(tc >= Tc){ + fully_masked = true; + partially_masked = false; + } + wl = fully_masked ? 0 : 1; + + unsigned mask = 0xffffffff; + // warp reduce sum + int wl_sum = wl; + for (int offset = 16; offset > 0; offset >>= 1) { + wl_sum += __shfl_down_sync(mask, wl_sum, offset); + } + if (laneId == 0) { + smem[warpId] = wl_sum; + } + __syncthreads(); + + if (threadIdx.x < 32) { + int val = (threadIdx.x < (blockDim.x + 31)/32) ? 
smem[threadIdx.x] : 0; + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_down_sync(mask, val, offset); + } + if (threadIdx.x == 0) { + workload[q_idx] = val; + } + } +} + +__global__ void indices_to_chunks_kernel( + const int* startend_row_indices, + const int* chunk_bucket_indices, + int* chunked_result, + int num_rows, + int num_buckets, + int chunk_size) +{ + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= num_rows) return; + + int max_chunk_index = 0; + int row_val = startend_row_indices[row]; + + for (int bucket = 0; bucket < num_buckets; ++bucket) { + int bucket_idx = chunk_bucket_indices[bucket]; + int chunk_start = bucket_idx * chunk_size; + int local_index = row_val - chunk_start; + local_index = max(local_index, 0); + local_index = min(local_index, chunk_size); + + if (local_index > 0) { + local_index += bucket * chunk_size; + } + + if (bucket == 0 || local_index > max_chunk_index) { + max_chunk_index = local_index; + } + } + chunked_result[row] = max_chunk_index; +} + +__global__ void indices_rerank_kernel( + const int* startend_row_indices, + int* output_reranked_indices, + const int* chunk_indices, + int batch_size, + int num_heads, + int seq_len, + int feature_dim, + int num_chunks, + int chunk_size +) { + int output_seq_len = num_chunks * chunk_size; + int total_elements = batch_size * output_seq_len * num_heads * feature_dim; + int flat_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (flat_idx >= total_elements) return; + + int d = flat_idx % feature_dim; + int s_out = (flat_idx / feature_dim) % output_seq_len; + int h = (flat_idx / feature_dim / output_seq_len) % num_heads; + int b = (flat_idx / feature_dim / output_seq_len / num_heads) % batch_size; + + int chunk_id = s_out / chunk_size; + int chunk_offset = s_out % chunk_size; + int src_s = chunk_indices[chunk_id] * chunk_size + chunk_offset; + + if (src_s >= seq_len) return; + + int src_flat_idx = ((b * num_heads + h) * seq_len + src_s) * feature_dim + d; + int dst_flat_idx = flat_idx; + + output_reranked_indices[dst_flat_idx] = startend_row_indices[src_flat_idx]; +} + + + + +// ============================================================================ +// ScanMaxMin Operator +// ============================================================================ + +std::vector scan_max_min_cuda( + const paddle::Tensor& input, + const int head_size_rounded, + const int seq_len_q, + const int seq_len_k, + const int blocksize = -1, + const bool is_causal = false, + const float softcap = 0.0, + const int window_size_left = 0, + const int window_size_right = 0) { + CHECK_CUDA_INPUT(input); + + const auto batch_size = input.shape()[0]; + const auto num_heads = input.shape()[1]; + const auto num_sequences = input.shape()[2]; + const auto head_dim = input.shape()[3]; + + PADDLE_ENFORCE_EQ( + num_sequences, + seq_len_k, + common::errors::InvalidArgument( + "Input tensor's third dimension (num_sequences) must be equal to seq_len_k.")); + + const bool is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; + const bool is_flashmask = true; + const bool has_softcap = softcap > 0.0; + const bool has_lt_end = !is_causal && head_dim >= 2; + const bool has_ut_start = head_dim == 4; + + const int kernel_block_size_n = + blocksize > 0 ? 
blocksize : get_kBlockN(head_size_rounded, + is_flashmask, + is_causal, + has_softcap, + is_local, + seq_len_q, + seq_len_k, + has_lt_end, + has_ut_start); + + // Pad the number of blocks to be a multiple of 4 for performance + const int num_blocks_seqlen = + ((num_sequences + kernel_block_size_n - 1) / kernel_block_size_n + 3) & 0xfffffffc; + + std::vector output_shape = {batch_size, num_blocks_seqlen}; + auto max_output = paddle::empty(output_shape, input.dtype(), input.place()); + auto min_output = paddle::empty(output_shape, input.dtype(), input.place()); + + // Launch kernel + dim3 block_dim(32, 4); + dim3 grid_dim((num_sequences + kernel_block_size_n - 1) / kernel_block_size_n, + (batch_size + 3) / 4); + + const cudaStream_t stream = input.stream(); + + switch (kernel_block_size_n) { + case 32: + scanMaxMinChunkedKernel<32><<>>( + input.data(), batch_size, num_sequences, + max_output.data(), min_output.data()); + break; + case 64: + scanMaxMinChunkedKernel<64><<>>( + input.data(), batch_size, num_sequences, + max_output.data(), min_output.data()); + break; + case 96: + scanMaxMinChunkedKernel<96><<>>( + input.data(), batch_size, num_sequences, + max_output.data(), min_output.data()); + break; + case 128: + scanMaxMinChunkedKernel<128><<>>( + input.data(), batch_size, num_sequences, + max_output.data(), min_output.data()); + break; + default: + PD_THROW("Unsupported kernel_block_size_n: %d", kernel_block_size_n); + } + return {max_output, min_output}; +} + +std::vector ScanMaxMin( + const paddle::Tensor& input, + int head_size_rounded, + int seq_len_q, + int seq_len_k, + int blocksize, + bool is_causal, + float softcap, + int window_size_left, + int window_size_right) { +#ifdef PADDLE_WITH_CUDA + if (input.is_gpu()) { + return scan_max_min_cuda(input, + head_size_rounded, + seq_len_q, + seq_len_k, + blocksize, + is_causal, + softcap, + window_size_left, + window_size_right); + } +#endif + PD_THROW("Unsupported device: ScanMaxMin operator is only available for CUDA."); +} + + +// ============================================================================ +// ReduceWorkload Operator +// ============================================================================ + +template +void launch_reduce_workload_kernel( + const paddle::Tensor& lt_start_max, + const paddle::Tensor& lt_start_min, + const paddle::optional& lt_end_max, + const paddle::optional& lt_end_min, + const paddle::optional& ut_start_max, + const paddle::optional& ut_start_min, + const paddle::optional& ut_end_max, + const paddle::optional& ut_end_min, + paddle::Tensor& workload, + int batch_times_heads, + int num_row_blocks, + int num_col_blocks, + int stride, + int row_block_size, + bool is_causal, + cudaStream_t stream) { + + dim3 block_dim(1024, 1); + dim3 grid_dim(num_row_blocks, batch_times_heads); + + int ptr_dispatch_tag = SINGLE_PTR; + if (lt_end_max || ut_end_max) { + ptr_dispatch_tag = DUAL_PTR; + if (ut_start_max) { + ptr_dispatch_tag = FULL_PTR; + } + } + + int* workload_ptr = workload.data(); + const int* lt_start_max_ptr = lt_start_max.data(); + const int* lt_start_min_ptr = lt_start_min.data(); + const int* lt_end_max_ptr = lt_end_max ? lt_end_max.get().data() : nullptr; + const int* lt_end_min_ptr = lt_end_min ? lt_end_min.get().data() : nullptr; + const int* ut_start_max_ptr = ut_start_max ? ut_start_max.get().data() : nullptr; + const int* ut_start_min_ptr = ut_start_min ? ut_start_min.get().data() : nullptr; + const int* ut_end_max_ptr = ut_end_max ? 
ut_end_max.get().data() : nullptr; + const int* ut_end_min_ptr = ut_end_min ? ut_end_min.get().data() : nullptr; + + if (ptr_dispatch_tag == FULL_PTR) { + reduce_workload_kernel<<>>( + lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr, + ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr, + workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size); + } else if (ptr_dispatch_tag == DUAL_PTR) { + if (is_causal) { + reduce_workload_kernel<<>>( + lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr, + ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr, + workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size); + } else { + reduce_workload_kernel<<>>( + lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr, + ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr, + workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size); + } + } else if (ptr_dispatch_tag == SINGLE_PTR) { + reduce_workload_kernel<<>>( + lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr, + ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr, + workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size); + } else { + PD_THROW("Unknown pointer dispatch tag."); + } +} + +std::vector reduce_workload_cuda( + const paddle::Tensor& lt_start_max, + const paddle::Tensor& lt_start_min, + const paddle::optional& lt_end_max, + const paddle::optional& lt_end_min, + const paddle::optional& ut_start_max, + const paddle::optional& ut_start_min, + const paddle::optional& ut_end_max, + const paddle::optional& ut_end_min, + int batch_size, + int num_heads, + int num_row_blocks, + int num_col_blocks, + int stride, + int row_block_size, + bool is_causal, + int m_block_size) { + + const int kBlockM = m_block_size; + const int batch_times_heads = batch_size * num_heads; + + // Allocate output tensor + std::vector output_shape = {batch_size, num_heads, num_row_blocks, 1}; + auto workload = paddle::empty(output_shape, lt_start_max.dtype(), lt_start_max.place()); + + cudaStream_t stream = lt_start_max.stream(); + + switch (kBlockM) { + case 64: + launch_reduce_workload_kernel<64>( + lt_start_max, lt_start_min, lt_end_max, lt_end_min, ut_start_max, + ut_start_min, ut_end_max, ut_end_min, workload, batch_times_heads, + num_row_blocks, num_col_blocks, stride, row_block_size, is_causal, stream); + break; + case 96: + launch_reduce_workload_kernel<96>( + lt_start_max, lt_start_min, lt_end_max, lt_end_min, ut_start_max, + ut_start_min, ut_end_max, ut_end_min, workload, batch_times_heads, + num_row_blocks, num_col_blocks, stride, row_block_size, is_causal, stream); + break; + case 128: + launch_reduce_workload_kernel<128>( + lt_start_max, lt_start_min, lt_end_max, lt_end_min, ut_start_max, + ut_start_min, ut_end_max, ut_end_min, workload, batch_times_heads, + num_row_blocks, num_col_blocks, stride, row_block_size, is_causal, stream); + break; + default: + PD_THROW("Unsupported m_block_size: %d", kBlockM); + } + return {workload}; +} + +std::vector ReduceWorkloadOp( + const paddle::Tensor& lt_start_max, + const paddle::Tensor& lt_start_min, + const paddle::optional& lt_end_max, + const paddle::optional& lt_end_min, + const paddle::optional& ut_start_max, + const paddle::optional& ut_start_min, + const paddle::optional& ut_end_max, + const paddle::optional& ut_end_min, + int batch_size, + int num_heads, + int 
num_row_blocks, + int num_col_blocks, + int stride, + int row_block_size, + bool is_causal, + int m_block_size) { +#ifdef PADDLE_WITH_CUDA + if (lt_start_max.is_gpu()) { + return reduce_workload_cuda(lt_start_max, + lt_start_min, + lt_end_max, + lt_end_min, + ut_start_max, + ut_start_min, + ut_end_max, + ut_end_min, + batch_size, + num_heads, + num_row_blocks, + num_col_blocks, + stride, + row_block_size, + is_causal, + m_block_size); + } +#endif + PD_THROW("Unsupported device: ReduceWorkload operator is only available for CUDA."); +} + + +// ============================================================================ +// IndicesToChunks & IndicesRerank Operators +// ============================================================================ + +std::vector IndicesToChunksOp( + const paddle::Tensor& row_indices, + const paddle::Tensor& chunk_bucket_indices, + int chunk_size) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_EQ(row_indices.is_gpu(), true, + common::errors::InvalidArgument("Input 'row_indices' must be a CUDA tensor.")); + + auto chunked_result = paddle::empty_like(row_indices); + + const int num_rows = row_indices.numel(); + const int num_buckets = chunk_bucket_indices.numel(); + const int num_threads_per_block = 256; + const int num_blocks = (num_rows + num_threads_per_block - 1) / num_threads_per_block; + + indices_to_chunks_kernel<<>>( + row_indices.data(), + chunk_bucket_indices.data(), + chunked_result.data(), + num_rows, + num_buckets, + chunk_size); + + return {chunked_result}; +#else + PD_THROW("Unsupported device: IndicesToChunks operator is only available for CUDA."); +#endif +} + +std::vector IndicesRerankOp( + const paddle::Tensor& input_row_indices, + const paddle::Tensor& chunk_indices, + int batch_size, + int num_heads, + int seq_len, + int feature_dim, + int num_chunks, + int chunk_size) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_EQ(input_row_indices.is_gpu(), true, + common::errors::InvalidArgument("Input 'input_row_indices' must be a CUDA tensor.")); + + const int output_seq_len = num_chunks * chunk_size; + auto reranked_indices = paddle::empty({batch_size, num_heads, output_seq_len, feature_dim}, + input_row_indices.dtype(), + input_row_indices.place()); + + const int total_elements = batch_size * output_seq_len * num_heads * feature_dim; + const int num_threads_per_block = 256; + const int num_blocks = (total_elements + num_threads_per_block - 1) / num_threads_per_block; + + indices_rerank_kernel<<>>( + input_row_indices.data(), + reranked_indices.data(), + chunk_indices.data(), + batch_size, + num_heads, + seq_len, + feature_dim, + num_chunks, + chunk_size); + + return {reranked_indices}; +#else + PD_THROW("Unsupported device: IndicesRerank operator is only available for CUDA."); +#endif +} + + +// ============================================================================ +// Operator Registrations +// ============================================================================ + +PD_BUILD_OP(scan_max_min) + .Inputs({"Input"}) + .Outputs({"MaxOut", "MinOut"}) + .Attrs({"head_size_rounded: int", + "seq_len_q: int", + "seq_len_k: int", + "blocksize: int", + "is_causal: bool", + "softcap: float", + "window_size_left: int", + "window_size_right: int"}) + .SetKernelFn(PD_KERNEL(ScanMaxMin)); + +PD_BUILD_OP(reduce_workload) + .Inputs({"LTStartMax", "LTStartMin", + paddle::Optional("LTEndMax"), paddle::Optional("LTEndMin"), + paddle::Optional("UTStartMax"), paddle::Optional("UTStartMin"), + paddle::Optional("UTEndMax"), paddle::Optional("UTEndMin")}) + 
.Outputs({"Workload"}) + .Attrs({"batch_size: int", + "num_heads: int", + "num_row_blocks: int", + "num_col_blocks: int", + "stride: int", + "row_block_size: int", + "is_causal: bool", + "m_block_size: int"}) + .SetKernelFn(PD_KERNEL(ReduceWorkloadOp)); + +PD_BUILD_OP(indices_to_chunks) + .Inputs({"RowIndices", "ChunkBucketIndices"}) + .Outputs({"ChunkedResult"}) + .Attrs({"chunk_size: int"}) + .SetKernelFn(PD_KERNEL(IndicesToChunksOp)); + +PD_BUILD_OP(indices_rerank) + .Inputs({"InputRowIndices", "ChunkIndices"}) + .Outputs({"RerankedIndices"}) + .Attrs({"batch_size: int", + "num_heads: int", + "seq_len: int", + "feature_dim: int", + "num_chunks: int", + "chunk_size: int"}) + .SetKernelFn(PD_KERNEL(IndicesRerankOp)); \ No newline at end of file diff --git a/csrc/utils/cp_balance/csrc/setup.py b/csrc/utils/cp_balance/csrc/setup.py new file mode 100644 index 00000000000..db837885009 --- /dev/null +++ b/csrc/utils/cp_balance/csrc/setup.py @@ -0,0 +1,121 @@ +import os +import subprocess +import shutil +import re + + +def get_version_from_txt(): + version_file = os.path.join(os.path.dirname(__file__), "version.txt") + with open(version_file, "r") as f: + version = f.read().strip() + return version + + +def custom_version_scheme(version): + base_version = get_version_from_txt() + date_str = ( + subprocess.check_output( + ["git", "log", "-1", "--format=%cd", "--date=format:%Y%m%d"] + ) + .decode() + .strip() + ) + return f"{base_version}.dev{date_str}" + + +def no_local_scheme(version): + return "" + + +def change_pwd(): + """change_pwd""" + path = os.path.dirname(__file__) + if path: + os.chdir(path) + +def get_cuda_version(): + nvcc_path = shutil.which("nvcc") + if nvcc_path is None: + raise FileNotFoundError( + "nvcc command not found. Please make sure CUDA toolkit is installed and nvcc is in PATH." + ) + + result = subprocess.run( + ["nvcc", "--version"], + capture_output=True, + text=True, + check=True, + ) + version_output = result.stdout + + match = re.search(r"release (\d+)\.(\d+)", version_output) + if not match: + raise ValueError( + f"Cannot parse CUDA version from nvcc output:\n{version_output}" + ) + cuda_major = int(match.group(1)) + cuda_minor = int(match.group(2)) + return cuda_major, cuda_minor + + +def setup_ops_extension(): + from paddle.utils.cpp_extension import CUDAExtension, setup + + # 定义 NVCC 编译参数 + nvcc_args = [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "-maxrregcount=32", + "-lineinfo", + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_90a,code=sm_90a", + "-gencode=arch=compute_100,code=sm_100", + "-DNDEBUG", + ] + cuda_major, cuda_minor = get_cuda_version() + if cuda_major < 12: + raise ValueError( + f"CUDA version must be >= 12. 
Detected version: {cuda_major}.{cuda_minor}" + ) + if cuda_major == 12 and cuda_minor < 8: + nvcc_args = [arg for arg in nvcc_args if "compute_100" not in arg] + + ext_module = CUDAExtension( + sources=[ + # cpp files + # cuda files + "./cp_balance_utils.cu", + ], + include_dirs=[ + os.path.join(os.getcwd(), "./"), + ], + extra_compile_args={ + "cxx": [ + "-O3", + "-w", + "-Wno-abi", + "-fPIC", + "-std=c++17", + ], + "nvcc": nvcc_args, + }, + ) + + change_pwd() + setup( + name="flashmask_cpbalance_cudaops", + ext_modules=[ext_module], + version="0.0.1", + setup_requires=["setuptools_scm"], + ) + + +setup_ops_extension() \ No newline at end of file diff --git a/csrc/utils/cp_balance/test/benchmark_cp_balance_balancex.py b/csrc/utils/cp_balance/test/benchmark_cp_balance_balancex.py new file mode 100644 index 00000000000..7328313af8b --- /dev/null +++ b/csrc/utils/cp_balance/test/benchmark_cp_balance_balancex.py @@ -0,0 +1,934 @@ +import numpy as np +from functools import partial +from typing import Optional, List +from tabulate import tabulate +import paddle +import os +import paddle.nn.functional as F +from paddle.nn.functional.flash_attention import flashmask_attention +from cp_balance.context_parallel_utils import flashmask_attention_cp, scatter_balance, all_gather_balance +from cp_balance.cp_balance import assign_tasks_heap, get_q_workload,balance_flashmask_input, tensor_md5, balance_alltoall + +import paddle.distributed.fleet as fleet +import time + +cp_size = 4 +strategy = fleet.DistributedStrategy() + +strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 4, + "pp_degree": 1, + "sharding_degree": 4, + "sep_degree": 1, + "ep_degree": 16, + "moe_sharding_degree": 1, + "cp_degree": cp_size, + "order": ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"] +} + +fleet.init(is_collective=True, strategy=strategy) +cp_group = fleet.get_hybrid_communicate_group().get_context_parallel_group() + +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +def from_paddle(x: paddle.Tensor): + if x.dtype == paddle.bfloat16 or x.dtype == "bfloat16": + return torch.from_numpy(x.view("uint16").numpy()).to("cuda").view(torch.bfloat16) + elif x.dtype == paddle.float32 or x.dtype == "float32": + return torch.from_numpy(x.numpy()).to("cuda") + else: + assert False + +def _summarize_statistics(times, quantiles, return_mode): + if quantiles is not None: + ret = paddle.quantile(times, paddle.to_tensor(quantiles, dtype=paddle.float32)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times.tolist() + return getattr(paddle, return_mode)(times).item() + +def split_sequence(sequence_length, num_answers=2): + if sequence_length < num_answers + 1: + raise ValueError(f"序列长度必须至少为 {num_answers + 1}") + + base = sequence_length // (num_answers + 1) + extra = sequence_length % (num_answers + 1) + # 前extra个部分多加1 + lengths = [base + (1 if i < extra else 0) for i in range(num_answers + 1)] + + return lengths + +def do_bench_dist(fn, cp_group, warmup=1, rep=300, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. 
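    Example (a minimal sketch, not part of the original helper: ``x`` and ``buckets`` are
    assumed to be prepared as in ``cp_flashmask_balance_bench`` below, and ``cp_group`` is
    the module-level context-parallel group)::

        step = lambda: scatter_balance(x, group=cp_group, axis=1,
                                       mode="balanced_swap", buckets=buckets)
        mean_ms = do_bench_dist(step, cp_group=cp_group)  # mean wall-clock time in ms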
+ + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float], optional + :param fast_flush: Use faster kernel to flush L2 cache between measurements + :type fast_flush: bool, default is True + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + paddle.base.core.nvprof_nvtx_push("paddle") + + fn() + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + + + # paddle.device.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + # cache_size = 256 * 1024 * 1024 + # if fast_flush: + # cache = paddle.empty([int(cache_size // 4)], dtype=paddle.int32) + # else: + # cache = paddle.empty([int(cache_size)], dtype=paddle.int8) + + # Estimate the runtime of the function + + # compute number of warmup and repeat + n_warmup = 3 + n_repeat = 5 + # Warm-up + for _ in range(n_warmup): + time.sleep(0.1) + fn() + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + # Benchmark + dist_time = [] + for i in range(n_repeat): + time.sleep(0.1) + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + time0 = time.perf_counter() + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + time1 = time.perf_counter() + dist_time.append((time1 - time0) * 1000) + # Record clocks + # paddle.device.synchronize() + paddle.base.core.nvprof_nvtx_pop() + return sum(dist_time) / n_repeat + + +def do_bench_flashmaskcp(q_local, k_local, v_local, o_grad_local, startend_row_indices, group, is_causal,mode = "balance_q", warmup=50, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float], optional + :param fast_flush: Use faster kernel to flush L2 cache between measurements + :type fast_flush: bool, default is True + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". 
:type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + + rank = paddle.distributed.get_rank() + out_local = flashmask_attention_cp(q_local, k_local, v_local, startend_row_indices, causal=is_causal, mode=mode) + # print('pt00') + out_local.backward(o_grad_local) + # print('pt0') + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + # paddle.device.synchronize() + # print('here') + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + # cache_size = 256 * 1024 * 1024 + # if fast_flush: + # cache = paddle.empty([int(cache_size // 4)], dtype=paddle.int32) + # else: + # cache = paddle.empty([int(cache_size)], dtype=paddle.int8) + + # Estimate the runtime of the function + start_event = paddle.device.Event(enable_timing=True) + end_event = paddle.device.Event(enable_timing=True) + start_event.record() + for _ in range(5): + # cache.zero_() + time.sleep(0.1) + out_local = flashmask_attention_cp(q_local, k_local, v_local, startend_row_indices, causal=is_causal, mode=mode) + out_local.backward(o_grad_local) + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + + # paddle.device.synchronize() + end_event.record() + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + # paddle.device.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # print('pt2') + # compute number of warmup and repeat + n_warmup = max(3, int(warmup / estimate_ms)) + n_repeat = max(5, int(rep / estimate_ms)) + start_event = [paddle.device.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [paddle.device.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + time.sleep(0.1) + out_local =flashmask_attention_cp(q_local, k_local, v_local, startend_row_indices, causal=is_causal, mode=mode) + out_local.backward(o_grad_local, retain_graph=True) + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + # paddle.device.synchronize() + # Benchmark + times_fwd = [] + times_bwd = [] + for i in range(n_repeat): + time.sleep(0.1) + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + # paddle.device.synchronize() + paddle.base.core.nvprof_nvtx_push(f"flashmask_cp_fwd_{rank}") + t0 = time.perf_counter() + out_local = flashmask_attention_cp(q_local, k_local, v_local, startend_row_indices, causal=is_causal, mode=mode) + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + # paddle.device.synchronize() + paddle.base.core.nvprof_nvtx_pop() + paddle.base.core.nvprof_nvtx_push(f"flashmask_cp_bwd_{rank}") + t1 = time.perf_counter() + out_local.backward(o_grad_local, retain_graph=True) + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + # paddle.device.synchronize() + paddle.base.core.nvprof_nvtx_pop() + t2 = time.perf_counter() + times_fwd.append(1000 * (t1 - t0)) + times_bwd.append(1000 * (t2 - t1)) + + # Record clocks + print(times_bwd) + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + # paddle.device.synchronize() + # print('pt3') + return sum(times_fwd) / n_repeat, sum(times_bwd) / n_repeat + +def cp_flashmask_balance_bench(query, key, value, startend_row_indices, is_causal,o_grad,mode): + B,S,H,D = query.shape + group = cp_group + rank = group.rank + local_qs = [] + local_ks = [] 
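    # NOTE: assign_tasks_heap is imported from cp_balance.cp_balance and its implementation
    # is not part of this diff. The standalone sketch below is only an illustration of the
    # greedy "heaviest chunk to the currently lightest rank" assignment it is assumed to
    # perform on the (workload, chunk_idx) pairs coming out of get_q_workload; the real
    # helper additionally returns the per-rank weights and cut points used below.
    import heapq

    def assign_tasks_heap_sketch(workloads, num_ranks):
        """workloads: iterable of (weight, chunk_idx); returns buckets[rank] = [(weight, idx), ...]."""
        buckets = [[] for _ in range(num_ranks)]
        rank_heap = [(0, r) for r in range(num_ranks)]      # (accumulated weight, rank)
        heapq.heapify(rank_heap)
        for weight, idx in sorted(workloads, key=lambda t: t[0], reverse=True):
            total, r = heapq.heappop(rank_heap)             # currently lightest rank
            buckets[r].append((weight, idx))
            heapq.heappush(rank_heap, (total + weight, r))
        bucket_weights = [sum(w for w, _ in b) for b in buckets]
        return buckets, bucket_weights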
+ local_vs = [] + local_ograds =[] + balance_q_chunksize = 2048 + workload = get_q_workload(startend_row_indices, balance_q_chunksize, 128,128) + + # print(workload) + total_workload = paddle.sum(workload,axis = 1) + buckets, bucket_weights,cuts = assign_tasks_heap(workload.reshape(-1,2), cp_size) + hcg = fleet.get_hybrid_communicate_group() + for(_, idx) in buckets[rank]: + local_qs.append(query[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_ks.append(key[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_vs.append(value[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_ograds.append(o_grad[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_q = paddle.concat(local_qs, axis=1).detach().contiguous() + local_k = paddle.concat(local_ks, axis=1).detach().contiguous() + local_v = paddle.concat(local_vs, axis=1).detach().contiguous() + local_o_grad = paddle.concat(local_ograds, axis=1) + local_startend_row_indices, buckets = balance_flashmask_input(startend_row_indices, cp_size, rank) + local_q = scatter_balance(query, group = cp_group, axis=1,mode = "balanced_swap", buckets = buckets).detach().contiguous() + print("pass0") + balancex = lambda: balance_flashmask_input(startend_row_indices, cp_size, rank) + balance_time = do_bench_dist(balancex,cp_group = cp_group) + print("pass1") + + local_k.stop_gradient = False + local_v.stop_gradient = False + local_q.stop_gradient = False + x = query.detach().reshape(B,S,-1).contiguous() + + local_startend_row_indices, buckets = balance_flashmask_input( startend_row_indices, cp_size, rank,balance_chunk_size= balance_q_chunksize) + + scatter_x = lambda: scatter_balance(x, group = cp_group, axis=1,mode = "balanced_swap", buckets = buckets) + scatter_x_time = do_bench_dist(scatter_x,cp_group = cp_group) + local_x = scatter_balance(x, group = cp_group, axis=1,mode = "balanced_swap", buckets = buckets) + + gather_x = lambda: all_gather_balance(local_x, group = cp_group, axis=1,mode = "balanced_swap", buckets = buckets) + gather_x_time = do_bench_dist(gather_x,cp_group = cp_group) + + # a2a_o = lambda: balance_alltoall(local_o_grad, cp_size, cp_group, balance_q_chunksize, recv_dict, send_dict) + # a2a_x_time = do_bench_dist(a2a_o,cp_group = cp_group) + + # startend_row_indices.stop_gradient = False + + cp_fwd_time, cp_bwd_time = do_bench_flashmaskcp(local_q, local_k, local_v, local_o_grad, local_startend_row_indices, group, is_causal, mode) + # print(f"cp balance fwd+bwd time: {cp_fwd_bwd_time} ms\n") + print("pass2") + return balance_time,scatter_x_time,gather_x_time, cp_fwd_time, cp_bwd_time + +def test_cp_famask( + startend_row_indices, + B: int = 16, + S: int = 8192, + H: int = 16, + D: int = 64, + dtype = 'bf16', +): + """ + 测试上下文并行FlashMask注意力机制的性能基准 + + 该函数用于测试在分布式并行环境中FlashMask注意力机制的前向传播和后向传播性能, + 支持不同类型的注意力掩码生成策略。 + + Args: + generate_mask_fn: 注意力掩码生成函数,用于生成startend_row_indices和因果关系标记 + B: 批次大小,默认16 + S: 序列长度,默认8192 + H: 注意力头数,默认16 + D: 每个注意力头的维度,默认64 + dtype: 数据类型,默认'bf16' + + Returns: + tuple: 包含前向传播时间和后向传播时间的元组 (fwd_time, bwd_time),单位为毫秒 + """ + # paddle.seed(2024) + paddle.seed(2024) + # batch_size = 1 + total_q = S + total_k = S + batch_size = B + num_head = H + num_head_q = 12 * H + head_size = D + rank = cp_group.rank + # total_k = total_q * 2 + if rank == 0: + query = paddle.randn([batch_size, total_q, num_head_q, head_size], dtype=paddle.bfloat16) + key = paddle.randn([batch_size, total_k, num_head, head_size], dtype=paddle.bfloat16) + value = 
paddle.randn([batch_size, total_k, num_head, head_size], dtype=paddle.bfloat16) + o_grad = paddle.randn([batch_size, total_q, num_head_q, head_size], dtype=paddle.bfloat16) + else: + query = paddle.empty([batch_size, total_q, num_head_q, head_size], dtype=paddle.bfloat16) + key = paddle.empty([batch_size, total_k, num_head, head_size], dtype=paddle.bfloat16) + value = paddle.empty([batch_size, total_k, num_head, head_size], dtype=paddle.bfloat16) + o_grad = paddle.empty([batch_size, total_q, num_head_q, head_size], dtype=paddle.bfloat16) + + # 广播到所有 rank + paddle.distributed.broadcast(query, src=cp_group.ranks[0],group=cp_group) + paddle.distributed.broadcast(key, src=cp_group.ranks[0],group=cp_group) + paddle.distributed.broadcast(value, src=cp_group.ranks[0],group=cp_group) + paddle.distributed.broadcast(o_grad,src = cp_group.ranks[0],group=cp_group) + paddle.device.synchronize() + paddle.distributed.barrier(group=cp_group) + query.stop_gradient = False + key.stop_gradient = False + value.stop_gradient = False + causal = False + + balance_time, scatter_x_time, gather_x_time, fwd_time, bwd_time = cp_flashmask_balance_bench(query, key, value, startend_row_indices, causal,o_grad, "balance_q") + paddle.device.synchronize() + # out1.backward(o_grad1) + # paddle.device.synchronize() + + # print("pypt2:") + # print(startend_row_indices) + total_time = fwd_time + bwd_time + return fwd_time, bwd_time, total_time, balance_time,scatter_x_time,gather_x_time + # with open("execution_times.txt", "a") as log_file: + # log_file.write(f"bsz: {batch_size},num_head_k: {num_head},num_head_q: {num_head * 4},hsz: {head_size},seqlen: {total_q}, flashattnv1: {flashattnv1_time:.6f}s, " + # f"flashattnv2: {flashattnv2_time:.6f}s\n") + # for x,y in [(out1,out),(dq1,query.grad),(dk1,key.grad),(dv1,value.grad)]: + # strict_check(x.flatten(), y.flatten()) + # for x,y in [(out1,out)]: + # strict_check(x.flatten(), y.flatten()) + +def strict_check(x, y): + if isinstance(x, paddle.Tensor): + if x.dtype == paddle.bfloat16 or x.dtype == "float16": + # x = x.view("float16").numpy() + x = x.cast("float32").numpy() + else: + x = x.numpy() + else: + assert False + + # if isinstance(y, torch.Tensor): + # if y.dtype == torch.bfloat16 or y.dtype == "bfloat16": + # # x = x.view("float16").numpy() + # y = y.to(torch.float32).detach().cpu().numpy() + # else: + # y = y.detach().cpu().numpy() + + if isinstance(y, paddle.Tensor): + if y.dtype == paddle.bfloat16 or y.dtype == "float16": + # y = y.view("float16").numpy() + y = y.cast("float32").numpy() + else: + y = y.numpy() + + try: + print(f"{x=}, {y=}") + np.testing.assert_allclose(x.flatten(), y.flatten(),rtol=1e-2, atol=1e-2) + except Exception as e: + print('---------------') + idx = np.where(~(x == y)) + print(f"fail idx: {idx=}") + print(f"shape:'{x.shape}'") + # print(f"fail idx:'{np.unique(idx[0])}'") + print(x[idx]) + print(y[idx]) + raise e + + +def ele_check(x, y): + if isinstance(x, paddle.Tensor): + if x.dtype == paddle.bfloat16 or x.dtype == "bfloat16": + # x = x.view("uint16").numpy() + x = x.cast("float32").numpy() + else: + x = x.numpy() + else: + assert False + + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16 or y.dtype == "bfloat16": + # x = x.view("uint16").numpy() + y = y.to(torch.float32).detach().cpu().numpy() + else: + y = y.detach().cpu().numpy() + + # if isinstance(y, paddle.Tensor): + # if y.dtype == paddle.bfloat16 or y.dtype == "bfloat16": + # # y = y.view("uint16").numpy() + # y = y.cast("float32").numpy() + # else: + # y = y.numpy() + + 
try: + print(f"{x=}, {y=}") + np.testing.assert_allclose(np.sort(x.flatten()), np.sort(y.flatten()),rtol=1e-3, atol=1e-6) + except Exception as e: + print('---------------') + idx = np.where(~(x == y)) + print(f"fail idx: {idx=}") + print(f"shape:'{x.shape}'") + # print(f"fail idx:'{np.unique(idx[0])}'") + print(x[idx]) + print(y[idx]) + raise e + +def flashmask_to_densemask(startend_row_indices, dtype, causal=True): + if startend_row_indices is None: + return None + bz, num_head, seq_len, bound_num = startend_row_indices.shape + m = paddle.zeros((bz, num_head, seq_len, seq_len), dtype=dtype) + has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4) + for bi in range(bz): + for hi in range(num_head): + for j in range(seq_len): + downstart = startend_row_indices[bi, hi, j, 0] + if has_end: + downend = startend_row_indices[bi, hi, j, 1] + m[bi, hi, downstart:downend, j] = -np.inf + else: + m[bi, hi, downstart:, j] = -np.inf + if causal: + m[bi, hi, :j, j] = -np.inf + else: + if has_end: + upstart = startend_row_indices[bi, hi, j, 2] + upend = startend_row_indices[bi, hi, j, 3] + m[bi, hi, upstart:upend, j] = -np.inf + else: + upend = startend_row_indices[bi, hi, j, 1] + m[bi, hi, :upend, j] = -np.inf + return m + +def generate_none_mask(B, S, H, D, causal=True): + return None, causal + +def generate_ones_mask(B, S, H, D): + startend_row_indices = paddle.zeros( + shape=(B, H, S, 2), dtype="int32" + ) + startend_row_indices[:,:,:,0]=S + causal = False + return startend_row_indices, causal + +def generate_causal_mask(B,S,H,D): + startend_row_indices = paddle.zeros( + shape=(B, H, S, 1), dtype="int32" + ) + startend_row_indices[:,:,:,0]=S + causal = True + return startend_row_indices, causal + +def generate_sliding_window_mask(B, S, H, D, window_size=1024): + startend_row_indices = paddle.arange( + window_size, S + window_size, dtype="int32" + ).reshape((1, 1, S, 1)) + startend_row_indices = paddle.clip( + startend_row_indices, max=S + ).repeat_interleave(B, 0) + + causal=True + return startend_row_indices, causal + +# def generate_causal_document_mask(B, S, H, D, doc_seq_lens=[2538, 1742, 3213]): +def generate_causal_document_mask(B,S,H,D, doc_seq_lens=[2538, 1742, 3213]): + total_seq_len = np.sum(doc_seq_lens) + assert total_seq_len <= S, f"{total_seq_len=}, {S=}" + padding = S - np.sum(doc_seq_lens) + doc_seq_lens[-1] += padding + seq_cusums = np.cumsum(doc_seq_lens) + + startend_row_indices = np.repeat(seq_cusums, doc_seq_lens) + startend_row_indices = paddle.to_tensor(startend_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)) + startend_row_indices = startend_row_indices.repeat_interleave(B, 0) + + causal = True + return startend_row_indices, causal + +def generate_upper_document_mask(B,S,H,D, doc_seq_lens=[2538, 1742, 3213],padding_size = 256): + total_seq_len = np.sum(doc_seq_lens) + assert total_seq_len <= S + padding = S - np.sum(doc_seq_lens) + + up_right_row_indices = [] + + cur_len_so_far = 0 + for i in range(len(doc_seq_lens)): + up_right_row_indices.extend([cur_len_so_far] * doc_seq_lens[i]) + if i < len(doc_seq_lens) -1: + cur_len_so_far += doc_seq_lens[i] + if padding > 0: + up_right_row_indices.extend([cur_len_so_far] * padding) + + up_right_row_indices = paddle.to_tensor(up_right_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + down_left_row_indices = paddle.ones_like(up_right_row_indices) * (S - padding_size) + startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) + + causal = False 
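    # A toy illustration (not part of the benchmark) of the FlashMask column encoding
    # produced by generate_causal_document_mask above: for S = 8 and doc_seq_lens = [3, 5],
    toy_seq_cusums = np.cumsum([3, 5])                    # -> [3, 8]
    toy_start_rows = np.repeat(toy_seq_cusums, [3, 5])    # -> [3, 3, 3, 8, 8, 8, 8, 8]
    # flashmask_to_densemask(..., causal=True) then expands column j into
    # m[toy_start_rows[j]:, j] = -inf plus the causal upper triangle m[:j, j] = -inf,
    # i.e. a block-diagonal causal document mask.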
+ return startend_row_indices, causal + +def generate_document_mask(B, S, H, D, doc_seq_lens=[2538, 1742, 3213]): + total_seq_len = np.sum(doc_seq_lens) + assert total_seq_len <= S + padding = S - np.sum(doc_seq_lens) + + down_left_row_indices = [] + up_right_row_indices = [] + + cur_len_so_far = doc_seq_lens[0] + for i in range(len(doc_seq_lens)): + down_left_row_indices.extend([cur_len_so_far] * doc_seq_lens[i]) + if i < len(doc_seq_lens) -1: + cur_len_so_far += doc_seq_lens[i+1] + if padding > 0: + down_left_row_indices.extend([cur_len_so_far] * padding) + + cur_len_so_far = 0 + for i in range(len(doc_seq_lens)): + up_right_row_indices.extend([cur_len_so_far] * doc_seq_lens[i]) + if i < len(doc_seq_lens) -1: + cur_len_so_far += doc_seq_lens[i] + if padding > 0: + up_right_row_indices.extend([cur_len_so_far] * padding) + + down_left_row_indices = paddle.to_tensor(down_left_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + up_right_row_indices = paddle.to_tensor(up_right_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) + + causal = False + return startend_row_indices, causal + +def generate_share_question_mask(B, S, H, D, doc_seq_lens=[2538, 1742, 3213]): + total_seq_len = np.sum(doc_seq_lens) + assert total_seq_len <= S + assert len(doc_seq_lens) >= 3 + padding = S - total_seq_len + + startend_row_indices = [S] * doc_seq_lens[0] + + cur_len_so_far = doc_seq_lens[0] + for idx in range(1, len(doc_seq_lens)): + cur_len_so_far += doc_seq_lens[idx] + startend_row_indices.extend([cur_len_so_far] * doc_seq_lens[idx]) + + if padding > 0: + startend_row_indices.extend([cur_len_so_far] * padding) + + startend_row_indices = paddle.to_tensor(startend_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + causal = True + return startend_row_indices, causal + +def generate_global_sliding_window_mask(B, S, H, D, global_token=16, window_size=(512, 512)): + assert len(window_size) == 2 + left_window_size, right_window_size = window_size + + down_left_start_row_indices = [] + down_left_end_row_indices = [] + up_right_start_row_indices = [] + up_right_end_row_indices = [] + + down_left_start_row_indices = paddle.arange( + left_window_size + 1, S + left_window_size + 1, dtype="int32" + ).clip(max=S) + down_left_start_row_indices[:global_token] = S + down_left_start_row_indices = down_left_start_row_indices.reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + down_left_end_row_indices = paddle.full([S], S, dtype="int32").reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + up_right_start_row_indices = paddle.full([S], global_token, dtype="int32") + up_right_start_row_indices[:global_token+right_window_size+1] = 0 + up_right_start_row_indices = up_right_start_row_indices.reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + up_right_end_row_indices = paddle.arange( + -right_window_size, S - right_window_size, dtype="int32" + ) + up_right_end_row_indices[:global_token+right_window_size+1] = 0 + up_right_end_row_indices = up_right_end_row_indices.reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + startend_row_indices = paddle.concat([down_left_start_row_indices, down_left_end_row_indices, up_right_start_row_indices, up_right_end_row_indices], axis=-1) + + causal = False + return startend_row_indices, causal + +def generate_causal_blockwise_mask(B, S, H, D, doc_seq_lens=[2538, 1742, 3213]): + total_seq_len = np.sum(doc_seq_lens) + assert 
total_seq_len <= S + assert len(doc_seq_lens) >= 3 + padding = S - np.sum(doc_seq_lens) + + start_row_indices = [] + cur_len_so_far = doc_seq_lens[0] + for i in range(len(doc_seq_lens)): + start_row_indices.extend([cur_len_so_far] * doc_seq_lens[i]) + if i < len(doc_seq_lens) - 1: + cur_len_so_far += doc_seq_lens[i+1] + if padding > 0: + start_row_indices.extend([cur_len_so_far] * padding) + start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + seq_cusums = np.cumsum(doc_seq_lens) + end_row_indices = [seq_cusums[-2]] * seq_cusums[-2] + [seq_cusums[-1]] * doc_seq_lens[-1] + [S] * padding + end_row_indices = paddle.to_tensor(end_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + startend_row_indices = paddle.concat([start_row_indices, end_row_indices], axis=-1) + + causal = True + return startend_row_indices, causal + +def generate_prefix_lm_document_mask(B, S, H, D, doc_seq_lens=[(1024, 2538), (1742, 1742), (512, 3213)]): + """ + tuple(prefix_length, seq_length) + """ + assert len(doc_seq_lens) >= 2 + total_seq_len = 0 + for prefix_length, seq_length in doc_seq_lens: + total_seq_len += seq_length + assert total_seq_len <= S + padding = S - total_seq_len + + down_left_row_indices = [] + cur_len_so_far = doc_seq_lens[0][1] + for i in range(len(doc_seq_lens)): + down_left_row_indices.extend([cur_len_so_far] * doc_seq_lens[i][1]) + if i < len(doc_seq_lens) - 1: + cur_len_so_far += doc_seq_lens[i+1][1] + if padding > 0: + down_left_row_indices.extend([cur_len_so_far] * padding) + down_left_row_indices = paddle.to_tensor(down_left_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + up_right_row_indices = [] + cur_len_so_far = 0 + for prefix_length, seq_length in doc_seq_lens: + up_right_row_indices.extend([cur_len_so_far] * prefix_length + list(range(cur_len_so_far+prefix_length, cur_len_so_far+seq_length))) + cur_len_so_far += seq_length + if padding > 0: + up_right_row_indices.extend([total_seq_len] * padding) + up_right_row_indices = paddle.to_tensor(up_right_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) + + causal = False + return startend_row_indices, causal + +def generate_prefix_lm_causal_mask(B, S, H, D, prefix_length=1024): + """ + tuple(prefix_length, seq_length) + """ + assert prefix_length <= S + down_left_row_indices = paddle.full([S], S, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + up_right_row_indices = paddle.to_tensor([0] * prefix_length + list(range(prefix_length, S)), dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) + + causal = False + return startend_row_indices, causal + +def generate_qk_sparse_mask(B, S, H, D, maskout_pair=[(1024, 538), (2358, 1700)]): + """ + tuple(offset, maskout_len) + """ + start_row_indices = [] + end_row_indices = [] + last_offset = 0 + for offset, maskout_len in maskout_pair: + assert offset > last_offset + start_row_indices.extend([S]*(offset-last_offset)) + end_row_indices.extend([S]*(offset-last_offset)) + + start_row_indices.extend(list(range(offset, offset+maskout_len))) + end_row_indices.extend([offset+maskout_len]*(maskout_len)) + + last_offset = offset + maskout_len + + last_offset <= S + start_row_indices.extend([S]*(S-last_offset)) + 
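    # A toy illustration (hypothetical values, not used by this script) of the two-bound
    # encoding built by generate_prefix_lm_document_mask above: a single document with
    # prefix_length = 2 and seq_length = 4 on a length-4 sequence gives
    toy_down_left = [4, 4, 4, 4]   # no lower-triangle mask-out
    toy_up_right = [0, 0, 2, 3]    # prefix columns visible to every row; suffix column j only from row j on
    # i.e. bidirectional attention inside the prefix and causal attention over the suffix,
    # which is what flashmask_to_densemask(causal=False) reconstructs from these two bounds.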
end_row_indices.extend([S]*(S-last_offset)) + + start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + end_row_indices = paddle.to_tensor(end_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + startend_row_indices = paddle.concat([start_row_indices, end_row_indices], axis=-1) + + causal = True + return startend_row_indices, causal + +#def generate_hash_sparse_mask(B, S, H, D, maskout_pair=[(1024, 538), (2358, 1700)]): +# """ +# tuple(offset, maskout_len) +# """ +# start_row_indices = [] +# end_row_indices = [] +# last_offset = 0 +# for offset, maskout_len in maskout_pair: +# assert offset > last_offset +# start_row_indices.append([S]*(offset-last_offset)) +# end_row_indices.append([S]*(offset-last_offset)) +# +# start_row_indices.append(list(range(offset, offset+maskout_len))) +# end_row_indices.append([offset+maskout_len]*(maskout_len)) +# +# last_offset = offset + maskout_len +# +# last_offset <= S +# start_row_indices.append([S]*(S-last_offset)) +# end_row_indices.append([S]*(S-last_offset)) +# +# start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) +# end_row_indices = paddle.to_tensor(end_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) +# startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) +# +# causal = False +# return startend_row_indices, causal + + +def generate_random_eviction_mask(B, S, H, D, start_row=4096): + np.random.seed(0) + start_rows_list = [] + for bz_idx in range(B): + for head_idx in range(H): + start_rows = np.array([S+1] * S) + mask_pos = np.random.choice(S-1, S - start_row, replace=False) + index = np.arange(start_row, S) + mask_pos = np.concatenate([mask_pos[mask_pos < index - 1], mask_pos[mask_pos >= index - 1]]) + start_rows[mask_pos] = index + min_index = np.arange(1,S+1) + start_rows = np.maximum(start_rows, min_index) + start_rows_list.append(start_rows) + startend_row_indices = paddle.to_tensor(start_rows_list, dtype=paddle.int32).reshape((B, H, S, 1)) + causal = True + return startend_row_indices, causal + +def gen_varlen(cu_seqlens): + # 初始化mask + mask = np.zeros((1, 1, 32768, 2), dtype=np.int32) + mask[0, 0, :, 0] = np.arange(32768, dtype=np.int32) + mask[0, 0, :, 1] = np.arange(32768, dtype=np.int32) + + # 按规则填充mask[..., 0] + for i in range(len(cu_seqlens)-1): + start = cu_seqlens[i] + end = cu_seqlens[i+1] + mask[0, 0, start:end, 0] = cu_seqlens[i+1] + + # mask[..., 1] = i + mask[0, 0, :, 1] = np.arange(32768, dtype=np.int32) + + # 转为 paddle tensor + mask_tensor = paddle.to_tensor(mask, dtype='int32') + return mask_tensor + + +def gen_varlen_causal(cu_seqlens): + # 初始化mask + mask = np.zeros((1, 1, 32768, 1), dtype=np.int32) + mask[0, 0, :, 0] = np.arange(32768, dtype=np.int32) +# mask[0, 0, :, 1] = np.arange(32768, dtype=np.int32) + + # 按规则填充mask[..., 0] + for i in range(len(cu_seqlens)-1): + start = cu_seqlens[i] + end = cu_seqlens[i+1] + mask[0, 0, start:end, 0] = cu_seqlens[i+1] + + # mask[..., 1] = i +# mask[0, 0, :, 1] = np.arange(32768, dtype=np.int32) + + # 转为 paddle tensor + mask_tensor = paddle.to_tensor(mask, dtype='int32') + print(mask_tensor) + return mask_tensor + +def main(examples: List[str] = ["all"], dtype='bf16'): + """Run the benchmark with the given examples. + + Args: + examples: List of examples to run. If "all" is specified, all examples will be run. 
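        dtype: Data type tag (defaults to "bf16"); also used to name the output dump directory.

    Example (a sketch, assuming one process per rank is started by a distributed launcher
    such as ``paddle.distributed.launch`` and that the dumped
    ``startend_row_indices_*.pdparams`` files referenced below are available)::

        python -m paddle.distributed.launch benchmark_cp_balance_balancex.py --examples all --dtype bf16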
+ """ + total_length = 0 + paddle.set_flags({'FLAGS_flash_attn_version': 3}) + doc_seq_lens_list = [] + rank = paddle.distributed.get_rank() + + + + #doc_seq_lens_list = doc_seq_lens_list[::-1] + for D in [128]: + H = 1 + # print(doc_seq_lens_list) + results = [] + for idx in range(0,50): + B = 1 + startend_row_indices = paddle.load(f'/root/paddlejob/workspace/env_run/xiehaoyang/flashmask/flashmask-cp/cp_balance/dump_32k_startend_row_indices/startend_row_indices_{idx}.pdparams') + print(startend_row_indices) + S = 32768 + print(f"{B}_{S}_{H}_{D}_{idx}_{dtype}") + + + available_examples = { + # "Full": lambda: test_cp_famask(generate_mask_fn=partial(generate_none_mask, causal=False), B=B, S=S, H=H, D=D, dtype=dtype), + # "Causal": lambda: test_cp_famask(generate_mask_fn=partial(generate_none_mask, causal=True), B=B, S=S, H=H, D=D, dtype=dtype), + # "Sliding Window": lambda: test_cp_famask(generate_mask_fn=partial(generate_sliding_window_mask, window_size=int(S*0.0625)), B=B, S=S, H=H, D=D, dtype=dtype), + # "Causal Document Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=S, H=H, D=D, dtype=dtype), + "Document Mask": lambda: test_cp_famask(startend_row_indices = startend_row_indices, B=B, S=S, H=H, D=D, dtype=dtype), + # "Share Question Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_share_question_mask, doc_seq_lens=share_qa_docs), B=B, S=S, H=H, D=D, dtype=dtype), + # "Global Sliding Window": lambda: test_cp_famask(generate_mask_fn=partial(generate_global_sliding_window_mask, global_token=16, window_size=(int(S*0.0625), int(S*0.0625))), B=B, S=S, H=H, D=D, dtype=dtype), + # "Causal Blockwise Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_causal_blockwise_mask, doc_seq_lens=doc_seq_lens), B=B, S=S, H=H, D=D, dtype=dtype), + # "Prefix LM Document Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=S, H=H, D=D, dtype=dtype), + # "Prefix LM Causal Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_prefix_lm_causal_mask, prefix_length=int(S*0.5)), B=B, S=S, H=H, D=D, dtype=dtype), + # "QK-sparse Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_qk_sparse_mask, maskout_pair=maskout_pair), B=B, S=S, H=H, D=D, dtype=dtype), + # "Random Eviction Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_random_eviction_mask, start_row=S//2), B=B, S=S, H=H, D=D, dtype=dtype), + } + + global total_num + total_num = len(available_examples) + + if "all" in examples: + ex_to_run = list(available_examples.keys()) + else: + ex_to_run = examples + + for ex in ex_to_run: + if ex in available_examples: + print(ex) + fw_time, bw_time, total_time, balance_time, scatter_x_time, gather_x_time = available_examples[ex]() + results.append([idx, f"{fw_time:.4f}", f"{bw_time:.4f}", f"{total_time:.4f}", f"{balance_time:.4f}", f"{scatter_x_time:.4f}", f"{gather_x_time:.4f}"]) + print(fw_time, bw_time) + else: + print(f"Warning: Unknown example key '{ex}'. 
Skipping.") + # if(idx >= 3): + # return + + # print(f'avg_fwd_time:{sum([float(result[1][:-1]) for result in results]) / len(results)} avg_bwd_time:{sum([float(result[2][:-1]) for result in results]) / len(results)}') + headers = [ + "Idx", + "FW Time (ms)", + "BW Time (ms)", + "TOTAL Time (ms)", + "BALANCE Time (ms)", + "SCATTER X Time (ms)", + "GATHER X Time (ms)" + ] + print( + tabulate( + results, + headers=headers, + tablefmt="grid", + ) + ) + content2=tabulate(results, headers=headers, tablefmt="tsv") + os.makedirs(f"{dtype}_dist_test_dump", exist_ok=True) + text_file = open(f"{dtype}_dist_test_dump/flashmask_{rank}_{B}_{S}_{H}_{D}.csv","w") + text_file.write(content2) + text_file.close() + + print(f'avg_fwd_time:{sum([float(result[1][:-1]) for result in results]) / len(results)} avg_bwd_time:{sum([float(result[2][:-1]) for result in results]) / len(results)}') + # assert False + +if __name__ == "__main__": + try: + from jsonargparse import ArgumentParser + except ImportError: + raise ImportError("Be sure to run: pip install -e .'[viz]'") + parser = ArgumentParser(description="Run specific examples or all examples.") + parser.add_argument( + "--examples", + type=str, + nargs="+", + default=["all"], + help="List of examples to run. Use space to separate multiple examples. " + "Available options: causal, alibi, sliding_window, prefix_lm, " + "document, softcap, softcap_approx, or 'all' to run all examples.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16" + ) + + args = parser.parse_args() + main(**vars(args)) diff --git a/csrc/utils/cp_balance/test/test_cp_balance_balancex.py b/csrc/utils/cp_balance/test/test_cp_balance_balancex.py new file mode 100644 index 00000000000..16d73d524f3 --- /dev/null +++ b/csrc/utils/cp_balance/test/test_cp_balance_balancex.py @@ -0,0 +1,920 @@ +import numpy as np +from functools import partial +from typing import Optional, List +from tabulate import tabulate +import paddle +import os +import paddle.nn.functional as F +from paddle.nn.functional.flash_attention import flashmask_attention +from cp_balance.context_parallel_utils import flashmask_attention_cp,scatter_balance, all_gather_balance +from cp_balance.cp_balance import assign_tasks_heap, get_q_workload,balance_flashmask_input, tensor_md5,balance_alltoall + +import paddle.distributed.fleet as fleet +import time + +cp_size = 4 +strategy = fleet.DistributedStrategy() + +strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 2, + "pp_degree": 1, + "sharding_degree": 4, + "sep_degree": 1, + "ep_degree": 8, + "moe_sharding_degree": 1, + "cp_degree": cp_size, + "order": ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"] +} + +fleet.init(is_collective=True, strategy=strategy) +cp_group = fleet.get_hybrid_communicate_group().get_context_parallel_group() + +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +def from_paddle(x: paddle.Tensor): + if x.dtype == paddle.bfloat16 or x.dtype == "bfloat16": + return torch.from_numpy(x.view("uint16").numpy()).to("cuda").view(torch.bfloat16) + elif x.dtype == paddle.float32 or x.dtype == "float32": + return torch.from_numpy(x.numpy()).to("cuda") + else: + assert False + +def _summarize_statistics(times, quantiles, return_mode): + if quantiles is not None: + ret = paddle.quantile(times, paddle.to_tensor(quantiles, dtype=paddle.float32)).tolist() + if len(ret) == 1: + ret 
= ret[0] + return ret + if return_mode == "all": + return times.tolist() + return getattr(paddle, return_mode)(times).item() + +def split_sequence(sequence_length, num_answers=2): + if sequence_length < num_answers + 1: + raise ValueError(f"序列长度必须至少为 {num_answers + 1}") + + base = sequence_length // (num_answers + 1) + extra = sequence_length % (num_answers + 1) + # 前extra个部分多加1 + lengths = [base + (1 if i < extra else 0) for i in range(num_answers + 1)] + + return lengths + +def do_bench_flashmaskcp(q_local, k_local, v_local, o_grad_local, startend_row_indices, group, is_causal,bucket = None, warmup=50, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float], optional + :param fast_flush: Use faster kernel to flush L2 cache between measurements + :type fast_flush: bool, default is True + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + + rank = paddle.distributed.get_rank() + out_local = flashmask_attention_cp(q_local, k_local, v_local, startend_row_indices, causal=is_causal) + # print('pt00') + out_local.backward(o_grad_local) + # print('pt0') + paddle.distributed.barrier(group=cp_group) + paddle.device.synchronize() + # print('here') + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + if fast_flush: + cache = paddle.empty([int(cache_size // 4)], dtype=paddle.int32) + else: + cache = paddle.empty([int(cache_size)], dtype=paddle.int8) + + # Estimate the runtime of the function + start_event = paddle.device.Event(enable_timing=True) + end_event = paddle.device.Event(enable_timing=True) + start_event.record() + for _ in range(5): + time.sleep(0.1) + cache.zero_() + out_local = flashmask_attention_cp(q_local, k_local, v_local, startend_row_indices, causal=is_causal, bucket=bucket) + out_local.backward(o_grad_local) + paddle.distributed.barrier(group=cp_group) + paddle.device.synchronize() + end_event.record() + paddle.distributed.barrier(group=cp_group) + paddle.device.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # print('pt2') + # compute number of warmup and repeat + n_warmup = max(3, int(warmup / estimate_ms)) + n_repeat = max(5, int(rep / estimate_ms)) + start_event = [paddle.device.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [paddle.device.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + time.sleep(0.1) + out_local =flashmask_attention_cp(q_local, k_local, v_local, startend_row_indices, causal=is_causal, bucket=bucket) + out_local.backward(o_grad_local, retain_graph=True) + paddle.distributed.barrier(group=cp_group) + paddle.device.synchronize() + # Benchmark + times_fwd = [] + times_bwd = [] + for i in 
range(n_repeat): + time.sleep(0.1) + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + paddle.distributed.barrier(group=cp_group) + paddle.device.synchronize() + paddle.base.core.nvprof_nvtx_push(f"flashmask_cp_fwd_{rank}") + t0 = time.perf_counter() + out_local = flashmask_attention_cp(q_local, k_local, v_local, startend_row_indices, causal=is_causal, bucket=bucket) + paddle.distributed.barrier(group=cp_group) + paddle.device.synchronize() + paddle.base.core.nvprof_nvtx_pop() + paddle.base.core.nvprof_nvtx_push(f"flashmask_cp_bwd_{rank}") + t1 = time.perf_counter() + out_local.backward(o_grad_local, retain_graph=True) + paddle.distributed.barrier(group=cp_group) + paddle.device.synchronize() + paddle.base.core.nvprof_nvtx_pop() + t2 = time.perf_counter() + times_fwd.append(1000 * (t1 - t0)) + times_bwd.append(1000 * (t2 - t1)) + + # Record clocks + print(times_bwd) + paddle.distributed.barrier(group=cp_group) + paddle.device.synchronize() + # print('pt3') + return sum(times_fwd) / n_repeat, sum(times_bwd) / n_repeat + +def cp_flashmask_balance_bench(q, k, v, startend_row_indices, is_causal,o_grad,bucket): + group = cp_group + rank = group.rank + print(f"rank: {rank}") + q_blocksize = (int)(q.shape[1] // (2 * cp_size)) + k_blocksize = (int)(k.shape[1] // cp_size) + q_local_1 = q[:, rank*q_blocksize:(rank+1)*q_blocksize, :, :] + q_local_2 = q[:, (cp_size *2 -rank -1)*q_blocksize:(cp_size *2 -rank)*q_blocksize, :, :] + q_local = paddle.concat([q_local_1, q_local_2], axis=1).detach() + k_local = k[:, rank*k_blocksize:(rank+1)*k_blocksize, :, :].detach().contiguous() + v_local = v[:, rank*k_blocksize:(rank+1)*k_blocksize, :, :].detach().contiguous() + o_grad_local_1 = o_grad[:, rank * q_blocksize : (rank + 1) * q_blocksize, :, :].detach() + o_grad_local_2 = o_grad[:, (cp_size * 2 - rank - 1) * q_blocksize : (cp_size * 2 - rank) * q_blocksize, :, :].detach() + o_grad_local = paddle.concat([o_grad_local_1, o_grad_local_2], axis=1).contiguous() + + + q_local.stop_gradient = False + k_local.stop_gradient = False + v_local.stop_gradient = False + # startend_row_indices.stop_gradient = False + + cp_fwd_time, cp_bwd_time = do_bench_flashmaskcp(q_local, k_local, v_local, o_grad_local, startend_row_indices, group, is_causal, bucket) + # print(f"cp balance fwd+bwd time: {cp_fwd_bwd_time} ms\n") + return cp_fwd_time, cp_bwd_time + +def test_cp_famask( + startend_row_indices, + B: int = 16, + S: int = 8192, + H: int = 16, + D: int = 64, + dtype = 'bf16', +): + """ + 测试上下文并行FlashMask注意力机制的性能基准 + + 该函数用于测试在分布式并行环境中FlashMask注意力机制的前向传播和后向传播性能, + 支持不同类型的注意力掩码生成策略。 + + Args: + generate_mask_fn: 注意力掩码生成函数,用于生成startend_row_indices和因果关系标记 + B: 批次大小,默认16 + S: 序列长度,默认8192 + H: 注意力头数,默认16 + D: 每个注意力头的维度,默认64 + dtype: 数据类型,默认'bf16' + + Returns: + tuple: 包含前向传播时间和后向传播时间的元组 (fwd_time, bwd_time),单位为毫秒 + """ + # paddle.seed(2024) + paddle.seed(2024) + # batch_size = 1 + total_q = S + total_k = S + batch_size = B + num_head = H + num_head_q = 12 * H + head_size = D + rank = cp_group.rank + # total_k = total_q * 2 + if rank == 0: + query = paddle.randn([batch_size, total_q, num_head_q, head_size], dtype=paddle.bfloat16) + key = paddle.randn([batch_size, total_k, num_head, head_size], dtype=paddle.bfloat16) + value = paddle.randn([batch_size, total_k, num_head, head_size], dtype=paddle.bfloat16) + o_grad = paddle.randn([batch_size, total_q, num_head_q, head_size], dtype=paddle.bfloat16) + else: + query = paddle.empty([batch_size, total_q, num_head_q, head_size], dtype=paddle.bfloat16) + key = 
paddle.empty([batch_size, total_k, num_head, head_size], dtype=paddle.bfloat16) + value = paddle.empty([batch_size, total_k, num_head, head_size], dtype=paddle.bfloat16) + o_grad = paddle.empty([batch_size, total_q, num_head_q, head_size], dtype=paddle.bfloat16) + + # 广播到所有 rank + paddle.distributed.broadcast(query, src=cp_group.ranks[0],group=cp_group) + paddle.distributed.broadcast(key, src=cp_group.ranks[0],group=cp_group) + paddle.distributed.broadcast(value, src=cp_group.ranks[0],group=cp_group) + paddle.distributed.broadcast(o_grad,src = cp_group.ranks[0],group=cp_group) + paddle.distributed.barrier(group=cp_group) + query.stop_gradient = False + key.stop_gradient = False + value.stop_gradient = False + + causal = False + balance_q_chunksize = 2048 + + workload = get_q_workload(startend_row_indices, balance_q_chunksize, 128,128) + + # print(workload) + total_workload = paddle.sum(workload,axis = 1) + buckets, bucket_weights,cuts = assign_tasks_heap(workload.reshape(-1,2).cpu().numpy(), cp_size) + buckets0 = buckets.copy() + hcg = fleet.get_hybrid_communicate_group() + + ref_o = flashmask_attention(query, key, value, startend_row_indices, causal=causal) + ref_o.backward(o_grad) + x = query.reshape([batch_size, total_q, -1]).detach().contiguous() + local_qs = [] + local_ks = [] + local_vs = [] + local_ograds =[] + local_ref_os = [] + local_ref_grad_qs = [] + local_ref_grad_ks = [] + local_ref_grad_vs = [] + local_xs = [] + for(_, idx) in buckets[rank]: + local_qs.append(query[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_ks.append(key[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_vs.append(value[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_ograds.append(o_grad[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_ref_os.append(ref_o[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_ref_grad_qs.append(query.grad[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_ref_grad_ks.append(key.grad[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_ref_grad_vs.append(value.grad[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:,:]) + local_xs.append(x[:,idx * balance_q_chunksize:(idx+1) * balance_q_chunksize,:]) + + local_q = paddle.concat(local_qs, axis=1).detach().contiguous() + local_ref_q = local_q.detach() + local_k = paddle.concat(local_ks, axis=1).detach().contiguous() + local_v = paddle.concat(local_vs, axis=1).detach().contiguous() + local_o_grad = paddle.concat(local_ograds, axis=1) + local_ref_o = paddle.concat(local_ref_os, axis=1) + local_ref_grad_q = paddle.concat(local_ref_grad_qs, axis=1) + local_ref_grad_k = paddle.concat(local_ref_grad_ks, axis=1) + local_ref_grad_v = paddle.concat(local_ref_grad_vs, axis=1) + local_ref_x = paddle.concat(local_xs, axis=1).detach().contiguous() + + local_startend_row_indices, buckets = balance_flashmask_input( startend_row_indices, cp_size, rank,balance_chunk_size= balance_q_chunksize) + assert buckets0 == buckets + local_q = scatter_balance(query, group = cp_group, axis=1,mode = "balanced_swap", buckets = buckets).detach().contiguous() + local_x = scatter_balance(x, group = cp_group, axis=1,mode = "balanced_swap", buckets = buckets) + gather_x = all_gather_balance(local_x, group = cp_group, axis=1,mode = "balanced_swap", buckets = buckets) + + local_k.stop_gradient = False + local_v.stop_gradient = False + # local_q = query[:,rank * S // cp_size:(rank+1) * S // 
cp_size,:,:].detach().contiguous() + # local_o_grad = o_grad[:,rank * S // cp_size:(rank+1) * S // cp_size,:,:] + # print(local_o_grad) + print(local_q.shape) + print(local_startend_row_indices.shape) + local_q.stop_gradient = False + local_o = flashmask_attention_cp(local_q, local_k, local_v, local_startend_row_indices, causal=causal, mode="balance_q") + local_ref_o1 = ref_o[:,rank * S // cp_size:(rank+1) * S // cp_size,:,:] + print(f"rank: {rank}") + local_o.backward(local_o_grad) + paddle.distributed.barrier(group=cp_group) + + if(rank == 0): + print(f"total_workload: {total_workload}") + print(bucket_weights) + print(cuts) + for i, bucket in enumerate(buckets): + workload_sum = 0 + for item in bucket: + workload_sum += item[0] + print(f"Bucket {i+1}: {workload_sum}") + print(f"Bucket {i+1}: {bucket}") + print(f"Bucket {i+1}: {len(bucket)}") + # x_np = buckets[0].numpy() + # np.savetxt('buckets.txt', x_np.reshape(-1, x_np.shape[-1]), fmt='%d') + # strict_check(local_o[:,:,0,0].flatten(), local_o1[:,:,0,0].flatten()) + # strict_check(local_o1[:,:,0,0].flatten(), local_ref_o[:,:,0,0].flatten()) + strict_check(local_q.flatten(), local_ref_q.flatten()) + strict_check(local_o.flatten(), local_ref_o.flatten()) + strict_check(local_k.grad.flatten(), local_ref_grad_k.flatten()) + strict_check(local_v.grad.flatten(), local_ref_grad_v.flatten()) + strict_check(local_q.grad.flatten(), local_ref_grad_q.flatten()) + strict_check(local_x.flatten(), local_ref_x.flatten()) + strict_check(gather_x, x) + + # startend_row_indices = regroup_chunks_by_buckets(startend_row_indices, buckets) + # print(buckets) + # if generate_mask_fn is not None: + # print("enter",generate_mask_fn) + # startend_row_indices = generate_mask_fn() + # startend_row_indices, causal = generate_mask_fn(total_q) + + # print(startend_row_indices) + # paddle.set_printoptions(precision=None, threshold=10000000, edgeitems=None, sci_mode=None, linewidth=None) + + # fwd_time, bwd_time = cp_flashmask_balance_bench(query, key, value, startend_row_indices, causal,o_grad,buckets[rank]) + fwd_time, bwd_time = 0,0 + paddle.device.synchronize() + # out1.backward(o_grad1) + # paddle.device.synchronize() + + # print("pypt2:") + # print(startend_row_indices) + total_time = fwd_time + bwd_time + return fwd_time, bwd_time, total_time + # with open("execution_times.txt", "a") as log_file: + # log_file.write(f"bsz: {batch_size},num_head_k: {num_head},num_head_q: {num_head * 4},hsz: {head_size},seqlen: {total_q}, flashattnv1: {flashattnv1_time:.6f}s, " + # f"flashattnv2: {flashattnv2_time:.6f}s\n") + # for x,y in [(out1,out),(dq1,query.grad),(dk1,key.grad),(dv1,value.grad)]: + # strict_check(x.flatten(), y.flatten()) + # for x,y in [(out1,out)]: + # strict_check(x.flatten(), y.flatten()) + +def strict_check(x, y): + if isinstance(x, paddle.Tensor): + if x.dtype == paddle.bfloat16 or x.dtype == "float16": + # x = x.view("float16").numpy() + x = x.cast("float32").numpy() + else: + x = x.numpy() + else: + assert False + + # if isinstance(y, torch.Tensor): + # if y.dtype == torch.bfloat16 or y.dtype == "bfloat16": + # # x = x.view("float16").numpy() + # y = y.to(torch.float32).detach().cpu().numpy() + # else: + # y = y.detach().cpu().numpy() + + if isinstance(y, paddle.Tensor): + if y.dtype == paddle.bfloat16 or y.dtype == "float16": + # y = y.view("float16").numpy() + y = y.cast("float32").numpy() + else: + y = y.numpy() + + try: + print(f"{x=}, {y=}") + np.testing.assert_allclose(x.flatten(), y.flatten(),rtol=1e-2, atol=1e-2) + except Exception as e: + 
print('---------------') + idx = np.where(~(x == y)) + print(f"fail idx: {idx=}") + print(f"shape:'{x.shape}'") + # print(f"fail idx:'{np.unique(idx[0])}'") + print(x[idx]) + print(y[idx]) + raise e + + +def ele_check(x, y): + if isinstance(x, paddle.Tensor): + if x.dtype == paddle.bfloat16 or x.dtype == "bfloat16": + # x = x.view("uint16").numpy() + x = x.cast("float32").numpy() + else: + x = x.numpy() + else: + assert False + + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16 or y.dtype == "bfloat16": + # x = x.view("uint16").numpy() + y = y.to(torch.float32).detach().cpu().numpy() + else: + y = y.detach().cpu().numpy() + + # if isinstance(y, paddle.Tensor): + # if y.dtype == paddle.bfloat16 or y.dtype == "bfloat16": + # # y = y.view("uint16").numpy() + # y = y.cast("float32").numpy() + # else: + # y = y.numpy() + + try: + print(f"{x=}, {y=}") + np.testing.assert_allclose(np.sort(x.flatten()), np.sort(y.flatten()),rtol=1e-3, atol=1e-6) + except Exception as e: + print('---------------') + idx = np.where(~(x == y)) + print(f"fail idx: {idx=}") + print(f"shape:'{x.shape}'") + # print(f"fail idx:'{np.unique(idx[0])}'") + print(x[idx]) + print(y[idx]) + raise e + +def flashmask_to_densemask(startend_row_indices, dtype, causal=True): + if startend_row_indices is None: + return None + bz, num_head, seq_len, bound_num = startend_row_indices.shape + m = paddle.zeros((bz, num_head, seq_len, seq_len), dtype=dtype) + has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4) + for bi in range(bz): + for hi in range(num_head): + for j in range(seq_len): + downstart = startend_row_indices[bi, hi, j, 0] + if has_end: + downend = startend_row_indices[bi, hi, j, 1] + m[bi, hi, downstart:downend, j] = -np.inf + else: + m[bi, hi, downstart:, j] = -np.inf + if causal: + m[bi, hi, :j, j] = -np.inf + else: + if has_end: + upstart = startend_row_indices[bi, hi, j, 2] + upend = startend_row_indices[bi, hi, j, 3] + m[bi, hi, upstart:upend, j] = -np.inf + else: + upend = startend_row_indices[bi, hi, j, 1] + m[bi, hi, :upend, j] = -np.inf + return m + +def generate_none_mask(B, S, H, D, causal=True): + return None, causal + +def generate_ones_mask(B, S, H, D): + startend_row_indices = paddle.zeros( + shape=(B, H, S, 2), dtype="int32" + ) + startend_row_indices[:,:,:,0]=S + causal = False + return startend_row_indices, causal + +def generate_causal_mask(B,S,H,D): + startend_row_indices = paddle.zeros( + shape=(B, H, S, 1), dtype="int32" + ) + startend_row_indices[:,:,:,0]=S + causal = True + return startend_row_indices, causal + +def generate_sliding_window_mask(B, S, H, D, window_size=1024): + startend_row_indices = paddle.arange( + window_size, S + window_size, dtype="int32" + ).reshape((1, 1, S, 1)) + startend_row_indices = paddle.clip( + startend_row_indices, max=S + ).repeat_interleave(B, 0) + + causal=True + return startend_row_indices, causal + +# def generate_causal_document_mask(B, S, H, D, doc_seq_lens=[2538, 1742, 3213]): +def generate_causal_document_mask(B,S,H,D, doc_seq_lens=[2538, 1742, 3213]): + total_seq_len = np.sum(doc_seq_lens) + assert total_seq_len <= S, f"{total_seq_len=}, {S=}" + padding = S - np.sum(doc_seq_lens) + doc_seq_lens[-1] += padding + seq_cusums = np.cumsum(doc_seq_lens) + + startend_row_indices = np.repeat(seq_cusums, doc_seq_lens) + startend_row_indices = paddle.to_tensor(startend_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)) + startend_row_indices = startend_row_indices.repeat_interleave(B, 0) + + causal = True + return 
startend_row_indices, causal + +def generate_upper_document_mask(B,S,H,D, doc_seq_lens=[2538, 1742, 3213],padding_size = 256): + total_seq_len = np.sum(doc_seq_lens) + assert total_seq_len <= S + padding = S - np.sum(doc_seq_lens) + + up_right_row_indices = [] + + cur_len_so_far = 0 + for i in range(len(doc_seq_lens)): + up_right_row_indices.extend([cur_len_so_far] * doc_seq_lens[i]) + if i < len(doc_seq_lens) -1: + cur_len_so_far += doc_seq_lens[i] + if padding > 0: + up_right_row_indices.extend([cur_len_so_far] * padding) + + up_right_row_indices = paddle.to_tensor(up_right_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + down_left_row_indices = paddle.ones_like(up_right_row_indices) * (S - padding_size) + startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) + + causal = False + return startend_row_indices, causal + +def generate_document_mask(B, S, H, D, doc_seq_lens=[2538, 1742, 3213]): + total_seq_len = np.sum(doc_seq_lens) + assert total_seq_len <= S + padding = S - np.sum(doc_seq_lens) + + down_left_row_indices = [] + up_right_row_indices = [] + + cur_len_so_far = doc_seq_lens[0] + for i in range(len(doc_seq_lens)): + down_left_row_indices.extend([cur_len_so_far] * doc_seq_lens[i]) + if i < len(doc_seq_lens) -1: + cur_len_so_far += doc_seq_lens[i+1] + if padding > 0: + down_left_row_indices.extend([cur_len_so_far] * padding) + + cur_len_so_far = 0 + for i in range(len(doc_seq_lens)): + up_right_row_indices.extend([cur_len_so_far] * doc_seq_lens[i]) + if i < len(doc_seq_lens) -1: + cur_len_so_far += doc_seq_lens[i] + if padding > 0: + up_right_row_indices.extend([cur_len_so_far] * padding) + + down_left_row_indices = paddle.to_tensor(down_left_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + up_right_row_indices = paddle.to_tensor(up_right_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) + + causal = False + return startend_row_indices, causal + +def generate_share_question_mask(B, S, H, D, doc_seq_lens=[2538, 1742, 3213]): + total_seq_len = np.sum(doc_seq_lens) + assert total_seq_len <= S + assert len(doc_seq_lens) >= 3 + padding = S - total_seq_len + + startend_row_indices = [S] * doc_seq_lens[0] + + cur_len_so_far = doc_seq_lens[0] + for idx in range(1, len(doc_seq_lens)): + cur_len_so_far += doc_seq_lens[idx] + startend_row_indices.extend([cur_len_so_far] * doc_seq_lens[idx]) + + if padding > 0: + startend_row_indices.extend([cur_len_so_far] * padding) + + startend_row_indices = paddle.to_tensor(startend_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + causal = True + return startend_row_indices, causal + +def generate_global_sliding_window_mask(B, S, H, D, global_token=16, window_size=(512, 512)): + assert len(window_size) == 2 + left_window_size, right_window_size = window_size + + down_left_start_row_indices = [] + down_left_end_row_indices = [] + up_right_start_row_indices = [] + up_right_end_row_indices = [] + + down_left_start_row_indices = paddle.arange( + left_window_size + 1, S + left_window_size + 1, dtype="int32" + ).clip(max=S) + down_left_start_row_indices[:global_token] = S + down_left_start_row_indices = down_left_start_row_indices.reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + down_left_end_row_indices = paddle.full([S], S, dtype="int32").reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + 
up_right_start_row_indices = paddle.full([S], global_token, dtype="int32") + up_right_start_row_indices[:global_token+right_window_size+1] = 0 + up_right_start_row_indices = up_right_start_row_indices.reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + up_right_end_row_indices = paddle.arange( + -right_window_size, S - right_window_size, dtype="int32" + ) + up_right_end_row_indices[:global_token+right_window_size+1] = 0 + up_right_end_row_indices = up_right_end_row_indices.reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + startend_row_indices = paddle.concat([down_left_start_row_indices, down_left_end_row_indices, up_right_start_row_indices, up_right_end_row_indices], axis=-1) + + causal = False + return startend_row_indices, causal + +def generate_causal_blockwise_mask(B, S, H, D, doc_seq_lens=[2538, 1742, 3213]): + total_seq_len = np.sum(doc_seq_lens) + assert total_seq_len <= S + assert len(doc_seq_lens) >= 3 + padding = S - np.sum(doc_seq_lens) + + start_row_indices = [] + cur_len_so_far = doc_seq_lens[0] + for i in range(len(doc_seq_lens)): + start_row_indices.extend([cur_len_so_far] * doc_seq_lens[i]) + if i < len(doc_seq_lens) - 1: + cur_len_so_far += doc_seq_lens[i+1] + if padding > 0: + start_row_indices.extend([cur_len_so_far] * padding) + start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + seq_cusums = np.cumsum(doc_seq_lens) + end_row_indices = [seq_cusums[-2]] * seq_cusums[-2] + [seq_cusums[-1]] * doc_seq_lens[-1] + [S] * padding + end_row_indices = paddle.to_tensor(end_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + startend_row_indices = paddle.concat([start_row_indices, end_row_indices], axis=-1) + + causal = True + return startend_row_indices, causal + +def generate_prefix_lm_document_mask(B, S, H, D, doc_seq_lens=[(1024, 2538), (1742, 1742), (512, 3213)]): + """ + tuple(prefix_length, seq_length) + """ + assert len(doc_seq_lens) >= 2 + total_seq_len = 0 + for prefix_length, seq_length in doc_seq_lens: + total_seq_len += seq_length + assert total_seq_len <= S + padding = S - total_seq_len + + down_left_row_indices = [] + cur_len_so_far = doc_seq_lens[0][1] + for i in range(len(doc_seq_lens)): + down_left_row_indices.extend([cur_len_so_far] * doc_seq_lens[i][1]) + if i < len(doc_seq_lens) - 1: + cur_len_so_far += doc_seq_lens[i+1][1] + if padding > 0: + down_left_row_indices.extend([cur_len_so_far] * padding) + down_left_row_indices = paddle.to_tensor(down_left_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + up_right_row_indices = [] + cur_len_so_far = 0 + for prefix_length, seq_length in doc_seq_lens: + up_right_row_indices.extend([cur_len_so_far] * prefix_length + list(range(cur_len_so_far+prefix_length, cur_len_so_far+seq_length))) + cur_len_so_far += seq_length + if padding > 0: + up_right_row_indices.extend([total_seq_len] * padding) + up_right_row_indices = paddle.to_tensor(up_right_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + + startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) + + causal = False + return startend_row_indices, causal + +def generate_prefix_lm_causal_mask(B, S, H, D, prefix_length=1024): + """ + tuple(prefix_length, seq_length) + """ + assert prefix_length <= S + down_left_row_indices = paddle.full([S], S, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0) + up_right_row_indices = paddle.to_tensor([0] * prefix_length + 
list(range(prefix_length, S)), dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0)
+    startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1)
+
+    causal = False
+    return startend_row_indices, causal
+
+def generate_qk_sparse_mask(B, S, H, D, maskout_pair=[(1024, 538), (2358, 1700)]):
+    """
+    maskout_pair: list of (offset, maskout_len) tuples
+    """
+    start_row_indices = []
+    end_row_indices = []
+    last_offset = 0
+    for offset, maskout_len in maskout_pair:
+        assert offset > last_offset
+        start_row_indices.extend([S]*(offset-last_offset))
+        end_row_indices.extend([S]*(offset-last_offset))
+
+        start_row_indices.extend(list(range(offset, offset+maskout_len)))
+        end_row_indices.extend([offset+maskout_len]*(maskout_len))
+
+        last_offset = offset + maskout_len
+
+    assert last_offset <= S
+    start_row_indices.extend([S]*(S-last_offset))
+    end_row_indices.extend([S]*(S-last_offset))
+
+    start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0)
+    end_row_indices = paddle.to_tensor(end_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0)
+    startend_row_indices = paddle.concat([start_row_indices, end_row_indices], axis=-1)
+
+    causal = True
+    return startend_row_indices, causal
+
+#def generate_hash_sparse_mask(B, S, H, D, maskout_pair=[(1024, 538), (2358, 1700)]):
+#    """
+#    tuple(offset, maskout_len)
+#    """
+#    start_row_indices = []
+#    end_row_indices = []
+#    last_offset = 0
+#    for offset, maskout_len in maskout_pair:
+#        assert offset > last_offset
+#        start_row_indices.append([S]*(offset-last_offset))
+#        end_row_indices.append([S]*(offset-last_offset))
+#
+#        start_row_indices.append(list(range(offset, offset+maskout_len)))
+#        end_row_indices.append([offset+maskout_len]*(maskout_len))
+#
+#        last_offset = offset + maskout_len
+#
+#    last_offset <= S
+#    start_row_indices.append([S]*(S-last_offset))
+#    end_row_indices.append([S]*(S-last_offset))
+#
+#    start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0)
+#    end_row_indices = paddle.to_tensor(end_row_indices, dtype=paddle.int32).reshape((1, 1, S, 1)).repeat_interleave(B, 0)
+#    startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1)
+#
+#    causal = False
+#    return startend_row_indices, causal
+
+
+def generate_random_eviction_mask(B, S, H, D, start_row=4096):
+    np.random.seed(0)
+    start_rows_list = []
+    for bz_idx in range(B):
+        for head_idx in range(H):
+            start_rows = np.array([S+1] * S)
+            mask_pos = np.random.choice(S-1, S - start_row, replace=False)
+            index = np.arange(start_row, S)
+            mask_pos = np.concatenate([mask_pos[mask_pos < index - 1], mask_pos[mask_pos >= index - 1]])
+            start_rows[mask_pos] = index
+            min_index = np.arange(1, S+1)
+            start_rows = np.maximum(start_rows, min_index)
+            start_rows_list.append(start_rows)
+    startend_row_indices = paddle.to_tensor(start_rows_list, dtype=paddle.int32).reshape((B, H, S, 1))
+    causal = True
+    return startend_row_indices, causal
+
+def gen_varlen(cu_seqlens):
+    # Initialize the mask
+    mask = np.zeros((1, 1, 32768, 2), dtype=np.int32)
+    mask[0, 0, :, 0] = np.arange(32768, dtype=np.int32)
+    mask[0, 0, :, 1] = np.arange(32768, dtype=np.int32)
+
+    # Fill mask[..., 0] with the end offset of the document each position belongs to
+    for i in range(len(cu_seqlens)-1):
+        start = cu_seqlens[i]
+        end = cu_seqlens[i+1]
+        mask[0, 0, start:end, 0] = cu_seqlens[i+1]
+
+    # mask[..., 1] = i
+    mask[0, 0, :, 1] = np.arange(32768, dtype=np.int32)
+
+    # Convert to a paddle tensor
+    mask_tensor = paddle.to_tensor(mask, dtype='int32')
+    print(mask_tensor)
+    return mask_tensor
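+
+# The helper below is an illustrative, hypothetical sketch (it is not called anywhere in the
+# benchmark) documenting the layout gen_varlen produces, read against the bound convention of
+# flashmask_to_densemask above: for a non-causal two-column mask, column 0 is the row where the
+# lower masked region starts for each key and column 1 is the row where the upper masked region
+# ends. Storing each key's document end in column 0 and the key's own index in column 1 therefore
+# realizes per-document causal attention even though causal=False is passed to the kernel.
+def _demo_gen_varlen_layout():
+    cu_seqlens = [0, 8192, 20480, 32768]  # assumed toy document boundaries for the 32K packed batch
+    mask = gen_varlen(cu_seqlens)
+    assert list(mask.shape) == [1, 1, 32768, 2]
+    assert int(mask[0, 0, 0, 0]) == 8192       # first document ends at row 8192
+    assert int(mask[0, 0, 8192, 0]) == 20480   # second document ends at row 20480
+    assert int(mask[0, 0, 5, 1]) == 5          # column 1 is simply the position index itself
+    return mask
+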
+def gen_varlen_causal(cu_seqlens):
+    # Initialize the mask
+    mask = np.zeros((1, 1, 32768, 1), dtype=np.int32)
+    mask[0, 0, :, 0] = np.arange(32768, dtype=np.int32)
+#    mask[0, 0, :, 1] = np.arange(32768, dtype=np.int32)
+
+    # Fill mask[..., 0] with the end offset of the document each position belongs to
+    for i in range(len(cu_seqlens)-1):
+        start = cu_seqlens[i]
+        end = cu_seqlens[i+1]
+        mask[0, 0, start:end, 0] = cu_seqlens[i+1]
+
+    # mask[..., 1] = i
+#    mask[0, 0, :, 1] = np.arange(32768, dtype=np.int32)
+
+    # Convert to a paddle tensor
+    mask_tensor = paddle.to_tensor(mask, dtype='int32')
+    print(mask_tensor)
+    return mask_tensor
+
+def main(examples: List[str] = ["all"], dtype='bf16'):
+    """Run the benchmark with the given examples.
+
+    Args:
+        examples: List of examples to run. If "all" is specified, all examples will be run.
+        dtype: Precision tag forwarded to the test cases and used in the dump directory name (e.g. 'bf16').
+    """
+    total_length = 0
+    paddle.set_flags({'FLAGS_flash_attn_version': 3})
+    doc_seq_lens_list = []
+    rank = paddle.distributed.get_rank()
+
+    #doc_seq_lens_list = doc_seq_lens_list[::-1]
+    for D in [128]:
+        H = 1
+        # print(doc_seq_lens_list)
+        results = []
+        for idx in range(0, 50):
+            B = 2
+            startend_row_indices = paddle.load(f'/root/paddlejob/workspace/env_run/xiehaoyang/flashmask/flashmask-cp/cp_balance/dump_32k_startend_row_indices/startend_row_indices_{idx}.pdparams')
+            startend_row_indices = startend_row_indices.repeat(B, 1, 1, 1)
+            print(startend_row_indices)
+            S = 32768
+            print(f"{B}_{S}_{H}_{D}_{idx}_{dtype}")
+
+            available_examples = {
+                # "Full": lambda: test_cp_famask(generate_mask_fn=partial(generate_none_mask, causal=False), B=B, S=S, H=H, D=D, dtype=dtype),
+                # "Causal": lambda: test_cp_famask(generate_mask_fn=partial(generate_none_mask, causal=True), B=B, S=S, H=H, D=D, dtype=dtype),
+                # "Sliding Window": lambda: test_cp_famask(generate_mask_fn=partial(generate_sliding_window_mask, window_size=int(S*0.0625)), B=B, S=S, H=H, D=D, dtype=dtype),
+                # "Causal Document Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=S, H=H, D=D, dtype=dtype),
+                "Document Mask": lambda: test_cp_famask(startend_row_indices=startend_row_indices, B=B, S=S, H=H, D=D, dtype=dtype),
+                # "Share Question Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_share_question_mask, doc_seq_lens=share_qa_docs), B=B, S=S, H=H, D=D, dtype=dtype),
+                # "Global Sliding Window": lambda: test_cp_famask(generate_mask_fn=partial(generate_global_sliding_window_mask, global_token=16, window_size=(int(S*0.0625), int(S*0.0625))), B=B, S=S, H=H, D=D, dtype=dtype),
+                # "Causal Blockwise Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_causal_blockwise_mask, doc_seq_lens=doc_seq_lens), B=B, S=S, H=H, D=D, dtype=dtype),
+                # "Prefix LM Document Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=S, H=H, D=D, dtype=dtype),
+                # "Prefix LM Causal Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_prefix_lm_causal_mask, prefix_length=int(S*0.5)), B=B, S=S, H=H, D=D, dtype=dtype),
+                # "QK-sparse Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_qk_sparse_mask, maskout_pair=maskout_pair), B=B, S=S, H=H, D=D, dtype=dtype),
+                # "Random Eviction Mask": lambda: test_cp_famask(generate_mask_fn=partial(generate_random_eviction_mask, start_row=S//2), B=B, S=S, H=H, D=D, dtype=dtype),
+            }
+
+            global total_num
+            total_num = len(available_examples)
+
+            if "all" in examples:
+                ex_to_run = list(available_examples.keys())
+            else:
+                ex_to_run = examples
+
+            for ex in ex_to_run:
+                if ex in available_examples:
+                    print(ex)
+                    fw_time, bw_time, total_time = available_examples[ex]()
+                    results.append([idx, f"{fw_time:.4f}", f"{bw_time:.4f}", f"{total_time:.4f}"])
+                    print(fw_time, bw_time)
+                else:
+                    print(f"Warning: Unknown example key '{ex}'. Skipping.")
+            # if(idx >= 3):
+            #     return
+
+        # print(f'avg_fwd_time:{sum([float(result[1][:-1]) for result in results]) / len(results)} avg_bwd_time:{sum([float(result[2][:-1]) for result in results]) / len(results)}')
+        headers = [
+            "Idx",
+            "FW Time (ms)",
+            "BW Time (ms)",
+            "TOTAL Time (ms)",
+        ]
+        print(
+            tabulate(
+                results,
+                headers=headers,
+                tablefmt="grid",
+            )
+        )
+        content2 = tabulate(results, headers=headers, tablefmt="tsv")
+        os.makedirs(f"{dtype}_dist_test_dump", exist_ok=True)
+        with open(f"{dtype}_dist_test_dump/flashmask_{rank}_{B}_{S}_{H}_{D}.csv", "w") as text_file:
+            text_file.write(content2)
+        paddle.device.synchronize()
+        paddle.distributed.barrier()
+
+        print(f'avg_fwd_time:{sum([float(result[1]) for result in results]) / len(results)} avg_bwd_time:{sum([float(result[2]) for result in results]) / len(results)}')
+        # assert False
+
+if __name__ == "__main__":
+    try:
+        from jsonargparse import ArgumentParser
+    except ImportError:
+        raise ImportError("Be sure to run: pip install -e .'[viz]'")
+    parser = ArgumentParser(description="Run specific examples or all examples.")
+    parser.add_argument(
+        "--examples",
+        type=str,
+        nargs="+",
+        default=["all"],
+        help="List of examples to run. Use space to separate multiple examples. "
+        "Available options are the example names defined in main() "
+        "(e.g. 'Document Mask'), or 'all' to run all examples.",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="bf16"
+    )
+
+    args = parser.parse_args()
+    main(**vars(args))
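+
+# Example launch (hypothetical; adjust the GPU list, the context-parallel degree, and the
+# hard-coded mask dump path in main() to your environment). The benchmark is meant to be run
+# under paddle.distributed.launch so that the distributed groups used above are initialized:
+#
+#   python -m paddle.distributed.launch --gpus "0,1,2,3" context_parallel_utils.py \
+#       --examples all --dtype bf16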