@@ -276,17 +276,12 @@ def get_attn_backend_cls(
276276 "FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
277277 "VLLM_MLA_DISABLE=1 to disable MLA for this model."
278278 )
279- if not use_v1 :
280- raise RuntimeError (
281- "MLA attention backends require the V1 engine. "
282- "Set VLLM_USE_V1=1 to enable them."
283- )
284279
285280 from vllm .attention .ops .flashmla import is_flashmla_dense_supported
286281 from vllm .attention .utils .fa_utils import flash_attn_supports_mla
287282
288283 if use_sparse :
289- logger .info_once ("Using Sparse MLA backend on V1 engine ." )
284+ logger .info_once ("Using Sparse MLA backend." )
290285 return (
291286 "vllm.v1.attention.backends.mla.flashmla_sparse."
292287 "FlashMLASparseBackend"
@@ -313,15 +308,13 @@ def get_attn_backend_cls(
             )

             if use_cutlassmla:
-                logger.info_once(
-                    "Using Cutlass MLA backend on V1 engine.", scope="local"
-                )
+                logger.info_once("Using Cutlass MLA backend.", scope="local")
                 return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
             if use_flashinfermla:
                 from vllm.v1.attention.backends.utils import set_kv_cache_layout

                 set_kv_cache_layout("HND")
-                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
+                logger.info_once("Using FlashInfer MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
                 )
@@ -333,116 +326,107 @@ def get_attn_backend_cls(
                         block_size,
                     )
                 else:
-                    logger.info_once("Using FlashMLA backend on V1 engine.")
+                    logger.info_once("Using FlashMLA backend.")
                     return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
             if use_flashattn:
-                logger.info_once("Using FlashAttention MLA backend on V1 engine.")
+                logger.info_once("Using FlashAttention MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
                 )
             if use_triton:
-                logger.info_once("Using Triton MLA backend on V1 engine.")
+                logger.info_once("Using Triton MLA backend.")
                 return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
-        if use_v1:
-            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
-            FLEX_ATTENTION_V1 = (
-                "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            )
-            TRITON_ATTN = (
-                "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
-            )
-            FLASH_ATTN_V1 = (
-                "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
-            )
-            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
-            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

-            use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
-                "fp8"
-            )
+        FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
+        FLEX_ATTENTION_V1 = (
+            "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
+        )
+        TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
+        XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

-            if selected_backend == _Backend.FLASHINFER:
-                logger.info_once("Using FlashInfer backend on V1 engine.")
-                if cls.has_device_capability(100):
-                    from vllm.v1.attention.backends.utils import set_kv_cache_layout
+        use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
+            "fp8"
+        )

-                    set_kv_cache_layout("HND")
-                return FLASHINFER_V1
-            elif selected_backend == _Backend.FLEX_ATTENTION:
-                logger.info_once("Using FlexAttention backend on V1 engine.")
-                return FLEX_ATTENTION_V1
-            elif selected_backend == _Backend.TRITON_ATTN:
-                logger.info_once("Using Triton backend on V1 engine.")
-                return TRITON_ATTN
-            elif selected_backend == _Backend.FLASH_ATTN:
-                logger.info_once("Using Flash Attention backend on V1 engine.")
-                return FLASH_ATTN_V1
-            elif selected_backend == _Backend.TREE_ATTN:
-                logger.info_once("Using Tree Attention backend on V1 engine.")
-                return TREE_ATTN_V1
-            elif selected_backend == _Backend.XFORMERS:
-                logger.info_once("Using XFormers backend on V1 engine.")
-                return XFORMERS_V1
+        if selected_backend == _Backend.FLASHINFER:
+            logger.info_once("Using FlashInfer backend.")
+            if cls.has_device_capability(100):
+                from vllm.v1.attention.backends.utils import set_kv_cache_layout

-            from vllm.attention.selector import is_attn_backend_supported
+                set_kv_cache_layout("HND")
+            return FLASHINFER_V1
+        elif selected_backend == _Backend.FLEX_ATTENTION:
+            logger.info_once("Using FlexAttention backend.")
+            return FLEX_ATTENTION_V1
+        elif selected_backend == _Backend.TRITON_ATTN:
+            logger.info_once("Using Triton backend.")
+            return TRITON_ATTN
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info_once("Using Flash Attention backend.")
+            return FLASH_ATTN_V1
+        elif selected_backend == _Backend.TREE_ATTN:
+            logger.info_once("Using Tree Attention backend.")
+            return TREE_ATTN_V1
+        elif selected_backend == _Backend.XFORMERS:
+            logger.info_once("Using XFormers backend.")
+            return XFORMERS_V1
+
+        from vllm.attention.selector import is_attn_backend_supported
+
+        # Default backends for V1 engine
+        # Prefer FlashInfer for Blackwell GPUs if installed
+        if cls.is_device_capability(100):
+            if is_default_backend_supported := is_attn_backend_supported(
+                FLASHINFER_V1, head_size, dtype
+            ):
+                from vllm.v1.attention.backends.utils import set_kv_cache_layout

-            # Default backends for V1 engine
-            # Prefer FlashInfer for Blackwell GPUs if installed
-            if cls.is_device_capability(100):
-                if is_default_backend_supported := is_attn_backend_supported(
-                    FLASHINFER_V1, head_size, dtype
-                ):
-                    from vllm.v1.attention.backends.utils import set_kv_cache_layout
-
-                    logger.info_once(
-                        "Using FlashInfer backend with HND KV cache layout on "
-                        "V1 engine by default for Blackwell (SM 10.0) GPUs."
-                    )
-                    set_kv_cache_layout("HND")
+                logger.info_once(
+                    "Using FlashInfer backend with HND KV cache layout on "
+                    "V1 engine by default for Blackwell (SM 10.0) GPUs."
+                )
+                set_kv_cache_layout("HND")

-                    return FLASHINFER_V1
+                return FLASHINFER_V1

-                if not is_default_backend_supported.can_import:
-                    logger.warning_once(
-                        "FlashInfer failed to import for V1 engine on "
-                        "Blackwell (SM 10.0) GPUs; it is recommended to "
-                        "install FlashInfer for better performance."
-                    )
+            if not is_default_backend_supported.can_import:
+                logger.warning_once(
+                    "FlashInfer failed to import on Blackwell (SM 10.0) GPUs; "
+                    "it is recommended to install FlashInfer for better "
+                    "performance."
+                )

-            # FlashAttention is the default for SM 8.0+ GPUs
-            if cls.has_device_capability(80):
-                if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
-                    logger.info_once("Using Triton backend on V1 engine.")
-                    return TRITON_ATTN
-                elif is_default_backend_supported := is_attn_backend_supported(
-                    FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
-                ):
-                    logger.info_once("Using Flash Attention backend on V1 engine.")
-                    return FLASH_ATTN_V1
-
-            # FlexAttention is the default for older GPUs
-            else:
-                logger.info_once("Using FlexAttention backend on V1 engine.")
-                return FLEX_ATTENTION_V1
+        # FlashAttention is the default for SM 8.0+ GPUs
+        if cls.has_device_capability(80):
+            if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
+                logger.info_once("Using Triton backend.")
+                return TRITON_ATTN
+            elif is_default_backend_supported := is_attn_backend_supported(
+                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
+            ):
+                logger.info_once("Using Flash Attention backend.")
+                return FLASH_ATTN_V1

-            assert not is_default_backend_supported
+        # FlexAttention is the default for older GPUs
+        else:
+            logger.info_once("Using FlexAttention backend.")
+            return FLEX_ATTENTION_V1

-            use_flex_attention_reason = {}
-            if not is_default_backend_supported.head_size:
-                use_flex_attention_reason["head_size"] = head_size
-            if not is_default_backend_supported.dtype:
-                use_flex_attention_reason["dtype"] = dtype
+        assert not is_default_backend_supported

-            logger.info_once(
-                "Using FlexAttention backend for %s on V1 engine.",
-                ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
-            )
-            return FLEX_ATTENTION_V1
+        use_flex_attention_reason = {}
+        if not is_default_backend_supported.head_size:
+            use_flex_attention_reason["head_size"] = head_size
+        if not is_default_backend_supported.dtype:
+            use_flex_attention_reason["dtype"] = dtype

-        raise RuntimeError(
-            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-            "to select a supported backend."
+        logger.info_once(
+            "Using FlexAttention backend for %s.",
+            ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
         )
+        return FLEX_ATTENTION_V1

     @classmethod
     def get_punica_wrapper(cls) -> str:
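Illustrative sketch (not part of this diff): the default-selection code above keys off the truthiness of the object returned by is_attn_backend_supported, plus its can_import / head_size / dtype flags, assigned inline with the walrus operator before falling back to FlexAttention with a logged reason. The self-contained Python below mimics that pattern; SupportInfo, check_backend, choose_default_backend, and the specific head-size/dtype constraints are assumptions made for illustration, not vLLM's actual helpers.

# Standalone sketch of the fallback-selection pattern used above.
# SupportInfo, check_backend, and choose_default_backend are hypothetical
# stand-ins, not vLLM's is_attn_backend_supported API; the constraints
# below are assumptions chosen only for illustration.
from dataclasses import dataclass


@dataclass
class SupportInfo:
    can_import: bool = True
    head_size: bool = True
    dtype: bool = True

    def __bool__(self) -> bool:
        # Truthy only if every check passes, which is what makes
        # "if supported := check_backend(...):" work as a gate.
        return self.can_import and self.head_size and self.dtype


def check_backend(head_size: int, dtype: str) -> SupportInfo:
    # Stand-in for a per-backend support probe.
    return SupportInfo(
        can_import=True,
        head_size=head_size % 32 == 0,           # assumed constraint
        dtype=dtype in ("float16", "bfloat16"),  # assumed constraint
    )


def choose_default_backend(head_size: int, dtype: str) -> str:
    # Try the preferred backend first; otherwise fall back and report why,
    # mirroring the FlexAttention fallback at the end of the hunk above.
    if supported := check_backend(head_size, dtype):
        return "FLASH_ATTN"

    reasons = {}
    if not supported.head_size:
        reasons["head_size"] = head_size
    if not supported.dtype:
        reasons["dtype"] = dtype
    print(
        "Falling back to FLEX_ATTENTION for "
        + ", ".join(f"{k}={v}" for k, v in reasons.items())
    )
    return "FLEX_ATTENTION"


assert choose_default_backend(128, "bfloat16") == "FLASH_ATTN"
assert choose_default_backend(80, "float32") == "FLEX_ATTENTION"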