Skip to content

add CDNA3 flash attention forward kernels#33

Open
Jayluci4 wants to merge 1 commit into
HazyResearch:cdna3from
Jayluci4:cdna3-attn-forward
Open

add CDNA3 flash attention forward kernels#33
Jayluci4 wants to merge 1 commit into
HazyResearch:cdna3from
Jayluci4:cdna3-attn-forward

Conversation

@Jayluci4
Copy link
Copy Markdown

Summary

  • 4 flash attention forward kernels for MI300X (gfx942): non-causal + causal, D=128 + D=64
  • FlashAttention-2 online softmax with base-2 exp, GQA support via head index mapping
  • register-buffer pipeline (load_global_to_register_buffer / store_register_buffer_to_shared) overlaps global memory loads with MFMA compute
  • launch_bounds(256,1) for D=128 eliminates all VGPR spills by using AGPRs, scratch dropped from 876 to 20 bytes/lane
  • scheduling hints (__builtin_amdgcn_s_setprio, sched_barrier) around MMA clusters

Performance (MI300X, B=4 H=32 H_KV=8 N=1024)

  • D=128 non-causal: 69.0 TFLOPS (1.25x PyTorch SDPA)
  • D=64 non-causal: 72.7 TFLOPS (1.37x PyTorch SDPA)
  • D=128 causal: 54.8 TFLOPS
  • D=64 causal: 56.5 TFLOPS

Test plan

  • correctness vs torch.nn.functional.scaled_dot_product_attention (cosine_sim > 0.999)
  • tested across N=64, 100, 128, 512, 1024, 2048, 4096
  • both causal and non-causal verified
  • zero VGPR spills confirmed via -Rpass-analysis=kernel-resource-usage

Closes #12

🤖 Generated with Claude Code

…causal)

FlashAttention-2 forward pass using mfma_f32_16x16x16 with online softmax
(base-2 exp), GQA head mapping, and register-buffer pipeline for overlapping
global loads with MFMA compute. Beats PyTorch SDPA by 1.25-1.37x on MI300X.

Closes HazyResearch#12

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@willhu-jpg
Copy link
Copy Markdown
Collaborator

Hi! Thank you for the contribution. I compiled the gqa and gqa causal kernels and found vgpr spillage and scratch usage. I think the performance would improve a lot once we get rid of those!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants