
[3/N] Add RDNA4 PA decode FP8 kernel, refactor shared reduce code #357

Closed
vivienfanghuagood wants to merge 1 commit into ROCm:main from vivienfanghuagood:rdna4-pa-decode-fp8

Conversation


@vivienfanghuagood vivienfanghuagood commented Apr 7, 2026

Motivation

Add RDNA4 (gfx120x) paged-attention FP8 decode kernel to FlyDSL, enabling LLM inference decode on consumer RDNA GPUs. The existing CDNA kernel uses MFMA wave64 instructions and cannot run on RDNA hardware.

Technical Details

  • kernels/pa_common.py (new): shared constants, stride computation, and reduce kernels extracted from pa_decode_fp8.py to eliminate duplication between CDNA and RDNA paths
  • kernels/rdna_pa_decode_fp8.py (new): RDNA4 decode dot kernel using wmma_f32_16x16x16_fp8_fp8 with wave32 (8 warps × 32 lanes), softmax P staged through LDS as f32
  • kernels/pa_decode_fp8.py (modified): CDNA kernel now imports shared code from pa_common, removing ~360 duplicate lines
  • tests/kernels/test_pa.py (modified): unified test for both CDNA and RDNA — arch-aware kernel selection, aiter made optional (RDNA validates against torch reference; CDNA validates against Gluon when aiter is available)
  • tests/arch_compat.py (modified): test_pa.py removed from CDNA_ONLY_TESTS since it now self-manages arch dispatch
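As a rough sketch of the arch-aware kernel selection described above (the function names and return values here are illustrative, not test_pa.py's actual API):

```python
# Hedged sketch of arch-aware dispatch; names are illustrative assumptions,
# not the real test_pa.py interface.
def is_rdna4(arch_name: str) -> bool:
    """RDNA4 parts report gfx120x arch names (e.g. gfx1200, gfx1201)."""
    return arch_name.startswith("gfx120")

def select_pa_kernel(arch_name: str) -> str:
    # RDNA4 takes the wave32 WMMA path; CDNA (gfx9xx) keeps the MFMA
    # wave64 kernel.
    if is_rdna4(arch_name):
        return "kernels/rdna_pa_decode_fp8.py"
    return "kernels/pa_decode_fp8.py"
```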

Test Plan

  • RDNA correctness: python tests/kernels/test_pa.py on gfx1201 — all PASS (cos_sim > 0.999 vs torch reference)
  • RDNA performance: 2.5–146× faster than PyTorch SDPA (bf16) on gfx1201
  • CDNA regression: existing run_single() path unchanged, needs gfx9xx CI validation
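A minimal sketch of the cos_sim acceptance check the test plan relies on (pure Python for illustration; the real test presumably flattens torch tensors):

```python
import math

# Cosine similarity between a kernel output and the torch reference,
# accepted when cos_sim > 0.999 as in the test plan above.
def cos_sim(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def passes(out, ref, threshold=0.999):
    return cos_sim(out, ref) > threshold
```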

Test Result

gfx1201 (RX 9070 XT, ROCm 7.1, PyTorch 2.9.1):

batch  ctx    kernel (us)  status
1      128    6.7          PASS
1      256    7.6          PASS
4      4096   7.7          PASS
32     4096   35.1         PASS

Submission Checklist


vivienfanghuagood commented Apr 7, 2026

@coderfeli Hi Felix, we propose adding an attention kernel for RDNA, reusing most of the code from the CDNA implementation. Could you help review our submission? Thanks a lot!

@coderfeli

@vivienfanghuagood our PA is in the middle of a major refactor and perf tuning. Could you wait a few days for that?

@vivienfanghuagood

> @vivienfanghuagood our PA is in the middle of a major refactor and perf tuning. Could you wait a few days for that?

Sure, that's fine. May I ask whether other kernels have similar refactoring plans? If so, we can avoid touching those kernels while developing.


coderfeli commented Apr 7, 2026

MoE as well, @vivienfanghuagood, but that one is a code-style change only. The current PA has many functional and perf issues.

Add RDNA4 paged-attention FP8 decode kernel using WMMA wave32, refactor
shared reduce/stride code into pa_common.py, and unify test_pa.py to
handle both CDNA and RDNA architectures.

New files:
- kernels/pa_common.py: shared constants, compute_pa_strides(),
  build_ps_reduce_kernel(), build_v2_reduce_kernel()
- kernels/rdna_pa_decode_fp8.py: RDNA4 WMMA dot kernel (383 lines)

Modified:
- kernels/pa_decode_fp8.py: imports shared code from pa_common
- tests/kernels/test_pa.py: unified CDNA+RDNA test with arch-aware
  kernel selection (aiter optional for RDNA path)
- tests/arch_compat.py: test_pa.py removed from CDNA_ONLY_TESTS
  (now self-manages via IS_RDNA/HAS_AITER guards)

Removed:
- tests/kernels/test_rdna_pa.py: merged into test_pa.py

RDNA4 architecture: 8 warps x 32 lanes, WMMA f32_16x16x16_fp8_fp8,
P staged as f32 in LDS (~16.5 KB). Correctness cos_sim > 0.999,
performance 2.5-146x vs PyTorch SDPA on gfx1201.

Made-with: Cursor
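A hypothetical sketch of what a `compute_pa_strides()`-style helper might compute for a paged KV cache laid out as [num_blocks, num_kv_heads, head_size, block_size]; the layout and return order here are assumptions for illustration, not FlyDSL's actual contract:

```python
# Row-major strides (in elements) for an assumed paged KV cache layout of
# [num_blocks, num_kv_heads, head_size, block_size]. All names here are
# illustrative, not the real pa_common.py API.
def compute_pa_strides(num_kv_heads: int, head_size: int, block_size: int):
    stride_elem = block_size                   # step between head_size rows
    stride_head = head_size * stride_elem      # step between KV heads
    stride_block = num_kv_heads * stride_head  # step between cache blocks
    return stride_block, stride_head, stride_elem
```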
