Flash attn pad bw seqs #2596
Conversation
Greptile Summary
Enables the FlashAttention 3 backend for the THD format when pad_between_seqs=True.
Verification Needed: check that FA3 with seqused_q/seqused_k handles the padded positions correctly (see the review comment below).
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant DotProductAttention
    participant Utils as get_attention_backend
    participant FlashAttention as FlashAttention Backend
    participant FlashAttn3 as flash_attn_3 Library
    User->>DotProductAttention: forward(q, k, v, pad_between_seqs=True, cu_seqlens_q_padded, cu_seqlens_kv_padded)
    DotProductAttention->>Utils: get_attention_backend(qkv_format="thd", pad_between_seqs=True)
    alt FlashAttention 3 installed
        Utils-->>DotProductAttention: use_flash_attention=True
    else FlashAttention 2 only
        Utils-->>DotProductAttention: use_flash_attention=False
    end
    DotProductAttention->>FlashAttention: forward(q, k, v, cu_seqlens_q, cu_seqlens_q_padded, pad_between_seqs=True)
    FlashAttention->>FlashAttention: Use cu_seqlens_q_padded instead of cu_seqlens_q
    FlashAttention->>FlashAttention: Calculate seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    FlashAttention->>FlashAttention: Calculate seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    FlashAttention->>FlashAttn3: flash_attn_varlen_func_v3(q, k, v, cu_seqlens_q_padded, cu_seqlens_kv_padded, seqused_q, seqused_k)
    FlashAttn3-->>FlashAttention: output (with padding handled internally)
    FlashAttention-->>DotProductAttention: output
    DotProductAttention-->>User: output
```
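To make the branch in the diagram concrete, here is a minimal sketch of the backend gating it describes. The function name and arguments are illustrative only; the real `get_attention_backend` in `dpa/utils.py` takes many more inputs and performs many more checks.

```python
# Illustrative sketch only -- not the real get_attention_backend logic.
# It mirrors the alt/else branch in the diagram: THD with padding between
# sequences is routed to FlashAttention only when flash_attn_3 is available.
def flash_attention_usable_for_thd_padding(pad_between_seqs: bool,
                                           fa3_installed: bool) -> bool:
    if pad_between_seqs:
        # FlashAttention 2's varlen path cannot skip the gaps between
        # sequences, so this case needs flash_attn_3 and its seqused_q/seqused_k.
        return fa3_installed
    return True  # without inter-sequence padding, FA2 remains an option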
4 files reviewed, 1 comment
```python
# if `pad_between_seqs` is True, provide flash_attn_3 with `seqused_q` and `seqused_k`
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
# padding positions.
if pad_between_seqs:
    fa_3_optional_forward_kwargs["seqused_q"] = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    fa_3_optional_forward_kwargs["seqused_k"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
```
style: verify that flash_attn_3 with seqused_q/seqused_k truly avoids writing to padding positions. The related issue #2391 mentions "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes the output in C++ for the THD format). If flash_attn_3 doesn't zero these internally, the output may contain garbage values in the padded positions. Have you verified that flash_attn_3 correctly handles padding internally with these parameters?
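If FA3 turns out not to zero those slots, a straightforward fallback would be to mask them on the Python side. The helper below is only a sketch of that idea (the function name and the shape assumptions are not from this PR); it assumes a packed THD output of shape [total_tokens_padded, num_heads, head_dim]:

```python
import torch

def zero_padded_positions(out, cu_seqlens_padded, seqused):
    """Zero the rows of a packed THD output that fall in the gaps between sequences."""
    valid = torch.zeros(out.shape[0], dtype=torch.bool, device=out.device)
    for i in range(seqused.shape[0]):
        start = int(cu_seqlens_padded[i])
        valid[start : start + int(seqused[i])] = True
    out[~valid] = 0  # wipe the padding slots in place
    return out
```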
Description
With THD and FlashAttention 3, TE should support pad_between_seqs=True. Right now this works for non-CP cases; enabling it with CP (context parallelism) is in progress.

Fixes #2399
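A minimal usage sketch of what this enables, assuming `pad_between_seqs` is exposed as a `forward()` keyword of `DotProductAttention` as described in the Changes list below (the exact keyword name, the illustrative shapes, and the `attn_mask_type` choice are assumptions, not taken verbatim from the PR):

```python
import torch
import transformer_engine.pytorch as te

num_heads, head_dim = 16, 64
# Two sequences of true lengths 3 and 5, starting at padded offsets 0 and 4.
cu_seqlens        = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
cu_seqlens_padded = torch.tensor([0, 4, 12], dtype=torch.int32, device="cuda")
total_tokens_padded, max_seqlen = 12, 5

q = torch.randn(total_tokens_padded, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

attn = te.DotProductAttention(num_heads, head_dim)
out = attn(
    q, k, v,
    qkv_format="thd",
    cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
    cu_seqlens_q_padded=cu_seqlens_padded, cu_seqlens_kv_padded=cu_seqlens_padded,
    max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen,
    attn_mask_type="padding_causal",
    pad_between_seqs=True,  # new argument plumbed through by this PR (assumed keyword)
)
```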
Type of change
Changes
Please list the changes introduced in this PR:
- `dpa/dot_product_attention.py::DotProductAttention->forward`: plumb through `pad_between_seqs`, `cu_seqlens_q_padded` and `cu_seqlens_k_padded`
- `backends.py::FlashAttention->forward`: use `pad_between_seqs`, `cu_seqlens_q_padded` and `cu_seqlens_k_padded` to compute `seqused_q`/`seqused_k` before calling `flash_attn_varlen_func` from `flash_attn_3`
- `dpa/utils.py::get_attention_backend`: switch `use_flash_attention` to `True` when `flash_attn_3` is installed
- `tests/pytorch/attention/test_attention.py::_run_dot_product_attention`: run FlashAttention for THD and `pad_between_seqs=True`

Checklist: