feat(ck_tile_unification): Swizzle support + gfx950 mixPrecScale + misc #8315
krithalith wants to merge 19 commits into
JP-Fernando
left a comment
This pull request adds swizzle support to the MMA-unified model. The author is looking for test cases for validation. So far my only concern is the computation of kABKLane in the three pipelines, exposed via WarpGemmAttribute::Impl.
chris-tsiaousis-hpc
left a comment
LGTM, added some comments you might want to address!
static constexpr index_t kM = MmaOp::kM; // Tentative

// Seems to be entire M size excluding blocks. Dubious for gfx1250, needs attention.
static constexpr index_t kAMLane =
Can we tighten this before it lands? kAMLane now feeds the old-style WarpGemmAttribute::Impl compatibility surface, so a hardcoded gfx1250 exception with Dubious in the comment feels a bit shaky. If this is only preserving existing behavior, maybe link it to the exact WMMA/gfx1250 follow-up?
I defined the value like this because I think this will match the value for all the existing WarpGemms. The gfx1250 value is dubious because although it matches what is used currently in CK tile, it might not actually make sense for all gfx1250 intrinsics. I will add a note to whatever gfx1250 issue(s) we have to look into this (and the other) old-style layout params.
 * @tparam MmaOp Intrinsic (amdgcn_mma) to be tested
 */
template <typename MmaOp> // TODO: C++20 concept for MmaOp
template <typename MmaOp, bool CTranspose, int SFactor> // TODO: C++20 concept for MmaOp
NIT: SFactor is int here, but run_mma_layout_test_single and TileDistrEncCalc use index_t. Could we keep it index_t all the way through? It avoids a weird NTTP mismatch if index_t ever differs from int.
EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
}
// Try all possible Swizzle factors. Incompatible intrinsics are skipped.
// CTranspose does not work with current test kernel.
The new WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution and i8 alias instantiate CTranspose=true with SFactor > 1, and those are the FMHA/SageAttention users. Can we add a CTranspose-aware check here, or at least a small compile-time test for that exact encoding?
The layout test is simply not compatible with CTranspose currently. Luckily, CTranspose is a pretty trivial modifier as far as the Tile Distribution Enc Calc is concerned. I will make an issue about CTranspose testing in the layout test. For now, we already know CTranspose is working in real pipelines because of the original tests and examples.
❌ PR Check — Action Required
📖 Need help? See the Policy FAQ for details on every check and how to fix failures.
🚫 Please fix the failed policies before requesting reviews. The following policy checks failed:
The |
…tor, with minimal restrictions. Seems to work within the minimal restrictions for gfx908 in layout test for Swizzle 2 and 4. Activated only those intrinsics I expect to pass the layout test on other platforms. Needs more testing. Not all "working" layout configurations may actually make sense.
… compatible with the intrinsic.
…the Unification Dispatcher take a scalar SwizzleFactor instead of a bool because we may need a SwizzleFactor of 4 for some named WarpGemms. Added unification version of the two named WarpGemms with swizzle that are actually used in higher level code.
…AMLane and kABKLane. Re-ordered to match original order.
…ibution to the enc calc for blockless intrinsics to placate (fragile) higher level code, and add all mixed precision gfx950 intrinsics!
…x950 scale intrinsics. We add a canonical (1,2) or (2,1) AttrNumAccess to these intrinsics and make the tile distr enc calc use these as minima again.
ISSUE ID #8960
Motivation
This MR adds Swizzle support to the Tile Distribution Encoding Calculator and Mma Pipelines in the Unification framework. Swizzle is a modifier for Tile Distribution Encodings that effectively performs a permutation in the M dimension, so it affects the Tile Distribution Encodings of A and C. When combined with CTranspose, it affects the Encodings of B and C instead. In principle, for a regular gemm, the Swizzle factor does not affect the correctness of the kernel: permuting the rows of A (the M dimension) just permutes the rows of C in the same way, so the result is unchanged as long as the same Encodings are used for loading and storing the data. For consecutive matrix multiplications, however, we may use Swizzle to account for the effective layout of an intermediate result, so that it can immediately be used in another matrix operation without additional shuffling. In these cases, the Swizzle factor is crucial for correctness. As far as I know, this is most likely to occur in attention kernels.
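The permutation argument in the Motivation can be checked on a toy example. The sketch below is plain host C++ and is not CK Tile code: `gemm`, `permute_rows`, and the tiny sizes are hypothetical names chosen for illustration. It only shows that permuting the rows of A yields a C whose rows are permuted the same way, which is why a swizzled load paired with an identically swizzled store is harmless.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Toy sizes, chosen only for illustration.
constexpr std::size_t M = 4, N = 3, K = 2;
using MatA = std::array<std::array<int, K>, M>;
using MatB = std::array<std::array<int, N>, K>;
using MatC = std::array<std::array<int, N>, M>;

// Reference gemm: C[m][n] = sum over k of A[m][k] * B[k][n].
MatC gemm(const MatA& a, const MatB& b) {
    MatC c{};
    for (std::size_t m = 0; m < M; ++m)
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t k = 0; k < K; ++k)
                c[m][n] += a[m][k] * b[k][n];
    return c;
}

// Row permutation along M: out[i] = in[perm[i]].
template <typename Mat>
Mat permute_rows(const Mat& in, const std::array<std::size_t, M>& perm) {
    Mat out{};
    for (std::size_t i = 0; i < M; ++i) out[i] = in[perm[i]];
    return out;
}

// True if gemm(P*A, B) == P*gemm(A, B) for one example permutation P.
bool row_permutation_commutes_with_gemm() {
    MatA a{};
    MatB b{};
    for (std::size_t m = 0; m < M; ++m)
        for (std::size_t k = 0; k < K; ++k)
            a[m][k] = static_cast<int>(m * K + k + 1);
    for (std::size_t k = 0; k < K; ++k)
        for (std::size_t n = 0; n < N; ++n)
            b[k][n] = static_cast<int>(k * N + n + 1);

    // Swap the two middle rows, loosely mimicking an M-dimension swizzle.
    const std::array<std::size_t, M> perm{{0, 2, 1, 3}};
    return gemm(permute_rows(a, perm), b) == permute_rows(gemm(a, b), perm);
}
```

If the second matrix operation consumes C assuming the unpermuted row order, the permutation is no longer cancelled, which is the case where the Swizzle factor matters for correctness.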
Changes
I adapted the Tile Distribution Encoding Calculator to accept any Swizzle modifier, and use this to modify the layouts just like in CK Tile. Note that Swizzle is only compatible with certain intrinsics, due to the restriction that the Swizzle factor divides kCMNumAccess. This is possible for 32x32 MFMA instructions with SFactor 2 or 4, and for gfx11 WMMA instructions with SFactor 2, 4, or 8, although this is not used in CK Tile.
I adapted the layout test to check the correctness of layout with Swizzle modification, for all possible Swizzle factors for each intrinsic.
I adapted the Unification Dispatcher to take a Swizzle Factor and pass it on to the MmaPipelines. Note that the original dispatcher takes a boolean instead, which I convert to an SFactor of 2 when true. I believe this is correct since in all cases where CK Tile previously used the old dispatcher, an SFactor of 2 ended up being used. However, there are two named WarpGemms (WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution and WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution) which can support any Swizzle factor, and are actually used with Swizzle factors up to 4. These were not used in the old dispatcher but instead always used directly in CK Tile pipelines.
I added custom named WarpGemms in case the Unification flag is ON, for the named WarpGemms using Swizzle that are directly used in CK Tile pipelines. There are only two of them and they are the ones mentioned in the previous point.
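The divisibility restriction and the bool-to-SFactor conversion described above can be sketched as compile-time helpers. This is a minimal illustration, not the code in this PR: `is_valid_swizzle_factor`, `swizzle_factor_from_legacy_flag`, and the two `kCMNumAccess`-style constants are hypothetical. The values 4 and 8 are inferred from the factors listed above (2 or 4 for 32x32 MFMA; 2, 4, or 8 for gfx11 WMMA) rather than read from the intrinsic definitions, and treating an SFactor of 1 as "no swizzle" is an assumption.

```cpp
#include <cassert>
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

// Hypothetical check: the Swizzle factor must divide kCMNumAccess.
constexpr bool is_valid_swizzle_factor(index_t c_m_num_access, index_t s_factor) {
    return s_factor >= 1 && c_m_num_access % s_factor == 0;
}

// Hypothetical mapping from the old boolean dispatcher argument.
// Assumption: an SFactor of 1 means "no swizzle".
constexpr index_t swizzle_factor_from_legacy_flag(bool use_swizzle) {
    return use_swizzle ? 2 : 1;
}

// Assumed values, inferred from the factors listed in the description.
constexpr index_t kMfma32x32CMNumAccess = 4;
constexpr index_t kWmmaGfx11CMNumAccess = 8;

static_assert(is_valid_swizzle_factor(kMfma32x32CMNumAccess, 2), "SFactor 2 fits 32x32 MFMA");
static_assert(is_valid_swizzle_factor(kMfma32x32CMNumAccess, 4), "SFactor 4 fits 32x32 MFMA");
static_assert(!is_valid_swizzle_factor(kMfma32x32CMNumAccess, 8), "SFactor 8 does not");
static_assert(is_valid_swizzle_factor(kWmmaGfx11CMNumAccess, 8), "SFactor 8 fits gfx11 WMMA");
static_assert(swizzle_factor_from_legacy_flag(true) == 2, "legacy true maps to SFactor 2");
```

A scalar factor is needed on the new dispatcher because the two named WarpGemms are used with factors up to 4, which a boolean cannot express.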
Changes part 2
While trying to get a swizzle example to work, I ended up having to add a lot of other changes which would have normally been their own issue. We have:
Note on AttrNumAccess
For the scale gfx950 intrinsics, the "canonical" layouts for A and B have NumAccess 1 or 2, depending on the A and B types. The 8-bit types have a canonical NumAccess of 2, and the others 1. So overall we may have (1, 1), (2, 1), (1, 2), or (2, 2). This is reflected in the intrinsic definitions. However, for the fully 8-bit intrinsics I still define them with (1, 1). The reason for this is that it is in principle possible to use these intrinsics with (1, 1) as long as you don't use scale. This may actually happen in CK Tile. Furthermore, there are some pipelines that instantiate a WarpGemm with (1, 1) just to peek at some parameters. Note that the (1, 2) and (2, 1) cases MUST have these NumAccess values or the base MMA does not work (regardless of scale). This is because you can't just permute K for A without doing the same for B and vice versa.
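The NumAccess rules in this note can be illustrated with a small host-side sketch. All names here (`canonical_num_access`, `defined_num_access`, `NumAccessPair`, `dot`, `permute`) are hypothetical and do not come from the PR; element types are identified by bit width only. The first part encodes the mapping as described above, including the (1, 1) choice for fully 8-bit intrinsics. The second part shows why a K permutation applied to only one operand changes the result, using a simple index permutation as a stand-in for the reordering implied by a different NumAccess.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

struct NumAccessPair { int a; int b; };

// Canonical NumAccess per operand: 2 for 8-bit types, 1 otherwise.
constexpr int canonical_num_access(int element_bits) {
    return element_bits == 8 ? 2 : 1;
}

// As described in the note: fully 8-bit intrinsics are still defined with
// (1, 1); mixed cases must use their canonical values.
constexpr NumAccessPair defined_num_access(int a_bits, int b_bits) {
    return (a_bits == 8 && b_bits == 8)
               ? NumAccessPair{1, 1}
               : NumAccessPair{canonical_num_access(a_bits), canonical_num_access(b_bits)};
}

constexpr std::size_t K = 4;
using Vec = std::array<int, K>;

// Dot product along K, i.e. one output element of a gemm.
int dot(const Vec& a, const Vec& b) {
    int sum = 0;
    for (std::size_t k = 0; k < K; ++k) sum += a[k] * b[k];
    return sum;
}

// K permutation: out[i] = in[perm[i]].
Vec permute(const Vec& in, const std::array<std::size_t, K>& perm) {
    Vec out{};
    for (std::size_t i = 0; i < K; ++i) out[i] = in[perm[i]];
    return out;
}
```

With `a = {1, 2, 3, 4}`, `b = {10, 20, 30, 40}` and the permutation `{0, 2, 1, 3}`, permuting both operands keeps the dot product at 300, while permuting only `a` gives 290.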
Tests
Layout tests with swizzle pass. tile_example_fmha_fwd and tile_example_fmha_bwd now compile and run, with correct verification for default settings. With fp8bf16 and init=3, 5% of the results are wrong on both this branch and develop. This case is definitely sensitive to swizzle, because without swizzle 50% of the results are wrong. I am still looking for better tests, but any remaining issues should show up in our overall unification coverage checking scripts.