[AutoSP] (Sequence Parallelism) support for Multimodal Models (ViT + LLM) #7984

Merged: delock merged 35 commits into deepspeedai:master from nathon-lee:multimodal-seq-parallel on May 1, 2026
Conversation

@nathon-lee (Contributor)

Description

Hello DeepSpeed Team! 👋

This PR directly addresses the "Multimodal model support" goal outlined in the DeepSpeed Roadmap Q2 2026 (#7861).

It introduces AutoSP (Sequence Parallelism) support for Multimodal Models (ViT + LLM) out of the box. As noted in the roadmap, multimodal models handle significantly longer sequence lengths, making SP critical. This PR automates the injection of DeepSpeed Ulysses-based sequence parallelism into multimodal architectures, removing the need for manual and error-prone engineering efforts.

This is a consolidated PR of several incremental features developed and thoroughly tested in my fork.

🎯 Related Issue

  • #7861 (DeepSpeed Roadmap Q2 2026: Multimodal model support)

🌟 Key Features & Contributions

  1. AutoSP Scaffolding & Detector (auto_wrap_model_for_sp):

    • Introduced a scanning utility to automatically detect ViT encoders and LLM decoders within a multimodal model.
    • Automatically wraps LLM decoder attention layers with DeepSpeed's existing DistributedAttention (a combined usage sketch follows this feature list).
  2. ViT Sequence Parallelism (UlyssesSPViTAttention):

    • Implemented a Ulysses-style Gather-Compute-Scatter sequence parallel wrapper tailored for non-causal ViT attention layers.
    • Significantly reduces the memory footprint of ViT Feed-Forward Networks (FFN) and LayerNorms across the sequence dimension.
  3. Cross-Modal Fusion Adapters (Phase 2):

    • Handled the complex sequence scatter/gather at the vision-language boundary to ensure the LLM decoder receives uniformly sharded fused sequences.
    • Supported architectures include:
      • LLaVA (LlavaFusionAdapter): Visual token splice replacing image placeholders.
      • InternVL (InternVLFusionAdapter): IMG_CONTEXT token splice.
      • Qwen2-VL (Qwen2VLFusionAdapter): Vision_start/end bounded splice.
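
For illustration, here is a minimal usage sketch of how these pieces compose. Import paths, argument names, and the fusion-adapter constructor below are assumptions for illustration and may not match the merged API exactly:

```python
# Hypothetical usage sketch -- import paths, argument names, and the fusion
# adapter constructor are illustrative assumptions, not necessarily the final API.
import torch.distributed as dist
from transformers import LlavaForConditionalGeneration

from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp      # this PR
from deepspeed.sequence.autosp_fusion import LlavaFusionAdapter    # this PR

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Process group spanning the sequence-parallel ranks (created however your
# setup provides it, e.g. from the DeepSpeed mpu).
sp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

# Phase 1: scan the model, wrap ViT encoder attention with UlyssesSPViTAttention
# and LLM decoder attention with DistributedAttention.
auto_wrap_model_for_sp(model, process_group=sp_group)

# Phase 2 (currently manual, see "Known Limitations" below): wrap the vision
# projection so the fused image+text sequence is re-sharded evenly per rank.
model.multi_modal_projector = LlavaFusionAdapter(model.multi_modal_projector,
                                                 process_group=sp_group)
```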

🧪 Testing & Validation

To ensure this PR does not break any existing functionality and is numerically sound, comprehensive tests have been added:

  • Numerical Equivalence Tests: Added multi-GPU tests (tests/unit/sequence_parallelism/test_autosp_equivalence.py) verifying that the SP-wrapped path across N ranks produces numerically identical results to the equivalent single-device (non-SP) computation (the check pattern is sketched after this list).
  • Integration Tests: End-to-end mock integration tests validating the full pipeline from ViT to fusion adapter.
  • Benchmarks Provided: Included a multimodal SP benchmark script (benchmarks/autosp/bench_multimodal_sp.py) to easily verify throughput scaling and peak GPU memory reduction.

(All tests pass cleanly on 2 GPUs with NCCL_P2P_DISABLE=1)
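
For reference, the equivalence check follows roughly this pattern. This is a simplified sketch: the real tests use DeepSpeed's distributed test harness and the actual wrapped modules, and the helper below assumes an evenly divisible sequence length:

```python
# Simplified sketch of the numerical-equivalence check. The real tests use
# DeepSpeed's distributed test harness and the actual wrapped modules; this
# helper assumes an evenly divisible sequence length.
import torch
import torch.distributed as dist

def check_sp_equivalence(reference_module, sp_module, hidden_states, sp_group):
    world_size = dist.get_world_size(group=sp_group)
    rank = dist.get_rank(group=sp_group)

    # Reference path: full sequence on a single device, no sequence parallelism.
    with torch.no_grad():
        expected = reference_module(hidden_states)

    # SP path: shard the sequence dimension, run the wrapped module, gather back.
    shards = hidden_states.chunk(world_size, dim=1)
    with torch.no_grad():
        local_out = sp_module(shards[rank].contiguous())
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out, group=sp_group)
    actual = torch.cat(gathered, dim=1)

    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
```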

🚧 Known Limitations & Future Work

To be fully transparent, there are a few limitations in the current design that I plan to improve in follow-up iterations (or would love guidance on from the team):

  1. Manual Wrapping for Fusion Layers: While ViT and LLM attentions are wrapped automatically, the vision projection layer currently requires manual wrapping with ModalityFusionSPAdapter due to varying HF model implementations. Fully automating Phase 2 is a logical next step.
  2. ViT SP Trade-off: The current UlyssesSPViTAttention uses a Gather-Compute-Scatter approach (sketched after this list). While it successfully reduces per-rank FFN activation memory to $1/N$ of the baseline, it still computes the full attention matrix on every rank. A true All-to-All sequence-to-head transposition for opaque ViT layers is something I am actively exploring.
  3. Padding Attention Mask: When fused_len % world_size != 0, zero-padding is applied. Currently, the global attention_mask is not automatically intercepted and patched, which might require user attention during inference.
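
To make the trade-off in item 2 concrete, here is a conceptual sketch of the Gather-Compute-Scatter pattern (simplified: no CLS token, evenly divisible sequence length, and illustrative names; a real HF attention module typically returns a tuple rather than a plain tensor):

```python
# Conceptual sketch of the current Gather-Compute-Scatter pattern in
# UlyssesSPViTAttention (simplified: no CLS token, sequence length divisible by
# world_size, and the wrapped attention is assumed to return a plain tensor).
import torch
import torch.distributed as dist

def gather_compute_scatter(attn_module, local_patches, sp_group):
    world_size = dist.get_world_size(group=sp_group)
    rank = dist.get_rank(group=sp_group)

    # 1) Gather: every rank reconstructs the full patch sequence, so the
    #    non-causal attention matrix is still computed over all tokens.
    gathered = [torch.empty_like(local_patches) for _ in range(world_size)]
    dist.all_gather(gathered, local_patches.contiguous(), group=sp_group)
    full_seq = torch.cat(gathered, dim=1)          # [bs, full_len, hidden_dim]

    # 2) Compute: full attention on every rank -- this is the trade-off, since
    #    attention FLOPs and the attention matrix are not sharded here.
    full_out = attn_module(full_seq)

    # 3) Scatter: keep only this rank's slice, so the following FFN/LayerNorm
    #    activations only cover 1/world_size of the sequence.
    local_len = full_out.shape[1] // world_size
    return full_out[:, rank * local_len:(rank + 1) * local_len, :]
```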

I would deeply appreciate any feedback or suggestions from the maintainers! I am more than happy to make any required adjustments, refactorings, or add further test cases to get this perfectly aligned with the Q2 roadmap and DeepSpeed's standards.

Thank you for your time reviewing this! 🚀


@chatgpt-codex-connector (Bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7475f39b05


Comment thread: deepspeed/sequence/auto_sp.py (Outdated)
# (query, key, value) tensors separately. For HuggingFace-style
# attention that receives hidden_states, use scatter_idx=2 / gather_idx=0
# defaults which match the typical [bs, seq, heads, dim] layout.
wrapped = DistributedAttention(local_attention=module, sequence_process_group=process_group)

P1: Gate LLM wrapping on DistributedAttention-compatible signatures

This replaces every detected LLM attention module with DistributedAttention, but DistributedAttention.forward expects (query, key, value, batch_dim_idx) and then calls the wrapped module with Q/K/V tensors (deepspeed/sequence/layer.py). For the HF-style classes listed in _LLM_ATTN_CLASSNAMES, decoder blocks are still invoked with hidden_states, so the first wrapped decoder forward will fail (or run with invalid semantics) because no call-site adapter is added here. The new integration tests also avoid running the LLM branch due to this interface mismatch, which confirms this auto-wrap path is not actually executable end-to-end.
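
For illustration, one possible form such a gate could take (an assumed helper, not the fix that ultimately landed; the merged change skips incompatible HF LLM attention classes instead):

```python
# Illustrative sketch of a signature gate (assumed helper, not the merged fix;
# the PR eventually skips incompatible HF LLM attention classes instead).
import inspect

def _accepts_separate_qkv(attn_module) -> bool:
    """True only if forward() takes query/key/value tensors, which is what
    DistributedAttention passes to its local_attention."""
    params = inspect.signature(attn_module.forward).parameters
    return {"query", "key", "value"}.issubset(params) and "hidden_states" not in params

# Inside auto_wrap_model_for_sp, before wrapping an LLM attention module:
#     if not _accepts_separate_qkv(module):
#         logger.warning(f"Skipping {name}: forward() takes hidden_states, "
#                        "which DistributedAttention cannot drive directly.")
#         continue
```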


Comment thread: deepspeed/sequence/auto_sp.py (Outdated)
# Wrap ViT encoder attention layers
# ------------------------------------------------------------------
for name, module in info.vit_attn_modules:
    wrapped = UlyssesSPViTAttention(module, process_group)

P1: Configure CLS handling when wrapping ViT attention modules

Auto-wrapping always constructs UlyssesSPViTAttention with its default has_cls_token=True, which is incorrect for patch-only vision encoders (including classes newly registered here such as InternVisionAttention and Qwen2VLVisionAttention). In those models, token 0 is a real patch, so forcing CLS mode causes gather/scatter to treat a patch as replicated CLS and mis-partition the sequence. The benchmark and integration tests must manually flip has_cls_token=False after wrapping, so the advertised one-call wrapping is incorrect by default for those targets.
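
For illustration, one possible way to avoid the hard-coded default (a sketch with assumed parameter names; the PR's eventual fix for CLS handling is in a later commit and may differ):

```python
# Illustrative sketch only: let the wrapper infer or accept CLS handling
# instead of hard-coding has_cls_token=True. Passing has_cls_token as a
# constructor kwarg is an assumption; the merged code may set it differently.
from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention

# Vision attention classes known to be patch-only (no CLS token).
_PATCH_ONLY_VIT_ATTN = {"InternVisionAttention", "Qwen2VLVisionAttention"}

def _wrap_vit_attention(module, process_group, has_cls_token=None):
    if has_cls_token is None:
        # Infer from the class name when the caller does not specify explicitly.
        has_cls_token = type(module).__name__ not in _PATCH_ONLY_VIT_ATTN
    return UlyssesSPViTAttention(module, process_group, has_cls_token=has_cls_token)
```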


nathon-lee force-pushed the multimodal-seq-parallel branch from 7475f39 to c21fe99 on April 27, 2026 02:34

Comment thread: deepspeed/sequence/autosp_fusion.py
# DeepSpeed Team
"""
ModalityFusionSPAdapter — Phase 2
@delock (Collaborator)

If these adapters are intended to be applied in the model script, the documentation should be updated to reflect their usage.

@nathon-lee (Contributor, Author)

Usage documentation has been added (lines 24-63 in autosp_fusion.py).

Comment thread: deepspeed/sequence/autosp_vit.py (Outdated)
across the sequence dimension. Each rank appends its local patches to the
same ``cls`` token before calling the wrapped attention.

Padding: when ``num_patches % world_size != 0``, we pad patches with zeros
@delock (Collaborator)

Is there a test covering num_patches % world_size != 0?

@nathon-lee (Contributor, Author) commented Apr 29, 2026

Test coverage has been added: test_noneven_patches exercises the num_patches % world_size != 0 case.

@delock (Collaborator) commented Apr 28, 2026

Hi @nathon-lee, thanks for your contribution! I have left my review comments, thanks!

@nathon-lee (Contributor, Author)

Hi @delock, thank you for the thorough review!

  1. Documentation: Absolutely agree — I'll add a concise usage example to
    ModalityFusionSPAdapter's docstring (and a brief note in the docs) showing
    how to apply the fusion adapters in a model training script. I'll include
    this in the current PR.

  2. num_patches % world_size != 0 test: Great catch! This edge case is
    currently untested. I'll add a dedicated test in test_autosp_equivalence.py
    to cover the non-divisible path and verify the padding/strip logic in
    UlyssesSPViTAttention.forward() is correct (the module docstring mentions
    it, but the implementation needs to be confirmed against the test).

I'll push the updates shortly — thanks again for your time!

Commit: [AutoSP] Fix ViT CLS handling and skip incompatible HF LLM wrapping
Comment thread: deepspeed/sequence/autosp_vit.py (Outdated)
for _ in range(self.world_size)
]
dist.all_gather(gathered, local_patches_padded.contiguous(), group=self.process_group)
full_patches = torch.cat(gathered, dim=1) # [bs, world_size * max_local_len, hidden_dim]
@delock (Collaborator)

Hi @nathon-lee, thanks for the fix for unevenly divided patches. I have a follow-up question. I see that the padding added here will appear in full_patches and therefore also in full_input. During attention computation over full_input, the softmax might be affected by these padding patches. Should we 'unpad' the all_gather result before computing full_input, or mask the padding during attention computation?

@nathon-lee (Contributor, Author) commented Apr 30, 2026

Thanks for catching this, @delock — much appreciated. You're right that the zero-padded tokens in full_patches would participate in the softmax and lead to divergence from single-device execution.

Fixed by de-padding before calling attention rather than masking. After all_gather, each shard is trimmed to its true length using per-rank lengths collected via a preceding scalar all_gather.

De-pad gathered shards before calling attention so that dummy zero-tokens
never enter the softmax computation:
- all_gather each rank's exact local_patch_len into all_lens
- strip per-rank padding from gathered buffers before torch.cat
- update scatter offset to sum(all_lens[:rank]) instead of rank*max_local_len

All 6 TestViTSPEquivalence tests pass (including test_noneven_patches).
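
A condensed sketch of the de-padding flow described above (variable names are illustrative; the actual logic lives in UlyssesSPViTAttention.forward()):

```python
# Condensed sketch of the de-padding flow described above (variable names are
# illustrative; the actual logic lives in UlyssesSPViTAttention.forward()).
import torch
import torch.nn.functional as F
import torch.distributed as dist

def gather_without_padding(local_patches, local_patch_len, max_local_len, sp_group):
    world_size = dist.get_world_size(group=sp_group)
    rank = dist.get_rank(group=sp_group)

    # 1) Exchange every rank's true (un-padded) patch count.
    len_t = torch.tensor([local_patch_len], device=local_patches.device)
    all_lens = [torch.zeros_like(len_t) for _ in range(world_size)]
    dist.all_gather(all_lens, len_t, group=sp_group)
    all_lens = [int(l.item()) for l in all_lens]

    # 2) Pad the local shard to a common length so all_gather shapes match.
    local_padded = F.pad(local_patches, (0, 0, 0, max_local_len - local_patch_len))
    gathered = [torch.empty_like(local_padded) for _ in range(world_size)]
    dist.all_gather(gathered, local_padded.contiguous(), group=sp_group)

    # 3) Strip each rank's padding *before* concatenation, so zero tokens never
    #    reach the softmax.
    full_patches = torch.cat([g[:, :n, :] for g, n in zip(gathered, all_lens)], dim=1)

    # Scatter offset for this rank: the sum of the true lengths before it,
    # not rank * max_local_len.
    offset = sum(all_lens[:rank])
    return full_patches, offset
```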

Commit: [AutoSP] Fix padding-before-attention bug in UlyssesSPViTAttention
nathon-lee force-pushed the multimodal-seq-parallel branch from 02262e4 to 462f169 on April 30, 2026 05:58
delock merged commit 4e668fc into deepspeedai:master on May 1, 2026
10 of 11 checks passed
sfc-gh-truwase removed the request for review from GuanhuaWang on May 5, 2026 10:13