diff --git a/benchmarks/autosp/bench_multimodal_sp.py b/benchmarks/autosp/bench_multimodal_sp.py new file mode 100644 index 000000000000..e6ad8faf86ce --- /dev/null +++ b/benchmarks/autosp/bench_multimodal_sp.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +Benchmark: AutoSP multimodal sequence parallelism (ViT SP + fusion adapter). + +Measures per-iteration latency, throughput, and peak GPU memory for the +ViT-SP + fusion-adapter pipeline at a given SP degree. + +Launch (from repo root): + + # SP degree 2 — two GPUs: + NCCL_P2P_DISABLE=1 torchrun --nproc_per_node=2 \\ + benchmarks/autosp/bench_multimodal_sp.py [args] + + # Baseline — single GPU (all-gather/scatter are no-ops): + torchrun --nproc_per_node=1 \\ + benchmarks/autosp/bench_multimodal_sp.py [args] + +Compare the two output tables to quantify memory savings and throughput scaling. + +Arguments: + --arch {internvl, qwen2vl} architecture to simulate (default: internvl) + --batch-size N samples per batch (default: 2) + --seq-len N text sequence length (default: 512) + --visual-tokens N total visual tokens per sample (default: 256) + --hidden N hidden dimension (default: 1024) + --num-layers N ViT and LLM layers each (default: 2) + --iters N measured iterations (default: 50) + --warmup N warmup iterations (default: 10) +""" + +import argparse +import logging +import statistics + +import torch +import torch.nn as nn + +import deepspeed +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp +from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter + +# --------------------------------------------------------------------------- +# Token IDs +# --------------------------------------------------------------------------- + +_INTERNVL_CONTEXT_ID = 92546 +_QWEN2VL_START_ID = 151652 +_QWEN2VL_END_ID = 151653 + +# --------------------------------------------------------------------------- +# Mock attention classes — names match autosp_detector registries exactly +# --------------------------------------------------------------------------- + + +class InternVisionAttention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class InternLM2Attention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2VLVisionAttention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2Attention(nn.Module): + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +# --------------------------------------------------------------------------- +# Model building blocks +# --------------------------------------------------------------------------- + + +class _ViTBlock(nn.Module): + """One ViT transformer block: attention (to be SP-wrapped) + linear FFN.""" + + def __init__(self, attn_cls, hidden: int) -> None: + super().__init__() + self.attn = attn_cls() + self.ffn = nn.Linear(hidden, hidden, bias=False) + + def forward(self, x, **kwargs): + out = self.attn(x, **kwargs) + if isinstance(out, (tuple, list)): + out = out[0] + return self.ffn(out) + + +class _MinimalInternVLModel(nn.Module): + """InternVL-like benchmark model. 
+ + Module paths detected by autosp_detector: + - ``vision_encoder.*.attn`` -> InternVisionAttention (_VIT_ATTN_CLASSNAMES) + - ``mm_projector`` -> keyword in _VISION_PROJ_KEYWORDS + + ``language_model`` uses plain nn.Linear layers so it is NOT wrapped by + DistributedAttention (avoids the Q/K/V interface requirement) yet still + contributes realistic compute on the scattered fused sequence. + """ + + def __init__(self, hidden: int, num_layers: int) -> None: + super().__init__() + self.vision_encoder = nn.Sequential(*[_ViTBlock(InternVisionAttention, hidden) for _ in range(num_layers)]) + self.mm_projector = nn.Identity() + self.language_model = nn.Sequential(*[nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)]) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.vision_encoder(local_patches) + local_fused = self.fusion(local_visual, text_embeds, input_ids) + return self.language_model(local_fused) + + +class _MinimalQwen2VLModel(nn.Module): + """Qwen2VL-like benchmark model.""" + + def __init__(self, hidden: int, num_layers: int) -> None: + super().__init__() + self.visual = nn.Sequential(*[_ViTBlock(Qwen2VLVisionAttention, hidden) for _ in range(num_layers)]) + self.multi_modal_projector = nn.Identity() + self.model = nn.Sequential(*[nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)]) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.visual(local_patches) + local_fused = self.fusion(local_visual, text_embeds, input_ids) + return self.model(local_fused) + + +# --------------------------------------------------------------------------- +# Setup helpers +# --------------------------------------------------------------------------- + + +def _build_model_and_inputs(arch: str, args, sp_group, device): + rank = dist.get_rank(sp_group) + world_size = dist.get_world_size(sp_group) + + local_v = args.visual_tokens // world_size + bs, text_len, hidden = args.batch_size, args.seq_len, args.hidden + + torch.manual_seed(0) + local_patches = torch.randn(bs, local_v, hidden, device=device) + text_embeds = torch.randn(bs, text_len, hidden, device=device) + input_ids = torch.zeros(bs, text_len, dtype=torch.long, device=device) + + if arch == "internvl": + num_ctx = min(local_v * world_size, text_len - 2) + input_ids[:, 2:2 + num_ctx] = _INTERNVL_CONTEXT_ID + + model = _MinimalInternVLModel(hidden, args.num_layers).to(device) + # Suppress the Phase 2 projection-layer warning: we wrap manually below. 
+ _auto_sp_logger = logging.getLogger("deepspeed.sequence.auto_sp") + _prev_level = _auto_sp_logger.level + _auto_sp_logger.setLevel(logging.ERROR) + auto_wrap_model_for_sp(model, sp_group) + _auto_sp_logger.setLevel(_prev_level) + model.fusion = InternVLFusionAdapter(model.mm_projector, sp_group, + image_token_id=_INTERNVL_CONTEXT_ID).to(device) + else: # qwen2vl + num_inner = min(local_v * world_size, text_len - 3) + input_ids[:, 1] = _QWEN2VL_START_ID + input_ids[:, 2 + num_inner] = _QWEN2VL_END_ID + + model = _MinimalQwen2VLModel(hidden, args.num_layers).to(device) + _auto_sp_logger = logging.getLogger("deepspeed.sequence.auto_sp") + _prev_level = _auto_sp_logger.level + _auto_sp_logger.setLevel(logging.ERROR) + auto_wrap_model_for_sp(model, sp_group) + _auto_sp_logger.setLevel(_prev_level) + model.fusion = Qwen2VLFusionAdapter(model.multi_modal_projector, + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(device) + + return model, local_patches, text_embeds, input_ids + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + + +def _run(arch: str, args) -> None: + deepspeed.init_distributed(dist_backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(get_accelerator().device_name(), rank % get_accelerator().device_count()) + get_accelerator().set_device(rank % get_accelerator().device_count()) + + sp_group = dist.new_group(ranks=list(range(world_size))) + model, local_patches, text_embeds, input_ids = _build_model_and_inputs(arch, args, sp_group, device) + model.eval() + + # Warmup + with torch.no_grad(): + for _ in range(args.warmup): + model(local_patches, text_embeds, input_ids) + get_accelerator().synchronize() + get_accelerator().reset_peak_memory_stats() + + # Timed iterations using CUDA events for accurate GPU-side measurement. + latencies_ms = [] + with torch.no_grad(): + for _ in range(args.iters): + t_start = get_accelerator().Event(enable_timing=True) + t_end = get_accelerator().Event(enable_timing=True) + t_start.record() + model(local_patches, text_embeds, input_ids) + t_end.record() + get_accelerator().synchronize() + latencies_ms.append(t_start.elapsed_time(t_end)) + + peak_mem_mb = get_accelerator().max_memory_allocated() / 1024**2 + mean_ms = statistics.mean(latencies_ms) + std_ms = statistics.stdev(latencies_ms) if len(latencies_ms) > 1 else 0.0 + # tokens/s: fused sequence length approximated by seq_len (length-preserving adapters). 
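+    # For example, batch_size=2 and seq_len=512 at mean_ms=20.0 gives (2 * 512) / 0.020 = 51,200 tokens/s.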
+ throughput = (args.batch_size * args.seq_len) / (mean_ms / 1000.0) + + if rank == 0: + sep = "=" * 62 + print(f"\n{sep}") + print(f" AutoSP Benchmark arch={arch} sp_degree={world_size}") + print(sep) + print(f" batch_size : {args.batch_size}") + print(f" seq_len : {args.seq_len}") + print(f" visual_tokens : {args.visual_tokens} (local={args.visual_tokens // world_size}/rank)") + print(f" hidden : {args.hidden}") + print(f" num_layers : {args.num_layers}") + print(f" warmup / iters : {args.warmup} / {args.iters}") + print(f" {'─' * 58}") + print(f" Latency : {mean_ms:.2f} ± {std_ms:.2f} ms/iter") + print(f" Throughput : {throughput:,.0f} tokens/s") + print(f" Peak GPU memory : {peak_mem_mb:.1f} MB") + print(f"{sep}\n") + + dist.destroy_process_group() + + +def main() -> None: + parser = argparse.ArgumentParser(description="AutoSP multimodal SP benchmark") + parser.add_argument("--arch", + choices=["internvl", "qwen2vl"], + default="internvl", + help="Model architecture to simulate") + parser.add_argument("--batch-size", type=int, default=2) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--visual-tokens", + type=int, + default=256, + help="Total visual tokens (must be divisible by --nproc_per_node)") + parser.add_argument("--hidden", type=int, default=1024) + parser.add_argument("--num-layers", type=int, default=2, help="Number of ViT blocks and LLM linear layers each") + parser.add_argument("--iters", type=int, default=50) + parser.add_argument("--warmup", type=int, default=10) + args = parser.parse_args() + + _run(args.arch, args) + + +if __name__ == "__main__": + main() diff --git a/deepspeed/sequence/__init__.py b/deepspeed/sequence/__init__.py index 208299fb8c50..b76f944eff79 100644 --- a/deepspeed/sequence/__init__.py +++ b/deepspeed/sequence/__init__.py @@ -2,3 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team + +from deepspeed.sequence.autosp_detector import detect_model_sp_info, SPModelInfo +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.autosp_fusion import (ModalityFusionSPAdapter, LlavaFusionAdapter, InternVLFusionAdapter, + Qwen2VLFusionAdapter) +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp diff --git a/deepspeed/sequence/auto_sp.py b/deepspeed/sequence/auto_sp.py new file mode 100644 index 000000000000..2bb17413512b --- /dev/null +++ b/deepspeed/sequence/auto_sp.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +AutoSP: one-call sequence parallelism for multimodal models. + +Usage:: + + from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp + from deepspeed.utils import groups + + model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, ...) + sp_group = groups._get_sequence_parallel_group() + model = auto_wrap_model_for_sp(model, process_group=sp_group) + +``auto_wrap_model_for_sp`` scans the model and injects: + +* :class:`~deepspeed.sequence.autosp_vit.UlyssesSPViTAttention` + for ViT encoder attention layers. +* a warning for LLM decoder attention layers: HuggingFace-style + ``hidden_states`` attention is incompatible with + :class:`~deepspeed.sequence.layer.DistributedAttention`'s Q/K/V interface; + configure LLM sequence parallelism manually. + +The vision-language projection layer (Phase 2) is detected and a warning is +emitted; wrap it manually with +:class:`~deepspeed.sequence.autosp_fusion.ModalityFusionSPAdapter` until +Phase 2 automation is implemented. 
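+
+For example, for a LLaVA-style checkpoint (``IMAGE_TOKEN_ID`` being whatever
+placeholder id the tokenizer uses)::
+
+    from deepspeed.sequence.autosp_fusion import LlavaFusionAdapter
+
+    model.mm_projector = LlavaFusionAdapter(model.mm_projector, sp_group,
+                                            image_token_id=IMAGE_TOKEN_ID)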
+""" + +import logging + +import torch.nn as nn + +from deepspeed.sequence.autosp_detector import detect_model_sp_info, _VIT_HAS_CLS_TOKEN +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention + +logger = logging.getLogger(__name__) + + +def auto_wrap_model_for_sp(model: nn.Module, process_group) -> nn.Module: + """Inject sequence-parallel wrappers into *model* in-place. + + Scans the model's named modules and replaces recognised attention layers + with their SP-aware equivalents: + + * ViT attention → :class:`UlyssesSPViTAttention` + * LLM attention → warning only (HuggingFace ``hidden_states`` interface + is incompatible with :class:`DistributedAttention`'s Q/K/V interface) + + The function modifies *model* in-place **and** returns it for convenience. + + Parameters + ---------- + model: + The multimodal model to wrap. Must be on the correct device before + calling this function. + process_group: + The sequence-parallel process group (from + ``groups._get_sequence_parallel_group()``). + + Returns + ------- + The same *model* object with attention layers replaced. + + Raises + ------ + ValueError + If no recognisable attention modules are found. Register the model's + attention class names in ``autosp_detector._VIT_ATTN_CLASSNAMES`` or + ``_LLM_ATTN_CLASSNAMES`` to fix this. + """ + info = detect_model_sp_info(model) + + if not info.vit_attn_modules and not info.llm_attn_modules: + raise ValueError("auto_wrap_model_for_sp: no recognisable attention modules found. " + "Add the model's attention class name(s) to " + "_VIT_ATTN_CLASSNAMES or _LLM_ATTN_CLASSNAMES in " + "deepspeed/sequence/autosp_detector.py and retry.") + + # ------------------------------------------------------------------ + # Wrap ViT encoder attention layers + # ------------------------------------------------------------------ + for name, module in info.vit_attn_modules: + cls_name = type(module).__name__ + # Look up whether this ViT architecture uses a CLS token; default True + # (safe fallback) for unknown classes not yet in the registry. + has_cls = _VIT_HAS_CLS_TOKEN.get(cls_name, True) + wrapped = UlyssesSPViTAttention(module, process_group, has_cls_token=has_cls) + _set_module_by_name(model, name, wrapped) + logger.debug("AutoSP: wrapped ViT attention '%s' with UlyssesSPViTAttention (has_cls_token=%s)", name, has_cls) + + logger.info("AutoSP: wrapped %d ViT attention layer(s).", len(info.vit_attn_modules)) + + # ------------------------------------------------------------------ + # LLM decoder attention layers — warn, do not auto-wrap + # ------------------------------------------------------------------ + # DistributedAttention expects a Megatron-style (query, key, value) + # interface, but every class in _LLM_ATTN_CLASSNAMES uses the + # HuggingFace hidden_states interface. Wrapping them silently would + # produce incorrect behaviour at the first forward pass. Emit a + # per-layer warning so the user can configure SP manually. + for name, module in info.llm_attn_modules: + logger.warning( + "AutoSP: LLM attention '%s' (class %s) uses a HuggingFace hidden_states " + "interface that is incompatible with DistributedAttention's Q/K/V interface. " + "Skipping auto-wrap. 
Configure sequence parallelism for this layer manually.", name, + type(module).__name__) + + if info.llm_attn_modules: + logger.info("AutoSP: found %d LLM attention layer(s); skipped wrapping (see warnings above).", + len(info.llm_attn_modules)) + + # ------------------------------------------------------------------ + # Warn about the vision projection layer (Phase 2) + # ------------------------------------------------------------------ + if info.vision_projection_module is not None: + proj_name, _ = info.vision_projection_module + logger.warning( + "AutoSP detected vision projection layer '%s'. " + "ModalityFusionSPAdapter (Phase 2) is not yet automated. " + "Wrap this layer manually with ModalityFusionSPAdapter if you " + "need correct cross-modal sequence gather/scatter.", proj_name) + + return model + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _set_module_by_name(model: nn.Module, dotted_name: str, new_module: nn.Module) -> None: + """Replace the submodule at *dotted_name* with *new_module* in-place.""" + parts = dotted_name.split(".") + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + setattr(parent, parts[-1], new_module) diff --git a/deepspeed/sequence/autosp_detector.py b/deepspeed/sequence/autosp_detector.py new file mode 100644 index 000000000000..be9aab5b320d --- /dev/null +++ b/deepspeed/sequence/autosp_detector.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Automatically detect ViT encoder and LLM decoder attention modules in +multimodal models to guide AutoSP injection. + +Extend _VIT_ATTN_CLASSNAMES / _LLM_ATTN_CLASSNAMES when adding support for +new model architectures. +""" + +import torch.nn as nn +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +# --------------------------------------------------------------------------- +# Architecture registry +# --------------------------------------------------------------------------- + +# Known ViT attention class names (HuggingFace transformers naming) +_VIT_ATTN_CLASSNAMES = { + "ViTAttention", + "CLIPAttention", + "SiglipAttention", + "InternVisionAttention", + "Qwen2VLVisionAttention", + "Idefics2VisionAttention", + "PaliGemmaVisionAttention", +} + +# Whether each known ViT class uses a prepended CLS token. +# CLS is replicated on every rank and is NOT sharded across the sequence. +# Defaults to True for unknown classes (safe fallback). 
+_VIT_HAS_CLS_TOKEN = {
+    "ViTAttention": True,
+    "CLIPAttention": True,
+    "SiglipAttention": False,
+    "InternVisionAttention": False,
+    "Qwen2VLVisionAttention": False,
+    "Idefics2VisionAttention": False,
+    "PaliGemmaVisionAttention": False,
+}
+
+# Known LLM decoder attention class names
+_LLM_ATTN_CLASSNAMES = {
+    "LlamaAttention",
+    "MistralAttention",
+    "Qwen2Attention",
+    "InternLM2Attention",
+    "GemmaAttention",
+    "Phi3Attention",
+    "GPTNeoXAttention",
+    "FalconAttention",
+    "MptAttention",
+}
+
+# Common attribute names that hold the vision-language projection layer
+_VISION_PROJ_KEYWORDS = (
+    "visual_projection",
+    "mm_projector",
+    "vision_proj",
+    "multi_modal_projector",
+    "img_projection",
+)
+
+# ---------------------------------------------------------------------------
+# Data structures
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class SPModelInfo:
+    """Holds the detection results for a multimodal model."""
+
+    # (dotted_name, module) pairs for ViT attention layers
+    vit_attn_modules: List[Tuple[str, nn.Module]] = field(default_factory=list)
+    # (dotted_name, module) pairs for LLM decoder attention layers
+    llm_attn_modules: List[Tuple[str, nn.Module]] = field(default_factory=list)
+    # (dotted_name, module) for the outermost vision-language projection layer
+    vision_projection_module: Optional[Tuple[str, nn.Module]] = None
+
+
+# ---------------------------------------------------------------------------
+# Detection logic
+# ---------------------------------------------------------------------------
+
+
+def detect_model_sp_info(model: nn.Module) -> SPModelInfo:
+    """Recursively scan *model* and return an :class:`SPModelInfo`.
+
+    The function identifies:
+    * ViT encoder attention layers → auto-wrapped with :class:`UlyssesSPViTAttention`
+    * LLM decoder attention layers → reported only (not auto-wrapped; a warning is emitted)
+    * The vision-language projection layer → reported only; wrap it manually with
+      :class:`ModalityFusionSPAdapter` (Phase 2)
+
+    To add support for a new architecture, register its attention class
+    names in ``_VIT_ATTN_CLASSNAMES`` or ``_LLM_ATTN_CLASSNAMES``.
+    """
+    info = SPModelInfo()
+    for name, module in model.named_modules():
+        cls_name = type(module).__name__
+        if cls_name in _VIT_ATTN_CLASSNAMES:
+            info.vit_attn_modules.append((name, module))
+        elif cls_name in _LLM_ATTN_CLASSNAMES:
+            info.llm_attn_modules.append((name, module))
+
+        # Record only the first (outermost) match to avoid double-wrapping
+        # nested projection modules.
+        if info.vision_projection_module is None:
+            if any(kw in name for kw in _VISION_PROJ_KEYWORDS):
+                info.vision_projection_module = (name, module)
+
+    return info
diff --git a/deepspeed/sequence/autosp_fusion.py b/deepspeed/sequence/autosp_fusion.py
new file mode 100644
index 000000000000..3608cfe9a64f
--- /dev/null
+++ b/deepspeed/sequence/autosp_fusion.py
@@ -0,0 +1,366 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""
+ModalityFusionSPAdapter — Phase 2
+
+Handles the sequence scatter/gather at the vision-language boundary so that
+the LLM decoder's :class:`~deepspeed.sequence.layer.DistributedAttention`
+receives a uniformly sharded fused (visual + text) sequence.
+ +Workflow +-------- +:: + + [visual tokens, sharded] ──all-gather──► [visual tokens, full] + │ + splice into text + │ + [fused embeds, full] ──scatter──► [fused embeds, sharded per rank] + │ + LLM decoder (SP-aware) + +Usage +----- +After calling :func:`~deepspeed.sequence.auto_sp.auto_wrap_model_for_sp` to +wrap the ViT attention layers, attach the appropriate fusion adapter to the +vision-language projection layer **before** the first forward pass. Choose +the adapter that matches your model architecture:: + + from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp + from deepspeed.sequence.autosp_fusion import ( + LlavaFusionAdapter, + InternVLFusionAdapter, + Qwen2VLFusionAdapter, + ) + from deepspeed.utils import groups + + # 1. Wrap ViT and LLM attention layers automatically. + sp_group = groups._get_sequence_parallel_group() + auto_wrap_model_for_sp(model, process_group=sp_group) + + # 2. Attach the fusion adapter for the vision-language projection layer. + # LLaVA — replaces image-placeholder tokens with visual tokens: + model.mm_projector = LlavaFusionAdapter( + model.mm_projector, sp_group, image_token_id=IMAGE_TOKEN_ID + ) + + # InternVL — replaces IMG_CONTEXT tokens 1-to-1 with visual tokens: + model.mm_projector = InternVLFusionAdapter( + model.mm_projector, sp_group, image_token_id=IMG_CONTEXT_TOKEN_ID + ) + + # Qwen2-VL — replaces tokens between vision_start/end pairs 1-to-1: + model.visual.merger = Qwen2VLFusionAdapter( + model.visual.merger, sp_group, + vision_start_token_id=VISION_START_ID, + vision_end_token_id=VISION_END_ID, + ) + + # 3. Use the model as normal; the adapter handles all SP gather/scatter. + outputs = model(input_ids=input_ids, pixel_values=pixel_values, ...) + +Status: Phase 2. ``_splice_visual_into_text`` is intentionally left as a +``NotImplementedError``; override it per model architecture (see docstring). +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed.comm as dist + +# Default image placeholder token ID used by LLaVA-style models. +_DEFAULT_IMAGE_TOKEN_ID = -200 + + +class ModalityFusionSPAdapter(nn.Module): + """Wraps the vision projection layer and handles cross-modal sequence fusion. + + After projecting visual features, this adapter: + + 1. Gathers the sharded visual token slices from all SP ranks into a single + full visual token tensor. + 2. Splices the visual tokens into the text embedding sequence at the + positions marked by ``image_token_id`` placeholders. + 3. Pads and re-shards the fused sequence so that the subsequent LLM + decoder layers receive uniformly distributed sequence slices. + + Parameters + ---------- + projection: + The vision projection module (e.g. ``mm_projector``). + process_group: + The sequence-parallel process group. + image_token_id: + The token ID used as an image placeholder in the input IDs tensor. + Defaults to ``-200`` (LLaVA convention). + + Notes + ----- + Subclass this and override :meth:`_splice_visual_into_text` to adapt to a + specific multimodal architecture (LLaVA, InternVL, Qwen-VL, …). 
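+
+    A minimal override (the concatenating adapter used by the unit tests)
+    looks like::
+
+        class ConcatFusionAdapter(ModalityFusionSPAdapter):
+
+            def _splice_visual_into_text(self, text_embeds, visual_embeds, input_ids):
+                return torch.cat([text_embeds, visual_embeds], dim=1)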
+ """ + + def __init__(self, projection: nn.Module, process_group, image_token_id: int = _DEFAULT_IMAGE_TOKEN_ID) -> None: + super().__init__() + self.projection = projection + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.image_token_id = image_token_id + + def forward(self, visual_features: torch.Tensor, text_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Project visual features and return a sharded fused embedding. + + Parameters + ---------- + visual_features: + Raw visual features from the ViT encoder. + Shape: ``[bs, local_visual_tokens, vit_hidden]``. + text_embeds: + Full text token embeddings (not sharded yet). + Shape: ``[bs, text_seq_len, lm_hidden]``. + input_ids: + Token IDs used to locate image placeholder positions. + Shape: ``[bs, text_seq_len]``. + + Returns + ------- + Sharded fused embedding for this rank. + Shape: ``[bs, local_fused_len, lm_hidden]``. + """ + # 1. Project visual features to the LLM hidden dimension + visual_embeds = self.projection(visual_features) # [bs, local_v, lm_hidden] + + # 2. All-gather visual slices from all SP ranks + parts = [torch.zeros_like(visual_embeds) for _ in range(self.world_size)] + dist.all_gather(parts, visual_embeds.contiguous(), group=self.process_group) + full_visual = torch.cat(parts, dim=1) # [bs, total_visual_tokens, lm_hidden] + + # 3. Splice visual tokens into text embedding sequence + fused = self._splice_visual_into_text(text_embeds, full_visual, input_ids) # [bs, fused_len, lm_hidden] + + # 4. Pad fused length to be divisible by world_size, then scatter + total_len = fused.shape[1] + pad = (self.world_size - total_len % self.world_size) % self.world_size + if pad > 0: + fused = F.pad(fused, (0, 0, 0, pad)) + + rank = dist.get_rank(self.process_group) + local_len = fused.shape[1] // self.world_size + return fused[:, rank * local_len:(rank + 1) * local_len, :].contiguous() + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Replace image placeholder positions in *text_embeds* with *visual_embeds*. + + This is intentionally architecture-specific. The default raises + ``NotImplementedError``; override this method for each supported model. + + Reference implementations: + * LLaVA: ``LlavaMetaForCausalLM.prepare_inputs_embeds`` + * InternVL: ``InternVLChatModel.extract_feature`` + * Qwen-VL: ``Qwen2VLForConditionalGeneration.get_rope_index`` + """ + raise NotImplementedError(f"{type(self).__name__}._splice_visual_into_text is not implemented. " + "Subclass ModalityFusionSPAdapter and override this method to match " + "your model's prepare_inputs_embeds logic.") + + +class LlavaFusionAdapter(ModalityFusionSPAdapter): + """LLaVA-style splice: replace each image placeholder token with visual tokens. + + Follows the logic of ``LlavaMetaForCausalLM.prepare_inputs_labels_for_multimodal``: + for each sample, locate ``image_token_id`` placeholders in ``input_ids``, + remove them, and insert the corresponding visual token chunk in their place. + + Visual tokens for a sample are split evenly across the number of image + placeholders found. This matches the common single-image case (one + placeholder per sample) and simple multi-image cases where every image + contributes the same number of tokens. + + Parameters are inherited from :class:`ModalityFusionSPAdapter`. 
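+
+    For example, with ``input_ids = [t0, <image>, t1]`` and three visual tokens
+    ``[v0, v1, v2]``, the fused sequence becomes ``[t0, v0, v1, v2, t1]``.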
+    """
+
+    def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor,
+                                 input_ids: torch.Tensor) -> torch.Tensor:
+        bs, text_len, hidden = text_embeds.shape
+        device = text_embeds.device
+
+        fused_samples = []
+        for i in range(bs):
+            img_pos = (input_ids[i] == self.image_token_id).nonzero(as_tuple=True)[0]
+            num_images = img_pos.numel()
+
+            if num_images == 0:
+                # No image in this sample — keep text embeddings unchanged.
+                fused_samples.append(text_embeds[i])
+                continue
+
+            # Split all visual tokens evenly across the image placeholders.
+            visual_chunks = torch.chunk(visual_embeds[i], num_images, dim=0)
+
+            segments = []
+            prev = 0
+            for j, pos in enumerate(img_pos.tolist()):
+                # Text segment before this placeholder.
+                if pos > prev:
+                    segments.append(text_embeds[i, prev:pos])
+                # Visual tokens replacing this placeholder.
+                segments.append(visual_chunks[j])
+                # Skip the placeholder token itself.
+                prev = pos + 1
+
+            # Remaining text after the last placeholder.
+            if prev < text_len:
+                segments.append(text_embeds[i, prev:])
+
+            fused_samples.append(torch.cat(segments, dim=0))
+
+        # Pad all samples to the same length so they stack into a tensor.
+        max_len = max(s.shape[0] for s in fused_samples)
+        out = torch.zeros(bs, max_len, hidden, dtype=text_embeds.dtype, device=device)
+        for i, s in enumerate(fused_samples):
+            out[i, :s.shape[0]] = s
+        return out
+
+
+class InternVLFusionAdapter(ModalityFusionSPAdapter):
+    """InternVL-style splice: replace IMG_CONTEXT token runs with visual tokens.
+
+    InternVL encodes each image as ``<img> <IMG_CONTEXT>×N </img>``
+    inside the token sequence. Each ``IMG_CONTEXT`` token (``image_token_id``)
+    is a 1-to-1 placeholder for one ViT visual token. This adapter locates
+    every contiguous run of ``image_token_id`` tokens and replaces them with
+    the corresponding slice of *visual_embeds*, while preserving the
+    ``IMG_START`` / ``IMG_END`` boundary embeddings unchanged.
+
+    Because the replacement is 1-to-1, the output sequence length equals the
+    input sequence length (no length change).
+
+    Parameters are inherited from :class:`ModalityFusionSPAdapter`.
+    Set ``image_token_id`` to the ``IMG_CONTEXT`` token id used by the model
+    (e.g. the id of ``<IMG_CONTEXT>``).
+    """
+
+    def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor,
+                                 input_ids: torch.Tensor) -> torch.Tensor:
+        # Start from a clone of text embeddings; we only overwrite context positions.
+        out = text_embeds.clone()
+        bs = text_embeds.shape[0]
+
+        for i in range(bs):
+            ctx_pos = (input_ids[i] == self.image_token_id).nonzero(as_tuple=True)[0]
+            if ctx_pos.numel() == 0:
+                continue
+            # ctx_pos lists every IMG_CONTEXT index in order. visual_embeds[i]
+            # has exactly ctx_pos.numel() tokens (one per context position).
+            out[i, ctx_pos] = visual_embeds[i, :ctx_pos.numel()]
+
+        return out
+
+
+class Qwen2VLFusionAdapter(nn.Module):
+    """Qwen2-VL-style splice: visual tokens enclosed by vision_start/end tokens.
+
+    Qwen2-VL wraps each image's visual tokens with a pair of special boundary
+    tokens in ``input_ids``: ``vision_start_token_id`` and
+    ``vision_end_token_id``. The placeholder tokens between each
+    (start, end) pair are replaced 1-to-1 by the projected visual token
+    embeddings. The boundary token embeddings are kept unchanged.
+
+    Because the replacement is 1-to-1, the output sequence length equals the
+    input sequence length.
+
+    Parameters
+    ----------
+    projection:
+        The vision projection module (e.g. ``visual.merger``).
+ process_group: + The sequence-parallel process group. + vision_start_token_id: + Token id of ``<|vision_start|>``. + vision_end_token_id: + Token id of ``<|vision_end|>``. + """ + + def __init__(self, projection: nn.Module, process_group, vision_start_token_id: int, + vision_end_token_id: int) -> None: + super().__init__() + self.projection = projection + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + + def forward(self, visual_features: torch.Tensor, text_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Project visual features and return a sharded fused embedding. + + Parameters + ---------- + visual_features: + Raw visual features from the ViT encoder. + Shape: ``[bs, local_visual_tokens, vit_hidden]``. + text_embeds: + Full text token embeddings (not sharded yet). + Shape: ``[bs, text_seq_len, lm_hidden]``. + input_ids: + Token IDs used to locate vision_start/end boundaries. + Shape: ``[bs, text_seq_len]``. + + Returns + ------- + Sharded fused embedding for this rank. + Shape: ``[bs, local_fused_len, lm_hidden]``. + """ + # 1. Project visual features to the LLM hidden dimension. + visual_embeds = self.projection(visual_features) # [bs, local_v, lm_hidden] + + # 2. All-gather visual slices from all SP ranks. + parts = [torch.zeros_like(visual_embeds) for _ in range(self.world_size)] + dist.all_gather(parts, visual_embeds.contiguous(), group=self.process_group) + full_visual = torch.cat(parts, dim=1) # [bs, total_visual_tokens, lm_hidden] + + # 3. Replace placeholder positions in text with visual tokens (length-preserving). + fused = self._splice_visual_into_text(text_embeds, full_visual, input_ids) + + # 4. Pad fused length to be divisible by world_size, then scatter. + total_len = fused.shape[1] + pad = (self.world_size - total_len % self.world_size) % self.world_size + if pad > 0: + fused = F.pad(fused, (0, 0, 0, pad)) + + rank = dist.get_rank(self.process_group) + local_len = fused.shape[1] // self.world_size + return fused[:, rank * local_len:(rank + 1) * local_len, :].contiguous() + + def _splice_visual_into_text(self, text_embeds: torch.Tensor, visual_embeds: torch.Tensor, + input_ids: torch.Tensor) -> torch.Tensor: + """Replace inner placeholder tokens between vision_start/end pairs with visual embeddings.""" + out = text_embeds.clone() + bs = text_embeds.shape[0] + + for i in range(bs): + start_pos = (input_ids[i] == self.vision_start_token_id).nonzero(as_tuple=True)[0] + end_pos = (input_ids[i] == self.vision_end_token_id).nonzero(as_tuple=True)[0] + + if start_pos.numel() == 0: + continue + + # Accumulate inner placeholder positions across all start/end pairs. + # Inner positions are (start+1) .. (end-1) inclusive, i.e. excluding + # the boundary tokens themselves. + inner_positions = [] + for s, e in zip(start_pos.tolist(), end_pos.tolist()): + inner_positions.extend(range(s + 1, e)) + + if not inner_positions: + continue + + inner_pos = torch.tensor(inner_positions, dtype=torch.long, device=text_embeds.device) + out[i, inner_pos] = visual_embeds[i, :len(inner_positions)] + + return out diff --git a/deepspeed/sequence/autosp_vit.py b/deepspeed/sequence/autosp_vit.py new file mode 100644 index 000000000000..09bebbbac8f3 --- /dev/null +++ b/deepspeed/sequence/autosp_vit.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Ulysses-style sequence-parallel wrapper for ViT encoder attention layers. + +Design notes +------------ +ViT self-attention is non-causal: every patch token attends to every other +patch token. This means a straightforward per-rank local attention (as used +for causal LLMs) would be *incorrect* — each rank must have access to the +full key/value context. + +We therefore use a **gather-compute-scatter** pattern: + +1. Input arrives already sharded along the sequence dimension (each rank owns + ``local_patches = num_patches // world_size`` consecutive patches). +2. Before attention we **all-gather** patch tokens so that every rank runs the + full ViT attention over the complete patch sequence. This keeps the + computation equivalent to single-device execution. +3. The output is **scattered** back so that each rank returns only its local + slice, matching the sharded input contract expected by downstream layers. + +Memory benefit: activations *outside* the attention block (e.g. feed-forward +layers, layer norms) are stored only locally, reducing per-rank memory +proportional to ``world_size``. + +The ``cls`` token (if present) is replicated on every rank and is not split +across the sequence dimension. Each rank appends its local patches to the +same ``cls`` token before calling the wrapped attention. + +Padding: when ``num_patches % world_size != 0``, shorter shards are +zero-padded to a uniform size for ``all_gather``. The padding is stripped +*before* the attention call by trimming each rank's contribution to its true +length, so the wrapped attention always sees exactly ``num_patches`` real +tokens — identical to single-device execution and free of softmax pollution +from dummy tokens. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed.comm as dist + + +class UlyssesSPViTAttention(nn.Module): + """Sequence-parallel wrapper for an opaque ViT attention module. + + Parameters + ---------- + attn: + The original ViT attention layer (any ``nn.Module`` that maps + ``hidden_states`` → ``hidden_states`` or a tuple whose first element + is the attention output tensor). + process_group: + The sequence-parallel process group. + has_cls_token: + Set to ``True`` (default) when the first token in the sequence is a + ``[CLS]`` token that should be replicated on every rank rather than + sharded. + """ + + def __init__(self, attn: nn.Module, process_group, has_cls_token: bool = True) -> None: + super().__init__() + self.attn = attn + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + self.has_cls_token = has_cls_token + + # ------------------------------------------------------------------ + # forward + # ------------------------------------------------------------------ + + def forward(self, hidden_states: torch.Tensor, **kwargs): + """ + Parameters + ---------- + hidden_states: + Shape ``[bs, local_seq_len, hidden_dim]`` where + ``local_seq_len = (1 + local_patches)`` if ``has_cls_token`` else + ``local_patches``. Each rank holds a contiguous slice of patches. + **kwargs: + Passed through to the wrapped attention (e.g. ``attention_mask``, + ``head_mask``, ``output_attentions``). + + Returns + ------- + Same shape as input (or a tuple whose first element matches the input + shape, preserving whatever the wrapped module returns). 
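+
+        For example, with 16 patches, ``world_size=4`` and a CLS token, every
+        rank passes ``[bs, 1 + 4, hidden_dim]`` in and receives the same shape
+        back.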
+ """ + bs, local_seq_len, hidden_dim = hidden_states.shape + + if self.has_cls_token: + # CLS token is replicated on every rank — not part of the sharded seq + cls_token = hidden_states[:, :1, :] + local_patches = hidden_states[:, 1:, :] + else: + local_patches = hidden_states + + local_patch_len = local_patches.shape[1] + + # ------------------------------------------------------------------- + # 1. All-gather patches from all ranks to reconstruct the full sequence + # ------------------------------------------------------------------- + # When num_patches % world_size != 0, ranks hold different shard sizes. + # We all-gather every rank's local_patch_len so we can: + # (a) zero-pad shorter slices to uniform size for all_gather, and + # (b) strip the padding per rank *before* calling attention, so that + # the wrapped module never sees dummy tokens (which would corrupt + # the softmax normalisation). + len_bufs = [torch.zeros(1, dtype=torch.long, device=local_patches.device) for _ in range(self.world_size)] + dist.all_gather(len_bufs, + torch.tensor([local_patch_len], dtype=torch.long, device=local_patches.device), + group=self.process_group) + all_lens = [int(t.item()) for t in len_bufs] + max_local_len = max(all_lens) + + pad_len = max_local_len - local_patch_len + if pad_len > 0: + # Append zero rows so this rank's buffer matches the largest shard. + local_patches_padded = F.pad(local_patches, (0, 0, 0, pad_len)) + else: + local_patches_padded = local_patches + + gathered = [ + torch.zeros(bs, max_local_len, hidden_dim, dtype=local_patches.dtype, device=local_patches.device) + for _ in range(self.world_size) + ] + dist.all_gather(gathered, local_patches_padded.contiguous(), group=self.process_group) + + # Strip per-rank padding before concatenation so attention only sees + # the true num_patches tokens, identical to single-device execution. + real_parts = [gathered[r][:, :all_lens[r], :] for r in range(self.world_size)] + full_patches = torch.cat(real_parts, dim=1) # [bs, total_real_patches, hidden_dim] + + # ------------------------------------------------------------------- + # 2. Build the full input (prepend CLS if needed) and call attention + # ------------------------------------------------------------------- + if self.has_cls_token: + full_input = torch.cat([cls_token, full_patches], dim=1) + else: + full_input = full_patches + + attn_out = self.attn(full_input, **kwargs) + + # Unwrap tuple: some ViT implementations return (attn_output, attn_weights) + if isinstance(attn_out, (tuple, list)): + full_out, *extra = attn_out + else: + full_out = attn_out + extra = [] + + # ------------------------------------------------------------------- + # 3. Scatter output: each rank keeps its local slice of the real patches. + # Because padding was stripped before attention, scatter offsets are + # the cumulative sums of all_lens, not rank * max_local_len. 
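+        #    For example, all_lens = [4, 3, 3] puts rank 0 at offset 0, rank 1
+        #    at offset 4, and rank 2 at offset 7.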
+ # ------------------------------------------------------------------- + if self.has_cls_token: + cls_out = full_out[:, :1, :] + patch_out = full_out[:, 1:, :] + else: + patch_out = full_out + + rank = dist.get_rank(self.process_group) + start = sum(all_lens[:rank]) + local_out = patch_out[:, start:start + local_patch_len, :].contiguous() + + if self.has_cls_token: + local_out = torch.cat([cls_out, local_out], dim=1) + + if extra: + return (local_out, *extra) + return local_out diff --git a/deepspeed/sequence/test_autosp.py b/deepspeed/sequence/test_autosp.py new file mode 100644 index 000000000000..771a9e7bb6b7 --- /dev/null +++ b/deepspeed/sequence/test_autosp.py @@ -0,0 +1,724 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Unit tests for AutoSP multimodal sequence parallelism: + - autosp_detector: model scanning + - UlyssesSPViTAttention: ViT SP wrapper + - auto_wrap_model_for_sp: end-to-end wrapping + - ModalityFusionSPAdapter: cross-modal gather/scatter + - LlavaFusionAdapter: LLaVA-style visual token splice + - InternVLFusionAdapter: InternVL-style IMG_CONTEXT token splice + - Qwen2VLFusionAdapter: Qwen2-VL vision_start/end bounded splice +""" + +import pytest +import torch +import torch.nn as nn + +from deepspeed.sequence.autosp_detector import (SPModelInfo, _LLM_ATTN_CLASSNAMES, _VIT_ATTN_CLASSNAMES, + detect_model_sp_info) +from deepspeed.sequence.autosp_fusion import (InternVLFusionAdapter, LlavaFusionAdapter, ModalityFusionSPAdapter, + Qwen2VLFusionAdapter) +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.auto_sp import _set_module_by_name, auto_wrap_model_for_sp +from deepspeed.sequence.layer import DistributedAttention + +# --------------------------------------------------------------------------- +# Minimal fake modules that mimic the interface of real attention layers +# without requiring a GPU or a real transformer model. 
+# --------------------------------------------------------------------------- + + +class _FakeViTAttn(nn.Module): + """Identity ViT attention — returns hidden_states unchanged.""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class _FakeViTAttnTuple(nn.Module): + """ViT attention that returns a (output, weights) tuple.""" + + def forward(self, hidden_states, **kwargs): + weights = torch.zeros(hidden_states.shape[0], 1, hidden_states.shape[1], hidden_states.shape[1]) + return hidden_states, weights + + +class _FakeLLMAttn(nn.Module): + """Identity LLM attention.""" + + def forward(self, query, key, value, *args, **kwargs): + return query + + +# Register fake class names so the detector recognises them +_VIT_ATTN_CLASSNAMES.add("_FakeViTAttn") +_VIT_ATTN_CLASSNAMES.add("_FakeViTAttnTuple") +_LLM_ATTN_CLASSNAMES.add("_FakeLLMAttn") + + +class _FakeMultimodalModel(nn.Module): + """Minimal multimodal model with one ViT and one LLM attention layer.""" + + def __init__(self): + super().__init__() + self.vision_encoder = nn.ModuleList([_FakeViTAttn()]) + self.mm_projector = nn.Linear(64, 64) + self.llm = nn.ModuleList([_FakeLLMAttn()]) + + +class _FakeViTOnlyModel(nn.Module): + + def __init__(self, num_layers=3): + super().__init__() + self.layers = nn.ModuleList([_FakeViTAttn() for _ in range(num_layers)]) + + +class _FakeLLMOnlyModel(nn.Module): + """Minimal LLM-only model with multiple decoder attention layers.""" + + def __init__(self, num_layers=2): + super().__init__() + self.layers = nn.ModuleList([_FakeLLMAttn() for _ in range(num_layers)]) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_process_group(world_size: int, rank: int): + """Return a mock object that satisfies dist.get_world_size / get_rank.""" + import unittest.mock as mock + import deepspeed.comm as dist + + pg = mock.MagicMock() + dist.get_world_size = mock.MagicMock(return_value=world_size) + dist.get_rank = mock.MagicMock(return_value=rank) + + def _fake_all_gather(tensor_list, tensor, group=None): + for t in tensor_list: + t.copy_(tensor) + + dist.all_gather = _fake_all_gather + return pg + + +# --------------------------------------------------------------------------- +# autosp_detector tests +# --------------------------------------------------------------------------- + + +class TestAutospDetector: + + def test_detects_vit_and_llm(self): + model = _FakeMultimodalModel() + info = detect_model_sp_info(model) + assert len(info.vit_attn_modules) == 1 + assert len(info.llm_attn_modules) == 1 + + def test_detects_vision_projection(self): + model = _FakeMultimodalModel() + info = detect_model_sp_info(model) + assert info.vision_projection_module is not None + name, module = info.vision_projection_module + assert "mm_projector" in name + + def test_detects_multiple_vit_layers(self): + model = _FakeViTOnlyModel(num_layers=4) + info = detect_model_sp_info(model) + assert len(info.vit_attn_modules) == 4 + assert len(info.llm_attn_modules) == 0 + assert info.vision_projection_module is None + + def test_empty_model_returns_empty_info(self): + model = nn.Sequential(nn.Linear(8, 8)) + info = detect_model_sp_info(model) + assert isinstance(info, SPModelInfo) + assert len(info.vit_attn_modules) == 0 + assert len(info.llm_attn_modules) == 0 + + def test_only_first_projection_is_recorded(self): + """Multiple projection-like names → only the outermost is 
recorded.""" + + class _M(nn.Module): + + def __init__(self): + super().__init__() + self.mm_projector = nn.Sequential(nn.Linear(8, 8)) + self.mm_projector.visual_projection = nn.Linear(8, 8) + + model = _M() + info = detect_model_sp_info(model) + assert info.vision_projection_module is not None + # Should be the outermost "mm_projector", not the nested one + name, _ = info.vision_projection_module + assert name == "mm_projector" + + +# --------------------------------------------------------------------------- +# UlyssesSPViTAttention tests (CPU, rank-0 simulation via mocks) +# --------------------------------------------------------------------------- + + +class TestUlyssesSPViTAttention: + + @pytest.mark.parametrize("has_cls_token", [True, False]) + @pytest.mark.parametrize("num_patches,world_size", [ + (16, 4), + (16, 2), + (9, 3), + ]) + def test_output_shape_matches_input(self, has_cls_token, num_patches, world_size): + """Output shape must equal input shape for any padding scenario.""" + pg = _make_mock_process_group(world_size=world_size, rank=0) + attn = _FakeViTAttn() + wrapper = UlyssesSPViTAttention(attn, pg, has_cls_token=has_cls_token) + + local_patches = num_patches // world_size + seq_len = (1 + local_patches) if has_cls_token else local_patches + x = torch.randn(2, seq_len, 32) + + out = wrapper(x) + assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" + + def test_tuple_output_unwrapped_correctly(self): + """Wrappers that return (output, weights) tuples are handled.""" + pg = _make_mock_process_group(world_size=2, rank=0) + attn = _FakeViTAttnTuple() + wrapper = UlyssesSPViTAttention(attn, pg, has_cls_token=False) + + x = torch.randn(1, 8, 16) # 8 patches, 2 ranks → 4 local each + result = wrapper(x) + # Should return a tuple: (attention_output, attention_weights) + assert isinstance(result, tuple) + assert result[0].shape == x.shape + + def test_identity_attn_preserves_values(self): + """When attn is identity, output values should match input values.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + attn = _FakeViTAttn() + wrapper = UlyssesSPViTAttention(attn, pg, has_cls_token=True) + + # Each rank holds cls + 4 local patches + x = torch.arange(2 * 5 * 4, dtype=torch.float).reshape(2, 5, 4) + out = wrapper(x) + # CLS token should be identical + assert torch.allclose(out[:, :1, :], x[:, :1, :]) + # Local patch slice should match input patches for identity attn + assert torch.allclose(out[:, 1:, :], x[:, 1:, :]) + + +# --------------------------------------------------------------------------- +# auto_wrap_model_for_sp tests +# --------------------------------------------------------------------------- + + +class TestAutoWrapModelForSP: + + def test_vit_layers_replaced(self): + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeViTOnlyModel(num_layers=2) + auto_wrap_model_for_sp(model, pg) + for layer in model.layers: + assert isinstance(layer, UlyssesSPViTAttention) + + def test_raises_on_unknown_model(self): + pg = _make_mock_process_group(world_size=2, rank=0) + model = nn.Sequential(nn.Linear(8, 8)) + with pytest.raises(ValueError, match="no recognisable attention"): + auto_wrap_model_for_sp(model, pg) + + def test_set_module_by_name_shallow(self): + model = _FakeViTOnlyModel(num_layers=1) + new_mod = nn.Linear(4, 4) + _set_module_by_name(model, "layers.0", new_mod) + assert model.layers[0] is new_mod + + def test_set_module_by_name_deep(self): + model = _FakeMultimodalModel() + new_mod = nn.Identity() 
+ _set_module_by_name(model, "vision_encoder.0", new_mod) + assert model.vision_encoder[0] is new_mod + + def test_llm_layers_replaced_with_distributed_attention(self): + """LLM attention layers must be wrapped with DistributedAttention.""" + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeLLMOnlyModel(num_layers=3) + auto_wrap_model_for_sp(model, pg) + for layer in model.layers: + assert isinstance(layer, DistributedAttention) + + def test_multimodal_model_wraps_both_branches(self): + """Both ViT and LLM attention layers must be replaced in a combined model.""" + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeMultimodalModel() + returned = auto_wrap_model_for_sp(model, pg) + # auto_wrap_model_for_sp must return the same object (in-place) + assert returned is model + assert isinstance(model.vision_encoder[0], UlyssesSPViTAttention) + assert isinstance(model.llm[0], DistributedAttention) + + def test_original_module_preserved_inside_wrapper(self): + """The wrapped module should still be accessible inside the wrapper.""" + pg = _make_mock_process_group(world_size=2, rank=0) + model = _FakeViTOnlyModel(num_layers=1) + original_attn = model.layers[0] + auto_wrap_model_for_sp(model, pg) + assert model.layers[0].attn is original_attn + + +# --------------------------------------------------------------------------- +# ModalityFusionSPAdapter tests +# --------------------------------------------------------------------------- + + +class _ConcatFusionAdapter(ModalityFusionSPAdapter): + """Concrete subclass that appends visual tokens after text tokens.""" + + def _splice_visual_into_text(self, text_embeds, visual_embeds, input_ids): + return torch.cat([text_embeds, visual_embeds], dim=1) + + +class TestModalityFusionSPAdapter: + + def test_base_class_raises_not_implemented(self): + """The base _splice_visual_into_text must raise NotImplementedError.""" + pg = _make_mock_process_group(world_size=2, rank=0) + adapter = ModalityFusionSPAdapter(nn.Identity(), pg) + with pytest.raises(NotImplementedError): + adapter._splice_visual_into_text(None, None, None) + + @pytest.mark.parametrize("world_size,local_v,text_len,hidden", [ + (2, 4, 6, 8), + (4, 3, 5, 16), + (1, 8, 8, 4), + ]) + def test_output_shape(self, world_size, local_v, text_len, hidden): + """Output local_len must equal ceil(fused_len / world_size).""" + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + + bs = 2 + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + + out = adapter(visual, text, ids) + + # all_gather mock copies local_v to each of world_size slots + fused_len = text_len + local_v * world_size + pad = (world_size - fused_len % world_size) % world_size + expected_local = (fused_len + pad) // world_size + assert out.shape == (bs, expected_local, hidden), f"Expected ({bs},{expected_local},{hidden}), got {out.shape}" + + def test_padding_produces_valid_output_when_not_divisible(self): + """When fused_len % world_size != 0, padding must not raise and output is well-formed.""" + world_size = 4 + # text_len=5, local_v=3 → fused_len = 5 + 3*4 = 17, needs padding of 3 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + + bs, local_v, text_len, hidden = 1, 3, 5, 4 + out = adapter( + torch.randn(bs, local_v, hidden), + torch.randn(bs, text_len, hidden), + torch.zeros(bs, text_len, 
dtype=torch.long), + ) + # padded_len = 20, local_len = 5 + assert out.shape == (bs, 5, hidden) + + def test_no_padding_when_divisible(self): + """When fused_len is already divisible, no extra tokens should be added.""" + world_size = 4 + # text_len=4, local_v=4 → fused_len = 4 + 4*4 = 20, divisible by 4 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + + bs, local_v, text_len, hidden = 1, 4, 4, 8 + out = adapter( + torch.randn(bs, local_v, hidden), + torch.randn(bs, text_len, hidden), + torch.zeros(bs, text_len, dtype=torch.long), + ) + assert out.shape == (bs, 5, hidden) # 20 // 4 = 5 + + def test_different_ranks_return_different_slices(self): + """Rank 0 and rank 1 must return different slices of the fused sequence.""" + world_size = 2 + bs, local_v, text_len, hidden = 1, 4, 4, 8 + # Use distinct text vs visual values so slices clearly differ + text = torch.zeros(bs, text_len, hidden) + visual = torch.ones(bs, local_v, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + + outputs = {} + for rank in range(world_size): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + adapter = _ConcatFusionAdapter(nn.Identity(), pg) + outputs[rank] = adapter(visual.clone(), text.clone(), ids.clone()) + + # fused = [0,0,0,0, 1,1,1,1, 1,1,1,1] (text zeros then visual ones x2) + # rank 0: indices 0-5, rank 1: indices 6-11 + assert not torch.allclose(outputs[0], outputs[1]) + + def test_projection_is_applied(self): + """Projection layer must transform visual features before gather.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + + # Use a projection that doubles all values + class _DoubleProjection(nn.Module): + + def forward(self, x): + return x * 2.0 + + adapter = _ConcatFusionAdapter(_DoubleProjection(), pg) + bs, local_v, text_len, hidden = 1, 4, 4, 8 + visual = torch.ones(bs, local_v, hidden) + text = torch.zeros(bs, text_len, hidden) + ids = torch.zeros(bs, text_len, dtype=torch.long) + + out = adapter(visual, text, ids) + # The visual part of the output should have value 2.0 (doubled), not 1.0 + # rank 0 gets the first local_len tokens; fused = [text(0)*4, visual(2)*8] + # Since text_len=4 and local_len=6, rank0 slice starts with text zeros + # and ends with some visual twos. + assert out.max().item() == pytest.approx(2.0) + + +# --------------------------------------------------------------------------- +# LlavaFusionAdapter tests (tests _splice_visual_into_text directly) +# --------------------------------------------------------------------------- + +_IMAGE_ID = -200 # matches ModalityFusionSPAdapter default + + +def _make_llava_adapter(world_size=2, rank=0): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + return LlavaFusionAdapter(nn.Identity(), pg, image_token_id=_IMAGE_ID) + + +class TestLlavaFusionAdapter: + + def test_single_image_fused_shape(self): + """One image placeholder per sample → fused length = text_len - 1 + num_visual.""" + adapter = _make_llava_adapter() + bs, text_len, num_vis, hidden = 2, 6, 4, 8 + # Place a single image placeholder at position 2. 
+ ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 2] = _IMAGE_ID + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, num_vis, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + # placeholder is removed and replaced by num_vis tokens + assert fused.shape == (bs, text_len - 1 + num_vis, hidden) + + def test_text_values_preserved_around_image(self): + """Text tokens before and after the placeholder must be numerically intact.""" + adapter = _make_llava_adapter() + bs, text_len, num_vis, hidden = 1, 5, 3, 4 + # Placeholder at index 2: text = [A, B, , C, D] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _IMAGE_ID + text = torch.arange(bs * text_len * hidden, dtype=torch.float).reshape(bs, text_len, hidden) + visual = torch.ones(bs, num_vis, hidden) * 99.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + # fused = [A, B, vis0, vis1, vis2, C, D] + assert torch.allclose(fused[0, :2], text[0, :2]) # A, B preserved + assert torch.allclose(fused[0, 5:], text[0, 3:]) # C, D preserved + assert torch.allclose(fused[0, 2:5], visual[0]) # visual inserted + + def test_no_image_token_returns_text_unchanged(self): + """When input_ids has no placeholder, output equals text_embeds exactly.""" + adapter = _make_llava_adapter() + bs, text_len, hidden = 2, 6, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) # no -200 + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + assert torch.allclose(fused, text) + + def test_multi_image_splice(self): + """Two placeholders per sample → visual tokens split evenly between them.""" + adapter = _make_llava_adapter() + bs, text_len, num_vis, hidden = 1, 7, 6, 4 + # Placeholders at index 1 and 4: [t0, , t2, t3, , t5, t6] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _IMAGE_ID + ids[0, 4] = _IMAGE_ID + text = torch.zeros(bs, text_len, hidden) + # First 3 visual tokens = 1.0, last 3 = 2.0 (so we can tell them apart) + visual = torch.cat([torch.ones(bs, 3, hidden), torch.full((bs, 3, hidden), 2.0)], dim=1) + + fused = adapter._splice_visual_into_text(text, visual, ids) + # Expected fused length: 7 - 2 placeholders + 6 visual = 11 + assert fused.shape == (bs, 11, hidden) + # First chunk (indices 1-3) should be 1.0 + assert torch.allclose(fused[0, 1:4], torch.ones(3, hidden)) + # Second chunk (indices 6-8) should be 2.0 + assert torch.allclose(fused[0, 6:9], torch.full((3, hidden), 2.0)) + + def test_batch_padding_when_lengths_differ(self): + """Samples with different numbers of image tokens are padded to max length.""" + adapter = _make_llava_adapter() + hidden = 4 + # Sample 0: 1 placeholder in a 4-token sequence + 2 visual → fused len = 5 + # Sample 1: no placeholder in a 4-token sequence → fused len = 4 + ids = torch.zeros(2, 4, dtype=torch.long) + ids[0, 1] = _IMAGE_ID + text = torch.ones(2, 4, hidden) + visual = torch.ones(2, 2, hidden) * 3.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + # Max fused length is 5; sample 1 padded with zeros at the end. 
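+        # Concretely: sample 0's fused row is [text, 3.0, 3.0, text, text], while
+        # sample 1 keeps its four text tokens followed by a single zero padding row.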
+ assert fused.shape == (2, 5, hidden) + assert torch.all(fused[1, 4:] == 0) # padding tokens are zero + + def test_forward_end_to_end_shape(self): + """Full forward pass through LlavaFusionAdapter returns the correct shard shape.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = LlavaFusionAdapter(nn.Identity(), pg, image_token_id=_IMAGE_ID) + + bs, local_v, text_len, hidden = 1, 4, 6, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _IMAGE_ID # one placeholder + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + + out = adapter(visual, text, ids) + # fused_len = text_len - 1 + local_v * world_size = 5 + 8 = 13 + # padded to 14 (next multiple of 2), local = 7 + assert out.shape == (bs, 7, hidden) + + +# --------------------------------------------------------------------------- +# InternVLFusionAdapter tests (tests _splice_visual_into_text directly) +# --------------------------------------------------------------------------- + +_CONTEXT_ID = 92546 # arbitrary IMG_CONTEXT token id for tests +_START_ID = 92545 +_END_ID = 92547 + + +def _make_internvl_adapter(world_size=2, rank=0): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + return InternVLFusionAdapter(nn.Identity(), pg, image_token_id=_CONTEXT_ID) + + +class TestInternVLFusionAdapter: + + def test_context_tokens_replaced_with_visual(self): + """IMG_CONTEXT positions must carry visual embeddings after splice.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 1, 7, 4 + # Layout: [t0, START, ctx, ctx, ctx, END, t6] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _CONTEXT_ID + ids[0, 3] = _CONTEXT_ID + ids[0, 4] = _CONTEXT_ID + + text = torch.zeros(bs, text_len, hidden) + visual = torch.ones(bs, 3, hidden) * 7.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused[0, 2:5], visual[0]) + + def test_sequence_length_preserved(self): + """Output length must equal input length (1-to-1 replacement).""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 2, 10, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 3:7] = _CONTEXT_ID # 4 context tokens per sample + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + + def test_boundary_tokens_preserved(self): + """IMG_START and IMG_END embeddings must be unchanged after splice.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 1, 5, 4 + # [START, ctx, ctx, END, text] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _CONTEXT_ID + ids[0, 2] = _CONTEXT_ID + + text = torch.arange(bs * text_len * hidden, dtype=torch.float).reshape(bs, text_len, hidden) + visual = torch.ones(bs, 2, hidden) * 99.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + # Position 0 (START) and 3 (END) must be unchanged. 
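+        # (Positions 1-2, the IMG_CONTEXT slots, are exercised by
+        # test_context_tokens_replaced_with_visual; here only the untouched
+        # boundary embeddings are checked.)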
+ assert torch.allclose(fused[0, 0], text[0, 0]) + assert torch.allclose(fused[0, 3], text[0, 3]) + + def test_no_context_tokens_returns_text_unchanged(self): + """When there are no IMG_CONTEXT tokens the output must equal text_embeds.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 2, 6, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused, text) + + def test_multi_image_replacement(self): + """Two separate runs of context tokens correspond to two images.""" + adapter = _make_internvl_adapter() + bs, text_len, hidden = 1, 10, 4 + # Image 1: positions 1-2, Image 2: positions 6-7 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _CONTEXT_ID + ids[0, 2] = _CONTEXT_ID + ids[0, 6] = _CONTEXT_ID + ids[0, 7] = _CONTEXT_ID + + text = torch.zeros(bs, text_len, hidden) + # First 2 visual tokens = 1.0, next 2 = 2.0 + visual = torch.cat([torch.ones(bs, 2, hidden), torch.full((bs, 2, hidden), 2.0)], dim=1) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + assert torch.allclose(fused[0, 1:3], torch.ones(2, hidden)) + assert torch.allclose(fused[0, 6:8], torch.full((2, hidden), 2.0)) + + def test_forward_end_to_end_shape(self): + """Full forward pass returns the correct shard shape.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = InternVLFusionAdapter(nn.Identity(), pg, image_token_id=_CONTEXT_ID) + + bs, local_v, text_len, hidden = 1, 3, 8, 4 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2:5] = _CONTEXT_ID # 3 context tokens; local_v * world_size = 6 total + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + + out = adapter(visual, text, ids) + # fused_len == text_len == 8 (length-preserving); padded to 8 (divisible by 2); local = 4 + assert out.shape == (bs, 4, hidden) + + +# --------------------------------------------------------------------------- +# Qwen2VLFusionAdapter tests (tests _splice_visual_into_text directly) +# --------------------------------------------------------------------------- + +_VIS_START_ID = 151652 +_VIS_END_ID = 151653 + + +def _make_qwen2vl_adapter(world_size=2, rank=0): + pg = _make_mock_process_group(world_size=world_size, rank=rank) + return Qwen2VLFusionAdapter(nn.Identity(), + pg, + vision_start_token_id=_VIS_START_ID, + vision_end_token_id=_VIS_END_ID) + + +class TestQwen2VLFusionAdapter: + + def test_inner_tokens_replaced_with_visual(self): + """Tokens between vision_start and vision_end must become visual embeddings.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 1, 7, 4 + # [t0, t1, , pad, pad, , t6] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 2] = _VIS_START_ID + ids[0, 5] = _VIS_END_ID + + text = torch.zeros(bs, text_len, hidden) + visual = torch.ones(bs, 2, hidden) * 5.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused[0, 3:5], visual[0]) + + def test_sequence_length_preserved(self): + """Output length must equal input length (1-to-1 replacement).""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 2, 12, 8 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[:, 2] = _VIS_START_ID + ids[:, 8] = _VIS_END_ID # 5 inner placeholder tokens + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 
5, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + + def test_boundary_tokens_preserved(self): + """vision_start and vision_end embeddings must be unchanged after splice.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 1, 6, 4 + # [t0, , pad, pad, , t5] + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _VIS_START_ID + ids[0, 4] = _VIS_END_ID + + text = torch.arange(bs * text_len * hidden, dtype=torch.float).reshape(bs, text_len, hidden) + visual = torch.ones(bs, 2, hidden) * 99.0 + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused[0, 1], text[0, 1]) # vision_start preserved + assert torch.allclose(fused[0, 4], text[0, 4]) # vision_end preserved + + def test_no_vision_tokens_returns_text_unchanged(self): + """When there are no vision_start/end tokens the output must equal text_embeds.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 2, 8, 4 + ids = torch.zeros(bs, text_len, dtype=torch.long) + text = torch.randn(bs, text_len, hidden) + visual = torch.randn(bs, 4, hidden) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert torch.allclose(fused, text) + + def test_multi_image_replacement(self): + """Two vision blocks are handled independently.""" + adapter = _make_qwen2vl_adapter() + bs, text_len, hidden = 1, 14, 4 + # Block 1: positions 1 (start) .. 4 (end), 2 inner tokens at 2-3 + # Block 2: positions 8 (start) .. 12 (end), 3 inner tokens at 9-11 + ids = torch.zeros(bs, text_len, dtype=torch.long) + ids[0, 1] = _VIS_START_ID + ids[0, 4] = _VIS_END_ID + ids[0, 8] = _VIS_START_ID + ids[0, 12] = _VIS_END_ID + + text = torch.zeros(bs, text_len, hidden) + visual = torch.cat([torch.ones(bs, 2, hidden), torch.full((bs, 3, hidden), 2.0)], dim=1) + + fused = adapter._splice_visual_into_text(text, visual, ids) + assert fused.shape == (bs, text_len, hidden) + assert torch.allclose(fused[0, 2:4], torch.ones(2, hidden)) + assert torch.allclose(fused[0, 9:12], torch.full((3, hidden), 2.0)) + + def test_forward_end_to_end_shape(self): + """Full forward pass returns the correct shard shape.""" + world_size = 2 + pg = _make_mock_process_group(world_size=world_size, rank=0) + adapter = Qwen2VLFusionAdapter(nn.Identity(), + pg, + vision_start_token_id=_VIS_START_ID, + vision_end_token_id=_VIS_END_ID) + + bs, local_v, text_len, hidden = 1, 3, 10, 4 + ids = torch.zeros(bs, text_len, dtype=torch.long) + # 6 inner placeholder tokens (local_v * world_size = 6) + ids[0, 1] = _VIS_START_ID + ids[0, 8] = _VIS_END_ID + visual = torch.randn(bs, local_v, hidden) + text = torch.randn(bs, text_len, hidden) + + out = adapter(visual, text, ids) + # fused_len == text_len == 10 (length-preserving); padded to 10; local = 5 + assert out.shape == (bs, 5, hidden) diff --git a/tests/unit/sequence_parallelism/test_autosp_equivalence.py b/tests/unit/sequence_parallelism/test_autosp_equivalence.py new file mode 100644 index 000000000000..0c289353a29d --- /dev/null +++ b/tests/unit/sequence_parallelism/test_autosp_equivalence.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +Numerical equivalence tests for AutoSP multimodal sequence parallelism. + +Each test verifies that running the SP-wrapped path across N ranks produces +the same result as the equivalent single-device (non-SP) computation. + +These tests require 2 GPUs. 
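+Two ranks are the minimum needed to exercise the cross-rank gather/scatter paths
+that these equivalence checks are designed to validate.
+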
+Run with: + + NCCL_P2P_DISABLE=1 python -m pytest tests/unit/sequence_parallelism/test_autosp_equivalence.py -v +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytest + +import deepspeed.comm as dist +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, LlavaFusionAdapter, Qwen2VLFusionAdapter +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest + +# --------------------------------------------------------------------------- +# Shared identity attention — deterministic, easy to verify +# --------------------------------------------------------------------------- + +_IMAGE_TOKEN_ID = -200 + + +class _IdentityAttn(nn.Module): + """Returns hidden_states unchanged so that gather-compute-scatter is a no-op.""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +# --------------------------------------------------------------------------- +# UlyssesSPViTAttention equivalence +# --------------------------------------------------------------------------- + + +class TestViTSPEquivalence(DistributedTest): + """SP-wrapped ViT attention with an identity inner module must reproduce + the unsharded output on every rank.""" + + world_size = 2 + + @pytest.mark.parametrize("has_cls_token", [True, False]) + @pytest.mark.parametrize("num_patches", [8, 12]) + def test_output_equals_single_device(self, has_cls_token, num_patches): + """Each rank's local output slice must match the corresponding slice of + the single-device output.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + bs, hidden = 2, 32 + + # --- Single-device reference --- + # Build the full input (all ranks see the same RNG seed so the tensor + # is identical everywhere). + torch.manual_seed(42) + if has_cls_token: + full_input = torch.randn(bs, 1 + num_patches, hidden).to(get_accelerator().device_name()) + else: + full_input = torch.randn(bs, num_patches, hidden).to(get_accelerator().device_name()) + + identity = _IdentityAttn().to(get_accelerator().device_name()) + # Single-device path is just identity — output == input. + ref_out = identity(full_input) + + # --- SP path --- + local_patches = num_patches // self.world_size + if has_cls_token: + cls = full_input[:, :1, :] + patch_slice = full_input[:, 1 + rank * local_patches:1 + (rank + 1) * local_patches, :] + local_input = torch.cat([cls, patch_slice], dim=1) + else: + local_input = full_input[:, rank * local_patches:(rank + 1) * local_patches, :] + + wrapper = UlyssesSPViTAttention(_IdentityAttn().to(get_accelerator().device_name()), + sp_group, + has_cls_token=has_cls_token).to(get_accelerator().device_name()) + sp_out = wrapper(local_input) + + # --- Compare --- + # sp_out is the local slice; reconstruct what slice of ref_out it maps to. + if has_cls_token: + ref_slice = torch.cat( + [ref_out[:, :1, :], ref_out[:, 1 + rank * local_patches:1 + (rank + 1) * local_patches, :]], dim=1) + else: + ref_slice = ref_out[:, rank * local_patches:(rank + 1) * local_patches, :] + + assert torch.allclose(sp_out, ref_slice, + atol=1e-5), (f"rank={rank} sp_out differs from reference: " + f"max_diff={( sp_out - ref_slice).abs().max().item():.2e}") + + @pytest.mark.parametrize("has_cls_token", [True, False]) + def test_noneven_patches(self, has_cls_token): + """When num_patches % world_size != 0, the wrapper must still produce + correct per-rank output. 
With 5 patches and world_size=2, rank 0 + holds 3 patches and rank 1 holds 2 patches.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + bs, hidden = 2, 16 + num_patches = 5 # not divisible by world_size=2 + + torch.manual_seed(77) + if has_cls_token: + full_input = torch.randn(bs, 1 + num_patches, hidden).to(get_accelerator().device_name()) + else: + full_input = torch.randn(bs, num_patches, hidden).to(get_accelerator().device_name()) + + # Distribute: first (num_patches % world_size) ranks carry one extra patch. + extra = num_patches % self.world_size # = 1 + base = num_patches // self.world_size # = 2 + local_v = base + (1 if rank < extra else 0) + patch_start = rank * base + min(rank, extra) + + if has_cls_token: + cls = full_input[:, :1, :] + patch_slice = full_input[:, 1 + patch_start:1 + patch_start + local_v, :] + local_input = torch.cat([cls, patch_slice], dim=1) + else: + local_input = full_input[:, patch_start:patch_start + local_v, :] + + wrapper = UlyssesSPViTAttention(_IdentityAttn().to(get_accelerator().device_name()), + sp_group, + has_cls_token=has_cls_token) + sp_out = wrapper(local_input) + + # Reference: identity wrapper — each rank's output must equal its input slice. + if has_cls_token: + ref_slice = torch.cat([full_input[:, :1, :], full_input[:, 1 + patch_start:1 + patch_start + local_v, :]], + dim=1) + else: + ref_slice = full_input[:, patch_start:patch_start + local_v, :] + + assert torch.allclose(sp_out, ref_slice, + atol=1e-5), (f"rank={rank} non-even patches: sp_out differs from reference: " + f"max_diff={(sp_out - ref_slice).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# LlavaFusionAdapter equivalence +# --------------------------------------------------------------------------- + + +class TestLlavaFusionEquivalence(DistributedTest): + """Verifies that the SP gather/scatter in LlavaFusionAdapter is a lossless + round-trip: concatenating all ranks' output shards reproduces the full + fused sequence that single-device splicing would produce.""" + + world_size = 2 + + def _build_inputs(self, bs, local_v, text_len, hidden, rank): + """Build deterministic visual and text tensors identical on every rank.""" + torch.manual_seed(0) + # Each rank holds a contiguous slice of the visual tokens. + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + ids[:, 1] = _IMAGE_TOKEN_ID # one image placeholder at position 1 + local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + return full_visual, local_visual, text, ids + + def test_shards_reassemble_to_full_fused(self): + """Gathering all ranks' output shards must equal the single-device + fused sequence (modulo padding zeros).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 4, 6, 8 + full_visual, local_visual, text, ids = self._build_inputs(bs, local_v, text_len, hidden, rank) + + # --- SP path: each rank gets one shard --- + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, + image_token_id=_IMAGE_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) # [bs, local_fused, hidden] + + # Gather all shards onto every rank so we can compare globally. 
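+        # The adapter pads the fused sequence to a multiple of world_size before
+        # scattering, so every rank's shard has the same shape and a plain
+        # all_gather + concat along dim=1 reconstructs the padded fused sequence.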
+ gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) # [bs, padded_fused, hidden] + + # --- Single-device reference --- + # Simulate what a non-SP LlavaFusionAdapter would produce: project the + # full visual tensor (identity here) and splice once. + ref_adapter = LlavaFusionAdapter(nn.Identity(), sp_group, + image_token_id=_IMAGE_TOKEN_ID).to(get_accelerator().device_name()) + # Call _splice_visual_into_text directly so we bypass the SP scatter. + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + + # Pad reference to the same padded length. + fused_len = ref_fused.shape[1] + pad = (self.world_size - fused_len % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} reassembled SP output differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + def test_no_image_token_passthrough(self): + """When there are no image placeholders the SP fused output must equal + the sharded text after padding/scatter (all-text path).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 2, 8, 4 + torch.manual_seed(1) + local_visual = torch.randn(bs, local_v, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) # no image placeholder + + adapter = LlavaFusionAdapter(nn.Identity(), sp_group, + image_token_id=_IMAGE_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + # Gather shards and strip the padding slice from visual gather. + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Expected: when there is no image token, the visual tokens are ignored. + # So the fused output should just be the text tokens. + ref_fused = text + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} no-image path differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# InternVLFusionAdapter equivalence +# --------------------------------------------------------------------------- + +_INTERNVL_CONTEXT_TOKEN_ID = 92546 + + +class TestInternVLFusionEquivalence(DistributedTest): + """Verifies that the SP gather/scatter in InternVLFusionAdapter is a lossless + round-trip: concatenating all ranks' output shards reproduces the full fused + sequence that single-device splicing would produce. + + InternVL replaces IMG_CONTEXT tokens 1-to-1 with visual tokens, so the + sequence length is preserved. 
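+    Any padding added before the scatter therefore exists only to make the fused
+    length divisible by the SP degree.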
+ """ + + world_size = 2 + + def _build_inputs(self, bs, local_v, text_len, hidden, rank, num_ctx_tokens): + """Build deterministic inputs with a run of IMG_CONTEXT tokens in the middle.""" + torch.manual_seed(2) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + # Place IMG_CONTEXT tokens starting at position 2. + ids[:, 2:2 + num_ctx_tokens] = _INTERNVL_CONTEXT_TOKEN_ID + local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + return full_visual, local_visual, text, ids + + def test_shards_reassemble_to_full_fused(self): + """Gathering all ranks' output shards must equal the single-device + fused sequence (modulo padding zeros).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 3, 8, 4 + full_visual, local_visual, text, ids = self._build_inputs(bs, + local_v, + text_len, + hidden, + rank, + num_ctx_tokens=local_v * self.world_size) + + # SP path. + adapter = InternVLFusionAdapter(nn.Identity(), sp_group, + image_token_id=_INTERNVL_CONTEXT_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Single-device reference. + ref_adapter = InternVLFusionAdapter(nn.Identity(), sp_group, image_token_id=_INTERNVL_CONTEXT_TOKEN_ID).to( + get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + + fused_len = ref_fused.shape[1] + pad = (self.world_size - fused_len % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} InternVL reassembled output differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + def test_no_context_token_passthrough(self): + """When there are no IMG_CONTEXT tokens the fused output must equal the text.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 2, 6, 4 + torch.manual_seed(3) + local_visual = torch.randn(bs, local_v, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + + adapter = InternVLFusionAdapter(nn.Identity(), sp_group, + image_token_id=_INTERNVL_CONTEXT_TOKEN_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + ref_fused = text + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} InternVL no-context path differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# Qwen2VLFusionAdapter 
equivalence +# --------------------------------------------------------------------------- + +_QWEN2VL_START_ID = 151652 +_QWEN2VL_END_ID = 151653 + + +class TestQwen2VLFusionEquivalence(DistributedTest): + """Verifies that the SP gather/scatter in Qwen2VLFusionAdapter is a lossless + round-trip: concatenating all ranks' output shards reproduces the full fused + sequence that single-device splicing would produce. + + Qwen2-VL replaces inner placeholder tokens (between vision_start/end pairs) + 1-to-1 with visual tokens, so the sequence length is preserved. + """ + + world_size = 2 + + def _build_inputs(self, bs, local_v, text_len, hidden, rank, num_inner): + """Build inputs with a single vision_start/end block containing num_inner placeholders.""" + torch.manual_seed(4) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + # [t0, , pad×num_inner, , ...] + ids[:, 1] = _QWEN2VL_START_ID + ids[:, 2 + num_inner] = _QWEN2VL_END_ID + local_visual = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + return full_visual, local_visual, text, ids + + def test_shards_reassemble_to_full_fused(self): + """Gathering all ranks' output shards must equal the single-device + fused sequence (modulo padding zeros).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 3, 10, 4 + num_inner = local_v * self.world_size # inner placeholder count equals total visual tokens + full_visual, local_visual, text, ids = self._build_inputs(bs, local_v, text_len, hidden, rank, num_inner) + + # SP path. + adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Single-device reference. 
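+        # A second adapter is built only for its _splice_visual_into_text helper;
+        # calling the splice directly bypasses the SP gather/scatter, making it
+        # the non-SP reference (same pattern as the LLaVA and InternVL tests).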
+ ref_adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + + fused_len = ref_fused.shape[1] + pad = (self.world_size - fused_len % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} Qwen2VL reassembled output differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") + + def test_no_vision_token_passthrough(self): + """When there are no vision_start/end tokens the fused output must equal the text.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 2, 8, 4 + torch.manual_seed(5) + local_visual = torch.randn(bs, local_v, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + + adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + local_out = adapter(local_visual, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + ref_fused = text + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} Qwen2VL no-vision path differs from reference: " + f"max_diff={( full_sp_out - ref_fused).abs().max().item():.2e}") diff --git a/tests/unit/sequence_parallelism/test_autosp_integration.py b/tests/unit/sequence_parallelism/test_autosp_integration.py new file mode 100644 index 000000000000..4efcdb07c302 --- /dev/null +++ b/tests/unit/sequence_parallelism/test_autosp_integration.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 +# DeepSpeed Team +""" +End-to-end integration tests for AutoSP multimodal sequence parallelism. + +Each test builds a minimal mock model whose attention-layer class names match +the autosp_detector registry, then verifies two things: + +1. auto_wrap_model_for_sp correctly identifies and wraps ViT attention modules + (with the correct has_cls_token value from the registry) and emits warnings + for HF-style LLM attention without wrapping them. +2. The full pipeline (SP-wrapped ViT -> fusion adapter) produces fused output + numerically equivalent to the single-device splice reference. + +These tests require 2 GPUs. 
+Run with: + + NCCL_P2P_DISABLE=1 python -m pytest tests/unit/sequence_parallelism/test_autosp_integration.py -v +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed.comm as dist +from deepspeed.sequence.auto_sp import auto_wrap_model_for_sp +from deepspeed.sequence.autosp_vit import UlyssesSPViTAttention +from deepspeed.sequence.autosp_fusion import InternVLFusionAdapter, Qwen2VLFusionAdapter +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest + +# --------------------------------------------------------------------------- +# Token IDs +# --------------------------------------------------------------------------- + +_INTERNVL_CONTEXT_ID = 92546 +_QWEN2VL_START_ID = 151652 +_QWEN2VL_END_ID = 151653 + +# --------------------------------------------------------------------------- +# Mock attention classes +# +# Class names must match exactly the entries in autosp_detector._VIT_ATTN_CLASSNAMES +# and _LLM_ATTN_CLASSNAMES so that auto_wrap_model_for_sp detects them. +# --------------------------------------------------------------------------- + + +class InternVisionAttention(nn.Module): + """Mock ViT attention for InternVL (registered in _VIT_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class InternLM2Attention(nn.Module): + """Mock LLM attention for InternVL (registered in _LLM_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2VLVisionAttention(nn.Module): + """Mock ViT attention for Qwen2-VL (registered in _VIT_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class Qwen2Attention(nn.Module): + """Mock LLM attention for Qwen2-VL (registered in _LLM_ATTN_CLASSNAMES).""" + + def forward(self, hidden_states, **kwargs): + return hidden_states + + +# --------------------------------------------------------------------------- +# Model skeleton helpers +# --------------------------------------------------------------------------- + + +class _AttnLayer(nn.Module): + """Generic transformer block that holds an attention submodule. + + auto_wrap_model_for_sp scans named_modules() and replaces ``self.attn`` + when its class name is in the detector's registry. + """ + + def __init__(self, attn: nn.Module) -> None: + super().__init__() + self.attn = attn + + def forward(self, x, **kwargs): + return self.attn(x, **kwargs) + + +class _MinimalInternVLModel(nn.Module): + """Minimal InternVL-like skeleton for integration testing. + + Module paths recognised by autosp_detector: + - ``vision_encoder.0.attn`` -> InternVisionAttention (_VIT_ATTN_CLASSNAMES) + - ``language_model.0.attn`` -> InternLM2Attention (_LLM_ATTN_CLASSNAMES) + - ``mm_projector`` -> keyword in _VISION_PROJ_KEYWORDS + + ``forward`` exercises only the ViT + fusion path; ``language_model`` is + present to verify that auto_wrap does NOT wrap HF-style LLM attention. 
+ """ + + def __init__(self) -> None: + super().__init__() + self.vision_encoder = nn.Sequential(_AttnLayer(InternVisionAttention())) + self.mm_projector = nn.Identity() + self.language_model = nn.Sequential(_AttnLayer(InternLM2Attention())) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.vision_encoder(local_patches) + return self.fusion(local_visual, text_embeds, input_ids) + + +class _MinimalQwen2VLModel(nn.Module): + """Minimal Qwen2-VL-like skeleton for integration testing. + + Module paths recognised by autosp_detector: + - ``visual.0.attn`` -> Qwen2VLVisionAttention (_VIT_ATTN_CLASSNAMES) + - ``model.0.attn`` -> Qwen2Attention (_LLM_ATTN_CLASSNAMES) + - ``multi_modal_projector`` -> keyword in _VISION_PROJ_KEYWORDS + """ + + def __init__(self) -> None: + super().__init__() + self.visual = nn.Sequential(_AttnLayer(Qwen2VLVisionAttention())) + self.multi_modal_projector = nn.Identity() + self.model = nn.Sequential(_AttnLayer(Qwen2Attention())) + self.fusion = None + + def forward(self, local_patches: torch.Tensor, text_embeds: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: + local_visual = self.visual(local_patches) + return self.fusion(local_visual, text_embeds, input_ids) + + +# --------------------------------------------------------------------------- +# InternVL integration tests +# --------------------------------------------------------------------------- + + +class TestInternVLIntegration(DistributedTest): + """Integration tests for the InternVL multimodal SP pipeline.""" + + world_size = 2 + + def test_auto_wrap_detects_and_wraps_modules(self): + """auto_wrap_model_for_sp must replace InternVisionAttention with + UlyssesSPViTAttention (has_cls_token=False) and must NOT wrap + InternLM2Attention (HF-style, incompatible with DistributedAttention).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + model = _MinimalInternVLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + + assert isinstance( + model.vision_encoder[0].attn, + UlyssesSPViTAttention), ("Expected vision_encoder[0].attn to be UlyssesSPViTAttention after auto_wrap") + assert not model.vision_encoder[0].attn.has_cls_token, ( + "InternVisionAttention has no CLS token; has_cls_token must be False") + assert isinstance(model.language_model[0].attn, + InternLM2Attention), ("HF-style LLM attention must NOT be wrapped by auto_wrap") + + def test_full_pipeline_visual_to_fused(self): + """SP-wrapped ViT -> InternVLFusionAdapter must produce fused output + numerically equivalent to the single-device splice reference.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 4, 10, 8 + num_ctx = local_v * self.world_size + + torch.manual_seed(20) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + ids[:, 2:2 + num_ctx] = _INTERNVL_CONTEXT_ID + + local_patches = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + + model = _MinimalInternVLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + model.fusion = InternVLFusionAdapter(model.mm_projector, sp_group, + 
image_token_id=_INTERNVL_CONTEXT_ID).to(get_accelerator().device_name()) + + local_out = model(local_patches, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + # Single-device reference: splice without SP scatter. + ref_adapter = InternVLFusionAdapter(nn.Identity(), sp_group, + image_token_id=_INTERNVL_CONTEXT_ID).to(get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} InternVL full pipeline output differs from reference: " + f"max_diff={(full_sp_out - ref_fused).abs().max().item():.2e}") + + +# --------------------------------------------------------------------------- +# Qwen2-VL integration tests +# --------------------------------------------------------------------------- + + +class TestQwen2VLIntegration(DistributedTest): + """Integration tests for the Qwen2-VL multimodal SP pipeline.""" + + world_size = 2 + + def test_auto_wrap_detects_and_wraps_modules(self): + """auto_wrap_model_for_sp must replace Qwen2VLVisionAttention with + UlyssesSPViTAttention (has_cls_token=False) and must NOT wrap + Qwen2Attention (HF-style, incompatible with DistributedAttention).""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + model = _MinimalQwen2VLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + + assert isinstance( + model.visual[0].attn, + UlyssesSPViTAttention), ("Expected visual[0].attn to be UlyssesSPViTAttention after auto_wrap") + assert not model.visual[0].attn.has_cls_token, ( + "Qwen2VLVisionAttention has no CLS token; has_cls_token must be False") + assert isinstance(model.model[0].attn, + Qwen2Attention), ("HF-style LLM attention must NOT be wrapped by auto_wrap") + + def test_full_pipeline_visual_to_fused(self): + """SP-wrapped ViT -> Qwen2VLFusionAdapter must produce fused output + numerically equivalent to the single-device splice reference.""" + sp_group = dist.new_group(ranks=list(range(self.world_size))) + rank = dist.get_rank(sp_group) + + bs, local_v, text_len, hidden = 1, 3, 10, 8 + num_inner = local_v * self.world_size + + torch.manual_seed(21) + full_visual = torch.randn(bs, local_v * self.world_size, hidden).to(get_accelerator().device_name()) + text = torch.randn(bs, text_len, hidden).to(get_accelerator().device_name()) + ids = torch.zeros(bs, text_len, dtype=torch.long).to(get_accelerator().device_name()) + ids[:, 1] = _QWEN2VL_START_ID + ids[:, 2 + num_inner] = _QWEN2VL_END_ID + + local_patches = full_visual[:, rank * local_v:(rank + 1) * local_v, :] + + model = _MinimalQwen2VLModel().to(get_accelerator().device_name()) + auto_wrap_model_for_sp(model, sp_group) + model.fusion = Qwen2VLFusionAdapter(model.multi_modal_projector, + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + + local_out = model(local_patches, text, ids) + + gathered = [torch.zeros_like(local_out) for _ in range(self.world_size)] + dist.all_gather(gathered, local_out, group=sp_group) + full_sp_out = torch.cat(gathered, dim=1) + + ref_adapter = Qwen2VLFusionAdapter(nn.Identity(), + sp_group, + vision_start_token_id=_QWEN2VL_START_ID, + 
vision_end_token_id=_QWEN2VL_END_ID).to(get_accelerator().device_name()) + ref_fused = ref_adapter._splice_visual_into_text(text, full_visual, ids) + pad = (self.world_size - ref_fused.shape[1] % self.world_size) % self.world_size + if pad > 0: + ref_fused = F.pad(ref_fused, (0, 0, 0, pad)) + + assert torch.allclose(full_sp_out, ref_fused, + atol=1e-5), (f"rank={rank} Qwen2VL full pipeline output differs from reference: " + f"max_diff={(full_sp_out - ref_fused).abs().max().item():.2e}")