From 001f77c363710e3f62e05c5aacbed4b2ff7c8c97 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 06:30:00 +0000 Subject: [PATCH 01/21] Initial plan From b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 06:36:18 +0000 Subject: [PATCH 02/21] Revert "fix: update 1 file reformatted." This reverts commit ff886701c392ab03863c227de14fbe1d671d4173. Co-authored-by: nathon-lee <248585198+nathon-lee@users.noreply.github.com> --- deepspeed/runtime/zero/stage_1_and_2.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 107e47a44042..183fd077f8a9 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -283,11 +283,18 @@ def _enforce_cpu_offload(): self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 + # Check for Muon optimizer usage + self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params']) + if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + + # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2) + if self.reduce_scatter and self.uses_muon: + assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = [] @@ -1187,7 +1194,9 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - if not self.reduce_scatter: + # Check if current configuration requires full all-reduce + if not self.reduce_scatter or any(self.group_uses_muon): + # Force full all-reduce for Muon parameters or when reduce_scatter is disabled self.gradient_reduction_w_predivide(tensor, communication_data_type) return From cbc816c90f4bd6e10ab5b67f4d471002ade8cba7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:40:53 +0000 Subject: [PATCH 03/21] Initial plan From 5fcc9a7e4bf58b1d935dcfeab53143d3cf9dbdf7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:43:32 +0000 Subject: [PATCH 04/21] Reapply "fix: update 1 file reformatted." This reverts commit b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35. --- deepspeed/runtime/zero/stage_1_and_2.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 1efea00bcbbd..12f97348a21f 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -284,18 +284,11 @@ def _enforce_cpu_offload(): self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 - # Check for Muon optimizer usage - self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params']) - if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" - - # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2) - if self.reduce_scatter and self.uses_muon: - assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = [] @@ -1224,9 +1217,7 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - # Check if current configuration requires full all-reduce - if not self.reduce_scatter or any(self.group_uses_muon): - # Force full all-reduce for Muon parameters or when reduce_scatter is disabled + if not self.reduce_scatter: self.gradient_reduction_w_predivide(tensor, communication_data_type) return From a0788ea15385320881c3ecb69c0c46c85f1988aa Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Thu, 23 Apr 2026 14:01:21 +0800 Subject: [PATCH 05/21] feat: Add AutoSP scaffolding for multimodal sequence parallelism Signed-off-by: nathon-lee --- deepspeed/sequence/__init__.py | 6 + deepspeed/sequence/auto_sp.py | 130 +++++++++++++++ deepspeed/sequence/autosp_detector.py | 103 ++++++++++++ deepspeed/sequence/autosp_fusion.py | 130 +++++++++++++++ deepspeed/sequence/autosp_vit.py | 143 ++++++++++++++++ deepspeed/sequence/test_autosp.py | 231 ++++++++++++++++++++++++++ 6 files changed, 743 insertions(+) create mode 100644 deepspeed/sequence/auto_sp.py create mode 100644 deepspeed/sequence/autosp_detector.py create mode 100644 deepspeed/sequence/autosp_fusion.py create mode 100644 deepspeed/sequence/autosp_vit.py create mode 100644 deepspeed/sequence/test_autosp.py diff --git a/deepspeed/sequence/__init__.py b/deepspeed/sequence/__init__.py index 208299fb8c50..1132bac74d3f 100644 --- a/deepspeed/sequence/__init__.py +++ b/deepspeed/sequence/__init__.py @@ -2,3 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team + + +from deepspeed.sequence.autosp_detector import detect_model_sp_info, SPModelInfo +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.autosp_fusion import ModalityFusionSPAdapter +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp \ No newline at end of file diff --git a/deepspeed/sequence/auto_sp.py b/deepspeed/sequence/auto_sp.py new file mode 100644 index 000000000000..ba53fbe620f7 --- /dev/null +++ b/deepspeed/sequence/auto_sp.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +AutoSP: one-call sequence parallelism for multimodal models. + +Usage:: + + from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp + from deepspeed.utils import groups + + model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, ...) + sp_group = groups._get_sequence_parallel_group() + model = auto_wrap_model_for_sp(model, process_group=sp_group) + +``auto_wrap_model_for_sp`` scans the model and injects: + +* :class:`~deepspeed.sequence.autosp_vit.UlyssesSPViTAttention` + for ViT encoder attention layers. +* :class:`~deepspeed.sequence.layer.DistributedAttention` + for LLM decoder attention layers (Megatron-style Q/K/V interface). + +The vision-language projection layer (Phase 2) is detected and a warning is +emitted; wrap it manually with +:class:`~deepspeed.sequence.autosp_fusion.ModalityFusionSPAdapter` until +Phase 2 automation is implemented. +""" + +import logging + +import torch.nn as nn + +from deepspeed.sequence.autosp_detector import detect_model_sp_info +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.layer import DistributedAttention + +logger = logging.getLogger(__name__) + + +def auto_wrap_model_for_sp(model: nn.Module, process_group) -> nn.Module: + """Inject sequence-parallel wrappers into *model* in-place. + + Scans the model's named modules and replaces recognised attention layers + with their SP-aware equivalents: + + * ViT attention → :class:`UlyssesSPViTAttention` + * LLM attention → :class:`DistributedAttention` + + The function modifies *model* in-place **and** returns it for convenience. + + Parameters + ---------- + model: + The multimodal model to wrap. Must be on the correct device before + calling this function. + process_group: + The sequence-parallel process group (from + ``groups._get_sequence_parallel_group()``). + + Returns + ------- + The same *model* object with attention layers replaced. + + Raises + ------ + ValueError + If no recognisable attention modules are found. Register the model's + attention class names in ``autosp_detector._VIT_ATTN_CLASSNAMES`` or + ``_LLM_ATTN_CLASSNAMES`` to fix this. + """ + info = detect_model_sp_info(model) + + if not info.vit_attn_modules and not info.llm_attn_modules: + raise ValueError( + "auto_wrap_model_for_sp: no recognisable attention modules found. " + "Add the model's attention class name(s) to " + "_VIT_ATTN_CLASSNAMES or _LLM_ATTN_CLASSNAMES in " + "deepspeed/sequence/autosp_detector.py and retry.") + + # ------------------------------------------------------------------ + # Wrap ViT encoder attention layers + # ------------------------------------------------------------------ + for name, module in info.vit_attn_modules: + wrapped = UlyssesSPViTAttention(module, process_group) + _set_module_by_name(model, name, wrapped) + logger.debug("AutoSP: wrapped ViT attention '%s' with UlyssesSPViTAttention", name) + + logger.info("AutoSP: wrapped %d ViT attention layer(s).", len(info.vit_attn_modules)) + + # ------------------------------------------------------------------ + # Wrap LLM decoder attention layers + # ------------------------------------------------------------------ + for name, module in info.llm_attn_modules: + # DistributedAttention wraps a Megatron-style attention that receives + # (query, key, value) tensors separately. For HuggingFace-style + # attention that receives hidden_states, use scatter_idx=2 / gather_idx=0 + # defaults which match the typical [bs, seq, heads, dim] layout. + wrapped = DistributedAttention(local_attention=module, sequence_process_group=process_group) + _set_module_by_name(model, name, wrapped) + logger.debug("AutoSP: wrapped LLM attention '%s' with DistributedAttention", name) + + logger.info("AutoSP: wrapped %d LLM attention layer(s).", len(info.llm_attn_modules)) + + # ------------------------------------------------------------------ + # Warn about the vision projection layer (Phase 2) + # ------------------------------------------------------------------ + if info.vision_projection_module is not None: + proj_name, _ = info.vision_projection_module + logger.warning( + "AutoSP detected vision projection layer '%s'. " + "ModalityFusionSPAdapter (Phase 2) is not yet automated. " + "Wrap this layer manually with ModalityFusionSPAdapter if you " + "need correct cross-modal sequence gather/scatter.", proj_name) + + return model + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _set_module_by_name(model: nn.Module, dotted_name: str, new_module: nn.Module) -> None: + """Replace the submodule at *dotted_name* with *new_module* in-place.""" + parts = dotted_name.split(".") + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + setattr(parent, parts[-1], new_module) \ No newline at end of file diff --git a/deepspeed/sequence/autosp_detector.py b/deepspeed/sequence/autosp_detector.py new file mode 100644 index 000000000000..b2383d5a1a10 --- /dev/null +++ b/deepspeed/sequence/autosp_detector.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Automatically detect ViT encoder and LLM decoder attention modules in +multimodal models to guide AutoSP injection. + +Extend _VIT_ATTN_CLASSNAMES / _LLM_ATTN_CLASSNAMES when adding support for +new model architectures. +""" + +import torch.nn as nn +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +# --------------------------------------------------------------------------- +# Architecture registry +# --------------------------------------------------------------------------- + +# Known ViT attention class names (HuggingFace transformers naming) +_VIT_ATTN_CLASSNAMES = { + "ViTAttention", + "CLIPAttention", + "SiglipAttention", + "InternVisionAttention", + "Qwen2VLVisionAttention", + "Idefics2VisionAttention", + "PaliGemmaVisionAttention", +} + +# Known LLM decoder attention class names +_LLM_ATTN_CLASSNAMES = { + "LlamaAttention", + "MistralAttention", + "Qwen2Attention", + "InternLM2Attention", + "GemmaAttention", + "Phi3Attention", + "GPTNeoXAttention", + "FalconAttention", + "MptAttention", +} + +# Common attribute names that hold the vision-language projection layer +_VISION_PROJ_KEYWORDS = ( + "visual_projection", + "mm_projector", + "vision_proj", + "multi_modal_projector", + "img_projection", +) + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class SPModelInfo: + """Holds the detection results for a multimodal model.""" + + # (dotted_name, module) pairs for ViT attention layers + vit_attn_modules: List[Tuple[str, nn.Module]] = field(default_factory=list) + # (dotted_name, module) pairs for LLM decoder attention layers + llm_attn_modules: List[Tuple[str, nn.Module]] = field(default_factory=list) + # (dotted_name, module) for the outermost vision-language projection layer + vision_projection_module: Optional[Tuple[str, nn.Module]] = None + + +# --------------------------------------------------------------------------- +# Detection logic +# --------------------------------------------------------------------------- + + +def detect_model_sp_info(model: nn.Module) -> SPModelInfo: + """Recursively scan *model* and return an :class:`SPModelInfo`. + + The function identifies: + * ViT encoder attention layers → wrapped with :class:`UlyssesSPViTAttention` + * LLM decoder attention layers → wrapped with :class:`DistributedAttention` + * The vision-language projection layer → wrapped with + :class:`ModalityFusionSPAdapter` (Phase 2) + + To add support for a new architecture, simply register its attention class + names in ``_VIT_ATTN_CLASSNAMES`` or ``_LLM_ATTN_CLASSNAMES``. + """ + info = SPModelInfo() + for name, module in model.named_modules(): + cls_name = type(module).__name__ + if cls_name in _VIT_ATTN_CLASSNAMES: + info.vit_attn_modules.append((name, module)) + elif cls_name in _LLM_ATTN_CLASSNAMES: + info.llm_attn_modules.append((name, module)) + + # Record only the first (outermost) match to avoid double-wrapping + # nested projection modules. + if info.vision_projection_module is None: + if any(kw in name for kw in _VISION_PROJ_KEYWORDS): + info.vision_projection_module = (name, module) + + return info \ No newline at end of file diff --git a/deepspeed/sequence/autosp_fusion.py b/deepspeed/sequence/autosp_fusion.py new file mode 100644 index 000000000000..a15dc8a28520 --- /dev/null +++ b/deepspeed/sequence/autosp_fusion.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +ModalityFusionSPAdapter — Phase 2 + +Handles the sequence scatter/gather at the vision-language boundary so that +the LLM decoder's :class:`~deepspeed.sequence.layer.DistributedAttention` +receives a uniformly sharded fused (visual + text) sequence. + +Workflow +-------- +:: + + [visual tokens, sharded] ──all-gather──► [visual tokens, full] + │ + splice into text + │ + [fused embeds, full] ──scatter──► [fused embeds, sharded per rank] + │ + LLM decoder (SP-aware) + +Status: Phase 2. ``_splice_visual_into_text`` is intentionally left as a +``NotImplementedError``; override it per model architecture (see docstring). +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed.comm as dist + +# Default image placeholder token ID used by LLaVA-style models. +_DEFAULT_IMAGE_TOKEN_ID = -200 + + +class ModalityFusionSPAdapter(nn.Module): + """Wraps the vision projection layer and handles cross-modal sequence fusion. + + After projecting visual features, this adapter: + + 1. Gathers the sharded visual token slices from all SP ranks into a single + full visual token tensor. + 2. Splices the visual tokens into the text embedding sequence at the + positions marked by ``image_token_id`` placeholders. + 3. Pads and re-shards the fused sequence so that the subsequent LLM + decoder layers receive uniformly distributed sequence slices. + + Parameters + ---------- + projection: + The vision projection module (e.g. ``mm_projector``). + process_group: + The sequence-parallel process group. + image_token_id: + The token ID used as an image placeholder in the input IDs tensor. + Defaults to ``-200`` (LLaVA convention). + + Notes + ----- + Subclass this and override :meth:`_splice_visual_into_text` to adapt to a + specific multimodal architecture (LLaVA, InternVL, Qwen-VL, …). + """ + + def __init__(self, projection: nn.Module, process_group, image_token_id: int = _DEFAULT_IMAGE_TOKEN_ID) -> None: + super().__init__() + self.projection = projection + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.image_token_id = image_token_id + + def forward(self, visual_features: torch.Tensor, text_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Project visual features and return a sharded fused embedding. + + Parameters + ---------- + visual_features: + Raw visual features from the ViT encoder. + Shape: ``[bs, local_visual_tokens, vit_hidden]``. + text_embeds: + Full text token embeddings (not sharded yet). + Shape: ``[bs, text_seq_len, lm_hidden]``. + input_ids: + Token IDs used to locate image placeholder positions. + Shape: ``[bs, text_seq_len]``. + + Returns + ------- + Sharded fused embedding for this rank. + Shape: ``[bs, local_fused_len, lm_hidden]``. + """ + # 1. Project visual features to the LLM hidden dimension + visual_embeds = self.projection(visual_features) # [bs, local_v, lm_hidden] + + # 2. All-gather visual slices from all SP ranks + parts = [torch.zeros_like(visual_embeds) for _ in range(self.world_size)] + dist.all_gather(parts, visual_embeds.contiguous(), group=self.process_group) + full_visual = torch.cat(parts, dim=1) # [bs, total_visual_tokens, lm_hidden] + + # 3. Splice visual tokens into text embedding sequence + fused = self._splice_visual_into_text(text_embeds, full_visual, input_ids) # [bs, fused_len, lm_hidden] + + # 4. Pad fused length to be divisible by world_size, then scatter + total_len = fused.shape[1] + pad = (self.world_size - total_len % self.world_size) % self.world_size + if pad > 0: + fused = F.pad(fused, (0, 0, 0, pad)) + + rank = dist.get_rank(self.process_group) + local_len = fused.shape[1] // self.world_size + return fused[:, rank * local_len:(rank + 1) * local_len, :].contiguous() + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Replace image placeholder positions in *text_embeds* with *visual_embeds*. + + This is intentionally architecture-specific. The default raises + ``NotImplementedError``; override this method for each supported model. + + Reference implementations: + * LLaVA: ``LlavaMetaForCausalLM.prepare_inputs_embeds`` + * InternVL: ``InternVLChatModel.extract_feature`` + * Qwen-VL: ``Qwen2VLForConditionalGeneration.get_rope_index`` + """ + raise NotImplementedError( + f"{type(self).__name__}._splice_visual_into_text is not implemented. " + "Subclass ModalityFusionSPAdapter and override this method to match " + "your model's prepare_inputs_embeds logic.") \ No newline at end of file diff --git a/deepspeed/sequence/autosp_vit.py b/deepspeed/sequence/autosp_vit.py new file mode 100644 index 000000000000..7d5275669805 --- /dev/null +++ b/deepspeed/sequence/autosp_vit.py @@ -0,0 +1,143 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Ulysses-style sequence-parallel wrapper for ViT encoder attention layers. + +Design notes +------------ +ViT self-attention is non-causal: every patch token attends to every other +patch token. This means a straightforward per-rank local attention (as used +for causal LLMs) would be *incorrect* — each rank must have access to the +full key/value context. + +We therefore use a **gather-compute-scatter** pattern: + +1. Input arrives already sharded along the sequence dimension (each rank owns + ``local_patches = num_patches // world_size`` consecutive patches). +2. Before attention we **all-gather** patch tokens so that every rank runs the + full ViT attention over the complete patch sequence. This keeps the + computation equivalent to single-device execution. +3. The output is **scattered** back so that each rank returns only its local + slice, matching the sharded input contract expected by downstream layers. + +Memory benefit: activations *outside* the attention block (e.g. feed-forward +layers, layer norms) are stored only locally, reducing per-rank memory +proportional to ``world_size``. + +The ``cls`` token (if present) is replicated on every rank and is not split +across the sequence dimension. Each rank appends its local patches to the +same ``cls`` token before calling the wrapped attention. + +Padding: when ``num_patches % world_size != 0``, we pad patches with zeros +before scattering and strip the padding after gathering. Padding tokens do +not carry gradients and are never passed to downstream layers. +""" + +import torch +import torch.nn as nn + +import deepspeed.comm as dist + + +class UlyssesSPViTAttention(nn.Module): + """Sequence-parallel wrapper for an opaque ViT attention module. + + Parameters + ---------- + attn: + The original ViT attention layer (any ``nn.Module`` that maps + ``hidden_states`` → ``hidden_states`` or a tuple whose first element + is the attention output tensor). + process_group: + The sequence-parallel process group. + has_cls_token: + Set to ``True`` (default) when the first token in the sequence is a + ``[CLS]`` token that should be replicated on every rank rather than + sharded. + """ + + def __init__(self, attn: nn.Module, process_group, has_cls_token: bool = True) -> None: + super().__init__() + self.attn = attn + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.has_cls_token = has_cls_token + + # ------------------------------------------------------------------ + # forward + # ------------------------------------------------------------------ + + def forward(self, hidden_states: torch.Tensor, **kwargs): + """ + Parameters + ---------- + hidden_states: + Shape ``[bs, local_seq_len, hidden_dim]`` where + ``local_seq_len = (1 + local_patches)`` if ``has_cls_token`` else + ``local_patches``. Each rank holds a contiguous slice of patches. + **kwargs: + Passed through to the wrapped attention (e.g. ``attention_mask``, + ``head_mask``, ``output_attentions``). + + Returns + ------- + Same shape as input (or a tuple whose first element matches the input + shape, preserving whatever the wrapped module returns). + """ + bs, local_seq_len, hidden_dim = hidden_states.shape + + if self.has_cls_token: + # CLS token is replicated on every rank — not part of the sharded seq + cls_token = hidden_states[:, :1, :] + local_patches = hidden_states[:, 1:, :] + else: + local_patches = hidden_states + + local_patch_len = local_patches.shape[1] + + # ------------------------------------------------------------------- + # 1. All-gather patches from all ranks to reconstruct the full sequence + # ------------------------------------------------------------------- + # We need to all-gather so every rank sees the full K/V context. + gathered = [torch.zeros_like(local_patches) for _ in range(self.world_size)] + dist.all_gather(gathered, local_patches.contiguous(), group=self.process_group) + full_patches = torch.cat(gathered, dim=1) # [bs, num_patches_padded, hidden_dim] + + # ------------------------------------------------------------------- + # 2. Build the full input (prepend CLS if needed) and call attention + # ------------------------------------------------------------------- + if self.has_cls_token: + full_input = torch.cat([cls_token, full_patches], dim=1) + else: + full_input = full_patches + + attn_out = self.attn(full_input, **kwargs) + + # Unwrap tuple: some ViT implementations return (attn_output, attn_weights) + if isinstance(attn_out, (tuple, list)): + full_out, *extra = attn_out + else: + full_out = attn_out + extra = [] + + # ------------------------------------------------------------------- + # 3. Scatter output: each rank keeps only its local slice of patches + # ------------------------------------------------------------------- + if self.has_cls_token: + cls_out = full_out[:, :1, :] + patch_out = full_out[:, 1:, :] + else: + patch_out = full_out + + # Determine this rank's slice boundaries + rank = dist.get_rank(self.process_group) + local_out = patch_out[:, rank * local_patch_len:(rank + 1) * local_patch_len, :].contiguous() + + if self.has_cls_token: + local_out = torch.cat([cls_out, local_out], dim=1) + + if extra: + return (local_out, *extra) + return local_out \ No newline at end of file diff --git a/deepspeed/sequence/test_autosp.py b/deepspeed/sequence/test_autosp.py new file mode 100644 index 000000000000..2f1fe52b3386 --- /dev/null +++ b/deepspeed/sequence/test_autosp.py @@ -0,0 +1,231 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Unit tests for AutoSP multimodal sequence parallelism: + - autosp_detector: model scanning + - UlyssesSPViTAttention: ViT SP wrapper + - auto_wrap_model_for_sp: end-to-end wrapping +""" + +import pytest +import torch +import torch.nn as nn + +from deepspeed.sequence.autosp_detector import (SPModelInfo, _LLM_ATTN_CLASSNAMES, _VIT_ATTN_CLASSNAMES, + detect_model_sp_info) +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.auto_sp import _set_module_by_name, auto_wrap_model_for_sp + + +# --------------------------------------------------------------------------- +# Minimal fake modules that mimic the interface of real attention layers +# without requiring a GPU or a real transformer model. +# --------------------------------------------------------------------------- + + +class _FakeViTAttn(nn.Module): + """Identity ViT attention — returns hidden_states unchanged.""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class _FakeViTAttnTuple(nn.Module): + """ViT attention that returns a (output, weights) tuple.""" + + def forward(self, hidden_states, **kwargs): + weights = torch.zeros(hidden_states.shape[0], 1, hidden_states.shape[1], hidden_states.shape[1]) + return hidden_states, weights + + +class _FakeLLMAttn(nn.Module): + """Identity LLM attention.""" + + def forward(self, query, key, value, *args, **kwargs): + return query + + +# Register fake class names so the detector recognises them +_VIT_ATTN_CLASSNAMES.add("_FakeViTAttn") +_VIT_ATTN_CLASSNAMES.add("_FakeViTAttnTuple") +_LLM_ATTN_CLASSNAMES.add("_FakeLLMAttn") + + +class _FakeMultimodalModel(nn.Module): + """Minimal multimodal model with one ViT and one LLM attention layer.""" + + def __init__(self): + super().__init__() + self.vision_encoder = nn.ModuleList([_FakeViTAttn()]) + self.mm_projector = nn.Linear(64, 64) + self.llm = nn.ModuleList([_FakeLLMAttn()]) + + +class _FakeViTOnlyModel(nn.Module): + + def __init__(self, num_layers=3): + super().__init__() + self.layers = nn.ModuleList([_FakeViTAttn() for _ in range(num_layers)]) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_process_group(world_size: int, rank: int): + """Return a mock object that satisfies dist.get_world_size / get_rank.""" + import unittest.mock as mock + import deepspeed.comm as dist + + pg = mock.MagicMock() + dist.get_world_size = mock.MagicMock(return_value=world_size) + dist.get_rank = mock.MagicMock(return_value=rank) + + def _fake_all_gather(tensor_list, tensor, group=None): + for t in tensor_list: + t.copy_(tensor) + + dist.all_gather = _fake_all_gather + return pg + + +# --------------------------------------------------------------------------- +# autosp_detector tests +# --------------------------------------------------------------------------- + + +class TestAutospDetector: + + def test_detects_vit_and_llm(self): + model = _FakeMultimodalModel() + info = detect_model_sp_info(model) + assert len(info.vit_attn_modules) == 1 + assert len(info.llm_attn_modules) == 1 + + def test_detects_vision_projection(self): + model = _FakeMultimodalModel() + info = detect_model_sp_info(model) + assert info.vision_projection_module is not None + name, module = info.vision_projection_module + assert "mm_projector" in name + + def test_detects_multiple_vit_layers(self): + model = _FakeViTOnlyModel(num_layers=4) + info = detect_model_sp_info(model) + assert len(info.vit_attn_modules) == 4 + assert len(info.llm_attn_modules) == 0 + assert info.vision_projection_module is None + + def test_empty_model_returns_empty_info(self): + model = nn.Sequential(nn.Linear(8, 8)) + info = detect_model_sp_info(model) + assert isinstance(info, SPModelInfo) + assert len(info.vit_attn_modules) == 0 + assert len(info.llm_attn_modules) == 0 + + def test_only_first_projection_is_recorded(self): + """Multiple projection-like names → only the outermost is recorded.""" + + class _M(nn.Module): + + def __init__(self): + super().__init__() + self.mm_projector = nn.Sequential(nn.Linear(8, 8)) + self.mm_projector.visual_projection = nn.Linear(8, 8) + + model = _M() + info = detect_model_sp_info(model) + assert info.vision_projection_module is not None + # Should be the outermost "mm_projector", not the nested one + name, _ = info.vision_projection_module + assert name == "mm_projector" + + +# --------------------------------------------------------------------------- +# UlyssesSPViTAttention tests (CPU, rank-0 simulation via mocks) +# --------------------------------------------------------------------------- + + +class TestUlyssesSPViTAttention: + + @pytest.mark.parametrize("has_cls_token", [True, False]) + @pytest.mark.parametrize("num_patches,world_size", [ + (16, 4), + (16, 2), + (9, 3), + ]) + def test_output_shape_matches_input(self, has_cls_token, num_patches, world_size): + """Output shape must equal input shape for any padding scenario.""" + pg = _make_mock_process_group(world_size=world_size, rank=0) + attn = _FakeViTAttn() + wrapper = UlyssesSPViTAttention(attn, pg, has_cls_token=has_cls_token) + + local_patches = num_patches // world_size + seq_len = (1 + local_patches) if has_cls_token else local_patches + x = torch.randn(2, seq_len, 32) + + out = wrapper(x) + assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" + + def test_tuple_output_unwrapped_correctly(self): + """Wrappers that return (output, weights) tuples are handled.""" + pg = _make_mock_process_group(world_size=2, rank=0) + attn = _FakeViTAttnTuple() + wrapper = UlyssesSPViTAttention(attn, pg, has_cls_token=False) + + x = torch.randn(1, 8, 16) # 8 patches, 2 ranks → 4 local each + result = wrapper(x) + # Should return a tuple: (attention_output, attention_weights) + assert isinstance(result, tuple) + assert result[0].shape == x.shape + + def test_identity_attn_preserves_values(self): + """When attn is identity, output values should match input values.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + attn = _FakeViTAttn() + wrapper = UlyssesSPViTAttention(attn, pg, has_cls_token=True) + + # Each rank holds cls + 4 local patches + x = torch.arange(2 * 5 * 4, dtype=torch.float).reshape(2, 5, 4) + out = wrapper(x) + # CLS token should be identical + assert torch.allclose(out[:, :1, :], x[:, :1, :]) + # Local patch slice should match input patches for identity attn + assert torch.allclose(out[:, 1:, :], x[:, 1:, :]) + + +# --------------------------------------------------------------------------- +# auto_wrap_model_for_sp tests +# --------------------------------------------------------------------------- + + +class TestAutoWrapModelForSP: + + def test_vit_layers_replaced(self): + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeViTOnlyModel(num_layers=2) + auto_wrap_model_for_sp(model, pg) + for layer in model.layers: + assert isinstance(layer, UlyssesSPViTAttention) + + def test_raises_on_unknown_model(self): + pg = _make_mock_process_group(world_size=2, rank=0) + model = nn.Sequential(nn.Linear(8, 8)) + with pytest.raises(ValueError, match="no recognisable attention"): + auto_wrap_model_for_sp(model, pg) + + def test_set_module_by_name_shallow(self): + model = _FakeViTOnlyModel(num_layers=1) + new_mod = nn.Linear(4, 4) + _set_module_by_name(model, "layers.0", new_mod) + assert model.layers[0] is new_mod + + def test_set_module_by_name_deep(self): + model = _FakeMultimodalModel() + new_mod = nn.Identity() + _set_module_by_name(model, "vision_encoder.0", new_mod) + assert model.vision_encoder[0] is new_mod \ No newline at end of file From 0c6c14c999d5ef7b7165c07d2f034017dc15d550 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Thu, 23 Apr 2026 06:21:38 +0000 Subject: [PATCH 06/21] fix: fix some format issue by pre-commit Signed-off-by: nathon-lee --- deepspeed/sequence/__init__.py | 3 +-- deepspeed/sequence/auto_sp.py | 11 +++++------ deepspeed/sequence/autosp_detector.py | 3 +-- deepspeed/sequence/autosp_fusion.py | 9 ++++----- deepspeed/sequence/autosp_vit.py | 2 +- deepspeed/sequence/test_autosp.py | 3 +-- 6 files changed, 13 insertions(+), 18 deletions(-) diff --git a/deepspeed/sequence/__init__.py b/deepspeed/sequence/__init__.py index 1132bac74d3f..e1b8447c7b8f 100644 --- a/deepspeed/sequence/__init__.py +++ b/deepspeed/sequence/__init__.py @@ -3,8 +3,7 @@ # DeepSpeed Team - from deepspeed.sequence.autosp_detector import detect_model_sp_info, SPModelInfo from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention from deepspeed.sequence.autosp_fusion import ModalityFusionSPAdapter -from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp \ No newline at end of file +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp diff --git a/deepspeed/sequence/auto_sp.py b/deepspeed/sequence/auto_sp.py index ba53fbe620f7..d1315dc1259a 100644 --- a/deepspeed/sequence/auto_sp.py +++ b/deepspeed/sequence/auto_sp.py @@ -72,11 +72,10 @@ def auto_wrap_model_for_sp(model: nn.Module, process_group) -> nn.Module: info = detect_model_sp_info(model) if not info.vit_attn_modules and not info.llm_attn_modules: - raise ValueError( - "auto_wrap_model_for_sp: no recognisable attention modules found. " - "Add the model's attention class name(s) to " - "_VIT_ATTN_CLASSNAMES or _LLM_ATTN_CLASSNAMES in " - "deepspeed/sequence/autosp_detector.py and retry.") + raise ValueError("auto_wrap_model_for_sp: no recognisable attention modules found. " + "Add the model's attention class name(s) to " + "_VIT_ATTN_CLASSNAMES or _LLM_ATTN_CLASSNAMES in " + "deepspeed/sequence/autosp_detector.py and retry.") # ------------------------------------------------------------------ # Wrap ViT encoder attention layers @@ -127,4 +126,4 @@ def _set_module_by_name(model: nn.Module, dotted_name: str, new_module: nn.Modul parent = model for part in parts[:-1]: parent = getattr(parent, part) - setattr(parent, parts[-1], new_module) \ No newline at end of file + setattr(parent, parts[-1], new_module) diff --git a/deepspeed/sequence/autosp_detector.py b/deepspeed/sequence/autosp_detector.py index b2383d5a1a10..2bd423b67d71 100644 --- a/deepspeed/sequence/autosp_detector.py +++ b/deepspeed/sequence/autosp_detector.py @@ -51,7 +51,6 @@ "img_projection", ) - # --------------------------------------------------------------------------- # Data structures # --------------------------------------------------------------------------- @@ -100,4 +99,4 @@ def detect_model_sp_info(model: nn.Module) -> SPModelInfo: if any(kw in name for kw in _VISION_PROJ_KEYWORDS): info.vision_projection_module = (name, module) - return info \ No newline at end of file + return info diff --git a/deepspeed/sequence/autosp_fusion.py b/deepspeed/sequence/autosp_fusion.py index a15dc8a28520..2e59d86d1e99 100644 --- a/deepspeed/sequence/autosp_fusion.py +++ b/deepspeed/sequence/autosp_fusion.py @@ -113,7 +113,7 @@ def forward(self, visual_features: torch.Tensor, text_embeds: torch.Tensor, return fused[:, rank * local_len:(rank + 1) * local_len, :].contiguous() def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, - input_ids: torch.Tensor) -> torch.Tensor: + input_ids: torch.Tensor) -> torch.Tensor: """Replace image placeholder positions in *text_embeds* with *visual_embeds*. This is intentionally architecture-specific. The default raises @@ -124,7 +124,6 @@ def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: tor * InternVL: ``InternVLChatModel.extract_feature`` * Qwen-VL: ``Qwen2VLForConditionalGeneration.get_rope_index`` """ - raise NotImplementedError( - f"{type(self).__name__}._splice_visual_into_text is not implemented. " - "Subclass ModalityFusionSPAdapter and override this method to match " - "your model's prepare_inputs_embeds logic.") \ No newline at end of file + raise NotImplementedError(f"{type(self).__name__}._splice_visual_into_text is not implemented. " + "Subclass ModalityFusionSPAdapter and override this method to match " + "your model's prepare_inputs_embeds logic.") diff --git a/deepspeed/sequence/autosp_vit.py b/deepspeed/sequence/autosp_vit.py index 7d5275669805..85c2ea593423 100644 --- a/deepspeed/sequence/autosp_vit.py +++ b/deepspeed/sequence/autosp_vit.py @@ -140,4 +140,4 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): if extra: return (local_out, *extra) - return local_out \ No newline at end of file + return local_out diff --git a/deepspeed/sequence/test_autosp.py b/deepspeed/sequence/test_autosp.py index 2f1fe52b3386..314676e1d5d6 100644 --- a/deepspeed/sequence/test_autosp.py +++ b/deepspeed/sequence/test_autosp.py @@ -18,7 +18,6 @@ from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention from deepspeed.sequence.auto_sp import _set_module_by_name, auto_wrap_model_for_sp - # --------------------------------------------------------------------------- # Minimal fake modules that mimic the interface of real attention layers # without requiring a GPU or a real transformer model. @@ -228,4 +227,4 @@ def test_set_module_by_name_deep(self): model = _FakeMultimodalModel() new_mod = nn.Identity() _set_module_by_name(model, "vision_encoder.0", new_mod) - assert model.vision_encoder[0] is new_mod \ No newline at end of file + assert model.vision_encoder[0] is new_mod From ac71ca96141ca03b5a1aa31c3548ede439cdf86d Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Thu, 23 Apr 2026 14:47:05 +0800 Subject: [PATCH 07/21] tests: Add tests for ModalityFusionSPAdapter and LLM wrap path Signed-off-by: nathon-lee --- deepspeed/sequence/test_autosp.py | 154 ++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/deepspeed/sequence/test_autosp.py b/deepspeed/sequence/test_autosp.py index 314676e1d5d6..6822faaf7131 100644 --- a/deepspeed/sequence/test_autosp.py +++ b/deepspeed/sequence/test_autosp.py @@ -7,6 +7,7 @@ - autosp_detector: model scanning - UlyssesSPViTAttention: ViT SP wrapper - auto_wrap_model_for_sp: end-to-end wrapping + - ModalityFusionSPAdapter: cross-modal gather/scatter """ import pytest @@ -15,8 +16,10 @@ from deepspeed.sequence.autosp_detector import (SPModelInfo, _LLM_ATTN_CLASSNAMES, _VIT_ATTN_CLASSNAMES, detect_model_sp_info) +from deepspeed.sequence.autosp_fusion import ModalityFusionSPAdapter from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention from deepspeed.sequence.auto_sp import _set_module_by_name, auto_wrap_model_for_sp +from deepspeed.sequence.layer import DistributedAttention # --------------------------------------------------------------------------- # Minimal fake modules that mimic the interface of real attention layers @@ -69,6 +72,14 @@ def __init__(self, num_layers=3): self.layers = nn.ModuleList([_FakeViTAttn() for _ in range(num_layers)]) +class _FakeLLMOnlyModel(nn.Module): + """Minimal LLM-only model with multiple decoder attention layers.""" + + def __init__(self, num_layers=2): + super().__init__() + self.layers = nn.ModuleList([_FakeLLMAttn() for _ in range(num_layers)]) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -228,3 +239,146 @@ def test_set_module_by_name_deep(self): new_mod = nn.Identity() _set_module_by_name(model, "vision_encoder.0", new_mod) assert model.vision_encoder[0] is new_mod + + def test_llm_layers_replaced_with_distributed_attention(self): + """LLM attention layers must be wrapped with DistributedAttention.""" + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeLLMOnlyModel(num_layers=3) + auto_wrap_model_for_sp(model, pg) + for layer in model.layers: + assert isinstance(layer, DistributedAttention) + + def test_multimodal_model_wraps_both_branches(self): + """Both ViT and LLM attention layers must be replaced in a combined model.""" + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeMultimodalModel() + returned = auto_wrap_model_for_sp(model, pg) + # auto_wrap_model_for_sp must return the same object (in-place) + assert returned is model + assert isinstance(model.vision_encoder[0], UlyssesSPViTAttention) + assert isinstance(model.llm[0], DistributedAttention) + + def test_original_module_preserved_inside_wrapper(self): + """The wrapped module should still be accessible inside the wrapper.""" + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeViTOnlyModel(num_layers=1) + original_attn = model.layers[0] + auto_wrap_model_for_sp(model, pg) + assert model.layers[0].attn is original_attn + + +# --------------------------------------------------------------------------- +# ModalityFusionSPAdapter tests +# --------------------------------------------------------------------------- + +class _ConcatFusionAdapter(ModalityFusionSPAdapter): + """Concrete subclass that appends visual tokens after text tokens.""" + + def _splice_visual_into_text(self, text_embeds, visual_embeds, input_ids): + return torch.cat([text_embeds, visual_embeds], dim=1) + + +class TestModalityFusionSPAdapter: + + def test_base_class_raises_not_implemented(self): + """The base _splice_visual_into_text must raise NotImplementedError.""" + pg = _make_mock_process_group(world_size=2, rank=0) + adapter = ModalityFusionSPAdapter(nn.Identity(), pg) + with pytest.raises(NotImplementedError): + adapter._splice_visual_into_text(None, None, None) + + @pytest.mark.parametrize("world_size,local_v,text_len,hidden", [ + (2, 4, 6, 8), + (4, 3, 5, 16), + (1, 8, 8, 4), + ]) + def test_output_shape(self, world_size, local_v, text_len, hidden): + """Output local_len must equal ceil(fused_len / world_size).""" + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + + bs = 2 + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + + out = adapter(visual, text, ids) + + # all_gather mock copies local_v to each of world_size slots + fused_len = text_len + local_v * world_size + pad = (world_size - fused_len % world_size) % world_size + expected_local = (fused_len + pad) // world_size + assert out.shape == (bs, expected_local, hidden), f"Expected ({bs},{expected_local},{hidden}), got {out.shape}" + + def test_padding_produces_valid_output_when_not_divisible(self): + """When fused_len % world_size != 0, padding must not raise and output is well-formed.""" + world_size = 4 + # text_len=5, local_v=3 → fused_len = 5 + 3*4 = 17, needs padding of 3 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + + bs, local_v, text_len, hidden = 1, 3, 5, 4 + out = adapter( + torch.randn(bs, local_v, hidden), + torch.randn(bs, text_len, hidden), + torch.zeros(bs, text_len, dtype=torch.long), + ) + # padded_len = 20, local_len = 5 + assert out.shape == (bs, 5, hidden) + + def test_no_padding_when_divisible(self): + """When fused_len is already divisible, no extra tokens should be added.""" + world_size = 4 + # text_len=4, local_v=4 → fused_len = 4 + 4*4 = 20, divisible by 4 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + + bs, local_v, text_len, hidden = 1, 4, 4, 8 + out = adapter( + torch.randn(bs, local_v, hidden), + torch.randn(bs, text_len, hidden), + torch.zeros(bs, text_len, dtype=torch.long), + ) + assert out.shape == (bs, 5, hidden) # 20 // 4 = 5 + + def test_different_ranks_return_different_slices(self): + """Rank 0 and rank 1 must return different slices of the fused sequence.""" + world_size = 2 + bs, local_v, text_len, hidden = 1, 4, 4, 8 + # Use distinct text vs visual values so slices clearly differ + text = torch.zeros(bs, text_len, hidden) + visual = torch.ones(bs, local_v, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + + outputs = {} + for rank in range(world_size): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + outputs[rank] = adapter(visual.clone(), text.clone(), ids.clone()) + + # fused = [0,0,0,0, 1,1,1,1, 1,1,1,1] (text zeros then visual ones x2) + # rank 0: indices 0-5, rank 1: indices 6-11 + assert not torch.allclose(outputs[0], outputs[1]) + + def test_projection_is_applied(self): + """Projection layer must transform visual features before gather.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + + # Use a projection that doubles all values + class _DoubleProjection(nn.Module): + def forward(self, x): + return x * 2.0 + + adapter = _ConcatFusionAdapter(_DoubleProjection(), pg) + bs, local_v, text_len, hidden = 1, 4, 4, 8 + visual = torch.ones(bs, local_v, hidden) + text = torch.zeros(bs, text_len, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + + out = adapter(visual, text, ids) + # The visual part of the output should have value 2.0 (doubled), not 1.0 + # rank 0 gets the first local_len tokens; fused = [text(0)*4, visual(2)*8] + # Since text_len=4 and local_len=6, rank0 slice starts with text zeros + # and ends with some visual twos. + assert out.max().item() == pytest.approx(2.0) From bb027587527c72b7d43535465447ac40b90cefe6 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Thu, 23 Apr 2026 06:51:02 +0000 Subject: [PATCH 08/21] fix: fix some format issue by pre-commit Signed-off-by: nathon-lee --- deepspeed/sequence/test_autosp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepspeed/sequence/test_autosp.py b/deepspeed/sequence/test_autosp.py index 6822faaf7131..4fbda160532c 100644 --- a/deepspeed/sequence/test_autosp.py +++ b/deepspeed/sequence/test_autosp.py @@ -271,6 +271,7 @@ def test_original_module_preserved_inside_wrapper(self): # ModalityFusionSPAdapter tests # --------------------------------------------------------------------------- + class _ConcatFusionAdapter(ModalityFusionSPAdapter): """Concrete subclass that appends visual tokens after text tokens.""" @@ -367,6 +368,7 @@ def test_projection_is_applied(self): # Use a projection that doubles all values class _DoubleProjection(nn.Module): + def forward(self, x): return x * 2.0 From efc956044ab19b5dc58cc37a6dfcefad69d9bdb1 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Thu, 23 Apr 2026 15:04:30 +0800 Subject: [PATCH 09/21] [Sequence Parallelism] Implement LlavaFusionAdapter for visual token splice Signed-off-by: nathon-lee fix: fix some format issue by pre-commit Signed-off-by: nathon-lee fix: fix some format err by tool Signed-off-by: nathon-lee --- deepspeed/sequence/__init__.py | 2 +- deepspeed/sequence/autosp_fusion.py | 58 ++++++++++++++ deepspeed/sequence/test_autosp.py | 113 +++++++++++++++++++++++++++- 3 files changed, 171 insertions(+), 2 deletions(-) diff --git a/deepspeed/sequence/__init__.py b/deepspeed/sequence/__init__.py index e1b8447c7b8f..424262240263 100644 --- a/deepspeed/sequence/__init__.py +++ b/deepspeed/sequence/__init__.py @@ -5,5 +5,5 @@ from deepspeed.sequence.autosp_detector import detect_model_sp_info, SPModelInfo from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention -from deepspeed.sequence.autosp_fusion import ModalityFusionSPAdapter +from deepspeed.sequence.autosp_fusion import ModalityFusionSPAdapter, LlavaFusionAdapter from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp diff --git a/deepspeed/sequence/autosp_fusion.py b/deepspeed/sequence/autosp_fusion.py index 2e59d86d1e99..98a40f98c016 100644 --- a/deepspeed/sequence/autosp_fusion.py +++ b/deepspeed/sequence/autosp_fusion.py @@ -127,3 +127,61 @@ def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: tor raise NotImplementedError(f"{type(self).__name__}._splice_visual_into_text is not implemented. " "Subclass ModalityFusionSPAdapter and override this method to match " "your model's prepare_inputs_embeds logic.") + + +class LlavaFusionAdapter(ModalityFusionSPAdapter): + """LLaVA-style splice: replace each image placeholder token with visual tokens. + + Follows the logic of ``LlavaMetaForCausalLM.prepare_inputs_labels_for_multimodal``: + for each sample, locate ``image_token_id`` placeholders in ``input_ids``, + remove them, and insert the corresponding visual token chunk in their place. + + Visual tokens for a sample are split evenly across the number of image + placeholders found. This matches the common single-image case (one + placeholder per sample) and simple multi-image cases where every image + contributes the same number of tokens. + + Parameters are inherited from :class:`ModalityFusionSPAdapter`. + """ + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + bs, text_len, hidden = text_embeds.shape + device = text_embeds.device + + fused_samples = [] + for i in range(bs): + img_pos = (input_ids[i] == self.image_token_id).nonzero(as_tuple=True)[0] + num_images = img_pos.numel() + + if num_images == 0: + # No image in this sample — keep text embeddings unchanged. + fused_samples.append(text_embeds[i]) + continue + + # Split all visual tokens evenly across the image placeholders. + visual_chunks = torch.chunk(visual_embeds[i], num_images, dim=0) + + segments = [] + prev = 0 + for j, pos in enumerate(img_pos.tolist()): + # Text segment before this placeholder. + if pos > prev: + segments.append(text_embeds[i, prev:pos]) + # Visual tokens replacing this placeholder. + segments.append(visual_chunks[j]) + # Skip the placeholder token itself. + prev = pos + 1 + + # Remaining text after the last placeholder. + if prev < text_len: + segments.append(text_embeds[i, prev:]) + + fused_samples.append(torch.cat(segments, dim=0)) + + # Pad all samples to the same length so they stack into a tensor. + max_len = max(s.shape[0] for s in fused_samples) + out = torch.zeros(bs, max_len, hidden, dtype=text_embeds.dtype, device=device) + for i, s in enumerate(fused_samples): + out[i, :s.shape[0]] = s + return out diff --git a/deepspeed/sequence/test_autosp.py b/deepspeed/sequence/test_autosp.py index 4fbda160532c..9c8564f1e257 100644 --- a/deepspeed/sequence/test_autosp.py +++ b/deepspeed/sequence/test_autosp.py @@ -8,6 +8,7 @@ - UlyssesSPViTAttention: ViT SP wrapper - auto_wrap_model_for_sp: end-to-end wrapping - ModalityFusionSPAdapter: cross-modal gather/scatter + - LlavaFusionAdapter: LLaVA-style visual token splice """ import pytest @@ -16,7 +17,7 @@ from deepspeed.sequence.autosp_detector import (SPModelInfo, _LLM_ATTN_CLASSNAMES, _VIT_ATTN_CLASSNAMES, detect_model_sp_info) -from deepspeed.sequence.autosp_fusion import ModalityFusionSPAdapter +from deepspeed.sequence.autosp_fusion import LlavaFusionAdapter, ModalityFusionSPAdapter from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention from deepspeed.sequence.auto_sp import _set_module_by_name, auto_wrap_model_for_sp from deepspeed.sequence.layer import DistributedAttention @@ -384,3 +385,113 @@ def forward(self, x): # Since text_len=4 and local_len=6, rank0 slice starts with text zeros # and ends with some visual twos. assert out.max().item() == pytest.approx(2.0) + + +# --------------------------------------------------------------------------- +# LlavaFusionAdapter tests (tests _splice_visual_into_text directly) +# --------------------------------------------------------------------------- + +_IMAGE_ID = -200 # matches ModalityFusionSPAdapter default + + +def _make_llava_adapter(world_size=2, rank=0): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + return LlavaFusionAdapter(nn.Identity(), pg, image_token_id=_IMAGE_ID) + + +class TestLlavaFusionAdapter: + + def test_single_image_fused_shape(self): + """One image placeholder per sample → fused length = text_len - 1 + num_visual.""" + adapter = _make_llava_adapter() + bs, text_len, num_vis, hidden = 2, 6, 4, 8 + # Place a single image placeholder at position 2. + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 2] = _IMAGE_ID + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, num_vis, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + # placeholder is removed and replaced by num_vis tokens + assert fused.shape == (bs, text_len - 1 + num_vis, hidden) + + def test_text_values_preserved_around_image(self): + """Text tokens before and after the placeholder must be numerically intact.""" + adapter = _make_llava_adapter() + bs, text_len, num_vis, hidden = 1, 5, 3, 4 + # Placeholder at index 2: text = [A, B, , C, D] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _IMAGE_ID + text = torch.arange(bs * text_len * hidden, dtype=torch.float).reshape(bs, text_len, hidden) + visual = torch.ones(bs, num_vis, hidden) * 99.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + # fused = [A, B, vis0, vis1, vis2, C, D] + assert torch.allclose(fused[0, :2], text[0, :2]) # A, B preserved + assert torch.allclose(fused[0, 5:], text[0, 3:]) # C, D preserved + assert torch.allclose(fused[0, 2:5], visual[0]) # visual inserted + + def test_no_image_token_returns_text_unchanged(self): + """When input_ids has no placeholder, output equals text_embeds exactly.""" + adapter = _make_llava_adapter() + bs, text_len, hidden = 2, 6, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) # no -200 + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + assert torch.allclose(fused, text) + + def test_multi_image_splice(self): + """Two placeholders per sample → visual tokens split evenly between them.""" + adapter = _make_llava_adapter() + bs, text_len, num_vis, hidden = 1, 7, 6, 4 + # Placeholders at index 1 and 4: [t0, , t2, t3, , t5, t6] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _IMAGE_ID + ids[0, 4] = _IMAGE_ID + text = torch.zeros(bs, text_len, hidden) + # First 3 visual tokens = 1.0, last 3 = 2.0 (so we can tell them apart) + visual = torch.cat([torch.ones(bs, 3, hidden), torch.full((bs, 3, hidden), 2.0)], dim=1) + + fused = adapter._splice_visual_into_text(text, visual, ids) + # Expected fused length: 7 - 2 placeholders + 6 visual = 11 + assert fused.shape == (bs, 11, hidden) + # First chunk (indices 1-3) should be 1.0 + assert torch.allclose(fused[0, 1:4], torch.ones(3, hidden)) + # Second chunk (indices 6-8) should be 2.0 + assert torch.allclose(fused[0, 6:9], torch.full((3, hidden), 2.0)) + + def test_batch_padding_when_lengths_differ(self): + """Samples with different numbers of image tokens are padded to max length.""" + adapter = _make_llava_adapter() + hidden = 4 + # Sample 0: 1 placeholder in a 4-token sequence + 2 visual → fused len = 5 + # Sample 1: no placeholder in a 4-token sequence → fused len = 4 + ids = torch.zeros(2, 4, dtype=torch.long) + ids[0, 1] = _IMAGE_ID + text = torch.ones(2, 4, hidden) + visual = torch.ones(2, 2, hidden) * 3.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + # Max fused length is 5; sample 1 padded with zeros at the end. + assert fused.shape == (2, 5, hidden) + assert torch.all(fused[1, 4:] == 0) # padding tokens are zero + + def test_forward_end_to_end_shape(self): + """Full forward pass through LlavaFusionAdapter returns the correct shard shape.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = LlavaFusionAdapter(nn.Identity(), pg, image_token_id=_IMAGE_ID) + + bs, local_v, text_len, hidden = 1, 4, 6, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _IMAGE_ID # one placeholder + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + + out = adapter(visual, text, ids) + # fused_len = text_len - 1 + local_v * world_size = 5 + 8 = 13 + # padded to 14 (next multiple of 2), local = 7 + assert out.shape == (bs, 7, hidden) From 5aa270b188cb1e575ea608fa7e8b75489e10180b Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Thu, 23 Apr 2026 08:52:54 +0000 Subject: [PATCH 10/21] feat: Add numerical equivalence tests for AutoSP multimodal SP Signed-off-by: nathon-lee --- .../test_autosp_equivalence.py | 185 ++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 tests/unit/sequence_parallelism/test_autosp_equivalence.py diff --git a/tests/unit/sequence_parallelism/test_autosp_equivalence.py b/tests/unit/sequence_parallelism/test_autosp_equivalence.py new file mode 100644 index 000000000000..3de9b5621918 --- /dev/null +++ b/tests/unit/sequence_parallelism/test_autosp_equivalence.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +Numerical equivalence tests for AutoSP multimodal sequence parallelism. + +Each test verifies that running the SP-wrapped path across N ranks produces +the same result as the equivalent single-device (non-SP) computation. + +These tests require 2 GPUs and the NCCL backend. +Run with: + + deepspeed --num_gpus 2 -m pytest tests/unit/sequence_parallelism/test_autosp_equivalence.py -v +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytest + +import deepspeed.comm as dist +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.autosp_fusion import LlavaFusionAdapter + +from unit.common import DistributedTest + +# --------------------------------------------------------------------------- +# Shared identity attention — deterministic, easy to verify +# --------------------------------------------------------------------------- + +_IMAGE_TOKEN_ID = -200 + + +class _IdentityAttn(nn.Module): + """Returns hidden_states unchanged so that gather-compute-scatter is a no-op.""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +# --------------------------------------------------------------------------- +# UlyssesSPViTAttention equivalence +# --------------------------------------------------------------------------- + + +class TestViTSPEquivalence(DistributedTest): + """SP-wrapped ViT attention with an identity inner module must reproduce + the unsharded output on every rank.""" + + world_size = 2 + + @pytest.mark.parametrize("has_cls_token", [True, False]) + @pytest.mark.parametrize("num_patches", [8, 12]) + def test_output_equals_single_device(self, has_cls_token, num_patches): + """Each rank's local output slice must match the corresponding slice of + the single-device output.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + bs, hidden = 2, 32 + + # --- Single-device reference --- + # Build the full input (all ranks see the same RNG seed so the tensor + # is identical everywhere). + torch.manual_seed(42) + if has_cls_token: + full_input = torch.randn(bs, 1 + num_patches, hidden) + else: + full_input = torch.randn(bs, num_patches, hidden) + + identity = _IdentityAttn() + # Single-device path is just identity — output == input. + ref_out = identity(full_input) + + # --- SP path --- + local_patches = num_patches // self.world_size + if has_cls_token: + cls = full_input[:, :1, :] + patch_slice = full_input[:, 1 + rank * local_patches:1 + (rank + 1) * local_patches, :] + local_input = torch.cat([cls, patch_slice], dim=1) + else: + local_input = full_input[:, rank * local_patches:(rank + 1) * local_patches, :] + + wrapper = UlyssesSPViTAttention(_IdentityAttn(), sp_group, has_cls_token=has_cls_token) + sp_out = wrapper(local_input) + + # --- Compare --- + # sp_out is the local slice; reconstruct what slice of ref_out it maps to. + if has_cls_token: + ref_slice = torch.cat( + [ref_out[:, :1, :], ref_out[:, 1 + rank * local_patches:1 + (rank + 1) * local_patches, :]], dim=1) + else: + ref_slice = ref_out[:, rank * local_patches:(rank + 1) * local_patches, :] + + assert torch.allclose(sp_out, ref_slice, + atol=1e-5), (f"rank={rank} sp_out differs from reference: " + f"max_diff={( sp_out - ref_slice).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# LlavaFusionAdapter equivalence +# --------------------------------------------------------------------------- + + +class TestLlavaFusionEquivalence(DistributedTest): + """Verifies that the SP gather/scatter in LlavaFusionAdapter is a lossless + round-trip: concatenating all ranks' output shards reproduces the full + fused sequence that single-device splicing would produce.""" + + world_size = 2 + + def _build_inputs(self, bs, local_v, text_len, hidden, rank): + """Build deterministic visual and text tensors identical on every rank.""" + torch.manual_seed(0) + # Each rank holds a contiguous slice of the visual tokens. + full_visual = torch.randn(bs, local_v * self.world_size, hidden) + text = torch.randn(bs, text_len, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 1] = _IMAGE_TOKEN_ID # one image placeholder at position 1 + local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + return full_visual, local_visual, text, ids + + def test_shards_reassemble_to_full_fused(self): + """Gathering all ranks' output shards must equal the single-device + fused sequence (modulo padding zeros).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 4, 6, 8 + full_visual, local_visual, text, ids = self._build_inputs(bs, local_v, text_len, hidden, rank) + + # --- SP path: each rank gets one shard --- + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID) + local_out = adapter(local_visual, text, ids) # [bs, local_fused, hidden] + + # Gather all shards onto every rank so we can compare globally. + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) # [bs, padded_fused, hidden] + + # --- Single-device reference --- + # Simulate what a non-SP LlavaFusionAdapter would produce: project the + # full visual tensor (identity here) and splice once. + ref_adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID) + # Call _splice_visual_into_text directly so we bypass the SP scatter. + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + + # Pad reference to the same padded length. + fused_len = ref_fused.shape[1] + pad = (self.world_size - fused_len % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} reassembled SP output differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + def test_no_image_token_passthrough(self): + """When there are no image placeholders the SP fused output must equal + the sharded text after padding/scatter (all-text path).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 2, 8, 4 + torch.manual_seed(1) + local_visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) # no image placeholder + + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID) + local_out = adapter(local_visual, text, ids) + + # Gather shards and strip the padding slice from visual gather. + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Expected: text unchanged, then visual tokens appended (all-gather copies). + full_visual = torch.cat([local_visual] * self.world_size, dim=1) + ref_fused = torch.cat([text, full_visual], dim=1) + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} no-image path differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") From d7de96fc48c474755f810f3727597a7bfa509e18 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Fri, 24 Apr 2026 02:36:11 +0000 Subject: [PATCH 11/21] test: add numerical equivalence tests for AutoSP multimodal sequence parallelism Signed-off-by: nathon-lee --- .../test_autosp_equivalence.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_autosp_equivalence.py b/tests/unit/sequence_parallelism/test_autosp_equivalence.py index 3de9b5621918..44121c9c128c 100644 --- a/tests/unit/sequence_parallelism/test_autosp_equivalence.py +++ b/tests/unit/sequence_parallelism/test_autosp_equivalence.py @@ -9,7 +9,7 @@ These tests require 2 GPUs and the NCCL backend. Run with: - deepspeed --num_gpus 2 -m pytest tests/unit/sequence_parallelism/test_autosp_equivalence.py -v + deepspeed --num_gpus 2 --no_local_rank --module pytest tests/unit/sequence_parallelism/test_autosp_equivalence.py -v """ import torch @@ -62,11 +62,11 @@ def test_output_equals_single_device(self, has_cls_token, num_patches): # is identical everywhere). torch.manual_seed(42) if has_cls_token: - full_input = torch.randn(bs, 1 + num_patches, hidden) + full_input = torch.randn(bs, 1 + num_patches, hidden).cuda() else: - full_input = torch.randn(bs, num_patches, hidden) + full_input = torch.randn(bs, num_patches, hidden).cuda() - identity = _IdentityAttn() + identity = _IdentityAttn().cuda() # Single-device path is just identity — output == input. ref_out = identity(full_input) @@ -79,7 +79,7 @@ def test_output_equals_single_device(self, has_cls_token, num_patches): else: local_input = full_input[:, rank * local_patches:(rank + 1) * local_patches, :] - wrapper = UlyssesSPViTAttention(_IdentityAttn(), sp_group, has_cls_token=has_cls_token) + wrapper = UlyssesSPViTAttention(_IdentityAttn().cuda(), sp_group, has_cls_token=has_cls_token).cuda() sp_out = wrapper(local_input) # --- Compare --- @@ -111,9 +111,9 @@ def _build_inputs(self, bs, local_v, text_len, hidden, rank): """Build deterministic visual and text tensors identical on every rank.""" torch.manual_seed(0) # Each rank holds a contiguous slice of the visual tokens. - full_visual = torch.randn(bs, local_v * self.world_size, hidden) - text = torch.randn(bs, text_len, hidden) - ids = torch.zeros(bs, text_len, dtype=torch.long) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).cuda() + text = torch.randn(bs, text_len, hidden).cuda() + ids = torch.zeros(bs, text_len, dtype=torch.long).cuda() ids[:, 1] = _IMAGE_TOKEN_ID # one image placeholder at position 1 local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] return full_visual, local_visual, text, ids @@ -128,7 +128,7 @@ def test_shards_reassemble_to_full_fused(self): full_visual, local_visual, text, ids = self._build_inputs(bs, local_v, text_len, hidden, rank) # --- SP path: each rank gets one shard --- - adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID) + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID).cuda() local_out = adapter(local_visual, text, ids) # [bs, local_fused, hidden] # Gather all shards onto every rank so we can compare globally. @@ -139,7 +139,7 @@ def test_shards_reassemble_to_full_fused(self): # --- Single-device reference --- # Simulate what a non-SP LlavaFusionAdapter would produce: project the # full visual tensor (identity here) and splice once. - ref_adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID) + ref_adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID).cuda() # Call _splice_visual_into_text directly so we bypass the SP scatter. ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) @@ -161,11 +161,11 @@ def test_no_image_token_passthrough(self): bs, local_v, text_len, hidden = 1, 2, 8, 4 torch.manual_seed(1) - local_visual = torch.randn(bs, local_v, hidden) - text = torch.randn(bs, text_len, hidden) - ids = torch.zeros(bs, text_len, dtype=torch.long) # no image placeholder + local_visual = torch.randn(bs, local_v, hidden).cuda() + text = torch.randn(bs, text_len, hidden).cuda() + ids = torch.zeros(bs, text_len, dtype=torch.long).cuda() # no image placeholder - adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID) + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID).cuda() local_out = adapter(local_visual, text, ids) # Gather shards and strip the padding slice from visual gather. @@ -173,9 +173,9 @@ def test_no_image_token_passthrough(self): dist.all_gather(gathered, local_out, group=sp_group) full_sp_out = torch.cat(gathered, dim=1) - # Expected: text unchanged, then visual tokens appended (all-gather copies). - full_visual = torch.cat([local_visual] * self.world_size, dim=1) - ref_fused = torch.cat([text, full_visual], dim=1) + # Expected: when there is no image token, the visual tokens are ignored. + # So the fused output should just be the text tokens. + ref_fused = text pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size if pad > 0: ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) From bc01ed75ed9404adfeaf9e6c0d1b0594229bd461 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Fri, 24 Apr 2026 02:47:44 +0000 Subject: [PATCH 12/21] test: fix check-torchcuda CI failure in autosp tests Signed-off-by: nathon-lee --- .../test_autosp_equivalence.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_autosp_equivalence.py b/tests/unit/sequence_parallelism/test_autosp_equivalence.py index 44121c9c128c..bb05aabe8ec8 100644 --- a/tests/unit/sequence_parallelism/test_autosp_equivalence.py +++ b/tests/unit/sequence_parallelism/test_autosp_equivalence.py @@ -20,6 +20,7 @@ import deepspeed.comm as dist from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention from deepspeed.sequence.autosp_fusion import LlavaFusionAdapter +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest @@ -62,11 +63,11 @@ def test_output_equals_single_device(self, has_cls_token, num_patches): # is identical everywhere). torch.manual_seed(42) if has_cls_token: - full_input = torch.randn(bs, 1 + num_patches, hidden).cuda() + full_input = torch.randn(bs, 1 + num_patches, hidden).to(get_accelerator().device_name()) else: - full_input = torch.randn(bs, num_patches, hidden).cuda() + full_input = torch.randn(bs, num_patches, hidden).to(get_accelerator().device_name()) - identity = _IdentityAttn().cuda() + identity = _IdentityAttn().to(get_accelerator().device_name()) # Single-device path is just identity — output == input. ref_out = identity(full_input) @@ -79,7 +80,9 @@ def test_output_equals_single_device(self, has_cls_token, num_patches): else: local_input = full_input[:, rank * local_patches:(rank + 1) * local_patches, :] - wrapper = UlyssesSPViTAttention(_IdentityAttn().cuda(), sp_group, has_cls_token=has_cls_token).cuda() + wrapper = UlyssesSPViTAttention(_IdentityAttn().to(get_accelerator().device_name()), + sp_group, + has_cls_token=has_cls_token).to(get_accelerator().device_name()) sp_out = wrapper(local_input) # --- Compare --- @@ -111,9 +114,9 @@ def _build_inputs(self, bs, local_v, text_len, hidden, rank): """Build deterministic visual and text tensors identical on every rank.""" torch.manual_seed(0) # Each rank holds a contiguous slice of the visual tokens. - full_visual = torch.randn(bs, local_v * self.world_size, hidden).cuda() - text = torch.randn(bs, text_len, hidden).cuda() - ids = torch.zeros(bs, text_len, dtype=torch.long).cuda() + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) ids[:, 1] = _IMAGE_TOKEN_ID # one image placeholder at position 1 local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] return full_visual, local_visual, text, ids @@ -128,7 +131,8 @@ def test_shards_reassemble_to_full_fused(self): full_visual, local_visual, text, ids = self._build_inputs(bs, local_v, text_len, hidden, rank) # --- SP path: each rank gets one shard --- - adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID).cuda() + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, + image_token_id=_IMAGE_TOKEN_ID).to(get_accelerator().device_name()) local_out = adapter(local_visual, text, ids) # [bs, local_fused, hidden] # Gather all shards onto every rank so we can compare globally. @@ -139,7 +143,8 @@ def test_shards_reassemble_to_full_fused(self): # --- Single-device reference --- # Simulate what a non-SP LlavaFusionAdapter would produce: project the # full visual tensor (identity here) and splice once. - ref_adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID).cuda() + ref_adapter = LlavaFusionAdapter(nn.Identity(), sp_group, + image_token_id=_IMAGE_TOKEN_ID).to(get_accelerator().device_name()) # Call _splice_visual_into_text directly so we bypass the SP scatter. ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) @@ -161,11 +166,12 @@ def test_no_image_token_passthrough(self): bs, local_v, text_len, hidden = 1, 2, 8, 4 torch.manual_seed(1) - local_visual = torch.randn(bs, local_v, hidden).cuda() - text = torch.randn(bs, text_len, hidden).cuda() - ids = torch.zeros(bs, text_len, dtype=torch.long).cuda() # no image placeholder + local_visual = torch.randn(bs, local_v, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) # no image placeholder - adapter = LlavaFusionAdapter(nn.Identity(), sp_group, image_token_id=_IMAGE_TOKEN_ID).cuda() + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, + image_token_id=_IMAGE_TOKEN_ID).to(get_accelerator().device_name()) local_out = adapter(local_visual, text, ids) # Gather shards and strip the padding slice from visual gather. From 34ca81782fdf3ff8cce82d8308754c914223e407 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Fri, 24 Apr 2026 15:19:09 +0800 Subject: [PATCH 13/21] feat: Added two new adapter classes after LlavaFusionAdapter Signed-off-by: nathon-lee fix: fix some format errs by tool Signed-off-by: nathon-lee --- deepspeed/sequence/__init__.py | 3 +- deepspeed/sequence/autosp_fusion.py | 139 +++++++++++ deepspeed/sequence/test_autosp.py | 229 +++++++++++++++++- .../test_autosp_equivalence.py | 196 ++++++++++++++- 4 files changed, 564 insertions(+), 3 deletions(-) diff --git a/deepspeed/sequence/__init__.py b/deepspeed/sequence/__init__.py index 424262240263..b76f944eff79 100644 --- a/deepspeed/sequence/__init__.py +++ b/deepspeed/sequence/__init__.py @@ -5,5 +5,6 @@ from deepspeed.sequence.autosp_detector import detect_model_sp_info, SPModelInfo from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention -from deepspeed.sequence.autosp_fusion import ModalityFusionSPAdapter, LlavaFusionAdapter +from deepspeed.sequence.autosp_fusion import (ModalityFusionSPAdapter, LlavaFusionAdapter, InternVLFusionAdapter, + Qwen2VLFusionAdapter) from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp diff --git a/deepspeed/sequence/autosp_fusion.py b/deepspeed/sequence/autosp_fusion.py index 98a40f98c016..cdd3286dd6e9 100644 --- a/deepspeed/sequence/autosp_fusion.py +++ b/deepspeed/sequence/autosp_fusion.py @@ -185,3 +185,142 @@ def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: tor for i, s in enumerate(fused_samples): out[i, :s.shape[0]] = s return out + + +class InternVLFusionAdapter(ModalityFusionSPAdapter): + """InternVL-style splice: replace IMG_CONTEXT token runs with visual tokens. + + InternVL encodes each image as `` ×N `` + inside the token sequence. Each ``IMG_CONTEXT`` token (``image_token_id``) + is a 1-to-1 placeholder for one ViT visual token. This adapter locates + every contiguous run of ``image_token_id`` tokens and replaces them with + the corresponding slice of *visual_embeds*, while preserving the + ``IMG_START`` / ``IMG_END`` boundary embeddings unchanged. + + Because the replacement is 1-to-1, the output sequence length equals the + input sequence length (no length change). + + Parameters are inherited from :class:`ModalityFusionSPAdapter`. + Set ``image_token_id`` to the ``IMG_CONTEXT`` token id used by the model + (e.g. the id of ````). + """ + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + # Start from a clone of text embeddings; we only overwrite context positions. + out = text_embeds.clone() + bs = text_embeds.shape[0] + + for i in range(bs): + ctx_pos = (input_ids[i] == self.image_token_id).nonzero(as_tuple=True)[0] + if ctx_pos.numel() == 0: + continue + # ctx_pos lists every IMG_CONTEXT index in order. visual_embeds[i] + # has exactly ctx_pos.numel() tokens (one per context position). + out[i, ctx_pos] = visual_embeds[i, :ctx_pos.numel()] + + return out + + +class Qwen2VLFusionAdapter(nn.Module): + """Qwen2-VL-style splice: visual tokens enclosed by vision_start/end tokens. + + Qwen2-VL wraps each image's visual tokens with a pair of special boundary + tokens in ``input_ids``: ``vision_start_token_id`` and + ``vision_end_token_id``. The placeholder tokens between each + (start, end) pair are replaced 1-to-1 by the projected visual token + embeddings. The boundary token embeddings are kept unchanged. + + Because the replacement is 1-to-1, the output sequence length equals the + input sequence length. + + Parameters + ---------- + projection: + The vision projection module (e.g. ``visual.merger``). + process_group: + The sequence-parallel process group. + vision_start_token_id: + Token id of ``<|vision_start|>``. + vision_end_token_id: + Token id of ``<|vision_end|>``. + """ + + def __init__(self, projection: nn.Module, process_group, vision_start_token_id: int, + vision_end_token_id: int) -> None: + super().__init__() + self.projection = projection + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + + def forward(self, visual_features: torch.Tensor, text_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Project visual features and return a sharded fused embedding. + + Parameters + ---------- + visual_features: + Raw visual features from the ViT encoder. + Shape: ``[bs, local_visual_tokens, vit_hidden]``. + text_embeds: + Full text token embeddings (not sharded yet). + Shape: ``[bs, text_seq_len, lm_hidden]``. + input_ids: + Token IDs used to locate vision_start/end boundaries. + Shape: ``[bs, text_seq_len]``. + + Returns + ------- + Sharded fused embedding for this rank. + Shape: ``[bs, local_fused_len, lm_hidden]``. + """ + # 1. Project visual features to the LLM hidden dimension. + visual_embeds = self.projection(visual_features) # [bs, local_v, lm_hidden] + + # 2. All-gather visual slices from all SP ranks. + parts = [torch.zeros_like(visual_embeds) for _ in range(self.world_size)] + dist.all_gather(parts, visual_embeds.contiguous(), group=self.process_group) + full_visual = torch.cat(parts, dim=1) # [bs, total_visual_tokens, lm_hidden] + + # 3. Replace placeholder positions in text with visual tokens (length-preserving). + fused = self._splice_visual_into_text(text_embeds, full_visual, input_ids) + + # 4. Pad fused length to be divisible by world_size, then scatter. + total_len = fused.shape[1] + pad = (self.world_size - total_len % self.world_size) % self.world_size + if pad > 0: + fused = F.pad(fused, (0, 0, 0, pad)) + + rank = dist.get_rank(self.process_group) + local_len = fused.shape[1] // self.world_size + return fused[:, rank * local_len:(rank + 1) * local_len, :].contiguous() + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Replace inner placeholder tokens between vision_start/end pairs with visual embeddings.""" + out = text_embeds.clone() + bs = text_embeds.shape[0] + + for i in range(bs): + start_pos = (input_ids[i] == self.vision_start_token_id).nonzero(as_tuple=True)[0] + end_pos = (input_ids[i] == self.vision_end_token_id).nonzero(as_tuple=True)[0] + + if start_pos.numel() == 0: + continue + + # Accumulate inner placeholder positions across all start/end pairs. + # Inner positions are (start+1) .. (end-1) inclusive, i.e. excluding + # the boundary tokens themselves. + inner_positions = [] + for s, e in zip(start_pos.tolist(), end_pos.tolist()): + inner_positions.extend(range(s + 1, e)) + + if not inner_positions: + continue + + inner_pos = torch.tensor(inner_positions, dtype=torch.long, device=text_embeds.device) + out[i, inner_pos] = visual_embeds[i, :len(inner_positions)] + + return out diff --git a/deepspeed/sequence/test_autosp.py b/deepspeed/sequence/test_autosp.py index 9c8564f1e257..771a9e7bb6b7 100644 --- a/deepspeed/sequence/test_autosp.py +++ b/deepspeed/sequence/test_autosp.py @@ -9,6 +9,8 @@ - auto_wrap_model_for_sp: end-to-end wrapping - ModalityFusionSPAdapter: cross-modal gather/scatter - LlavaFusionAdapter: LLaVA-style visual token splice + - InternVLFusionAdapter: InternVL-style IMG_CONTEXT token splice + - Qwen2VLFusionAdapter: Qwen2-VL vision_start/end bounded splice """ import pytest @@ -17,7 +19,8 @@ from deepspeed.sequence.autosp_detector import (SPModelInfo, _LLM_ATTN_CLASSNAMES, _VIT_ATTN_CLASSNAMES, detect_model_sp_info) -from deepspeed.sequence.autosp_fusion import LlavaFusionAdapter, ModalityFusionSPAdapter +from deepspeed.sequence.autosp_fusion import (InternVLFusionAdapter, LlavaFusionAdapter, ModalityFusionSPAdapter, + Qwen2VLFusionAdapter) from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention from deepspeed.sequence.auto_sp import _set_module_by_name, auto_wrap_model_for_sp from deepspeed.sequence.layer import DistributedAttention @@ -495,3 +498,227 @@ def test_forward_end_to_end_shape(self): # fused_len = text_len - 1 + local_v * world_size = 5 + 8 = 13 # padded to 14 (next multiple of 2), local = 7 assert out.shape == (bs, 7, hidden) + + +# --------------------------------------------------------------------------- +# InternVLFusionAdapter tests (tests _splice_visual_into_text directly) +# --------------------------------------------------------------------------- + +_CONTEXT_ID = 92546 # arbitrary IMG_CONTEXT token id for tests +_START_ID = 92545 +_END_ID = 92547 + + +def _make_internvl_adapter(world_size=2, rank=0): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + return InternVLFusionAdapter(nn.Identity(), pg, image_token_id=_CONTEXT_ID) + + +class TestInternVLFusionAdapter: + + def test_context_tokens_replaced_with_visual(self): + """IMG_CONTEXT positions must carry visual embeddings after splice.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 1, 7, 4 + # Layout: [t0, START, ctx, ctx, ctx, END, t6] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _CONTEXT_ID + ids[0, 3] = _CONTEXT_ID + ids[0, 4] = _CONTEXT_ID + + text = torch.zeros(bs, text_len, hidden) + visual = torch.ones(bs, 3, hidden) * 7.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused[0, 2:5], visual[0]) + + def test_sequence_length_preserved(self): + """Output length must equal input length (1-to-1 replacement).""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 2, 10, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 3:7] = _CONTEXT_ID # 4 context tokens per sample + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + + def test_boundary_tokens_preserved(self): + """IMG_START and IMG_END embeddings must be unchanged after splice.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 1, 5, 4 + # [START, ctx, ctx, END, text] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _CONTEXT_ID + ids[0, 2] = _CONTEXT_ID + + text = torch.arange(bs * text_len * hidden, dtype=torch.float).reshape(bs, text_len, hidden) + visual = torch.ones(bs, 2, hidden) * 99.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + # Position 0 (START) and 3 (END) must be unchanged. + assert torch.allclose(fused[0, 0], text[0, 0]) + assert torch.allclose(fused[0, 3], text[0, 3]) + + def test_no_context_tokens_returns_text_unchanged(self): + """When there are no IMG_CONTEXT tokens the output must equal text_embeds.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 2, 6, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused, text) + + def test_multi_image_replacement(self): + """Two separate runs of context tokens correspond to two images.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 1, 10, 4 + # Image 1: positions 1-2, Image 2: positions 6-7 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _CONTEXT_ID + ids[0, 2] = _CONTEXT_ID + ids[0, 6] = _CONTEXT_ID + ids[0, 7] = _CONTEXT_ID + + text = torch.zeros(bs, text_len, hidden) + # First 2 visual tokens = 1.0, next 2 = 2.0 + visual = torch.cat([torch.ones(bs, 2, hidden), torch.full((bs, 2, hidden), 2.0)], dim=1) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + assert torch.allclose(fused[0, 1:3], torch.ones(2, hidden)) + assert torch.allclose(fused[0, 6:8], torch.full((2, hidden), 2.0)) + + def test_forward_end_to_end_shape(self): + """Full forward pass returns the correct shard shape.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = InternVLFusionAdapter(nn.Identity(), pg, image_token_id=_CONTEXT_ID) + + bs, local_v, text_len, hidden = 1, 3, 8, 4 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2:5] = _CONTEXT_ID # 3 context tokens; local_v * world_size = 6 total + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + + out = adapter(visual, text, ids) + # fused_len == text_len == 8 (length-preserving); padded to 8 (divisible by 2); local = 4 + assert out.shape == (bs, 4, hidden) + + +# --------------------------------------------------------------------------- +# Qwen2VLFusionAdapter tests (tests _splice_visual_into_text directly) +# --------------------------------------------------------------------------- + +_VIS_START_ID = 151652 +_VIS_END_ID = 151653 + + +def _make_qwen2vl_adapter(world_size=2, rank=0): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + return Qwen2VLFusionAdapter(nn.Identity(), + pg, + vision_start_token_id=_VIS_START_ID, + vision_end_token_id=_VIS_END_ID) + + +class TestQwen2VLFusionAdapter: + + def test_inner_tokens_replaced_with_visual(self): + """Tokens between vision_start and vision_end must become visual embeddings.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 1, 7, 4 + # [t0, t1, , pad, pad, , t6] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _VIS_START_ID + ids[0, 5] = _VIS_END_ID + + text = torch.zeros(bs, text_len, hidden) + visual = torch.ones(bs, 2, hidden) * 5.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused[0, 3:5], visual[0]) + + def test_sequence_length_preserved(self): + """Output length must equal input length (1-to-1 replacement).""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 2, 12, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 2] = _VIS_START_ID + ids[:, 8] = _VIS_END_ID # 5 inner placeholder tokens + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 5, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + + def test_boundary_tokens_preserved(self): + """vision_start and vision_end embeddings must be unchanged after splice.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 1, 6, 4 + # [t0, , pad, pad, , t5] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _VIS_START_ID + ids[0, 4] = _VIS_END_ID + + text = torch.arange(bs * text_len * hidden, dtype=torch.float).reshape(bs, text_len, hidden) + visual = torch.ones(bs, 2, hidden) * 99.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused[0, 1], text[0, 1]) # vision_start preserved + assert torch.allclose(fused[0, 4], text[0, 4]) # vision_end preserved + + def test_no_vision_tokens_returns_text_unchanged(self): + """When there are no vision_start/end tokens the output must equal text_embeds.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 2, 8, 4 + ids = torch.zeros(bs, text_len, dtype=torch.long) + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused, text) + + def test_multi_image_replacement(self): + """Two vision blocks are handled independently.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 1, 14, 4 + # Block 1: positions 1 (start) .. 4 (end), 2 inner tokens at 2-3 + # Block 2: positions 8 (start) .. 12 (end), 3 inner tokens at 9-11 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _VIS_START_ID + ids[0, 4] = _VIS_END_ID + ids[0, 8] = _VIS_START_ID + ids[0, 12] = _VIS_END_ID + + text = torch.zeros(bs, text_len, hidden) + visual = torch.cat([torch.ones(bs, 2, hidden), torch.full((bs, 3, hidden), 2.0)], dim=1) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + assert torch.allclose(fused[0, 2:4], torch.ones(2, hidden)) + assert torch.allclose(fused[0, 9:12], torch.full((3, hidden), 2.0)) + + def test_forward_end_to_end_shape(self): + """Full forward pass returns the correct shard shape.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = Qwen2VLFusionAdapter(nn.Identity(), + pg, + vision_start_token_id=_VIS_START_ID, + vision_end_token_id=_VIS_END_ID) + + bs, local_v, text_len, hidden = 1, 3, 10, 4 + ids = torch.zeros(bs, text_len, dtype=torch.long) + # 6 inner placeholder tokens (local_v * world_size = 6) + ids[0, 1] = _VIS_START_ID + ids[0, 8] = _VIS_END_ID + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + + out = adapter(visual, text, ids) + # fused_len == text_len == 10 (length-preserving); padded to 10; local = 5 + assert out.shape == (bs, 5, hidden) diff --git a/tests/unit/sequence_parallelism/test_autosp_equivalence.py b/tests/unit/sequence_parallelism/test_autosp_equivalence.py index bb05aabe8ec8..bb473b4ec5e1 100644 --- a/tests/unit/sequence_parallelism/test_autosp_equivalence.py +++ b/tests/unit/sequence_parallelism/test_autosp_equivalence.py @@ -19,7 +19,7 @@ import deepspeed.comm as dist from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention -from deepspeed.sequence.autosp_fusion import LlavaFusionAdapter +from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, LlavaFusionAdapter, Qwen2VLFusionAdapter from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest @@ -189,3 +189,197 @@ def test_no_image_token_passthrough(self): assert torch.allclose(full_sp_out, ref_fused, atol=1e-5), (f"rank={rank} no-image path differs from reference: " f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# InternVLFusionAdapter equivalence +# --------------------------------------------------------------------------- + +_INTERNVL_CONTEXT_TOKEN_ID = 92546 + + +class TestInternVLFusionEquivalence(DistributedTest): + """Verifies that the SP gather/scatter in InternVLFusionAdapter is a lossless + round-trip: concatenating all ranks' output shards reproduces the full fused + sequence that single-device splicing would produce. + + InternVL replaces IMG_CONTEXT tokens 1-to-1 with visual tokens, so the + sequence length is preserved. + """ + + world_size = 2 + + def _build_inputs(self, bs, local_v, text_len, hidden, rank, num_ctx_tokens): + """Build deterministic inputs with a run of IMG_CONTEXT tokens in the middle.""" + torch.manual_seed(2) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + # Place IMG_CONTEXT tokens starting at position 2. + ids[:, 2:2 + num_ctx_tokens] = _INTERNVL_CONTEXT_TOKEN_ID + local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + return full_visual, local_visual, text, ids + + def test_shards_reassemble_to_full_fused(self): + """Gathering all ranks' output shards must equal the single-device + fused sequence (modulo padding zeros).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 3, 8, 4 + full_visual, local_visual, text, ids = self._build_inputs(bs, + local_v, + text_len, + hidden, + rank, + num_ctx_tokens=local_v * self.world_size) + + # SP path. + adapter = InternVLFusionAdapter(nn.Identity(), sp_group, + image_token_id=_INTERNVL_CONTEXT_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Single-device reference. + ref_adapter = InternVLFusionAdapter(nn.Identity(), sp_group, image_token_id=_INTERNVL_CONTEXT_TOKEN_ID).to( + get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + + fused_len = ref_fused.shape[1] + pad = (self.world_size - fused_len % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} InternVL reassembled output differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + def test_no_context_token_passthrough(self): + """When there are no IMG_CONTEXT tokens the fused output must equal the text.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 2, 6, 4 + torch.manual_seed(3) + local_visual = torch.randn(bs, local_v, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + + adapter = InternVLFusionAdapter(nn.Identity(), sp_group, + image_token_id=_INTERNVL_CONTEXT_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + ref_fused = text + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} InternVL no-context path differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# Qwen2VLFusionAdapter equivalence +# --------------------------------------------------------------------------- + +_QWEN2VL_START_ID = 151652 +_QWEN2VL_END_ID = 151653 + + +class TestQwen2VLFusionEquivalence(DistributedTest): + """Verifies that the SP gather/scatter in Qwen2VLFusionAdapter is a lossless + round-trip: concatenating all ranks' output shards reproduces the full fused + sequence that single-device splicing would produce. + + Qwen2-VL replaces inner placeholder tokens (between vision_start/end pairs) + 1-to-1 with visual tokens, so the sequence length is preserved. + """ + + world_size = 2 + + def _build_inputs(self, bs, local_v, text_len, hidden, rank, num_inner): + """Build inputs with a single vision_start/end block containing num_inner placeholders.""" + torch.manual_seed(4) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + # [t0, , pad×num_inner, , ...] + ids[:, 1] = _QWEN2VL_START_ID + ids[:, 2 + num_inner] = _QWEN2VL_END_ID + local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + return full_visual, local_visual, text, ids + + def test_shards_reassemble_to_full_fused(self): + """Gathering all ranks' output shards must equal the single-device + fused sequence (modulo padding zeros).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 3, 10, 4 + num_inner = local_v * self.world_size # inner placeholder count equals total visual tokens + full_visual, local_visual, text, ids = self._build_inputs(bs, local_v, text_len, hidden, rank, num_inner) + + # SP path. + adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Single-device reference. + ref_adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + + fused_len = ref_fused.shape[1] + pad = (self.world_size - fused_len % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} Qwen2VL reassembled output differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + def test_no_vision_token_passthrough(self): + """When there are no vision_start/end tokens the fused output must equal the text.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 2, 8, 4 + torch.manual_seed(5) + local_visual = torch.randn(bs, local_v, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + + adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + ref_fused = text + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} Qwen2VL no-vision path differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") From 2fd81bcd4ed5824f1836ad54307db6963376c049 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Sat, 25 Apr 2026 13:02:34 +0800 Subject: [PATCH 14/21] [AutoSP] Add InternVLFusionAdapter and Qwen2VLFusionAdapter Signed-off-by: nathon-lee --- tests/unit/sequence_parallelism/test_autosp_equivalence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/sequence_parallelism/test_autosp_equivalence.py b/tests/unit/sequence_parallelism/test_autosp_equivalence.py index bb473b4ec5e1..391d3ae264cc 100644 --- a/tests/unit/sequence_parallelism/test_autosp_equivalence.py +++ b/tests/unit/sequence_parallelism/test_autosp_equivalence.py @@ -6,10 +6,10 @@ Each test verifies that running the SP-wrapped path across N ranks produces the same result as the equivalent single-device (non-SP) computation. -These tests require 2 GPUs and the NCCL backend. +These tests require 2 GPUs. Run with: - deepspeed --num_gpus 2 --no_local_rank --module pytest tests/unit/sequence_parallelism/test_autosp_equivalence.py -v + NCCL_P2P_DISABLE=1 python -m pytest tests/unit/sequence_parallelism/test_autosp_equivalence.py -v """ import torch From 6bf1ec312c1c7b4a64a30c0a05b4aabd4ff18209 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Sat, 25 Apr 2026 13:40:21 +0800 Subject: [PATCH 15/21] [AutoSP] Add integration tests and multimodal SP benchmark Signed-off-by: nathon-lee fix: fix some format err by tool Signed-off-by: nathon-lee --- benchmarks/autosp/bench_multimodal_sp.py | 270 +++++++++++++++++ .../test_autosp_integration.py | 283 ++++++++++++++++++ 2 files changed, 553 insertions(+) create mode 100644 benchmarks/autosp/bench_multimodal_sp.py create mode 100644 tests/unit/sequence_parallelism/test_autosp_integration.py diff --git a/benchmarks/autosp/bench_multimodal_sp.py b/benchmarks/autosp/bench_multimodal_sp.py new file mode 100644 index 000000000000..d83fb46a412d --- /dev/null +++ b/benchmarks/autosp/bench_multimodal_sp.py @@ -0,0 +1,270 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +Benchmark: AutoSP multimodal sequence parallelism (ViT SP + fusion adapter). + +Measures per-iteration latency, throughput, and peak GPU memory for the +ViT-SP + fusion-adapter pipeline at a given SP degree. + +Launch (from repo root): + + # SP degree 2 — two GPUs: + NCCL_P2P_DISABLE=1 torchrun --nproc_per_node=2 \\ + benchmarks/autosp/bench_multimodal_sp.py [args] + + # Baseline — single GPU (all-gather/scatter are no-ops): + torchrun --nproc_per_node=1 \\ + benchmarks/autosp/bench_multimodal_sp.py [args] + +Compare the two output tables to quantify memory savings and throughput scaling. + +Arguments: + --arch {internvl, qwen2vl} architecture to simulate (default: internvl) + --batch-size N samples per batch (default: 2) + --seq-len N text sequence length (default: 512) + --visual-tokens N total visual tokens per sample (default: 256) + --hidden N hidden dimension (default: 1024) + --num-layers N ViT and LLM layers each (default: 2) + --iters N measured iterations (default: 50) + --warmup N warmup iterations (default: 10) +""" + +import argparse +import statistics + +import torch +import torch.nn as nn + +import deepspeed +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter + +# --------------------------------------------------------------------------- +# Token IDs +# --------------------------------------------------------------------------- + +_INTERNVL_CONTEXT_ID = 92546 +_QWEN2VL_START_ID = 151652 +_QWEN2VL_END_ID = 151653 + +# --------------------------------------------------------------------------- +# Mock attention classes — names match autosp_detector registries exactly +# --------------------------------------------------------------------------- + + +class InternVisionAttention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class InternLM2Attention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2VLVisionAttention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2Attention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +# --------------------------------------------------------------------------- +# Model building blocks +# --------------------------------------------------------------------------- + + +class _ViTBlock(nn.Module): + """One ViT transformer block: attention (to be SP-wrapped) + linear FFN.""" + + def __init__(self, attn_cls, hidden: int) -> None: + super().__init__() + self.attn = attn_cls() + self.ffn = nn.Linear(hidden, hidden, bias=False) + + def forward(self, x, **kwargs): + out = self.attn(x, **kwargs) + if isinstance(out, (tuple, list)): + out = out[0] + return self.ffn(out) + + +class _MinimalInternVLModel(nn.Module): + """InternVL-like benchmark model. + + Module paths detected by autosp_detector: + - ``vision_encoder.*.attn`` -> InternVisionAttention (_VIT_ATTN_CLASSNAMES) + - ``mm_projector`` -> keyword in _VISION_PROJ_KEYWORDS + + ``language_model`` uses plain nn.Linear layers so it is NOT wrapped by + DistributedAttention (avoids the Q/K/V interface requirement) yet still + contributes realistic compute on the scattered fused sequence. + """ + + def __init__(self, hidden: int, num_layers: int) -> None: + super().__init__() + self.vision_encoder = nn.Sequential(*[_ViTBlock(InternVisionAttention, hidden) for _ in range(num_layers)]) + self.mm_projector = nn.Identity() + self.language_model = nn.Sequential(*[nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)]) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.vision_encoder(local_patches) + local_fused = self.fusion(local_visual, text_embeds, input_ids) + return self.language_model(local_fused) + + +class _MinimalQwen2VLModel(nn.Module): + """Qwen2VL-like benchmark model.""" + + def __init__(self, hidden: int, num_layers: int) -> None: + super().__init__() + self.visual = nn.Sequential(*[_ViTBlock(Qwen2VLVisionAttention, hidden) for _ in range(num_layers)]) + self.multi_modal_projector = nn.Identity() + self.model = nn.Sequential(*[nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)]) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.visual(local_patches) + local_fused = self.fusion(local_visual, text_embeds, input_ids) + return self.model(local_fused) + + +# --------------------------------------------------------------------------- +# Setup helpers +# --------------------------------------------------------------------------- + + +def _build_model_and_inputs(arch: str, args, sp_group, device): + rank = dist.get_rank(sp_group) + world_size = dist.get_world_size(sp_group) + + local_v = args.visual_tokens // world_size + bs, text_len, hidden = args.batch_size, args.seq_len, args.hidden + + torch.manual_seed(0) + local_patches = torch.randn(bs, local_v, hidden, device=device) + text_embeds = torch.randn(bs, text_len, hidden, device=device) + input_ids = torch.zeros(bs, text_len, dtype=torch.long, device=device) + + if arch == "internvl": + num_ctx = min(local_v * world_size, text_len - 2) + input_ids[:, 2:2 + num_ctx] = _INTERNVL_CONTEXT_ID + + model = _MinimalInternVLModel(hidden, args.num_layers).to(device) + auto_wrap_model_for_sp(model, sp_group) + for m in model.modules(): + if isinstance(m, UlyssesSPViTAttention): + m.has_cls_token = False + model.fusion = InternVLFusionAdapter(model.mm_projector, sp_group, + image_token_id=_INTERNVL_CONTEXT_ID).to(device) + else: # qwen2vl + num_inner = min(local_v * world_size, text_len - 3) + input_ids[:, 1] = _QWEN2VL_START_ID + input_ids[:, 2 + num_inner] = _QWEN2VL_END_ID + + model = _MinimalQwen2VLModel(hidden, args.num_layers).to(device) + auto_wrap_model_for_sp(model, sp_group) + for m in model.modules(): + if isinstance(m, UlyssesSPViTAttention): + m.has_cls_token = False + model.fusion = Qwen2VLFusionAdapter(model.multi_modal_projector, + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(device) + + return model, local_patches, text_embeds, input_ids + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + + +def _run(arch: str, args) -> None: + deepspeed.init_distributed(dist_backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + torch.cuda.set_device(device) + + sp_group = dist.new_group(ranks=list(range(world_size))) + model, local_patches, text_embeds, input_ids = _build_model_and_inputs(arch, args, sp_group, device) + model.eval() + + # Warmup + with torch.no_grad(): + for _ in range(args.warmup): + model(local_patches, text_embeds, input_ids) + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats(device) + + # Timed iterations using CUDA events for accurate GPU-side measurement. + latencies_ms = [] + with torch.no_grad(): + for _ in range(args.iters): + t_start = torch.cuda.Event(enable_timing=True) + t_end = torch.cuda.Event(enable_timing=True) + t_start.record() + model(local_patches, text_embeds, input_ids) + t_end.record() + torch.cuda.synchronize() + latencies_ms.append(t_start.elapsed_time(t_end)) + + peak_mem_mb = torch.cuda.max_memory_allocated(device) / 1024**2 + mean_ms = statistics.mean(latencies_ms) + std_ms = statistics.stdev(latencies_ms) if len(latencies_ms) > 1 else 0.0 + # tokens/s: fused sequence length approximated by seq_len (length-preserving adapters). + throughput = (args.batch_size * args.seq_len) / (mean_ms / 1000.0) + + if rank == 0: + sep = "=" * 62 + print(f"\n{sep}") + print(f" AutoSP Benchmark arch={arch} sp_degree={world_size}") + print(sep) + print(f" batch_size : {args.batch_size}") + print(f" seq_len : {args.seq_len}") + print(f" visual_tokens : {args.visual_tokens} (local={args.visual_tokens // world_size}/rank)") + print(f" hidden : {args.hidden}") + print(f" num_layers : {args.num_layers}") + print(f" warmup / iters : {args.warmup} / {args.iters}") + print(f" {'─' * 58}") + print(f" Latency : {mean_ms:.2f} ± {std_ms:.2f} ms/iter") + print(f" Throughput : {throughput:,.0f} tokens/s") + print(f" Peak GPU memory : {peak_mem_mb:.1f} MB") + print(f"{sep}\n") + + +def main() -> None: + parser = argparse.ArgumentParser(description="AutoSP multimodal SP benchmark") + parser.add_argument("--arch", + choices=["internvl", "qwen2vl"], + default="internvl", + help="Model architecture to simulate") + parser.add_argument("--batch-size", type=int, default=2) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--visual-tokens", + type=int, + default=256, + help="Total visual tokens (must be divisible by --nproc_per_node)") + parser.add_argument("--hidden", type=int, default=1024) + parser.add_argument("--num-layers", type=int, default=2, help="Number of ViT blocks and LLM linear layers each") + parser.add_argument("--iters", type=int, default=50) + parser.add_argument("--warmup", type=int, default=10) + args = parser.parse_args() + + _run(args.arch, args) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/sequence_parallelism/test_autosp_integration.py b/tests/unit/sequence_parallelism/test_autosp_integration.py new file mode 100644 index 000000000000..3412283535fd --- /dev/null +++ b/tests/unit/sequence_parallelism/test_autosp_integration.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +End-to-end integration tests for AutoSP multimodal sequence parallelism. + +Each test builds a minimal mock model whose attention-layer class names match +the autosp_detector registry, then verifies two things: + +1. auto_wrap_model_for_sp correctly identifies and wraps the attention modules. +2. The full pipeline (SP-wrapped ViT -> fusion adapter) produces fused output + numerically equivalent to the single-device splice reference. + +The LLM decoder branch is intentionally not called in the forward-pass +equivalence tests because DistributedAttention uses a Megatron Q/K/V +interface that is incompatible with the simple hidden_states mock. Its +correct injection is verified by the detection tests instead. + +These tests require 2 GPUs. +Run with: + + NCCL_P2P_DISABLE=1 python -m pytest tests/unit/sequence_parallelism/test_autosp_integration.py -v +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed.comm as dist +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.layer import DistributedAttention +from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest + +# --------------------------------------------------------------------------- +# Token IDs +# --------------------------------------------------------------------------- + +_INTERNVL_CONTEXT_ID = 92546 +_QWEN2VL_START_ID = 151652 +_QWEN2VL_END_ID = 151653 + +# --------------------------------------------------------------------------- +# Mock attention classes +# +# Class names must match exactly the entries in autosp_detector._VIT_ATTN_CLASSNAMES +# and _LLM_ATTN_CLASSNAMES so that auto_wrap_model_for_sp detects them. +# --------------------------------------------------------------------------- + + +class InternVisionAttention(nn.Module): + """Mock ViT attention for InternVL (registered in _VIT_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class InternLM2Attention(nn.Module): + """Mock LLM attention for InternVL (registered in _LLM_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2VLVisionAttention(nn.Module): + """Mock ViT attention for Qwen2-VL (registered in _VIT_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2Attention(nn.Module): + """Mock LLM attention for Qwen2-VL (registered in _LLM_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +# --------------------------------------------------------------------------- +# Model skeleton helpers +# --------------------------------------------------------------------------- + + +class _AttnLayer(nn.Module): + """Generic transformer block that holds an attention submodule. + + auto_wrap_model_for_sp scans named_modules() and replaces ``self.attn`` + when its class name is in the detector's registry. + """ + + def __init__(self, attn: nn.Module) -> None: + super().__init__() + self.attn = attn + + def forward(self, x, **kwargs): + return self.attn(x, **kwargs) + + +class _MinimalInternVLModel(nn.Module): + """Minimal InternVL-like skeleton for integration testing. + + Module paths recognised by autosp_detector: + - ``vision_encoder.0.attn`` -> InternVisionAttention (_VIT_ATTN_CLASSNAMES) + - ``language_model.0.attn`` -> InternLM2Attention (_LLM_ATTN_CLASSNAMES) + - ``mm_projector`` -> keyword in _VISION_PROJ_KEYWORDS + + ``forward`` exercises only the ViT + fusion path; ``language_model`` is + present solely for detection verification (DistributedAttention wrapping). + """ + + def __init__(self) -> None: + super().__init__() + self.vision_encoder = nn.Sequential(_AttnLayer(InternVisionAttention())) + self.mm_projector = nn.Identity() + self.language_model = nn.Sequential(_AttnLayer(InternLM2Attention())) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.vision_encoder(local_patches) + return self.fusion(local_visual, text_embeds, input_ids) + + +class _MinimalQwen2VLModel(nn.Module): + """Minimal Qwen2-VL-like skeleton for integration testing. + + Module paths recognised by autosp_detector: + - ``visual.0.attn`` -> Qwen2VLVisionAttention (_VIT_ATTN_CLASSNAMES) + - ``model.0.attn`` -> Qwen2Attention (_LLM_ATTN_CLASSNAMES) + - ``multi_modal_projector`` -> keyword in _VISION_PROJ_KEYWORDS + """ + + def __init__(self) -> None: + super().__init__() + self.visual = nn.Sequential(_AttnLayer(Qwen2VLVisionAttention())) + self.multi_modal_projector = nn.Identity() + self.model = nn.Sequential(_AttnLayer(Qwen2Attention())) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.visual(local_patches) + return self.fusion(local_visual, text_embeds, input_ids) + + +# --------------------------------------------------------------------------- +# InternVL integration tests +# --------------------------------------------------------------------------- + + +class TestInternVLIntegration(DistributedTest): + """Integration tests for the InternVL multimodal SP pipeline.""" + + world_size = 2 + + def test_auto_wrap_detects_and_wraps_modules(self): + """auto_wrap_model_for_sp must replace InternVisionAttention with + UlyssesSPViTAttention and InternLM2Attention with DistributedAttention.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + model = _MinimalInternVLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + + assert isinstance( + model.vision_encoder[0].attn, + UlyssesSPViTAttention), ("Expected vision_encoder[0].attn to be UlyssesSPViTAttention after auto_wrap") + assert isinstance( + model.language_model[0].attn, + DistributedAttention), ("Expected language_model[0].attn to be DistributedAttention after auto_wrap") + + def test_full_pipeline_visual_to_fused(self): + """SP-wrapped ViT -> InternVLFusionAdapter must produce fused output + numerically equivalent to the single-device splice reference.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 4, 10, 8 + num_ctx = local_v * self.world_size + + torch.manual_seed(20) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + ids[:, 2:2 + num_ctx] = _INTERNVL_CONTEXT_ID + + local_patches = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + + model = _MinimalInternVLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + # The mock ViT has no CLS token; override the default set by auto_wrap. + for m in model.modules(): + if isinstance(m, UlyssesSPViTAttention): + m.has_cls_token = False + model.fusion = InternVLFusionAdapter(model.mm_projector, sp_group, + image_token_id=_INTERNVL_CONTEXT_ID).to(get_accelerator().device_name()) + + local_out = model(local_patches, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Single-device reference: splice without SP scatter. + ref_adapter = InternVLFusionAdapter(nn.Identity(), sp_group, + image_token_id=_INTERNVL_CONTEXT_ID).to(get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} InternVL full pipeline output differs from reference: " + f"max_diff={(full_sp_out - ref_fused).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# Qwen2-VL integration tests +# --------------------------------------------------------------------------- + + +class TestQwen2VLIntegration(DistributedTest): + """Integration tests for the Qwen2-VL multimodal SP pipeline.""" + + world_size = 2 + + def test_auto_wrap_detects_and_wraps_modules(self): + """auto_wrap_model_for_sp must replace Qwen2VLVisionAttention with + UlyssesSPViTAttention and Qwen2Attention with DistributedAttention.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + model = _MinimalQwen2VLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + + assert isinstance( + model.visual[0].attn, + UlyssesSPViTAttention), ("Expected visual[0].attn to be UlyssesSPViTAttention after auto_wrap") + assert isinstance(model.model[0].attn, + DistributedAttention), ("Expected model[0].attn to be DistributedAttention after auto_wrap") + + def test_full_pipeline_visual_to_fused(self): + """SP-wrapped ViT -> Qwen2VLFusionAdapter must produce fused output + numerically equivalent to the single-device splice reference.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 3, 10, 8 + num_inner = local_v * self.world_size + + torch.manual_seed(21) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + ids[:, 1] = _QWEN2VL_START_ID + ids[:, 2 + num_inner] = _QWEN2VL_END_ID + + local_patches = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + + model = _MinimalQwen2VLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + for m in model.modules(): + if isinstance(m, UlyssesSPViTAttention): + m.has_cls_token = False + model.fusion = Qwen2VLFusionAdapter(model.multi_modal_projector, + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + + local_out = model(local_patches, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + ref_adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} Qwen2VL full pipeline output differs from reference: " + f"max_diff={(full_sp_out - ref_fused).abs().max().item():.2e}") From 4e7726e2a466a79580314b41c2016f82bb3e2a6d Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Sat, 25 Apr 2026 16:05:18 +0800 Subject: [PATCH 16/21] fix: fix some warn errs . Signed-off-by: nathon-lee fix: delete get_accelerator for not use. Signed-off-by: nathon-lee --- benchmarks/autosp/bench_multimodal_sp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/benchmarks/autosp/bench_multimodal_sp.py b/benchmarks/autosp/bench_multimodal_sp.py index d83fb46a412d..96b5617b190b 100644 --- a/benchmarks/autosp/bench_multimodal_sp.py +++ b/benchmarks/autosp/bench_multimodal_sp.py @@ -30,6 +30,7 @@ """ import argparse +import logging import statistics import torch @@ -37,7 +38,6 @@ import deepspeed import deepspeed.comm as dist -from deepspeed.accelerator import get_accelerator from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter @@ -162,7 +162,12 @@ def _build_model_and_inputs(arch: str, args, sp_group, device): input_ids[:, 2:2 + num_ctx] = _INTERNVL_CONTEXT_ID model = _MinimalInternVLModel(hidden, args.num_layers).to(device) + # Suppress the Phase 2 projection-layer warning: we wrap manually below. + _auto_sp_logger = logging.getLogger("deepspeed.sequence.auto_sp") + _prev_level = _auto_sp_logger.level + _auto_sp_logger.setLevel(logging.ERROR) auto_wrap_model_for_sp(model, sp_group) + _auto_sp_logger.setLevel(_prev_level) for m in model.modules(): if isinstance(m, UlyssesSPViTAttention): m.has_cls_token = False @@ -174,7 +179,11 @@ def _build_model_and_inputs(arch: str, args, sp_group, device): input_ids[:, 2 + num_inner] = _QWEN2VL_END_ID model = _MinimalQwen2VLModel(hidden, args.num_layers).to(device) + _auto_sp_logger = logging.getLogger("deepspeed.sequence.auto_sp") + _prev_level = _auto_sp_logger.level + _auto_sp_logger.setLevel(logging.ERROR) auto_wrap_model_for_sp(model, sp_group) + _auto_sp_logger.setLevel(_prev_level) for m in model.modules(): if isinstance(m, UlyssesSPViTAttention): m.has_cls_token = False @@ -244,6 +253,8 @@ def _run(arch: str, args) -> None: print(f" Peak GPU memory : {peak_mem_mb:.1f} MB") print(f"{sep}\n") + dist.destroy_process_group() + def main() -> None: parser = argparse.ArgumentParser(description="AutoSP multimodal SP benchmark") From c21fe99b4e070cb4cc2552371e19a328118affae Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Sat, 25 Apr 2026 16:33:07 +0800 Subject: [PATCH 17/21] fix: benchmarks/autosp: replace torch.cuda with get_accelerator() Signed-off-by: nathon-lee --- benchmarks/autosp/bench_multimodal_sp.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/benchmarks/autosp/bench_multimodal_sp.py b/benchmarks/autosp/bench_multimodal_sp.py index 96b5617b190b..4909298ee63c 100644 --- a/benchmarks/autosp/bench_multimodal_sp.py +++ b/benchmarks/autosp/bench_multimodal_sp.py @@ -38,6 +38,7 @@ import deepspeed import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter @@ -204,8 +205,8 @@ def _run(arch: str, args) -> None: deepspeed.init_distributed(dist_backend="nccl") rank = dist.get_rank() world_size = dist.get_world_size() - device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") - torch.cuda.set_device(device) + device = torch.device(get_accelerator().device_name(), rank % get_accelerator().device_count()) + get_accelerator().set_device(rank % get_accelerator().device_count()) sp_group = dist.new_group(ranks=list(range(world_size))) model, local_patches, text_embeds, input_ids = _build_model_and_inputs(arch, args, sp_group, device) @@ -215,22 +216,22 @@ def _run(arch: str, args) -> None: with torch.no_grad(): for _ in range(args.warmup): model(local_patches, text_embeds, input_ids) - torch.cuda.synchronize() - torch.cuda.reset_peak_memory_stats(device) + get_accelerator().synchronize() + get_accelerator().reset_peak_memory_stats() # Timed iterations using CUDA events for accurate GPU-side measurement. latencies_ms = [] with torch.no_grad(): for _ in range(args.iters): - t_start = torch.cuda.Event(enable_timing=True) - t_end = torch.cuda.Event(enable_timing=True) + t_start = get_accelerator().Event(enable_timing=True) + t_end = get_accelerator().Event(enable_timing=True) t_start.record() model(local_patches, text_embeds, input_ids) t_end.record() - torch.cuda.synchronize() + get_accelerator().synchronize() latencies_ms.append(t_start.elapsed_time(t_end)) - peak_mem_mb = torch.cuda.max_memory_allocated(device) / 1024**2 + peak_mem_mb = get_accelerator().max_memory_allocated() / 1024**2 mean_ms = statistics.mean(latencies_ms) std_ms = statistics.stdev(latencies_ms) if len(latencies_ms) > 1 else 0.0 # tokens/s: fused sequence length approximated by seq_len (length-preserving adapters). From baa8e51061e04896f9876c67d432d169ed5f856b Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Wed, 29 Apr 2026 09:52:47 +0800 Subject: [PATCH 18/21] [AutoSP] Fix ViT CLS handling and skip incompatible HF LLM wrapping Signed-off-by: nathon-lee --- benchmarks/autosp/bench_multimodal_sp.py | 7 --- deepspeed/sequence/auto_sp.py | 46 ++++++++++++------- deepspeed/sequence/autosp_detector.py | 13 ++++++ .../test_autosp_integration.py | 36 ++++++--------- 4 files changed, 57 insertions(+), 45 deletions(-) diff --git a/benchmarks/autosp/bench_multimodal_sp.py b/benchmarks/autosp/bench_multimodal_sp.py index 4909298ee63c..e6ad8faf86ce 100644 --- a/benchmarks/autosp/bench_multimodal_sp.py +++ b/benchmarks/autosp/bench_multimodal_sp.py @@ -40,7 +40,6 @@ import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp -from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter # --------------------------------------------------------------------------- @@ -169,9 +168,6 @@ def _build_model_and_inputs(arch: str, args, sp_group, device): _auto_sp_logger.setLevel(logging.ERROR) auto_wrap_model_for_sp(model, sp_group) _auto_sp_logger.setLevel(_prev_level) - for m in model.modules(): - if isinstance(m, UlyssesSPViTAttention): - m.has_cls_token = False model.fusion = InternVLFusionAdapter(model.mm_projector, sp_group, image_token_id=_INTERNVL_CONTEXT_ID).to(device) else: # qwen2vl @@ -185,9 +181,6 @@ def _build_model_and_inputs(arch: str, args, sp_group, device): _auto_sp_logger.setLevel(logging.ERROR) auto_wrap_model_for_sp(model, sp_group) _auto_sp_logger.setLevel(_prev_level) - for m in model.modules(): - if isinstance(m, UlyssesSPViTAttention): - m.has_cls_token = False model.fusion = Qwen2VLFusionAdapter(model.multi_modal_projector, sp_group, vision_start_token_id=_QWEN2VL_START_ID, diff --git a/deepspeed/sequence/auto_sp.py b/deepspeed/sequence/auto_sp.py index d1315dc1259a..e6f5326fb11f 100644 --- a/deepspeed/sequence/auto_sp.py +++ b/deepspeed/sequence/auto_sp.py @@ -18,8 +18,10 @@ * :class:`~deepspeed.sequence.autosp_vit.UlyssesSPViTAttention` for ViT encoder attention layers. -* :class:`~deepspeed.sequence.layer.DistributedAttention` - for LLM decoder attention layers (Megatron-style Q/K/V interface). +* a warning for LLM decoder attention layers: HuggingFace-style + ``hidden_states`` attention is incompatible with + :class:`~deepspeed.sequence.layer.DistributedAttention`'s Q/K/V interface; + configure LLM sequence parallelism manually. The vision-language projection layer (Phase 2) is detected and a warning is emitted; wrap it manually with @@ -31,9 +33,8 @@ import torch.nn as nn -from deepspeed.sequence.autosp_detector import detect_model_sp_info +from deepspeed.sequence.autosp_detector import detect_model_sp_info, _VIT_HAS_CLS_TOKEN from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention -from deepspeed.sequence.layer import DistributedAttention logger = logging.getLogger(__name__) @@ -45,7 +46,8 @@ def auto_wrap_model_for_sp(model: nn.Module, process_group) -> nn.Module: with their SP-aware equivalents: * ViT attention → :class:`UlyssesSPViTAttention` - * LLM attention → :class:`DistributedAttention` + * LLM attention → warning only (HuggingFace ``hidden_states`` interface + is incompatible with :class:`DistributedAttention`'s Q/K/V interface) The function modifies *model* in-place **and** returns it for convenience. @@ -81,25 +83,35 @@ def auto_wrap_model_for_sp(model: nn.Module, process_group) -> nn.Module: # Wrap ViT encoder attention layers # ------------------------------------------------------------------ for name, module in info.vit_attn_modules: - wrapped = UlyssesSPViTAttention(module, process_group) + cls_name = type(module).__name__ + # Look up whether this ViT architecture uses a CLS token; default True + # (safe fallback) for unknown classes not yet in the registry. + has_cls = _VIT_HAS_CLS_TOKEN.get(cls_name, True) + wrapped = UlyssesSPViTAttention(module, process_group, has_cls_token=has_cls) _set_module_by_name(model, name, wrapped) - logger.debug("AutoSP: wrapped ViT attention '%s' with UlyssesSPViTAttention", name) + logger.debug("AutoSP: wrapped ViT attention '%s' with UlyssesSPViTAttention (has_cls_token=%s)", name, + has_cls) logger.info("AutoSP: wrapped %d ViT attention layer(s).", len(info.vit_attn_modules)) # ------------------------------------------------------------------ - # Wrap LLM decoder attention layers + # LLM decoder attention layers — warn, do not auto-wrap # ------------------------------------------------------------------ + # DistributedAttention expects a Megatron-style (query, key, value) + # interface, but every class in _LLM_ATTN_CLASSNAMES uses the + # HuggingFace hidden_states interface. Wrapping them silently would + # produce incorrect behaviour at the first forward pass. Emit a + # per-layer warning so the user can configure SP manually. for name, module in info.llm_attn_modules: - # DistributedAttention wraps a Megatron-style attention that receives - # (query, key, value) tensors separately. For HuggingFace-style - # attention that receives hidden_states, use scatter_idx=2 / gather_idx=0 - # defaults which match the typical [bs, seq, heads, dim] layout. - wrapped = DistributedAttention(local_attention=module, sequence_process_group=process_group) - _set_module_by_name(model, name, wrapped) - logger.debug("AutoSP: wrapped LLM attention '%s' with DistributedAttention", name) - - logger.info("AutoSP: wrapped %d LLM attention layer(s).", len(info.llm_attn_modules)) + logger.warning( + "AutoSP: LLM attention '%s' (class %s) uses a HuggingFace hidden_states " + "interface that is incompatible with DistributedAttention's Q/K/V interface. " + "Skipping auto-wrap. Configure sequence parallelism for this layer manually.", name, + type(module).__name__) + + if info.llm_attn_modules: + logger.info("AutoSP: found %d LLM attention layer(s); skipped wrapping (see warnings above).", + len(info.llm_attn_modules)) # ------------------------------------------------------------------ # Warn about the vision projection layer (Phase 2) diff --git a/deepspeed/sequence/autosp_detector.py b/deepspeed/sequence/autosp_detector.py index 2bd423b67d71..be9aab5b320d 100644 --- a/deepspeed/sequence/autosp_detector.py +++ b/deepspeed/sequence/autosp_detector.py @@ -29,6 +29,19 @@ "PaliGemmaVisionAttention", } +# Whether each known ViT class uses a prepended CLS token. +# CLS is replicated on every rank and is NOT sharded across the sequence. +# Defaults to True for unknown classes (safe fallback). +_VIT_HAS_CLS_TOKEN = { + "ViTAttention": True, + "CLIPAttention": True, + "SiglipAttention": False, + "InternVisionAttention": False, + "Qwen2VLVisionAttention": False, + "Idefics2VisionAttention": False, + "PaliGemmaVisionAttention": False, +} + # Known LLM decoder attention class names _LLM_ATTN_CLASSNAMES = { "LlamaAttention", diff --git a/tests/unit/sequence_parallelism/test_autosp_integration.py b/tests/unit/sequence_parallelism/test_autosp_integration.py index 3412283535fd..4efcdb07c302 100644 --- a/tests/unit/sequence_parallelism/test_autosp_integration.py +++ b/tests/unit/sequence_parallelism/test_autosp_integration.py @@ -6,15 +6,12 @@ Each test builds a minimal mock model whose attention-layer class names match the autosp_detector registry, then verifies two things: -1. auto_wrap_model_for_sp correctly identifies and wraps the attention modules. +1. auto_wrap_model_for_sp correctly identifies and wraps ViT attention modules + (with the correct has_cls_token value from the registry) and emits warnings + for HF-style LLM attention without wrapping them. 2. The full pipeline (SP-wrapped ViT -> fusion adapter) produces fused output numerically equivalent to the single-device splice reference. -The LLM decoder branch is intentionally not called in the forward-pass -equivalence tests because DistributedAttention uses a Megatron Q/K/V -interface that is incompatible with the simple hidden_states mock. Its -correct injection is verified by the detection tests instead. - These tests require 2 GPUs. Run with: @@ -28,7 +25,6 @@ import deepspeed.comm as dist from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention -from deepspeed.sequence.layer import DistributedAttention from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter from deepspeed.accelerator import get_accelerator @@ -107,7 +103,7 @@ class _MinimalInternVLModel(nn.Module): - ``mm_projector`` -> keyword in _VISION_PROJ_KEYWORDS ``forward`` exercises only the ViT + fusion path; ``language_model`` is - present solely for detection verification (DistributedAttention wrapping). + present to verify that auto_wrap does NOT wrap HF-style LLM attention. """ def __init__(self) -> None: @@ -155,7 +151,8 @@ class TestInternVLIntegration(DistributedTest): def test_auto_wrap_detects_and_wraps_modules(self): """auto_wrap_model_for_sp must replace InternVisionAttention with - UlyssesSPViTAttention and InternLM2Attention with DistributedAttention.""" + UlyssesSPViTAttention (has_cls_token=False) and must NOT wrap + InternLM2Attention (HF-style, incompatible with DistributedAttention).""" sp_group = dist.new_group(ranks=list(range(self.world_size))) model = _MinimalInternVLModel().to(get_accelerator().device_name()) auto_wrap_model_for_sp(model, sp_group) @@ -163,9 +160,10 @@ def test_auto_wrap_detects_and_wraps_modules(self): assert isinstance( model.vision_encoder[0].attn, UlyssesSPViTAttention), ("Expected vision_encoder[0].attn to be UlyssesSPViTAttention after auto_wrap") - assert isinstance( - model.language_model[0].attn, - DistributedAttention), ("Expected language_model[0].attn to be DistributedAttention after auto_wrap") + assert not model.vision_encoder[0].attn.has_cls_token, ( + "InternVisionAttention has no CLS token; has_cls_token must be False") + assert isinstance(model.language_model[0].attn, + InternLM2Attention), ("HF-style LLM attention must NOT be wrapped by auto_wrap") def test_full_pipeline_visual_to_fused(self): """SP-wrapped ViT -> InternVLFusionAdapter must produce fused output @@ -186,10 +184,6 @@ def test_full_pipeline_visual_to_fused(self): model = _MinimalInternVLModel().to(get_accelerator().device_name()) auto_wrap_model_for_sp(model, sp_group) - # The mock ViT has no CLS token; override the default set by auto_wrap. - for m in model.modules(): - if isinstance(m, UlyssesSPViTAttention): - m.has_cls_token = False model.fusion = InternVLFusionAdapter(model.mm_projector, sp_group, image_token_id=_INTERNVL_CONTEXT_ID).to(get_accelerator().device_name()) @@ -224,7 +218,8 @@ class TestQwen2VLIntegration(DistributedTest): def test_auto_wrap_detects_and_wraps_modules(self): """auto_wrap_model_for_sp must replace Qwen2VLVisionAttention with - UlyssesSPViTAttention and Qwen2Attention with DistributedAttention.""" + UlyssesSPViTAttention (has_cls_token=False) and must NOT wrap + Qwen2Attention (HF-style, incompatible with DistributedAttention).""" sp_group = dist.new_group(ranks=list(range(self.world_size))) model = _MinimalQwen2VLModel().to(get_accelerator().device_name()) auto_wrap_model_for_sp(model, sp_group) @@ -232,8 +227,10 @@ def test_auto_wrap_detects_and_wraps_modules(self): assert isinstance( model.visual[0].attn, UlyssesSPViTAttention), ("Expected visual[0].attn to be UlyssesSPViTAttention after auto_wrap") + assert not model.visual[0].attn.has_cls_token, ( + "Qwen2VLVisionAttention has no CLS token; has_cls_token must be False") assert isinstance(model.model[0].attn, - DistributedAttention), ("Expected model[0].attn to be DistributedAttention after auto_wrap") + Qwen2Attention), ("HF-style LLM attention must NOT be wrapped by auto_wrap") def test_full_pipeline_visual_to_fused(self): """SP-wrapped ViT -> Qwen2VLFusionAdapter must produce fused output @@ -255,9 +252,6 @@ def test_full_pipeline_visual_to_fused(self): model = _MinimalQwen2VLModel().to(get_accelerator().device_name()) auto_wrap_model_for_sp(model, sp_group) - for m in model.modules(): - if isinstance(m, UlyssesSPViTAttention): - m.has_cls_token = False model.fusion = Qwen2VLFusionAdapter(model.multi_modal_projector, sp_group, vision_start_token_id=_QWEN2VL_START_ID, From b46e5888488f379a5d174803c8fe6fc5aa2df864 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Wed, 29 Apr 2026 10:52:55 +0800 Subject: [PATCH 19/21] fix: Fix non-divisible patch padding in UlyssesSPViTAttention; add docs and tests Signed-off-by: nathon-lee --- deepspeed/sequence/autosp_fusion.py | 40 +++++++++++++++++ deepspeed/sequence/autosp_vit.py | 34 +++++++++++--- .../test_autosp_equivalence.py | 45 +++++++++++++++++++ 3 files changed, 112 insertions(+), 7 deletions(-) diff --git a/deepspeed/sequence/autosp_fusion.py b/deepspeed/sequence/autosp_fusion.py index cdd3286dd6e9..3608cfe9a64f 100644 --- a/deepspeed/sequence/autosp_fusion.py +++ b/deepspeed/sequence/autosp_fusion.py @@ -21,6 +21,46 @@ │ LLM decoder (SP-aware) +Usage +----- +After calling :func:`~deepspeed.sequence.auto_sp.auto_wrap_model_for_sp` to +wrap the ViT attention layers, attach the appropriate fusion adapter to the +vision-language projection layer **before** the first forward pass. Choose +the adapter that matches your model architecture:: + + from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp + from deepspeed.sequence.autosp_fusion import ( + LlavaFusionAdapter, + InternVLFusionAdapter, + Qwen2VLFusionAdapter, + ) + from deepspeed.utils import groups + + # 1. Wrap ViT and LLM attention layers automatically. + sp_group = groups._get_sequence_parallel_group() + auto_wrap_model_for_sp(model, process_group=sp_group) + + # 2. Attach the fusion adapter for the vision-language projection layer. + # LLaVA — replaces image-placeholder tokens with visual tokens: + model.mm_projector = LlavaFusionAdapter( + model.mm_projector, sp_group, image_token_id=IMAGE_TOKEN_ID + ) + + # InternVL — replaces IMG_CONTEXT tokens 1-to-1 with visual tokens: + model.mm_projector = InternVLFusionAdapter( + model.mm_projector, sp_group, image_token_id=IMG_CONTEXT_TOKEN_ID + ) + + # Qwen2-VL — replaces tokens between vision_start/end pairs 1-to-1: + model.visual.merger = Qwen2VLFusionAdapter( + model.visual.merger, sp_group, + vision_start_token_id=VISION_START_ID, + vision_end_token_id=VISION_END_ID, + ) + + # 3. Use the model as normal; the adapter handles all SP gather/scatter. + outputs = model(input_ids=input_ids, pixel_values=pixel_values, ...) + Status: Phase 2. ``_splice_visual_into_text`` is intentionally left as a ``NotImplementedError``; override it per model architecture (see docstring). """ diff --git a/deepspeed/sequence/autosp_vit.py b/deepspeed/sequence/autosp_vit.py index 85c2ea593423..9834b8ee0621 100644 --- a/deepspeed/sequence/autosp_vit.py +++ b/deepspeed/sequence/autosp_vit.py @@ -37,6 +37,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F import deepspeed.comm as dist @@ -100,10 +101,27 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): # ------------------------------------------------------------------- # 1. All-gather patches from all ranks to reconstruct the full sequence # ------------------------------------------------------------------- - # We need to all-gather so every rank sees the full K/V context. - gathered = [torch.zeros_like(local_patches) for _ in range(self.world_size)] - dist.all_gather(gathered, local_patches.contiguous(), group=self.process_group) - full_patches = torch.cat(gathered, dim=1) # [bs, num_patches_padded, hidden_dim] + # When num_patches % world_size != 0, ranks may hold different numbers + # of patches (the first `num_patches % world_size` ranks carry one extra + # patch). We find the largest local_patch_len across ranks and zero-pad + # shorter slices so that all_gather receives equal-size tensors. + max_len_t = torch.tensor(local_patch_len, dtype=torch.long, device=local_patches.device) + dist.all_reduce(max_len_t, op=dist.ReduceOp.MAX, group=self.process_group) + max_local_len = int(max_len_t.item()) + + pad_len = max_local_len - local_patch_len + if pad_len > 0: + # Append zero rows so this rank's buffer matches the largest shard. + local_patches_padded = F.pad(local_patches, (0, 0, 0, pad_len)) + else: + local_patches_padded = local_patches + + gathered = [ + torch.zeros(bs, max_local_len, hidden_dim, dtype=local_patches.dtype, device=local_patches.device) + for _ in range(self.world_size) + ] + dist.all_gather(gathered, local_patches_padded.contiguous(), group=self.process_group) + full_patches = torch.cat(gathered, dim=1) # [bs, world_size * max_local_len, hidden_dim] # ------------------------------------------------------------------- # 2. Build the full input (prepend CLS if needed) and call attention @@ -123,7 +141,9 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): extra = [] # ------------------------------------------------------------------- - # 3. Scatter output: each rank keeps only its local slice of patches + # 3. Scatter output: each rank keeps only its local slice of patches. + # Slice starts at rank * max_local_len and spans local_patch_len + # tokens, dropping the zero-padding rows that may have been appended. # ------------------------------------------------------------------- if self.has_cls_token: cls_out = full_out[:, :1, :] @@ -131,9 +151,9 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): else: patch_out = full_out - # Determine this rank's slice boundaries rank = dist.get_rank(self.process_group) - local_out = patch_out[:, rank * local_patch_len:(rank + 1) * local_patch_len, :].contiguous() + start = rank * max_local_len + local_out = patch_out[:, start:start + local_patch_len, :].contiguous() if self.has_cls_token: local_out = torch.cat([cls_out, local_out], dim=1) diff --git a/tests/unit/sequence_parallelism/test_autosp_equivalence.py b/tests/unit/sequence_parallelism/test_autosp_equivalence.py index 391d3ae264cc..2b908e1c87c1 100644 --- a/tests/unit/sequence_parallelism/test_autosp_equivalence.py +++ b/tests/unit/sequence_parallelism/test_autosp_equivalence.py @@ -97,6 +97,51 @@ def test_output_equals_single_device(self, has_cls_token, num_patches): atol=1e-5), (f"rank={rank} sp_out differs from reference: " f"max_diff={( sp_out - ref_slice).abs().max().item():.2e}") + @pytest.mark.parametrize("has_cls_token", [True, False]) + def test_noneven_patches(self, has_cls_token): + """When num_patches % world_size != 0, the wrapper must still produce + correct per-rank output. With 5 patches and world_size=2, rank 0 + holds 3 patches and rank 1 holds 2 patches.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + bs, hidden = 2, 16 + num_patches = 5 # not divisible by world_size=2 + + torch.manual_seed(77) + if has_cls_token: + full_input = torch.randn(bs, 1 + num_patches, hidden).to(get_accelerator().device_name()) + else: + full_input = torch.randn(bs, num_patches, hidden).to(get_accelerator().device_name()) + + # Distribute: first (num_patches % world_size) ranks carry one extra patch. + extra = num_patches % self.world_size # = 1 + base = num_patches // self.world_size # = 2 + local_v = base + (1 if rank < extra else 0) + patch_start = rank * base + min(rank, extra) + + if has_cls_token: + cls = full_input[:, :1, :] + patch_slice = full_input[:, 1 + patch_start:1 + patch_start + local_v, :] + local_input = torch.cat([cls, patch_slice], dim=1) + else: + local_input = full_input[:, patch_start:patch_start + local_v, :] + + wrapper = UlyssesSPViTAttention(_IdentityAttn().to(get_accelerator().device_name()), + sp_group, + has_cls_token=has_cls_token) + sp_out = wrapper(local_input) + + # Reference: identity wrapper — each rank's output must equal its input slice. + if has_cls_token: + ref_slice = torch.cat( + [full_input[:, :1, :], full_input[:, 1 + patch_start:1 + patch_start + local_v, :]], dim=1) + else: + ref_slice = full_input[:, patch_start:patch_start + local_v, :] + + assert torch.allclose(sp_out, ref_slice, + atol=1e-5), (f"rank={rank} non-even patches: sp_out differs from reference: " + f"max_diff={(sp_out - ref_slice).abs().max().item():.2e}") + # --------------------------------------------------------------------------- # LlavaFusionAdapter equivalence From 9f0c277bbf59779c7e704d822b5012701c168224 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Wed, 29 Apr 2026 03:00:06 +0000 Subject: [PATCH 20/21] fix: fix some format errs by tool Signed-off-by: nathon-lee --- deepspeed/sequence/auto_sp.py | 3 +-- tests/unit/sequence_parallelism/test_autosp_equivalence.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/deepspeed/sequence/auto_sp.py b/deepspeed/sequence/auto_sp.py index e6f5326fb11f..2bb17413512b 100644 --- a/deepspeed/sequence/auto_sp.py +++ b/deepspeed/sequence/auto_sp.py @@ -89,8 +89,7 @@ def auto_wrap_model_for_sp(model: nn.Module, process_group) -> nn.Module: has_cls = _VIT_HAS_CLS_TOKEN.get(cls_name, True) wrapped = UlyssesSPViTAttention(module, process_group, has_cls_token=has_cls) _set_module_by_name(model, name, wrapped) - logger.debug("AutoSP: wrapped ViT attention '%s' with UlyssesSPViTAttention (has_cls_token=%s)", name, - has_cls) + logger.debug("AutoSP: wrapped ViT attention '%s' with UlyssesSPViTAttention (has_cls_token=%s)", name, has_cls) logger.info("AutoSP: wrapped %d ViT attention layer(s).", len(info.vit_attn_modules)) diff --git a/tests/unit/sequence_parallelism/test_autosp_equivalence.py b/tests/unit/sequence_parallelism/test_autosp_equivalence.py index 2b908e1c87c1..0c289353a29d 100644 --- a/tests/unit/sequence_parallelism/test_autosp_equivalence.py +++ b/tests/unit/sequence_parallelism/test_autosp_equivalence.py @@ -133,8 +133,8 @@ def test_noneven_patches(self, has_cls_token): # Reference: identity wrapper — each rank's output must equal its input slice. if has_cls_token: - ref_slice = torch.cat( - [full_input[:, :1, :], full_input[:, 1 + patch_start:1 + patch_start + local_v, :]], dim=1) + ref_slice = torch.cat([full_input[:, :1, :], full_input[:, 1 + patch_start:1 + patch_start + local_v, :]], + dim=1) else: ref_slice = full_input[:, patch_start:patch_start + local_v, :] From b6bb8b3c63b78e8fa9987ee3c99cd92200392789 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Thu, 30 Apr 2026 10:39:37 +0800 Subject: [PATCH 21/21] [AutoSP] Fix padding-before-attention bug in UlyssesSPViTAttention De-pad gathered shards before calling attention so that dummy zero-tokens never enter the softmax computation: - all_gather each rank's exact local_patch_len into all_lens - strip per-rank padding from gathered buffers before torch.cat - update scatter offset to sum(all_lens[:rank]) instead of rank*max_local_len All 6 TestViTSPEquivalence tests pass (including test_noneven_patches). Signed-off-by: nathon-lee fix some format err by tool Signed-off-by: nathon-lee --- deepspeed/sequence/autosp_vit.py | 42 ++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/deepspeed/sequence/autosp_vit.py b/deepspeed/sequence/autosp_vit.py index 9834b8ee0621..09bebbbac8f3 100644 --- a/deepspeed/sequence/autosp_vit.py +++ b/deepspeed/sequence/autosp_vit.py @@ -30,9 +30,12 @@ across the sequence dimension. Each rank appends its local patches to the same ``cls`` token before calling the wrapped attention. -Padding: when ``num_patches % world_size != 0``, we pad patches with zeros -before scattering and strip the padding after gathering. Padding tokens do -not carry gradients and are never passed to downstream layers. +Padding: when ``num_patches % world_size != 0``, shorter shards are +zero-padded to a uniform size for ``all_gather``. The padding is stripped +*before* the attention call by trimming each rank's contribution to its true +length, so the wrapped attention always sees exactly ``num_patches`` real +tokens — identical to single-device execution and free of softmax pollution +from dummy tokens. """ import torch @@ -101,13 +104,18 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): # ------------------------------------------------------------------- # 1. All-gather patches from all ranks to reconstruct the full sequence # ------------------------------------------------------------------- - # When num_patches % world_size != 0, ranks may hold different numbers - # of patches (the first `num_patches % world_size` ranks carry one extra - # patch). We find the largest local_patch_len across ranks and zero-pad - # shorter slices so that all_gather receives equal-size tensors. - max_len_t = torch.tensor(local_patch_len, dtype=torch.long, device=local_patches.device) - dist.all_reduce(max_len_t, op=dist.ReduceOp.MAX, group=self.process_group) - max_local_len = int(max_len_t.item()) + # When num_patches % world_size != 0, ranks hold different shard sizes. + # We all-gather every rank's local_patch_len so we can: + # (a) zero-pad shorter slices to uniform size for all_gather, and + # (b) strip the padding per rank *before* calling attention, so that + # the wrapped module never sees dummy tokens (which would corrupt + # the softmax normalisation). + len_bufs = [torch.zeros(1, dtype=torch.long, device=local_patches.device) for _ in range(self.world_size)] + dist.all_gather(len_bufs, + torch.tensor([local_patch_len], dtype=torch.long, device=local_patches.device), + group=self.process_group) + all_lens = [int(t.item()) for t in len_bufs] + max_local_len = max(all_lens) pad_len = max_local_len - local_patch_len if pad_len > 0: @@ -121,7 +129,11 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): for _ in range(self.world_size) ] dist.all_gather(gathered, local_patches_padded.contiguous(), group=self.process_group) - full_patches = torch.cat(gathered, dim=1) # [bs, world_size * max_local_len, hidden_dim] + + # Strip per-rank padding before concatenation so attention only sees + # the true num_patches tokens, identical to single-device execution. + real_parts = [gathered[r][:, :all_lens[r], :] for r in range(self.world_size)] + full_patches = torch.cat(real_parts, dim=1) # [bs, total_real_patches, hidden_dim] # ------------------------------------------------------------------- # 2. Build the full input (prepend CLS if needed) and call attention @@ -141,9 +153,9 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): extra = [] # ------------------------------------------------------------------- - # 3. Scatter output: each rank keeps only its local slice of patches. - # Slice starts at rank * max_local_len and spans local_patch_len - # tokens, dropping the zero-padding rows that may have been appended. + # 3. Scatter output: each rank keeps its local slice of the real patches. + # Because padding was stripped before attention, scatter offsets are + # the cumulative sums of all_lens, not rank * max_local_len. # ------------------------------------------------------------------- if self.has_cls_token: cls_out = full_out[:, :1, :] @@ -152,7 +164,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): patch_out = full_out rank = dist.get_rank(self.process_group) - start = rank * max_local_len + start = sum(all_lens[:rank]) local_out = patch_out[:, start:start + local_patch_len, :].contiguous() if self.has_cls_token: