11 changes: 10 additions & 1 deletion vllm/config/attention.py
@@ -67,7 +67,16 @@ def compute_hash(self) -> str:
def validate_backend_before(cls, value: Any) -> Any:
"""Enable parsing of the `backend` enum type from string."""
if isinstance(value, str):
return AttentionBackendEnum[value.upper()]
value = AttentionBackendEnum[value.upper()]

if value == AttentionBackendEnum.FLASH_ATTN_CUTE:
raise ValueError(
"AttentionConfig.backend does not support FLASH_ATTN_CUTE "
"(FA4 / flash_attn.cute). This is a ViT/MM-encoder-only "
"attention tag. Use --mm-encoder-attn-backend / "
"MultiModalConfig.mm_encoder_attn_backend instead."
)

return value

def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
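A minimal, self-contained sketch of the parse-and-reject behavior this validator adds; the enum and free-standing helper below are illustrative stand-ins, not the actual vLLM classes:

from enum import Enum
from typing import Any


class BackendTag(Enum):
    FLASH_ATTN = "flash_attn"
    FLASH_ATTN_CUTE = "flash_attn_cute"  # ViT/MM-encoder-only tag


def validate_backend_before(value: Any) -> Any:
    # Parse strings into the enum, then reject the encoder-only FA4 tag.
    if isinstance(value, str):
        value = BackendTag[value.upper()]
    if value is BackendTag.FLASH_ATTN_CUTE:
        raise ValueError(
            "FLASH_ATTN_CUTE is a ViT/MM-encoder-only attention tag; "
            "use --mm-encoder-attn-backend instead."
        )
    return value


assert validate_backend_before("flash_attn") is BackendTag.FLASH_ATTN
# validate_backend_before("flash_attn_cute") raises ValueError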
43 changes: 42 additions & 1 deletion vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -11,6 +11,7 @@
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_fa4_flash_attn_wrapper,
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
@@ -79,6 +80,10 @@ def __init__(
AttentionBackendEnum.ROCM_AITER_FA,
}

self.is_fa4_backend = (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN_CUTE
)

self._fa_version = (
get_flash_attn_version() if self.is_flash_attn_backend else None
)
@@ -182,6 +187,40 @@ def _forward_fa(
output = output.reshape(bsz, q_len, -1)
return output

def _forward_fa4(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
"""FA4 (flash_attn.cute) attention for multimodal encoder."""
assert (cu_seqlens is not None and max_seqlen is not None) or (
cu_seqlens is None and max_seqlen is None
), "cu_seqlens and max_seqlen should be both set or both None."

bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4

query, key, value = self.maybe_reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len
)

output = vit_fa4_flash_attn_wrapper(
q=query,
k=key,
v=value,
batch_size=bsz,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if is_reshaped:
output = output.reshape(bsz, q_len, -1)
return output

def forward_native(
self,
query: torch.Tensor,
@@ -200,7 +239,9 @@ def forward_cuda(
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
if self.is_flash_attn_backend:
if self.is_fa4_backend:
return self._forward_fa4(query, key, value, cu_seqlens, max_seqlen)
elif self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
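For context on the cu_seqlens/max_seqlen pair that _forward_fa4 asserts on: both describe the same packed variable-length batch and must be supplied together. A small sketch of how that metadata is typically built for a two-image ViT batch (the sequence lengths here are made up):

import torch

seq_lens = [196, 256]                     # patches per image in the batch
cu_seqlens = torch.tensor([0] + seq_lens, dtype=torch.int32).cumsum(0)
cu_seqlens = cu_seqlens.to(torch.int32)   # prefix sums: [0, 196, 452]
max_seqlen = torch.tensor(max(seq_lens))  # 256; the wrapper calls .item() on it
# Passing only one of the two (or mismatched values) trips the assert in
# _forward_fa4, since the kernel needs both to interpret the packed layout.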
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/rotary_embedding/common.py
@@ -135,7 +135,10 @@ def __init__(

self.apply_rotary_emb_flash_attn = None
if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary
try:
from flash_attn.ops.triton.rotary import apply_rotary
except (ImportError, ModuleNotFoundError):
apply_rotary = None

self.apply_rotary_emb_flash_attn = apply_rotary

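The try/except above guards against installs where the flash_attn package exists but lacks the Triton ops (e.g. a cute/FA4-only build). The same probe pattern restated in isolation, so module import never fails just because the fused rotary kernel is absent:

from importlib.util import find_spec

apply_rotary = None
if find_spec("flash_attn") is not None:
    try:
        from flash_attn.ops.triton.rotary import apply_rotary
    except (ImportError, ModuleNotFoundError):
        apply_rotary = None  # fall back to the native rotary path

use_fused_rotary = apply_rotary is not None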
38 changes: 37 additions & 1 deletion vllm/platforms/cuda.py
@@ -361,6 +361,7 @@ def get_attn_backend_cls(
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.FLASH_ATTN_CUTE,
AttentionBackendEnum.FLASH_ATTN,
]

@@ -376,10 +377,45 @@ def get_vit_attn_backend(
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
if backend == AttentionBackendEnum.FLASH_ATTN_CUTE:
cc = cls.get_device_capability()
if cc is None or cc.major < 10:
    device_info = f"SM{cc.major}{cc.minor}" if cc else "no CUDA device found"
    raise ValueError(
        "FLASH_ATTN_CUTE (FA4) requires Blackwell (SM100+). "
        f"Current device: {device_info}."
    )
from vllm.v1.attention.backends.fa4_utils import (
is_flash_attn_cute_available,
)
if not is_flash_attn_cute_available():
raise ImportError(
"flash_attn.cute is not installed. "
"Install with: pip install "
"git+https://github.com/Dao-AILab/flash-attention.git"
"#subdirectory=flash_attn/cute"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend

# Try FlashAttention first
# On Blackwell, try FA4 first
if (cc := cls.get_device_capability()) and cc.major >= 10:
try:
from vllm.v1.attention.backends.fa4_utils import (
is_flash_attn_cute_available,
)
if is_flash_attn_cute_available() and dtype in (
torch.float16, torch.bfloat16
):
logger.info_once(
"Auto-selecting FLASH_ATTN_CUTE (FA4) for ViT on "
"Blackwell."
)
return AttentionBackendEnum.FLASH_ATTN_CUTE
except ImportError:
pass

# Try FlashAttention (FA2)
if (cc := cls.get_device_capability()) and cc.major >= 8:
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
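The hunk above establishes a selection order for ViT attention on CUDA: an explicitly requested backend is validated (with a hard error if FA4 is requested off-Blackwell or without flash_attn.cute), otherwise FA4 is auto-selected on SM100+ when available, then FA2, then SDPA. A rough sketch of that order with the capability and availability checks stubbed out (names are placeholders, not the real vLLM helpers):

def pick_vit_backend(cc_major: int, fa4_ok: bool, fa2_ok: bool) -> str:
    if cc_major >= 10 and fa4_ok:
        return "FLASH_ATTN_CUTE"  # FA4 on Blackwell
    if cc_major >= 8 and fa2_ok:
        return "FLASH_ATTN"       # FA2/FA3 on Ampere and newer
    return "TORCH_SDPA"           # portable fallback


assert pick_vit_backend(10, True, True) == "FLASH_ATTN_CUTE"
assert pick_vit_backend(9, False, True) == "FLASH_ATTN"
assert pick_vit_backend(7, False, False) == "TORCH_SDPA"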
82 changes: 82 additions & 0 deletions vllm/v1/attention/backends/fa4_utils.py
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
"""Utilities for Flash Attention 4 (flash_attn.cute) on Blackwell."""

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)

_FA4_AVAILABLE: bool | None = None
_FA4_FUNC = None

# Head sizes optimized for FA4 on Blackwell
FA4_SUPPORTED_HEAD_SIZES = (64, 96, 128, 192)


def _import_fa4_fwd():
"""Try importing FA4. Prefer flash_attn_cute to avoid polluting the
flash_attn namespace which would break vllm's flash_attn.ops imports."""
try:
from flash_attn_cute.interface import _flash_attn_fwd
return _flash_attn_fwd
except (ImportError, ModuleNotFoundError):
pass
try:
from flash_attn.cute.interface import _flash_attn_fwd
return _flash_attn_fwd
except (ImportError, ModuleNotFoundError):
pass
return None


def is_flash_attn_cute_available() -> bool:
global _FA4_AVAILABLE
if _FA4_AVAILABLE is not None:
return _FA4_AVAILABLE
_FA4_AVAILABLE = _import_fa4_fwd() is not None
return _FA4_AVAILABLE


def _get_fa4_func():
global _FA4_FUNC
if _FA4_FUNC is None:
_FA4_FUNC = _import_fa4_fwd()
if _FA4_FUNC is None:
raise ImportError(
"flash_attn.cute is not available. "
"Install flash-attn-4 for Blackwell FA4 support."
)
return _FA4_FUNC


def flash_attn_cute_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float | None = None,
causal: bool = False,
) -> torch.Tensor:
"""Wrapper around flash_attn.cute for varlen (variable-length) attention."""
fa4_fwd = _get_fa4_func()

if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

result = fa4_fwd(
q, k, v,
softmax_scale=softmax_scale,
causal=causal,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
)
# _flash_attn_fwd returns (output, softmax_lse); we only need output
if isinstance(result, tuple):
return result[0]
return result
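A hypothetical smoke test for the new helper, gated on availability so it is a no-op on machines without a Blackwell GPU or the flash_attn.cute install (tensor sizes are illustrative):

import torch

from vllm.v1.attention.backends.fa4_utils import (
    flash_attn_cute_varlen_func,
    is_flash_attn_cute_available,
)

if is_flash_attn_cute_available() and torch.cuda.is_available():
    heads, head_dim = 16, 128                # head_dim from FA4_SUPPORTED_HEAD_SIZES
    seq_lens = [196, 256]
    q = torch.randn(sum(seq_lens), heads, head_dim,
                    dtype=torch.bfloat16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    cu = torch.tensor([0, 196, 452], dtype=torch.int32, device="cuda")
    out = flash_attn_cute_varlen_func(
        q, k, v,
        cu_seqlens_q=cu, cu_seqlens_k=cu,
        max_seqlen_q=256, max_seqlen_k=256,
        causal=False,
    )
    # Expected: a packed output with the same (total_tokens, heads, head_dim)
    # shape as q.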
3 changes: 0 additions & 3 deletions vllm/v1/attention/backends/flashinfer.py
@@ -1059,9 +1059,6 @@ def build(
## DECODE PATHWAY
if num_decodes > 0:
if decode_use_trtllm:
assert num_decode_tokens % num_decodes == 0, (
"TRTLLM decode requires uniform query lengths per request."
)
attn_metadata.decode = TRTLLMDecode(
block_tables=block_table_tensor[:num_decodes],
seq_lens=seq_lens[:num_decodes],
1 change: 1 addition & 0 deletions vllm/v1/attention/backends/registry.py
@@ -58,6 +58,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
)
TORCH_SDPA = "" # this tag is only used for ViT
FLASH_ATTN_CUTE = "" # FA4 via flash_attn.cute, ViT/MM encoder only
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
68 changes: 68 additions & 0 deletions vllm/v1/attention/ops/vit_attn_wrappers.py
@@ -183,3 +183,71 @@ def vit_torch_sdpa_wrapper(
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens)


# ---- FA4 (flash_attn.cute) wrappers ----

def fa4_flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
from vllm.v1.attention.backends.fa4_utils import flash_attn_cute_varlen_func

q_len = q.size(1)
if cu_seqlens is None:
cu_seqlens = torch.arange(
0, (batch_size + 1) * q_len, step=q_len,
dtype=torch.int32, device=q.device,
)
max_seqlen_val = q_len if max_seqlen is None else max_seqlen.item()

q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_cute_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen_val,
max_seqlen_k=max_seqlen_val,
softmax_scale=scale,
causal=False,
)
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer


def fa4_flash_attn_maxseqlen_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q)


direct_register_custom_op(
op_name="fa4_flash_attn_maxseqlen_wrapper",
op_func=fa4_flash_attn_maxseqlen_wrapper,
fake_impl=fa4_flash_attn_maxseqlen_wrapper_fake,
)


def vit_fa4_flash_attn_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.fa4_flash_attn_maxseqlen_wrapper(
q, k, v, batch_size, scale, cu_seqlens, max_seqlen,
)
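And a hypothetical call into the new custom op for a uniform-length batch, where cu_seqlens/max_seqlen are omitted and the wrapper synthesizes them; again this needs the FA4 install and hardware described above:

import torch

from vllm.v1.attention.ops.vit_attn_wrappers import vit_fa4_flash_attn_wrapper

b, s, h, d = 2, 196, 16, 64
q = torch.randn(b, s, h, d, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = vit_fa4_flash_attn_wrapper(q, k, v, batch_size=b, scale=d**-0.5)
# out keeps the (b, s, h, d) layout, mirroring the existing FA2 wrapper.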
4 changes: 4 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -4382,7 +4382,11 @@ def _dummy_run(
self.seq_lens.copy_to_gpu()

cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
self.query_start_loc.np[0] = 0
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require it.
self.query_start_loc.np[num_reqs + 1 :].fill(cum_num_tokens[-1])
self.query_start_loc.copy_to_gpu()

pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
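A worked example of the padding added above, with illustrative buffer sizes: given two real requests of 3 and 5 tokens and a buffer sized for four requests, the tail entries are filled with the final cumulative sum so the array stays non-decreasing, which varlen kernels such as FlashAttention expect:

import numpy as np

num_reqs, max_reqs = 2, 4
cum_num_tokens = np.array([3, 8])               # cumulative sum of [3, 5]
query_start_loc = np.zeros(max_reqs + 1, dtype=np.int32)
query_start_loc[1 : num_reqs + 1] = cum_num_tokens
query_start_loc[num_reqs + 1 :] = cum_num_tokens[-1]
print(query_start_loc)                          # [0 3 8 8 8]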