Using Sparse Matrix Instructions for Skinny GEMM on AMDGPU
This issue is an evolving document describing the design and implementation of using sparse matrix instructions for optimizing skinny GEMM on AMDGPU in IREE.
Introduction
The original trick was described in the excellent blog post Creating custom kernels for the AMD MI300 by Rémi Ouazan Reboul and seungrok jung. The motivation is to enable this optimization in IREE for our skinny GEMM workloads.
Using FP8 as an example: skinny GEMM workloads with M=8 are currently padded to M=16 so they can use the smallest MFMA operation. On CDNA3, that is 16x16x32, which takes 16 cycles; because half of the M dimension is padding, redundant work is performed. The blog post proposes using sparse MFMA instructions instead. The sparse equivalent is 16x16x64 and takes the same 16 cycles, allowing us to process twice the K-depth for the same number of cycles.
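As a quick sanity check on the arithmetic above, the effective K-depth per cycle doubles (a trivial sketch using only the cycle counts quoted above):

```python
# Cycle counts from the paragraph above (CDNA3): both ops take 16 cycles.
dense_k_per_cycle = 32 / 16   # 16x16x32 dense MFMA
sparse_k_per_cycle = 64 / 16  # 16x16x64 sparse MFMA (SMFMAC)
assert sparse_k_per_cycle == 2 * dense_k_per_cycle
```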
Semantics of Instruction (on CDNA3+)
The V_SMFMAC family of instructions performs matrix multiply-accumulate operations on a 4:2 structurally sparse matrix A and dense matrices B, C, and D: D = C + A × B. The A matrix is sparse and requires 2 VGPRs per lane; the B matrix is dense and requires 4 VGPRs per lane; C/D share 16 VGPRs. However, the C operand (SRC2) has been repurposed to hold the index data, so the destination D also serves as the accumulator.
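To make the semantics concrete, here is a small Python sketch of a 4:2 structured-sparse dot product for one row of A against one column of B. This models only the math, not the hardware datapath, and the helper names (`compress_4to2`, `smfmac_dot`) are illustrative, not ISA names:

```python
def compress_4to2(row):
    """Split a row into groups of 4; keep the (up to) 2 non-zero values per
    group together with their 2-bit in-group indices."""
    vals, idxs = [], []
    for base in range(0, len(row), 4):
        group = row[base:base + 4]
        nz = [i for i, v in enumerate(group) if v != 0]
        assert len(nz) <= 2, "violates 4:2 structured sparsity"
        nz = (nz + [0, 1])[:2]  # pad groups with fewer than 2 non-zeros
        vals += [group[i] for i in nz]
        idxs += nz
    return vals, idxs

def smfmac_dot(vals, idxs, col, acc=0.0):
    """Accumulate using only the kept values; each index recovers the
    value's original K position within its group of 4."""
    for g in range(len(vals) // 2):
        for j in range(2):
            k = 4 * g + idxs[2 * g + j]
            acc += vals[2 * g + j] * col[k]
    return acc

# One row with 4:2 sparsity along K=8, against an all-ones column:
vals, idxs = compress_4to2([1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0, 4.0])
assert smfmac_dot(vals, idxs, [1.0] * 8) == 10.0
```

The compressed form reproduces the dense dot product exactly, which is why no accuracy is lost as long as the A operand actually satisfies the 4:2 structure.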
Index Data Encoding
For 16-bit source data: each lane holds 2 VGPRs × 32 bits = 64 bits of compressed A data, i.e. 4 non-zero 16-bit elements selected from the original 8. With the 4:2 structure, the index data consists of 2 groups, each encoding the in-group positions of 2 non-zero elements. Each group needs 4 bits, laid out as [idx1[1:0], idx0[1:0]] with idx0 < idx1.
For example, suppose we have the following 8 elements:
a0 a1 0 0 | 0 0 a2 a3 | a4 a5 0 0 | 0 0 a6 a7
Then the indices of the non-zero elements, packed as [group3][group2][group1][group0], are:
1110 0100 1110 0100 = 0xE4E4
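The packing above can be sketched in Python (`encode_index` is an illustrative helper name, not an ISA name; it handles any number of groups, so the same routine covers the 8-bit case below):

```python
def encode_index(row):
    """Pack [idx1[1:0], idx0[1:0]] per group of 4 elements, with group 0
    in the lowest 4 bits, as described above."""
    word = 0
    for g, base in enumerate(range(0, len(row), 4)):
        nz = [i for i, v in enumerate(row[base:base + 4]) if v != 0]
        assert len(nz) == 2 and nz[0] < nz[1]
        word |= ((nz[1] << 2) | nz[0]) << (4 * g)
    return word

# The worked example from the text:
elems = ["a0", "a1", 0, 0, 0, 0, "a2", "a3",
         "a4", "a5", 0, 0, 0, 0, "a6", "a7"]
assert encode_index(elems) == 0xE4E4
```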
For 8-bit source data: this works similarly, with 8 non-zero elements per thread selected from the original 16. The index data thus consists of 4 groups describing the positions of 8 non-zero elements.
CBSZ and ABID
The CBSZ and ABID fields have been repurposed to select the desired portion of the index data. If CBSZ == 0, then ABID[1:0] specifies which index set to use. If CBSZ != 0, then the first set is always selected.
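A sketch of the selection logic, assuming index sets are packed starting at bit 0 of the 32-bit SRC2 register (this packing order and the set widths are assumptions to be checked against the CDNA3 ISA manual; `select_index_set` is an illustrative name):

```python
def select_index_set(src2, cbsz, abid, bits_per_set):
    """bits_per_set: e.g. 8 for 16-bit sources (2 groups of 4 bits),
    16 for 8-bit sources (4 groups of 4 bits)."""
    set_id = (abid & 0x3) if cbsz == 0 else 0  # CBSZ != 0 forces set 0
    return (src2 >> (set_id * bits_per_set)) & ((1 << bits_per_set) - 1)

# Two hypothetical 16-bit index sets packed into one 32-bit register:
src2 = (0x4444 << 16) | 0xE4E4
assert select_index_set(src2, cbsz=0, abid=0, bits_per_set=16) == 0xE4E4
assert select_index_set(src2, cbsz=0, abid=1, bits_per_set=16) == 0x4444
assert select_index_set(src2, cbsz=1, abid=3, bits_per_set=16) == 0xE4E4
```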
Other Differences from Dense MFMA
- SRC2 (index values) ignores the ACC_CD bit; SRC2 always comes from architectural VGPRs since it contains index values rather than matrix data.
- The `blgp` field is not relevant.
For full details, please refer to the complete CDNA3/4 ISA.
Usage in Skinny GEMM
The kernel implementation described in the blog post shows how this trick is applied. Even-numbered threads cover indices 0 and 1 of each group of 4 elements along K in matrix A, while odd-numbered threads cover indices 2 and 3. After the intrinsic, the partial sums from each thread pair are added together to produce the full result.
The thread layout per subgroup is as follows:
K (64 positions)
K=0-15 K=16-31 K=32-47 K=48-63
Dense row 0: (t0,t1) (t2,t3) (t4,t5) (t6,t7)
Dense row 1: (t8,t9) (t10,t11) (t12,t13) (t14,t15)
Dense row 2: (t16,t17) (t18,t19) (t20,t21) (t22,t23)
Dense row 3: (t24,t25) (t26,t27) (t28,t29) (t30,t31)
Dense row 4: (t32,t33) (t34,t35) (t36,t37) (t38,t39)
Dense row 5: (t40,t41) (t42,t43) (t44,t45) (t46,t47)
Dense row 6: (t48,t49) (t50,t51) (t52,t53) (t54,t55)
Dense row 7: (t56,t57) (t58,t59) (t60,t61) (t62,t63)
Each (even, odd) thread pair jointly covers 16 K positions for one dense row. For example, for K=0-15 of dense row 0:
- t0 (index 0x4444): reads K = 0,1,4,5,8,9,12,13 (8 values at positions 0,1 in each group)
- t1 (index 0xEEEE): reads K = 2,3,6,7,10,11,14,15 (8 values at positions 2,3 in each group)
Together, all 16 K positions are covered, and the hardware produces partial sums that are later combined to yield the full dense result.
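The pairing can be verified with a short decode sketch (`k_positions` is an illustrative helper; it assumes index data for 4 groups of 4 contiguous K positions, matching the layout above):

```python
def k_positions(index_word, num_groups=4):
    """Decode a packed index word into the K positions a thread reads."""
    ks = []
    for g in range(num_groups):
        nibble = (index_word >> (4 * g)) & 0xF
        idx0, idx1 = nibble & 0x3, (nibble >> 2) & 0x3
        ks += [4 * g + idx0, 4 * g + idx1]
    return ks

t0 = k_positions(0x4444)  # [0, 1, 4, 5, 8, 9, 12, 13]
t1 = k_positions(0xEEEE)  # [2, 3, 6, 7, 10, 11, 14, 15]
assert sorted(t0 + t1) == list(range(16))  # full K coverage per pair
```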
Support in IREE
- AMDGPU wrapper for sparse matrix instructions
- #iree_gpu.sparse_mma_layout (?)
- Required plumbing, including selecting sparse_mfma when deducing desired MMA.
- Extend to RDNA equivalents