Skip to content

[ROCm] Split skinny_gemms_int8.cu into per-N translation units#1028

Draft
marcusr-amd wants to merge 7 commits into
gfx11from
marcusr/split-heavy-tu
Draft

[ROCm] Split skinny_gemms_int8.cu into per-N translation units#1028
marcusr-amd wants to merge 7 commits into
gfx11from
marcusr/split-heavy-tu

Conversation

@marcusr-amd

@marcusr-amd marcusr-amd commented Jun 26, 2026

Copy link
Copy Markdown

Summary

Split the 6 heaviest HIP translation units into per-N / per-dtype shards so MAX_JOBS=4 fits in the 16 GB CI runner memory budget. Cold wheel build drops from ~76 min to ~29 min (2.6x).

Before: MAX_JOBS=2 OOM-killed skinny_gemms_int4 (12.6 GB peak RSS). MAX_JOBS=1 was the only option.

After: Largest shard is 3.2 GB. Measured MAX_JOBS=4 at 14.0 GB peak (2 GB headroom) in a memory-capped docker container matching CI.

What changed

File Was Now
skinny_gemms_int4.cu 12.6 GB, 16 min 5 shards @ 3.2 GB
skinny_gemms.cu (bf16) 6.3 GB, 4.3 min 5 shards @ ~1.3 GB
skinny_gemms_int8.cu monolithic 5 shards (same pattern as w8a8)
attention.cu 3706 lines 2 per-dtype dispatch shards + headers
paged_attention_v1.cu 9 HEAD_SIZE instantiations 3 shards (small/medium/large)
selective_scan_fwd.cu 5 dtype instantiations 2 shards (fp16/bf16)

All splits follow the existing skinny_gemms_w8a8/ pattern: kernel templates in .cuh headers, per-shard .cu files with explicit template instantiations, thin entry point dispatching by N / dtype / head_size.

Test plan

  • Full wheel build on Strix Halo (gfx1103+gfx1150+gfx1151), all ops register
  • MAX_JOBS=2 @ 16g docker: PASS (44 min, 10.7 GB peak)
  • MAX_JOBS=4 @ 16g docker: PASS (29 min, 14.0 GB peak)
  • CI wheel build + kernel correctness tests
  • CI kernel performance tests

🤖 Generated with Claude Code

Break the monolithic 621-line skinny_gemms_int8.cu into per-N
instantiation shards following the existing skinny_gemms_w8a8/ pattern.
This lets make parallelize across the 5 TUs so that with MAX_JOBS=2,
no two heavy HIP compilations run concurrently.

New files in csrc/rocm/skinny_gemms_int8/:
  kernel.cuh        - kernel template + device helpers
  dispatch.cuh      - dispatch_int8<scalar_t, N> with GROUP_SIZE variants
  launch.h          - per-N launcher declarations
  instantiate_n{1-5}.cu - one TU per N, explicit half + bf16 instantiations

The entry point (skinny_gemms_int8.cu) is reduced to validation,
heuristic selection, and a switch(N) that calls launch_int8_nX().

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Break the monolithic 3706-line attention.cu into a thin entry point
plus two per-dtype dispatch shards (auto and fp8). The kernel templates
and launcher functions move to headers under csrc/rocm/attention/.

This gives the build system 3 .cu files × 3 GPU architectures = 9
compilation tasks that can be parallelized with MAX_JOBS >= 2, vs the
original single file that serialized all template instantiations.

New files in csrc/rocm/attention/:
  kernels.cuh      - all __global__ kernel templates (GFX9/GFX11/GFX12/fallback)
  launcher.h       - paged_attention_custom_launcher + _navi (template defs)
  launch_macros.h  - LAUNCH_CUSTOM_ATTENTION_MFMA16/MFMA4/REDUCTION macros
  dispatch_auto.cu - kv_cache_dtype=="auto" dispatch (bf16/fp16 KV cache)
  dispatch_fp8.cu  - kv_cache_dtype=="fp8" dispatch (FP8 KV cache)

The entry point (attention.cu) is reduced to is_navi_gpu() detection and
a kv_cache_dtype switch that calls the per-shard dispatch functions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Break the paged_attention_v1 launcher into 3 per-HEAD_SIZE shards
(small: 32/64/80, medium: 96/112/120/128, large: 192/256) to reduce
peak memory per compilation unit. Each shard defines its own launcher
template with only its HEAD_SIZE cases, so the compiler instantiates
fewer kernel templates per TU.

