FlashMLA from DeepSeek #892

Open · zhyncs opened this issue Feb 24, 2025 · 14 comments

zhyncs (Member) commented Feb 24, 2025

as titled

ref https://github.com/deepseek-ai/FlashMLA

celsowm commented Feb 24, 2025

I came here for this! @zhyncs was really fast

MichoChan commented:

How about #887? How does it compare to https://github.com/deepseek-ai/FlashMLA?

yzh119 (Collaborator) commented Feb 24, 2025

The pipeline design is a little bit different from #887; I'll check what we can learn from it.

yzh119 (Collaborator) commented Feb 24, 2025

@zhyncs @celsowm @MichoChan Here is the result I got on H100 by running the latest flashinfer and FlashMLA mainline (higher is better). For flashinfer we use page_size=1, while FlashMLA uses page_size=64.

[benchmark figure: flashinfer vs FlashMLA throughput on H100; higher is better]
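
For readers who want to reproduce a setup like this, here is a minimal sketch of the flashinfer side. It assumes the v0.2.x `flashinfer.mla.BatchMLAPagedAttentionWrapper` API (the exact `plan` argument names and order are best-effort from memory and may differ — the gist linked in the next comment is the authoritative script) and DeepSeek-style MLA shapes, which match the JIT module names in the logs below (head_dim_ckv=512, head_dim_kpe=64):

```python
import torch
import flashinfer

num_heads, head_dim_ckv, head_dim_kpe = 128, 512, 64
batch_size, kv_len, page_size = 64, 4096, 1  # page_size=64 for the FlashMLA-style run

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace)

pages_per_seq = (kv_len + page_size - 1) // page_size
# Single-token decode: one query per request.
q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_seq
kv_indices = torch.arange(batch_size * pages_per_seq, dtype=torch.int32, device="cuda")
kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")

wrapper.plan(
    q_indptr, kv_indptr, kv_indices, kv_lens,
    num_heads, head_dim_ckv, head_dim_kpe, page_size,
    False,        # causal (irrelevant for single-token decode)
    192 ** -0.5,  # sm_scale = 1/sqrt(qk_nope_head_dim + qk_rope_head_dim)
    torch.bfloat16, torch.bfloat16,
)

q_nope = torch.randn(batch_size, num_heads, head_dim_ckv, dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn(batch_size, num_heads, head_dim_kpe, dtype=torch.bfloat16, device="cuda")
ckv = torch.randn(batch_size * pages_per_seq, page_size, head_dim_ckv, dtype=torch.bfloat16, device="cuda")
kpe = torch.randn(batch_size * pages_per_seq, page_size, head_dim_kpe, dtype=torch.bfloat16, device="cuda")
out = wrapper.run(q_nope, q_pe, ckv, kpe)  # [batch_size, num_heads, head_dim_ckv]
```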

abcdabcd987 (Member) commented Feb 24, 2025

Here's my benchmark code and result on H100:

https://gist.github.com/abcdabcd987/b215c5f00f4b5e8399b95d7933bcf475

https://docs.google.com/spreadsheets/d/1t0Txa7Ph9u7Su9LyWpS24vqr9A5FB-FyL0EZNpYOqwg/edit?gid=0#gid=0

Both are using page size 64. FlashMLA is faster in general, way faster on small batch sizes.
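
(For reference, benchmarks like this typically time kernels with CUDA events. A generic harness of that sort — illustrative, not the gist's exact code:)

```python
import torch

def bench_gpu_ms(fn, warmup: int = 10, iters: int = 100) -> float:
    """Average GPU time of fn() in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait for both events to complete
    return start.elapsed_time(end) / iters
```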

yzh119 added a commit that referenced this issue Feb 24, 2025:

As pointed out in #892 (comment), the second stage of split-k seems to have a huge overhead. This PR is the first step in addressing these issues, changing the vector size from 4 to 8.
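
For context, the "second stage of split-k" is the merge that combines each split's partial attention output using its log-sum-exp (LSE). A torch reference of that math (illustrative only, not flashinfer's kernel) shows the stage does almost no arithmetic per byte loaded, so it is likely bandwidth-bound — which is why wider per-thread vector loads (8 elements instead of 4) help:

```python
import torch

def merge_splits(o: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    # o:   [num_splits, num_rows, head_dim], each split's partial output,
    #      already softmax-normalized within its own KV chunk
    # lse: [num_splits, num_rows], per-split log-sum-exp of the scores
    m = lse.max(dim=0, keepdim=True).values      # [1, num_rows], for stability
    w = (lse - m).exp()                          # per-split weights
    num = (w.unsqueeze(-1) * o).sum(dim=0)       # [num_rows, head_dim]
    return num / w.sum(dim=0).unsqueeze(-1)
```
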
yzh119 (Collaborator) commented Feb 24, 2025

Hi @abcdabcd987, you're right. I hadn't profiled the low-batch-size use cases, and I just realized we get low performance for small batches with long context.

#894 alleviates the issue a little bit.

Regarding the cases where qo_len * num_heads >= 128: the current flashinfer implementation is not good at them, because we prioritize page_size=1 (for larger page sizes, using TMA + multicast would help). I'll also take a look at FlashMLA's implementation and check how their schedule deals with this case.
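
A back-of-the-envelope illustration of why qo_len * num_heads matters: the kernel maps (query tokens × heads) onto the row dimension of its tensor-core tiles, so decode (qo_len=1) only fills a tile when enough heads run on each rank. The 128-row tile below is an assumed, illustrative number, not the kernel's exact tile shape:

```python
def tile_row_utilization(qo_len: int, num_heads: int, tile_rows: int = 128) -> float:
    """Fraction of a tensor-core tile's rows filled by (query tokens x heads)."""
    return min(qo_len * num_heads, tile_rows) / tile_rows

print(tile_row_utilization(1, 16))   # 0.125 -- decode at TP8: tile mostly idle
print(tile_row_utilization(1, 128))  # 1.0   -- decode at TP1: tile saturated
```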

liangzelang commented:

I found DeepSeek FlashMLA is much faster than flashinfer when q_head_num is 128 (TP1): almost 100% faster at bs=32. When q_head_num is 16, 32, or 64, it is 10%-20% faster.
Tested on H800.

yzh119 (Collaborator) commented Feb 24, 2025

We will try out the FlashMLA-style warp specialization in the next release.

Created an issue for performance tracking: #897

yzh119 added a commit that referenced this issue Feb 26, 2025:

As observed in #892, flashinfer MLA's second stage of split-k is very slow when the batch size is small, because our scheduler only uses one CTA for that stage. This PR fixes the issue.

yanghailong-git commented:

> Here's my benchmark code and result on H100:
>
> https://gist.github.com/abcdabcd987/b215c5f00f4b5e8399b95d7933bcf475
>
> https://docs.google.com/spreadsheets/d/1t0Txa7Ph9u7Su9LyWpS24vqr9A5FB-FyL0EZNpYOqwg/edit?gid=0#gid=0
>
> Both are using page size 64. FlashMLA is faster in general, way faster on small batch sizes.

Hello, I noticed the significant speed improvement in the latest test results, but the test script throws errors when running with the new version of FlashInfer. Does the test script need modifications?

yzh119 (Collaborator) commented Feb 27, 2025

@yanghailong-git can you report the error message?

yanghailong-git commented:

> @yanghailong-git can you report the error message?

When running this script https://gist.github.com/abcdabcd987/b215c5f00f4b5e8399b95d7933bcf475 with version v0.2.2.post1, I encountered the error below. How should I resolve this? Thanks.

[screenshot of the build error; full text posted below]

yzh119 (Collaborator) commented Feb 28, 2025

Can you post the full error message as text instead? Some key information was clipped in your screenshot.

yanghailong-git commented Feb 28, 2025

> Can you post the full error message as text instead? Some key information was clipped in your screenshot.

The detailed error is as follows:

```
2025-02-28 11:38:16,665 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
nhead   q_len   kv_len  bs      FA2     FA3     FlashMLA
2025-02-28 11:38:16,846 - INFO - flashinfer.jit: Loading JIT ops: batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64
2025-02-28 11:38:45,662 - INFO - flashinfer.jit: Finished loading JIT ops: batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64
2025-02-28 11:38:45,667 - INFO - flashinfer.jit: Loading JIT ops: batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90
Traceback (most recent call last):
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 2104, in _run_ninja_build
    subprocess.run(
  File "/root/miniconda3/envs/torch/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/workspace/work/yhl/mla-flashinfer-vs-deepseek-0.2.2.post1.py", line 155, in <module>
    main()
  File "/root/workspace/work/yhl/mla-flashinfer-vs-deepseek-0.2.2.post1.py", line 151, in main
    bench_ragged_vs_mla(num_heads, q_len, kv_len, batch_size)
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/workspace/work/yhl/mla-flashinfer-vs-deepseek-0.2.2.post1.py", line 66, in bench_ragged_vs_mla
    mla.plan(
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/mla.py", line 225, in plan
    self._cached_module = get_batch_mla_module(self._backend)(
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/mla.py", line 44, in backend_module
    modules_dict[args] = gen_batch_mla_module(backend, *args)
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/jit/attention/pytorch.py", line 181, in gen_batch_mla_module
    return load_cuda_ops(
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/jit/core.py", line 123, in load_cuda_ops
    torch_cpp_ext.load(
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1314, in load
    return _jit_compile(
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1721, in _jit_compile
    _write_ninja_file_and_build_library(
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1833, in _write_ninja_file_and_build_library
    _run_ninja_build(
  File "/root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 2120, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error building extension 'batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90': [1/4] /root/miniconda3/envs/torch/bin/nvcc --generate-dependencies-with-compile --dependency-output batch_mla_sm90_run.cuda.o.d -ccbin /root/miniconda3/envs/torch/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/include -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/csrc -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/cutlass/include -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/cutlass/tools/util/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/TH -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/THC -isystem /root/miniconda3/envs/torch/include -isystem /root/miniconda3/envs/torch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 --expt-relaxed-constexpr -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -O3 -std=c++17 --threads 4 -use_fast_math -DFLASHINFER_ENABLE_F16 -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 -gencode=arch=compute_90a,code=sm_90a -c /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_run.cu -o batch_mla_sm90_run.cuda.o 
FAILED: batch_mla_sm90_run.cuda.o 
/root/miniconda3/envs/torch/bin/nvcc --generate-dependencies-with-compile --dependency-output batch_mla_sm90_run.cuda.o.d -ccbin /root/miniconda3/envs/torch/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/include -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/csrc -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/cutlass/include -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/cutlass/tools/util/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/TH -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/THC -isystem /root/miniconda3/envs/torch/include -isystem /root/miniconda3/envs/torch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 --expt-relaxed-constexpr -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -O3 -std=c++17 --threads 4 -use_fast_math -DFLASHINFER_ENABLE_F16 -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 -gencode=arch=compute_90a,code=sm_90a -c /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_run.cu -o batch_mla_sm90_run.cuda.o 
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
In file included from /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_run.cu:21:
/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_config.inc:20:111: warning: backslash-newline at end of file
   20 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
      |                                                                                                                
In file included from /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_run.cu:21:
/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_config.inc:20:111: warning: backslash-newline at end of file
   20 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
      |                                                                                                                
In file included from /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_run.cu:21:
/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_config.inc:20:111: warning: backslash-newline at end of file
   20 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
      |                                                                                                                
/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/include/flashinfer/attention/mla_hopper.cuh(554): error: explicit type is missing ("int" assumed)
  __attribute__((device)) __inline__ __attribute__((always_inline)) convert_s_to_p(float* s_frag, uint32_t* p_frag) {
                          ^

1 error detected in the compilation of "/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_run.cu".
[2/4] /root/miniconda3/envs/torch/bin/nvcc --generate-dependencies-with-compile --dependency-output batch_mla_sm90_pybind.cuda.o.d -ccbin /root/miniconda3/envs/torch/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/include -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/csrc -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/cutlass/include -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/cutlass/tools/util/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/TH -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/THC -isystem /root/miniconda3/envs/torch/include -isystem /root/miniconda3/envs/torch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 --expt-relaxed-constexpr -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -O3 -std=c++17 --threads 4 -use_fast_math -DFLASHINFER_ENABLE_F16 -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 -gencode=arch=compute_90a,code=sm_90a -c /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_pybind.cu -o batch_mla_sm90_pybind.cuda.o 
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
In file included from /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_pybind.cu:16:
/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_config.inc:20:111: warning: backslash-newline at end of file
   20 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
      |                                                                                                                
In file included from /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_pybind.cu:16:
/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_config.inc:20:111: warning: backslash-newline at end of file
   20 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
      |                                                                                                                
In file included from /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_pybind.cu:16:
/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_config.inc:20:111: warning: backslash-newline at end of file
   20 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
      |                                                                                                                
[3/4] /root/miniconda3/envs/torch/bin/nvcc --generate-dependencies-with-compile --dependency-output batch_mla_sm90_plan.cuda.o.d -ccbin /root/miniconda3/envs/torch/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/include -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/csrc -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/cutlass/include -I/root/miniconda3/envs/torch/lib/python3.10/site-packages/flashinfer/data/cutlass/tools/util/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/TH -isystem /root/miniconda3/envs/torch/lib/python3.10/site-packages/torch/include/THC -isystem /root/miniconda3/envs/torch/include -isystem /root/miniconda3/envs/torch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 --expt-relaxed-constexpr -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -O3 -std=c++17 --threads 4 -use_fast_math -DFLASHINFER_ENABLE_F16 -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 -gencode=arch=compute_90a,code=sm_90a -c /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_plan.cu -o batch_mla_sm90_plan.cuda.o 
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
In file included from /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_plan.cu:19:
/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_config.inc:20:111: warning: backslash-newline at end of file
   20 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
      |                                                                                                                
In file included from /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_plan.cu:19:
/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_config.inc:20:111: warning: backslash-newline at end of file
   20 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
      |                                                                                                                
In file included from /root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_plan.cu:19:
/root/.cache/flashinfer/90/generated/batch_mla_attention_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_sm90/batch_mla_sm90_config.inc:20:111: warning: backslash-newline at end of file
   20 | #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
      |                                                                                                                
ninja: build stopped: subcommand failed.
```

yzh119 (Collaborator) commented Feb 28, 2025

@yanghailong-git #904 should fix it.
