
[AMDGPU][Codegen] Optimizing skinny GEMM workloads with sparse matrix instructions #22863

@efric


Using Sparse Matrix Instructions for Skinny GEMM on AMDGPU

This issue is an evolving document describing the design and implementation of using sparse matrix instructions for optimizing skinny GEMM on AMDGPU in IREE.

Introduction

The original trick was described in the excellent blog post Creating custom kernels for the AMD MI300 by Rémi Ouazan Reboul and seungrok jung. The motivation is to enable this optimization in IREE for our skinny GEMM workloads.

Using FP8 as an example: currently, skinny GEMM workloads with M=8 are padded to M=16 so they can use the smallest dense MFMA operation. On CDNA3, this is 16x16x32 and takes 16 cycles; since half of the M dimension is padding, redundant work is performed. The blog post proposes using sparse MFMA instructions instead. The sparse equivalent is 16x16x64 and takes the same 16 cycles, allowing us to process twice the K-depth for the same number of cycles.
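To make the throughput argument concrete, here is a back-of-the-envelope comparison using the shapes and cycle counts quoted above (illustrative arithmetic, not measured results):

```python
# Dense MFMA on CDNA3:  16x16x32, 16 cycles (with M=8 padded to M=16).
# Sparse SMFMAC:        16x16x64, 16 cycles.
dense_k, dense_cycles = 32, 16
sparse_k, sparse_cycles = 64, 16

k_per_cycle_dense = dense_k / dense_cycles      # 2.0
k_per_cycle_sparse = sparse_k / sparse_cycles   # 4.0

# Twice the K-depth is processed in the same number of cycles.
speedup = k_per_cycle_sparse / k_per_cycle_dense
print(speedup)  # 2.0
```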

Semantics of Instruction (on CDNA3+)

The V_SMFMAC family of instructions performs matrix multiply-accumulate operations on a 4:2 structurally sparse matrix A and dense matrices B, C, and D: D = C + A × B. The sparse A operand requires 2 VGPRs per lane, the dense B operand requires 4 VGPRs per lane, and C/D share 16 VGPRs. The SRC2 (C) operand has been repurposed to hold the index data.

Index Data Encoding

For 16-bit source data: the sparse A operand is 2 VGPRs × 32 bits = 64 bits per lane, which at 16 bits per element holds 4 non-zero elements per thread, drawn from the original 8 elements. With the 4:2 structure, the index data consists of 2 groups, each encoding the positions of the 2 non-zero elements within a group of 4. Each group needs 4 bits, laid out as [idx1[1:0], idx0[1:0]] where idx0 < idx1.

For 8-bit source data: this works similarly, with 8 non-zero elements per thread from the original 16 elements, so there are 4 groups describing the indices of 8 non-zero elements.

For example, suppose a thread holds the following 16 elements:

a0 a1 0 0 | 0 0 a2 a3 | a4 a5 0 0 | 0 0 a6 a7 

Then the indices of the non-zero elements, packed as [group3][group2][group1][group0], are:

1110 0100 1110 0100 = 0xE4E4
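The packing rule above can be sketched in a few lines of Python (a hypothetical helper for illustration, not IREE or hardware code; it reproduces the 0xE4E4 example):

```python
def encode_sparse_indices(elements):
    """Pack 4:2 sparsity indices: one 4-bit group [idx1[1:0], idx0[1:0]]
    per 4 source elements, with the lowest group in the lowest bits."""
    word = 0
    for g in range(len(elements) // 4):
        group = elements[4 * g:4 * g + 4]
        nonzero = [i for i, v in enumerate(group) if v != 0]
        assert len(nonzero) == 2, "4:2 sparsity: exactly 2 non-zeros per group"
        idx0, idx1 = nonzero  # enumerate order guarantees idx0 < idx1
        word |= ((idx1 << 2) | idx0) << (4 * g)
    return word

# The 16 elements from the example: a0 a1 0 0 | 0 0 a2 a3 | a4 a5 0 0 | 0 0 a6 a7
block = [1, 1, 0, 0,  0, 0, 1, 1,  1, 1, 0, 0,  0, 0, 1, 1]
print(hex(encode_sparse_indices(block)))  # 0xe4e4
```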

CBSZ and ABID

The CBSZ and ABID fields have been repurposed to select the desired portion of the index data. If CBSZ == 0, then ABID[1:0] specifies which index set to use. If CBSZ != 0, then the first set is always selected.
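The selection rule can be stated as a tiny predicate (a sketch of the rule described above, not actual hardware or compiler code):

```python
def selected_index_set(cbsz, abid):
    """Which of the index sets in SRC2 is consumed.
    If CBSZ == 0, ABID[1:0] picks the set; otherwise set 0 is used."""
    return (abid & 0b11) if cbsz == 0 else 0

print(selected_index_set(0, 2))  # 2
print(selected_index_set(3, 2))  # 0
```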

Other Differences from Dense MFMA

  • SRC2 (index values) ignores the ACC_CD bit; SRC2 always comes from architectural VGPRs since it contains index values rather than matrix data.
  • The BLGP field is not relevant for these instructions.

For full details, please refer to the CDNA3/CDNA4 ISA reference guides.

Usage in Skinny GEMM

In the kernel implementation that the blog describes, we see how this trick is applied. Even-numbered threads cover indices 0,1 for each group of 4 elements along K in matrix A, while odd-numbered threads cover indices 2,3. After the intrinsic, the partial sums from paired threads are added together to produce the full result.

The thread layout per subgroup is as follows:

                        K (64 positions)
                K=0-15    K=16-31   K=32-47   K=48-63
Dense row 0:    (t0,t1)   (t2,t3)   (t4,t5)   (t6,t7)
Dense row 1:    (t8,t9)   (t10,t11) (t12,t13) (t14,t15)
Dense row 2:    (t16,t17) (t18,t19) (t20,t21) (t22,t23)
Dense row 3:    (t24,t25) (t26,t27) (t28,t29) (t30,t31)
Dense row 4:    (t32,t33) (t34,t35) (t36,t37) (t38,t39)
Dense row 5:    (t40,t41) (t42,t43) (t44,t45) (t46,t47)
Dense row 6:    (t48,t49) (t50,t51) (t52,t53) (t54,t55)
Dense row 7:    (t56,t57) (t58,t59) (t60,t61) (t62,t63)

Each (even, odd) thread pair jointly covers 16 K positions for one dense row. For example, for K=0-15 of dense row 0:

  • t0 (index 0x4444): reads K = 0,1,4,5,8,9,12,13 (8 values at positions 0,1 in each group)
  • t1 (index 0xEEEE): reads K = 2,3,6,7,10,11,14,15 (8 values at positions 2,3 in each group)

Together, all 16 K positions are covered, and the hardware produces partial sums that are later combined to yield the full dense result.
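Decoding the index words quoted above shows how a thread pair tiles its 16 K positions (a hypothetical decoder mirroring the encoding rule from the previous section; not IREE code):

```python
def covered_k_positions(index_word, num_groups=4):
    """K positions read by one thread, given its per-group 4-bit index word
    laid out as [idx1[1:0], idx0[1:0]] with the lowest group in the low bits."""
    positions = []
    for g in range(num_groups):
        nibble = (index_word >> (4 * g)) & 0xF
        idx0, idx1 = nibble & 0b11, (nibble >> 2) & 0b11
        positions += [4 * g + idx0, 4 * g + idx1]
    return sorted(positions)

t0 = covered_k_positions(0x4444)  # even thread: positions 0,1 of each group
t1 = covered_k_positions(0xEEEE)  # odd thread:  positions 2,3 of each group
print(t0)  # [0, 1, 4, 5, 8, 9, 12, 13]
print(t1)  # [2, 3, 6, 7, 10, 11, 14, 15]
# Together the pair covers all 16 K positions of its slice.
print(sorted(t0 + t1) == list(range(16)))  # True
```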

Support in IREE

  • AMDGPU wrapper for sparse matrix instructions
  • #iree_gpu.sparse_mma_layout (?)
  • Required plumbing, including selecting sparse_mfma when deducing desired MMA.
  • Extend to RDNA equivalents
