Integrate training DSpark draft models #127

Closed
chungen04 wants to merge 9 commits into lightseekorg:main from chungen04:chungen/dspark

chungen04 commented Jun 27, 2026

Collaborator

Summary

DeepSeek recently released the DSpark speculative-decoding draft model, but the upstream implementation is architected as a synchronous pipeline and requires dumping the hidden states. TorchSpec's asynchronous, disaggregated pipeline design alleviates that storage requirement.

This PR ports DeepSpec's DSpark speculative-decoding draft model into TorchSpec's online (SGLang) training pipeline, as a first-class parallel path alongside Eagle3 and DFlash. DSpark is a block-anchor parallel drafter. It samples num_anchors anchor positions, expands each into a block_size-token block of MASK tokens, and predicts the whole block in one pass using dual-source KV attention (context KV from the projected multi-layer target hidden states + draft KV from the noise embeddings) under a block-causal FlexAttention mask. Optional low-rank Markov and confidence heads refine the block logits and predict per-position accept rates.
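
The anchor/block layout described above can be illustrated with a small plain-Python sketch. It is not the code from this PR: `sample_anchors` and `block_causal_mask` are hypothetical names, and the visibility rule (a block sees context KV up to and including its anchor, plus draft KV from its own block only) is an assumption about what "block-causal" means here. A real FlexAttention `mask_mod` also takes batch and head indices and operates on tensors rather than Python ints.

```python
import random


def sample_anchors(seq_len, num_anchors, block_size, seed=0):
    """Pick sorted, distinct anchor positions that leave room for a full block."""
    rng = random.Random(seed)
    # Assumption: an anchor at position a predicts a+1 .. a+block_size,
    # so the last usable anchor is seq_len - block_size - 1.
    candidates = range(seq_len - block_size)
    return sorted(rng.sample(candidates, num_anchors))


def block_causal_mask(anchors, block_size, seq_len):
    """Build a (q_idx, kv_idx) -> bool predicate over a dual-source KV layout.

    KV layout assumed here: indices [0, seq_len) are context KV (from the
    target hidden states); indices >= seq_len are draft KV (from the MASK
    noise embeddings), laid out block after block.
    """
    def mask_mod(q_idx, kv_idx):
        block = q_idx // block_size  # which anchor block this query belongs to
        if kv_idx < seq_len:
            # Context KV: visible up to and including the block's anchor.
            return kv_idx <= anchors[block]
        # Draft KV: visible only inside the query's own block.
        return (kv_idx - seq_len) // block_size == block

    return mask_mod


anchors = [2, 5, 9]
mask = block_causal_mask(anchors, block_size=4, seq_len=16)
# Query 5 sits in block 1 (anchor 5): context position 5 is visible, 6 is not.
print(mask(5, 5), mask(5, 6))
```

Under these assumptions every query in a block shares the same context window, which is what lets the whole block be predicted in one pass.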

The current scope is limited to the Qwen3 backbone with the SGLang engine.

Files changed

New modules (ported from DeepSpec, add_metric logging side-effects dropped):

  • torchspec/models/dspark_common.py — anchor sampler, block-causal mask, noise embed,
    eval mask, forward-output dataclass, confidence predictor.
  • torchspec/models/dspark_markov.py — vanilla / gated / RNN Markov heads.
  • torchspec/models/dspark_loss.py — compute_dspark_loss (DP-correct global-denominator
    normalization; returns backward loss + detached components for logging).
  • torchspec/models/draft/dspark.py — Qwen3DSparkModel, DSparkConfig, and
    build_dspark_draft_config (derives the backbone from the target model config).
  • torchspec/training/dspark_trainer.py — DSparkTrainer (loads + freezes target
    embed/lm_head, FSDP-shards decoder layers, runs forward + loss, reports per-position
    acc / sim_acc_len).

Wiring:

  • models/draft/auto.py, torchspec/__init__.py, models/draft/__init__.py — register
    DSparkConfig -> Qwen3DSparkModel and export the public symbols.
  • training/trainer_actor.py — dispatch DSparkConfig draft configs to DSparkTrainer.
  • config/train_config.py, train_entry.py — dspark_* loss hyperparameters and
    _validate_and_configure_dspark (forces store_last_hidden_states, auto-sets aux layer
    ids, fails fast on misconfig).

Configs:

  • config/dspark_draft_config.json (Qwen3-8B recipe), configs/sglang_qwen3_8b_dspark.yaml,
    configs/dspark_qwen3_8b_repro.yaml.

Faithfulness to DeepSpec

The training math is mirrored exactly: model forward, Markov/confidence heads, loss terms, the world_size-scaled global-denominator normalization, and the loss / gradient_accumulation_steps scaling all match DeepSpec line-for-line. The only
adaptations are infrastructural (offline-cache+FSDP1+DataLoader → online-mooncake+FSDP2+Ray) and the metric logging (replaced DeepSpec's add_metric with TorchSpec's per-position acc / sim_acc_len).
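
The arithmetic behind a world_size-scaled global-denominator normalization can be checked with a short simulation. This is a sketch of one common way to read that phrase, not the DeepSpec or TorchSpec code: `rank_loss` is a hypothetical helper, the all-reduce is replaced by a plain `sum`, and the per-rank numbers are made up for illustration.

```python
def rank_loss(local_num, global_den, world_size, grad_accum_steps):
    """Loss one rank backpropagates (hypothetical helper).

    Dividing by the global denominator and multiplying by world_size cancels
    the 1/world_size that data-parallel gradient averaging applies later.
    """
    return local_num * world_size / global_den / grad_accum_steps


# Made-up per-rank loss sums and valid-token counts for 3 training ranks.
local_nums = [12.0, 1.0, 3.0]
local_dens = [4, 1, 3]
world_size = len(local_nums)

# Stands in for an all-reduce (SUM) of the denominators across ranks.
global_den = sum(local_dens)

per_rank = [rank_loss(n, global_den, world_size, 1) for n in local_nums]

# Data-parallel training averages gradients (and so, effectively, losses).
averaged = sum(per_rank) / world_size

# The quantity we actually want: one global mean over all valid tokens.
global_mean = sum(local_nums) / global_den

# A naive mean of per-rank means weights ranks equally and differs.
naive = sum(n / d for n, d in zip(local_nums, local_dens)) / world_size

print(averaged, global_mean, naive)
```

With uneven token counts per rank, the naive mean of per-rank means drifts from the global mean, while the scaled form recovers it after averaging.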

Validation

1k-step training on open-perfectblend (not re-generated), Qwen3-8B target, SGLang engine, 4×B300 GPUs (1 inference + 3 training):

| metric           | step 500 | step 1000 |
|------------------|----------|-----------|
| eval loss        | 5.69     | 5.07      |
| eval acc         | 0.217    | 0.271     |
| eval sim_acc_len | 0.38     | 0.55      |

Throughput: ~9.6 entries/s (training-bound on 3 GPUs).

Follow-ups

chungen04 and others added 9 commits June 27, 2026 18:23
Anchor sampler, block-causal FlexAttention mask, noise-input construction, eval mask, forward-output dataclass, and confidence-head predictor shared by the DSpark draft model. Ported from DeepSpec.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
Vanilla / gated / RNN low-rank token-conditioned bias heads added to draft logits. Ported from DeepSpec.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
compute_dspark_loss with DP-correct global-denominator normalization; returns backward loss plus detached components for logging.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
Qwen3DSparkModel (dual-source KV, block-parallel forward, Markov + confidence heads) and DSparkConfig, plus build_dspark_draft_config that derives the backbone from the target model config.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
DSparkTrainer drives the model forward + compute_dspark_loss in TorchSpec's async pipeline; loads and freezes target embed/lm_head, FSDP-shards decoder layers, reports per-position acc / sim_acc_len.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
Map DSparkConfig -> Qwen3DSparkModel in AutoEagle3DraftModel / AutoDraftModelConfig and export the public symbols.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
Route DSparkConfig draft configs to DSparkTrainer (alongside the DFlash/Eagle3 dispatch).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
Add dspark_* loss hyperparameters to TrainingConfig and _validate_and_configure_dspark (force store_last_hidden_states, set aux layer ids, fail fast on misconfig).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>
DSpark draft-config JSON plus SGLang training example and 4-GPU repro config.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: chungen04 <cho322@gatech.edu>

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 17037f58ba

target_last_hidden_states.size(-1),
),
)
aligned_target_logits = self.compute_logits(aligned_target_hidden)

P1: Apply target norm before vLLM target logits

For inference_engine_type=vllm, last_hidden_states_prenorm defaults to true and the vLLM connector stores the final pre-norm hidden state; Eagle3 handles this by applying the verifier norm before the LM head. DSpark allows vLLM in _validate_and_configure_dspark, but this path feeds those raw pre-norm states directly through the copied lm_head, so the L1 and confidence targets are logits from the wrong distribution for vLLM DSpark runs. Either apply the target model norm here when last_hidden_states_prenorm is set, or reject vLLM for DSpark.
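
A minimal plain-Python sketch of the first option, assuming the target's final norm is an RMSNorm (as in Qwen3). The helper names `rms_norm`, `lm_head`, and `target_logits` and the toy weights are hypothetical; the actual fix would use the target model's own norm module on tensors.

```python
import math


def rms_norm(hidden, weight, eps=1e-6):
    """RMSNorm over one hidden vector: weight * x / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(x * x for x in hidden) / len(hidden)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [w * x * scale for x, w in zip(hidden, weight)]


def lm_head(hidden, head_rows):
    """Bias-free linear projection to vocabulary logits."""
    return [sum(w * x for w, x in zip(row, hidden)) for row in head_rows]


def target_logits(hidden, norm_weight, head_rows, prenorm):
    """Apply the final norm first only when the stored state is pre-norm."""
    if prenorm:
        hidden = rms_norm(hidden, norm_weight)
    return lm_head(hidden, head_rows)


# Toy sizes: hidden dim 4, vocabulary of 2 tokens.
norm_weight = [1.0, 1.0, 1.0, 1.0]
head_rows = [[0.5, -0.5, 0.0, 1.0], [1.0, 0.0, 0.25, -1.0]]
pre_norm_state = [1.0, 2.0, 3.0, 4.0]

with_norm = target_logits(pre_norm_state, norm_weight, head_rows, prenorm=True)
without_norm = target_logits(pre_norm_state, norm_weight, head_rows, prenorm=False)
print(with_norm, without_norm)
```

Skipping the norm leaves the logits on the scale of the raw residual stream, so any loss term that compares against them would see a different distribution.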

Comment on lines +147 to +150
has_confidence = outputs.confidence_pred is not None
confidence_loss_num = zero
confidence_loss_den = zero
if has_confidence:

P2: Gate confidence loss on its alpha

When users disable the auxiliary DSpark objectives (for example dspark_l1_loss_alpha=0, dspark_confidence_head_alpha=0) and turn off store_last_hidden_states to run CE-only training, the validator no longer forces last hidden states. With the provided draft config, however, the confidence module still exists, so has_confidence becomes true and the loss immediately asserts that aligned_target_logits is present even though the confidence term has zero weight. This makes CE-only DSpark configs crash unless they keep storing unused target hidden states; gate this on confidence_head_alpha > 0 or skip producing/consuming confidence_pred when the term is disabled.
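
One way to express that gating, as a hedged plain-Python sketch: `confidence_term_active` and `total_loss` are hypothetical helpers, scalars stand in for tensors, and `confidence_loss_fn` is a placeholder for whatever the real confidence loss computes.

```python
def confidence_term_active(confidence_pred, confidence_head_alpha):
    """The confidence term only matters if a prediction exists AND its weight is positive."""
    return confidence_pred is not None and confidence_head_alpha > 0


def total_loss(ce_loss, confidence_pred, aligned_target_logits,
               confidence_head_alpha, confidence_loss_fn):
    """Combine CE with the optional confidence term (hypothetical helper)."""
    loss = ce_loss
    if confidence_term_active(confidence_pred, confidence_head_alpha):
        # Target logits are required only when the term contributes.
        if aligned_target_logits is None:
            raise ValueError("confidence loss needs aligned_target_logits")
        loss += confidence_head_alpha * confidence_loss_fn(
            confidence_pred, aligned_target_logits
        )
    return loss


# CE-only run: the confidence head exists, but alpha is 0 and no target
# hidden states were stored, so aligned_target_logits is None.
print(total_loss(1.5, [0.2], None, 0.0, lambda pred, tgt: 99.0))
```

With the check keyed on the weight rather than on the module's existence, a CE-only configuration no longer depends on stored target hidden states.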

@chungen04

Collaborator Author

@yubofredwang Please check. Looking at the upstream implementation, the pipeline produces a large hidden-state dump, and TorchSpec avoids this by design. I am also reproducing DeepSeek's results. Suggestions are welcome. Thank you!

Dogacel commented Jun 27, 2026

Collaborator

@chungen04 I have a WIP PR, the implementation is pretty similar but I've simplified the trainer by subclassing DFlash trainer into DSpec trainer. I am doing some final sanity checks and I will share here soon.

@chungen04

Collaborator Author

@Dogacel sure, also feel free to collaborate on this pr if the implementation is similar

Dogacel commented Jun 27, 2026

Collaborator

> @Dogacel sure, also feel free to collaborate on this pr if the implementation is similar

#129

I think it looks very similar, I have high confidence about all files except one:

  • torchspec/models/dspark.py

Let me know if you see any errors or want to merge yours into this if you are certain about your loss/fwd pass implementation.

Dogacel commented Jun 29, 2026

Collaborator

Hi @chungen04, we have implemented and merged PR #129, let us know if you have any improvements over that implementation. Thanks for your effort!

Dogacel closed this Jun 29, 2026

@chungen04

Collaborator Author

@Dogacel Sure, will try it out in the following days. Thanks for your contribution as well!
