fix(pd): transfer MiniMax-M3 sparse indexer-key cache in disaggregation #1368
Merged
MiniMax-M3 sparse attention reuses the unified KV cache and kv_scale for K/V, so the fp8 per-token scales already travel with the KV blocks. It keeps one extra per-token buffer, runner.sparse_attention_index_cache, holding the indexer keys used for top-k block selection at decode time. get_kv_transfer_tensors() never registered that buffer, so under PD disaggregation the decode node ran top-k against a zero/stale index for the prefilled tokens and attended to the wrong KV blocks. This is masked for short prompts (the init+local+topk window already covers every block, so selection is moot) but corrupts output once the context exceeds that window.
Register the indexer-key cache as block-indexed transfer regions (one per sparse layer, same physical-block striding as the KV cache), guarded by getattr so non-sparse models and bf16 paths are unaffected.
Tested (latest image, 1P+1D TP4, fp8 KV via Triton attention): GSM8K 5-shot = 0.9401, i.e. no regression to M3 fp8 PD. Short-prompt GSM8K does not exercise the long-context top-k path the buffer affects; that path is covered by review, not this run.
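The short-prompt masking described above can be illustrated with a toy model. This is only a sketch: the real MiniMax-M3 selection rule is not shown in this PR, so the function `select_blocks`, its parameters, and the per-block scores (a stand-in for scores derived from the indexer keys) are all assumptions made for illustration.

```python
def select_blocks(scores, n_init, n_local, top_k):
    """Toy model (assumed, not the real kernel): attend to the first
    n_init blocks, the last n_local blocks, and the top_k best-scoring
    blocks among the rest."""
    n = len(scores)
    init = set(range(min(n_init, n)))
    local = set(range(max(0, n - n_local), n))
    rest = [b for b in range(n) if b not in init and b not in local]
    # Highest score first; ties resolve to the lowest block id.
    top = sorted(rest, key=lambda b: (-scores[b], b))[:top_k]
    return sorted(init | local | set(top))


# Hypothetical per-block scores from a correctly transferred index cache,
# and the all-zero scores a decode node would see if the cache never arrived.
good = [0.1, 0.2, 0.9, 0.3, 0.8, 0.1, 0.7, 0.2]
stale = [0.0] * len(good)

# Short context: 4 blocks fit inside init(1) + local(1) + top_k(2).
short_good = select_blocks(good[:4], n_init=1, n_local=1, top_k=2)
short_stale = select_blocks(stale[:4], n_init=1, n_local=1, top_k=2)

# Long context: 8 blocks exceed the window, so the scores now matter.
long_good = select_blocks(good, n_init=1, n_local=1, top_k=2)
long_stale = select_blocks(stale, n_init=1, n_local=1, top_k=2)
```

In this toy setup the short-context selections are identical whether or not the scores are valid, while the long-context selections diverge, which matches why a short-prompt benchmark such as GSM8K would not surface the bug.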
Pull request overview
Fixes PD disaggregation correctness for MiniMax-M3 sparse attention by ensuring the per-token sparse indexer-key cache is included in the KV RDMA transfer set. This prevents decode workers from running top‑k block selection against stale/zero index-cache data for prefilled tokens, which can mis-select KV blocks once context grows beyond the init/local/top‑k coverage window.
Changes:
- Register `runner.sparse_attention_index_cache` as additional block-indexed transfer regions in `get_kv_transfer_tensors()`.
- Guard the new transfer registration via `getattr(...)` so non-sparse models (and runners without the buffer) are unaffected.
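The two changes listed above might look roughly like the following. This is a hedged sketch, not the code from this PR: `TransferRegion`, `index_cache_regions`, the `addr`/`nbytes` fields, and the assumption that the index cache is an iterable with one buffer per sparse layer are all hypothetical. Only the attribute name `sparse_attention_index_cache` and the `getattr` guard come from the description.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass(frozen=True)
class TransferRegion:
    """Hypothetical descriptor for one block-indexed transfer region."""
    name: str
    base_addr: int
    block_stride_bytes: int
    num_blocks: int


def index_cache_regions(runner, num_blocks):
    """Return one region per sparse layer, or an empty list when the
    runner has no indexer-key cache (non-sparse models)."""
    index_cache = getattr(runner, "sparse_attention_index_cache", None)
    if index_cache is None:
        return []  # nothing extra to register
    regions = []
    for layer_id, buf in enumerate(index_cache):  # assumed: one buffer per layer
        regions.append(
            TransferRegion(
                name=f"sparse_index.{layer_id}",
                base_addr=buf.addr,
                # Same per-physical-block striding idea as the KV cache.
                block_stride_bytes=buf.nbytes // num_blocks,
                num_blocks=num_blocks,
            )
        )
    return regions


# Stand-in runners: one without the buffer, one with two sparse layers.
dense_runner = SimpleNamespace()
sparse_runner = SimpleNamespace(
    sparse_attention_index_cache=[
        SimpleNamespace(addr=0x1000, nbytes=4096),
        SimpleNamespace(addr=0x3000, nbytes=4096),
    ]
)

dense_regions = index_cache_regions(dense_runner, num_blocks=16)
sparse_regions = index_cache_regions(sparse_runner, num_blocks=16)
```

The point of the `getattr` default is that a runner lacking the attribute yields no extra regions at all, so the existing KV transfer set is unchanged for those models.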
valarLip approved these changes on Jun 29, 2026