【3/N】add RDNA4 PA decode FP8 kernel, refactor shared reduce code#357
Closed
vivienfanghuagood wants to merge 1 commit intoROCm:mainfrom
Closed
【3/N】add RDNA4 PA decode FP8 kernel, refactor shared reduce code#357vivienfanghuagood wants to merge 1 commit intoROCm:mainfrom
vivienfanghuagood wants to merge 1 commit intoROCm:mainfrom
Conversation
Collaborator
Author
|
@coderfeli Hi Felix, we propose to add attention kernel for RDNA, and reuse most codes from CDNA's implementation. Can you help review our submissions? Thanks a lot! |
Collaborator
|
@vivienfanghuagood our PA is in a big reconstruction and perf tuning. could you wait several days for that? |
Collaborator
Author
Sure, it's okay. May I ask whether other kernels have similar refactoring plans? We can avoid these kernels when developing. |
Collaborator
|
also moe @vivienfanghuagood but code style change only. Current PA has many functional and perf issues. |
Add RDNA4 paged-attention FP8 decode kernel using WMMA wave32, refactor shared reduce/stride code into pa_common.py, and unify test_pa.py to handle both CDNA and RDNA architectures. New files: - kernels/pa_common.py: shared constants, compute_pa_strides(), build_ps_reduce_kernel(), build_v2_reduce_kernel() - kernels/rdna_pa_decode_fp8.py: RDNA4 WMMA dot kernel (383 lines) Modified: - kernels/pa_decode_fp8.py: imports shared code from pa_common - tests/kernels/test_pa.py: unified CDNA+RDNA test with arch-aware kernel selection (aiter optional for RDNA path) - tests/arch_compat.py: test_pa.py removed from CDNA_ONLY_TESTS (now self-manages via IS_RDNA/HAS_AITER guards) Removed: - tests/kernels/test_rdna_pa.py: merged into test_pa.py RDNA4 architecture: 8 warps x 32 lanes, WMMA f32_16x16x16_fp8_fp8, P staged as f32 in LDS (~16.5 KB). Correctness cos_sim > 0.999, performance 2.5-146x vs PyTorch SDPA on gfx1201. Made-with: Cursor
c944d2f to
9e84b58
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Add RDNA4 (gfx120x) paged-attention FP8 decode kernel to FlyDSL, enabling LLM inference decode on consumer RDNA GPUs. The existing CDNA kernel uses MFMA wave64 instructions and cannot run on RDNA hardware.
Technical Details
kernels/pa_common.py(new): shared constants, stride computation, and reduce kernels extracted frompa_decode_fp8.pyto eliminate duplication between CDNA and RDNA pathskernels/rdna_pa_decode_fp8.py(new): RDNA4 decode dot kernel usingwmma_f32_16x16x16_fp8_fp8with wave32 (8 warps × 32 lanes), softmax P staged through LDS as f32kernels/pa_decode_fp8.py(modified): CDNA kernel now imports shared code frompa_common, removing ~360 duplicate linestests/kernels/test_pa.py(modified): unified test for both CDNA and RDNA — arch-aware kernel selection, aiter made optional (RDNA validates against torch reference; CDNA validates against Gluon when aiter is available)tests/arch_compat.py(modified):test_pa.pyremoved fromCDNA_ONLY_TESTSsince it now self-manages arch dispatchTest Plan
python tests/kernels/test_pa.pyon gfx1201 — all PASS (cos_sim > 0.999 vs torch reference)run_single()path unchanged, needs gfx9xx CI validationTest Result
gfx1201 (RX 9070 XT, ROCm 7.1, PyTorch 2.9.1):
Submission Checklist