New files in csrc/libtorch_stable/attention/paged_attention_v1/:
  common.h        - shared includes, macros, launcher preamble, shard decls
  shard_small.cu  - HEAD_SIZE 32, 64, 80
  shard_medium.cu - HEAD_SIZE 96, 112, 120, 128
  shard_large.cu  - HEAD_SIZE 192, 256

The entry point (paged_attention_v1.cu) dispatches to the right shard
by head_size at runtime.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Move the kernel template and launcher into a header, then split the 5
explicit template instantiations across 2 per-dtype shards (fp16/fp32
and bf16). This reduces peak memory per compilation unit.

New files in csrc/libtorch_stable/mamba/selective_scan_fwd/:
  kernel.cuh    - kernel traits, __global__ kernel, launcher templates
  shard_fp16.cu - explicit instantiations for Half and float input types
  shard_bf16.cu - explicit instantiations for BFloat16 input type

The entry point (selective_scan_fwd.cu) keeps only the param setup,
validation, and dtype dispatch. It calls selective_scan_fwd_cuda via
forward declaration; the linker resolves to the shard that has the
matching explicit instantiation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
@marcusr-amd marcusr-amd force-pushed the marcusr/split-heavy-tu branch from 3123d18 to 71ab095 Compare June 26, 2026 19:58
Break the monolithic skinny_gemms_int4.cu (12.6 GB peak RSS, 16 min)
into per-N shards following the w8a8/int8 pattern. Each shard handles
both the standard wvSplitK_int4_g and fused MoE dispatch for one N
value.

Peak RSS drops from 12.6 GB to ~3.2 GB per shard, which was the
primary OOM bottleneck blocking MAX_JOBS=2 on 16 GB CI runners.

New files in csrc/rocm/skinny_gemms_int4/:
  dispatch.cuh         - dispatch_int4_g and dispatch_moe_int4_g templates
  launch.h             - per-N launcher declarations
  instantiate_n{1-5}.cu - one TU per N, explicit half + bf16 instantiations

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Break the monolithic skinny_gemms.cu (6.3 GB peak RSS, 260s) into
per-N shards for the wvSplitK dispatch. The fused variants
(fused_silu_mul, fused_silu_gate_mul) and other ops (LLMM1, wvSplitKrc,
wvSplitKQ) remain in the entry point since they are N=1 only or
unrelated to the per-N split.

New files in csrc/rocm/skinny_gemms/:
  kernel.cuh          - helper functions (inline) + kernel templates
  dispatch.cuh        - dispatch_wvsplitk<scalar_t, N> template
  launch.h            - per-N launcher declarations
  instantiate_n{1-5}.cu - one TU per N, explicit half + bf16 instantiations

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
The heavy HIP translation units have been split into per-N shards,
reducing peak RSS per compilation from 12.6 GB (skinny_gemms_int4)
and 6.3 GB (skinny_gemms bf16) down to <=3.2 GB per shard.

MAX_JOBS=2 was previously reverted (adf8d91) due to OOM on the
16 GB GitHub-hosted runners. With the TU splits, MAX_JOBS=2 passes
with 10.7 GB peak (5 GB headroom). MAX_JOBS=4 is expected to peak
at ~13 GB based on per-file RSS measurements.

Split files:
  skinny_gemms_int4.cu  12.6 GB -> 5 shards @ 3.2 GB each
  skinny_gemms.cu        6.3 GB -> 5 shards @ ~1.3 GB each
  skinny_gemms_int8.cu   split into 5 per-N shards
  attention.cu           split into 2 per-dtype shards
  paged_attention_v1.cu  split into 3 per-HEAD_SIZE shards
  selective_scan_fwd.cu  split into 2 per-dtype shards

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
@marcusr-amd marcusr-amd requested a review from mgehre-amd June 29, 2026 16:36
@mgehre-amd

Copy link
Copy Markdown

I'm concerned that this kind of large changes will make it harder for use to merge upstream/main unless we can upstream those change ourselves quickly. But at the same time, I'm happy to reduce compile time as this really hurts the capability to iterate.
Thoughts?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants