Integrate training DSpark draft models#127
Conversation
Anchor sampler, block-causal FlexAttention mask, noise-input construction, eval mask, forward-output dataclass, and confidence-head predictor shared by the DSpark draft model. Ported from DeepSpec. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
Vanilla / gated / RNN low-rank token-conditioned bias heads added to draft logits. Ported from DeepSpec. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
compute_dspark_loss with DP-correct global-denominator normalization; returns backward loss plus detached components for logging. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
Qwen3DSparkModel (dual-source KV, block-parallel forward, Markov + confidence heads) and DSparkConfig, plus build_dspark_draft_config that derives the backbone from the target model config. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
DSparkTrainer drives the model forward + compute_dspark_loss in TorchSpec's async pipeline; loads and freezes target embed/lm_head, FSDP-shards decoder layers, reports per-position acc / sim_acc_len. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
Map DSparkConfig -> Qwen3DSparkModel in AutoEagle3DraftModel / AutoDraftModelConfig and export the public symbols. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
Route DSparkConfig draft configs to DSparkTrainer (alongside the DFlash/Eagle3 dispatch). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
Add dspark_* loss hyperparameters to TrainingConfig and _validate_and_configure_dspark (force store_last_hidden_states, set aux layer ids, fail fast on misconfig). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
DSpark draft-config JSON plus SGLang training example and 4-GPU repro config. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: chungen04 <cho322@gatech.edu>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 17037f58ba
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| target_last_hidden_states.size(-1), | ||
| ), | ||
| ) | ||
| aligned_target_logits = self.compute_logits(aligned_target_hidden) |
There was a problem hiding this comment.
Apply target norm before vLLM target logits
For inference_engine_type=vllm, last_hidden_states_prenorm defaults to true and the vLLM connector stores the final pre-norm hidden state; Eagle3 handles this by applying the verifier norm before the LM head. DSpark allows vLLM in _validate_and_configure_dspark, but this path feeds those raw pre-norm states directly through the copied lm_head, so the L1 and confidence targets are logits from the wrong distribution for vLLM DSpark runs. Either apply the target model norm here when last_hidden_states_prenorm is set, or reject vLLM for DSpark.
Useful? React with 👍 / 👎.
| has_confidence = outputs.confidence_pred is not None | ||
| confidence_loss_num = zero | ||
| confidence_loss_den = zero | ||
| if has_confidence: |
There was a problem hiding this comment.
Gate confidence loss on its alpha
When users disable the auxiliary DSpark objectives (for example dspark_l1_loss_alpha=0, dspark_confidence_head_alpha=0) and turn off store_last_hidden_states to run CE-only training, the validator no longer forces last hidden states. With the provided draft config, however, the confidence module still exists, so has_confidence becomes true and the loss immediately asserts that aligned_target_logits is present even though the confidence term has zero weight. This makes CE-only DSpark configs crash unless they keep storing unused target hidden states; gate this on confidence_head_alpha > 0 or skip producing/consuming confidence_pred when the term is disabled.
Useful? React with 👍 / 👎.
|
@yubofredwang Please check -- Looking at the upstream implementation, the pipeline will produce large hidden state dump, and TorchSpec avoids this by design. I am also reproducing DeepSeek's results. Suggestions is welcomed. Thank you! |
|
@chungen04 I have a WIP PR, the implementation is pretty similar but I've simplified the trainer by subclassing DFlash trainer into DSpec trainer. I am doing some final sanity checks and I will share here soon. |
|
@Dogacel sure, also feel free to collaborate on this pr if the implementation is similar |
I think it looks very similar, I have high confidence about all files except one:
Let me know if you see any errors or want to merge yours into this if you are certain about your loss/fwd pass implementation. |
|
Hi @chungen04, we have implemented and merged PR #129, let us know if you have any improvements over that implementation. Thanks for your effort! |
|
@Dogacel Sure, will try it out in the following days. Thanks for your contribution as well! |
Summary
DeepSeek just released DSpark speculative-decoding draft model, yet the upstream implementation is architected as a synchronous pipeline and requires to dump the hidden state. TorchSpec's design alleviates the storage requirement via the asynchronous, disaggregated pipeline design.
This PR ports DeepSpec's DSpark speculative-decoding draft model into TorchSpec's online (SGLang) training pipeline, as a first-class parallel path alongside Eagle3 and DFlash. DSpark is a block-anchor parallel drafter. It samples
num_anchorsanchor positions, expands each into ablock_size-token block of MASK tokens, and predicts the whole block in one pass using dual-source KV attention (context KV from the projected multi-layer target hidden states + draft KV from the noise embeddings) under a block-causal FlexAttention mask. Optional low-rank Markov and confidence heads refine the block logits and predict per-position accept rates.The current scope is with Qwen3 backbone + SGLang engine.
Files changed
New modules (ported from DeepSpec,
add_metriclogging side-effects dropped):torchspec/models/dspark_common.py— anchor sampler, block-causal mask, noise embed,eval mask, forward-output dataclass, confidence predictor.
torchspec/models/dspark_markov.py— vanilla / gated / RNN Markov heads.torchspec/models/dspark_loss.py—compute_dspark_loss(DP-correct global-denominatornormalization; returns backward loss + detached components for logging).
torchspec/models/draft/dspark.py—Qwen3DSparkModel,DSparkConfig, andbuild_dspark_draft_config(derives the backbone from the target model config).torchspec/training/dspark_trainer.py—DSparkTrainer(loads + freezes targetembed/lm_head, FSDP-shards decoder layers, runs forward + loss, reports per-position
acc /
sim_acc_len).Wiring:
models/draft/auto.py,torchspec/__init__.py,models/draft/__init__.py— registerDSparkConfig -> Qwen3DSparkModeland export the public symbols.training/trainer_actor.py— dispatchDSparkConfigdraft configs toDSparkTrainer.config/train_config.py,train_entry.py—dspark_*loss hyperparameters and_validate_and_configure_dspark(forcesstore_last_hidden_states, auto-sets aux layerids, fails fast on misconfig).
Configs:
config/dspark_draft_config.json(Qwen3-8B recipe),configs/sglang_qwen3_8b_dspark.yaml,configs/dspark_qwen3_8b_repro.yaml.Faithfulness to DeepSpec
The training math is mirrored exactly: model forward, Markov/confidence heads, loss terms, the
world_size-scaled global-denominator normalization, and theloss / gradient_accumulation_stepsscaling all match DeepSpec line-for-line. The onlyadaptations are infrastructural (offline-cache+FSDP1+DataLoader → online-mooncake+FSDP2+Ray) and the metric logging (replaced DeepSpec's
add_metricwith TorchSpec's per-position acc /sim_acc_len).Validation
1k-step training on open-perfectblend (not re-generated), Qwen3-8B target, SGLang engine, 4 GPUs (1 inference + 3 training), 4xB300 GPUs:
sim_acc_lenThroughput: ~9.6 entries/s (training-bound on 3 GPUs).
Follow-ups
DataFetcher#126 .