[ROCm] Split skinny_gemms_int8.cu into per-N translation units#1028
Draft
marcusr-amd wants to merge 7 commits into
Draft
[ROCm] Split skinny_gemms_int8.cu into per-N translation units#1028marcusr-amd wants to merge 7 commits into
marcusr-amd wants to merge 7 commits into
Conversation
Break the monolithic 621-line skinny_gemms_int8.cu into per-N
instantiation shards following the existing skinny_gemms_w8a8/ pattern.
This lets make parallelize across the 5 TUs so that with MAX_JOBS=2,
no two heavy HIP compilations run concurrently.
New files in csrc/rocm/skinny_gemms_int8/:
kernel.cuh - kernel template + device helpers
dispatch.cuh - dispatch_int8<scalar_t, N> with GROUP_SIZE variants
launch.h - per-N launcher declarations
instantiate_n{1-5}.cu - one TU per N, explicit half + bf16 instantiations
The entry point (skinny_gemms_int8.cu) is reduced to validation,
heuristic selection, and a switch(N) that calls launch_int8_nX().
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Break the monolithic 3706-line attention.cu into a thin entry point plus two per-dtype dispatch shards (auto and fp8). The kernel templates and launcher functions move to headers under csrc/rocm/attention/. This gives the build system 3 .cu files × 3 GPU architectures = 9 compilation tasks that can be parallelized with MAX_JOBS >= 2, vs the original single file that serialized all template instantiations. New files in csrc/rocm/attention/: kernels.cuh - all __global__ kernel templates (GFX9/GFX11/GFX12/fallback) launcher.h - paged_attention_custom_launcher + _navi (template defs) launch_macros.h - LAUNCH_CUSTOM_ATTENTION_MFMA16/MFMA4/REDUCTION macros dispatch_auto.cu - kv_cache_dtype=="auto" dispatch (bf16/fp16 KV cache) dispatch_fp8.cu - kv_cache_dtype=="fp8" dispatch (FP8 KV cache) The entry point (attention.cu) is reduced to is_navi_gpu() detection and a kv_cache_dtype switch that calls the per-shard dispatch functions. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Break the paged_attention_v1 launcher into 3 per-HEAD_SIZE shards (small: 32/64/80, medium: 96/112/120/128, large: 192/256) to reduce peak memory per compilation unit. Each shard defines its own launcher template with only its HEAD_SIZE cases, so the compiler instantiates fewer kernel templates per TU. New files in csrc/libtorch_stable/attention/paged_attention_v1/: common.h - shared includes, macros, launcher preamble, shard decls shard_small.cu - HEAD_SIZE 32, 64, 80 shard_medium.cu - HEAD_SIZE 96, 112, 120, 128 shard_large.cu - HEAD_SIZE 192, 256 The entry point (paged_attention_v1.cu) dispatches to the right shard by head_size at runtime. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Move the kernel template and launcher into a header, then split the 5 explicit template instantiations across 2 per-dtype shards (fp16/fp32 and bf16). This reduces peak memory per compilation unit. New files in csrc/libtorch_stable/mamba/selective_scan_fwd/: kernel.cuh - kernel traits, __global__ kernel, launcher templates shard_fp16.cu - explicit instantiations for Half and float input types shard_bf16.cu - explicit instantiations for BFloat16 input type The entry point (selective_scan_fwd.cu) keeps only the param setup, validation, and dtype dispatch. It calls selective_scan_fwd_cuda via forward declaration; the linker resolves to the shard that has the matching explicit instantiation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
3123d18 to
71ab095
Compare
Break the monolithic skinny_gemms_int4.cu (12.6 GB peak RSS, 16 min)
into per-N shards following the w8a8/int8 pattern. Each shard handles
both the standard wvSplitK_int4_g and fused MoE dispatch for one N
value.
Peak RSS drops from 12.6 GB to ~3.2 GB per shard, which was the
primary OOM bottleneck blocking MAX_JOBS=2 on 16 GB CI runners.
New files in csrc/rocm/skinny_gemms_int4/:
dispatch.cuh - dispatch_int4_g and dispatch_moe_int4_g templates
launch.h - per-N launcher declarations
instantiate_n{1-5}.cu - one TU per N, explicit half + bf16 instantiations
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Break the monolithic skinny_gemms.cu (6.3 GB peak RSS, 260s) into
per-N shards for the wvSplitK dispatch. The fused variants
(fused_silu_mul, fused_silu_gate_mul) and other ops (LLMM1, wvSplitKrc,
wvSplitKQ) remain in the entry point since they are N=1 only or
unrelated to the per-N split.
New files in csrc/rocm/skinny_gemms/:
kernel.cuh - helper functions (inline) + kernel templates
dispatch.cuh - dispatch_wvsplitk<scalar_t, N> template
launch.h - per-N launcher declarations
instantiate_n{1-5}.cu - one TU per N, explicit half + bf16 instantiations
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
The heavy HIP translation units have been split into per-N shards, reducing peak RSS per compilation from 12.6 GB (skinny_gemms_int4) and 6.3 GB (skinny_gemms bf16) down to <=3.2 GB per shard. MAX_JOBS=2 was previously reverted (adf8d91) due to OOM on the 16 GB GitHub-hosted runners. With the TU splits, MAX_JOBS=2 passes with 10.7 GB peak (5 GB headroom). MAX_JOBS=4 is expected to peak at ~13 GB based on per-file RSS measurements. Split files: skinny_gemms_int4.cu 12.6 GB -> 5 shards @ 3.2 GB each skinny_gemms.cu 6.3 GB -> 5 shards @ ~1.3 GB each skinny_gemms_int8.cu split into 5 per-N shards attention.cu split into 2 per-dtype shards paged_attention_v1.cu split into 3 per-HEAD_SIZE shards selective_scan_fwd.cu split into 2 per-dtype shards Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
|
I'm concerned that this kind of large changes will make it harder for use to merge upstream/main unless we can upstream those change ourselves quickly. But at the same time, I'm happy to reduce compile time as this really hurts the capability to iterate. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Split the 6 heaviest HIP translation units into per-N / per-dtype shards so
MAX_JOBS=4fits in the 16 GB CI runner memory budget. Cold wheel build drops from ~76 min to ~29 min (2.6x).Before:
MAX_JOBS=2OOM-killedskinny_gemms_int4(12.6 GB peak RSS).MAX_JOBS=1was the only option.After: Largest shard is 3.2 GB. Measured
MAX_JOBS=4at 14.0 GB peak (2 GB headroom) in a memory-capped docker container matching CI.What changed
skinny_gemms_int4.cuskinny_gemms.cu(bf16)skinny_gemms_int8.cuattention.cupaged_attention_v1.cuselective_scan_fwd.cuAll splits follow the existing
skinny_gemms_w8a8/pattern: kernel templates in.cuhheaders, per-shard.cufiles with explicit template instantiations, thin entry point dispatching by N / dtype / head_size.Test plan
MAX_JOBS=2 @ 16gdocker: PASS (44 min, 10.7 GB peak)MAX_JOBS=4 @ 16gdocker: PASS (29 min, 14.0 GB peak)🤖 Generated with Claude Code