@sudhakarsingh27
Collaborator

Description

With the THD format and FlashAttention 3, TE should support pad_between_seqs=True. (Right now this works for non-CP cases; enabling it with CP is still in progress.)

Fixes #2399
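
For context, a minimal usage sketch of what this enables at the DotProductAttention level. The tensor shapes and cu_seqlens values are made up, and the pad_between_seqs forward argument is the one introduced by this PR, so exact argument names and defaults may differ from the final API:

import torch
import transformer_engine.pytorch as te

# Two sequences of 100 and 112 valid tokens, each padded out to a 128-token
# bucket in the THD layout (so padding sits between the sequences).
cu_seqlens = torch.tensor([0, 100, 212], dtype=torch.int32, device="cuda")
cu_seqlens_padded = torch.tensor([0, 128, 256], dtype=torch.int32, device="cuda")

# [total_padded_tokens, num_heads, head_dim] tensors in THD format (dummy data).
q = torch.randn(256, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64, qkv_format="thd")
out = attn(
    q, k, v,
    qkv_format="thd",
    attn_mask_type="padding_causal",
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    cu_seqlens_q_padded=cu_seqlens_padded,
    cu_seqlens_kv_padded=cu_seqlens_padded,
    pad_between_seqs=True,  # new in this PR
)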

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

Please list the changes introduced in this PR:

  • dpa/dot_product_attention.py::DotProductAttention->forward: plumb through pad_between_seqs, cu_seqlens_q_padded and cu_seqlens_kv_padded
  • backends.py::FlashAttention->forward:
    • plumb through pad_between_seqs, cu_seqlens_q_padded and cu_seqlens_kv_padded
    • calculate seqused_q/seqused_k before calling flash_attn_varlen_func from flash_attn_3 (see the sketch after this list)
  • dpa/utils.py::get_attention_backend: switch use_flash_attention to True when flash_attn_3 is installed
  • tests/pytorch/attention/test_attention.py::_run_dot_product_attention: run FlashAttn for THD and pad_between_seqs=True
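
A minimal sketch of the backends.py logic listed above, assuming flash_attn_3's flash_attn_varlen_func accepts seqused_q/seqused_k; the import path, argument order, and return value vary across flash-attn 3 builds:

import torch
# flash-attn 3 ("hopper") interface; the module name depends on how FA3 was installed
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3

def fa3_thd_forward(q, k, v, cu_seqlens_q, cu_seqlens_kv,
                    cu_seqlens_q_padded, cu_seqlens_kv_padded,
                    max_seqlen_q, max_seqlen_kv, pad_between_seqs):
    kwargs = {}
    if pad_between_seqs:
        # Actual (unpadded) per-sequence lengths, recovered from the original
        # cumulative sequence lengths; FA3 uses these so it does not read from
        # or attend to the padding tokens.
        kwargs["seqused_q"] = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
        kwargs["seqused_k"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
        # Address sequences by their offsets in the padded THD layout.
        cu_q, cu_kv = cu_seqlens_q_padded, cu_seqlens_kv_padded
    else:
        cu_q, cu_kv = cu_seqlens_q, cu_seqlens_kv
    # Depending on the FA3 version, this returns `out` or `(out, softmax_lse)`.
    return flash_attn_varlen_func_v3(
        q, k, v, cu_q, cu_kv, max_seqlen_q, max_seqlen_kv, causal=True, **kwargs
    )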

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sudhakarsingh27 sudhakarsingh27 self-assigned this Jan 14, 2026
pre-commit-ci bot and others added 3 commits January 14, 2026 01:13
@greptile-apps
Contributor

greptile-apps bot commented Jan 14, 2026

Greptile Summary

Enables the FlashAttention 3 backend for the THD format with pad_between_seqs=True by leveraging FA3's seqused_q and seqused_k parameters.

Key Changes

  • backends.py: passes cu_seqlens_q_padded/cu_seqlens_kv_padded to FA3 and computes seqused_q/seqused_k from the original cumulative sequence lengths, so FA3 knows the actual sequence lengths versus the padded lengths (a worked example follows this list)
  • utils.py: enables FlashAttention backend when FA3 is installed for THD+padding case (previously disabled for FA2)
  • dot_product_attention.py: plumbs through new parameters cu_seqlens_q_padded, cu_seqlens_kv_padded, and pad_between_seqs
  • test_attention.py: updates test to exercise FlashAttention with padded inputs similar to FusedAttention
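
A worked example of the seqused_q/seqused_k computation under an assumed padded-THD layout; the concrete numbers are illustrative, not taken from the PR:

import torch

# Two sequences with 100 and 112 valid tokens; each is padded out to a 128-token
# bucket in the THD layout, leaving padding gaps between the sequences.
cu_seqlens_q = torch.tensor([0, 100, 212], dtype=torch.int32)         # valid-token offsets
cu_seqlens_q_padded = torch.tensor([0, 128, 256], dtype=torch.int32)  # padded-layout offsets

# FA3 addresses each sequence via the padded offsets but is told the true
# lengths, so the 28 + 44 padding tokens are skipped.
seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]   # tensor([100, 112])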

Verification Needed

Check that FA3 with the seqused_q/seqused_k parameters correctly avoids writing to padding positions. Issue #2391 mentions that padding positions may need manual zeroing (as FusedAttention does in C++).

Confidence Score: 4/5

  • Safe to merge, with verification recommended
  • The implementation correctly plumbs the parameters through the layers and follows FA3's expected interface for handling padding, but lacks explicit verification that FA3 internally zeroes padded output positions (unlike FusedAttention's explicit zeroing)
  • Verify backends.py: confirm that FA3 handles padding correctly with seqused_q/seqused_k

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/backends.py: Added support for pad_between_seqs=True with FlashAttention 3 by passing cu_seqlens_q_padded and cu_seqlens_kv_padded, and computing seqused_q/seqused_k from the difference in cumulative sequence lengths
  • transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py: Plumbed through the cu_seqlens_q_padded, cu_seqlens_kv_padded, and pad_between_seqs parameters to the FlashAttention backend
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Enabled FlashAttention 3 for THD format with padding between sequences, while keeping FlashAttention 2 disabled for this case
  • tests/pytorch/attention/test_attention.py: Updated the test to treat FlashAttention similarly to FusedAttention by passing padded inputs and extracting valid ranges from the output

Sequence Diagram

sequenceDiagram
    participant User
    participant DotProductAttention
    participant Utils as get_attention_backend
    participant FlashAttention as FlashAttention Backend
    participant FlashAttn3 as flash_attn_3 Library

    User->>DotProductAttention: forward(q, k, v, pad_between_seqs=True, cu_seqlens_q_padded, cu_seqlens_kv_padded)
    DotProductAttention->>Utils: get_attention_backend(qkv_format="thd", pad_between_seqs=True)
    
    alt FlashAttention 3 installed
        Utils-->>DotProductAttention: use_flash_attention=True
    else FlashAttention 2 only
        Utils-->>DotProductAttention: use_flash_attention=False
    end
    
    DotProductAttention->>FlashAttention: forward(q, k, v, cu_seqlens_q, cu_seqlens_q_padded, pad_between_seqs=True)
    
    FlashAttention->>FlashAttention: Use cu_seqlens_q_padded instead of cu_seqlens_q
    FlashAttention->>FlashAttention: Calculate seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    FlashAttention->>FlashAttention: Calculate seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    
    FlashAttention->>FlashAttn3: flash_attn_varlen_func_v3(q, k, v, cu_seqlens_q_padded, cu_seqlens_kv_padded, seqused_q, seqused_k)
    FlashAttn3-->>FlashAttention: output (with padding handled internally)
    FlashAttention-->>DotProductAttention: output
    DotProductAttention-->>User: output

@greptile-apps greptile-apps bot left a comment
Contributor

4 files reviewed, 1 comment

Comment on lines +974 to +983

# if `pad_between_seqs` is True, provide flash_attn_3 with `seqused_q` and `seqused_k`
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
# padding positions.
if pad_between_seqs:
    fa_3_optional_forward_kwargs["seqused_q"] = (
        cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    )
    fa_3_optional_forward_kwargs["seqused_k"] = (
        cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    )

style: Verify that flash_attn_3 with seqused_q/seqused_k truly avoids writing to padding positions. The related issue #2391 mentions that "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes output in C++ for THD format). If flash_attn_3 doesn't zero these internally, the output may contain garbage values in padded positions. Have you verified that flash_attn_3 correctly handles padding internally with these parameters?
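
One way to check this empirically is sketched below. This is a hypothetical helper, not part of the PR; it assumes `out` is in the padded THD layout and that zeroed padding is the desired contract, as with FusedAttention:

import torch

def assert_padding_zeroed(out, cu_seqlens, cu_seqlens_padded):
    # Per-sequence valid lengths recovered from the unpadded cumulative offsets.
    seqused = cu_seqlens[1:] - cu_seqlens[:-1]
    for i, used in enumerate(seqused.tolist()):
        pad_start = int(cu_seqlens_padded[i]) + used   # first padding token of sequence i
        pad_end = int(cu_seqlens_padded[i + 1])        # end of sequence i's padded bucket
        assert torch.all(out[pad_start:pad_end] == 0), (
            f"non-zero values in padding positions of sequence {i}"
        )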



Development

Successfully merging this pull request may close these issues.

Support FlashAttention with pad_between_seqs=True
