A from-scratch speculative decoding inference engine built in PyTorch.
The goal is to deeply understand, implement, and benchmark speculative decoding — measuring real latency, token acceptance rates, and speedup across multiple draft/target model pairs.
Status: Infrastructure, data pipeline, model layer, profiling foundation, and autoregressive baseline complete. Speculative decoding loop next.
- What is Speculative Decoding?
- Project Structure
- Setup
- What You Can Run Right Now
- Model Pairs
- Roadmap
- Theory Reference
Autoregressive LLM inference is memory-bandwidth-bound at batch size 1: generating each token requires reading all of the model's weights from GPU memory, so the compute units sit idle roughly 99% of the time.
Speculative decoding exploits this by running a cheap draft model to propose K candidate tokens, then verifying all K with the target model in a single parallel forward pass. That pass reads the target weights once, so its memory traffic is about the same as generating one token. A rejection sampling scheme guarantees the output distribution is identical to running the target model alone.
At a 70–80% token acceptance rate with a 10–20× cheaper draft model, this yields 2–3× wall-clock speedup with no change in output quality.
For the full derivation with worked numerical examples, see docs/speculative_decoding_theory.md.
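The accept/resample rule behind that guarantee can be sketched in a few lines. This is a minimal illustration of the standard scheme (accept a draft token x with probability min(1, p(x)/q(x)), otherwise resample from the renormalised max(0, p − q)), not the project's `rejection.py`, which is still pending. The function names and the toy three-token distributions are assumptions made for the example.

```python
import random


def speculative_accept(p_target, q_draft, rng):
    """Return (token, accepted) such that token is distributed as p_target.

    p_target and q_draft are probability lists over the same vocabulary.
    """
    vocab = range(len(p_target))
    # The draft model proposes a token sampled from q.
    x = rng.choices(vocab, weights=q_draft, k=1)[0]
    # Keep the proposal with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p_target[x] / q_draft[x]):
        return x, True
    # On rejection, resample from the adjusted distribution max(0, p - q);
    # rng.choices normalises the weights for us.
    residual = [max(0.0, p - q) for p, q in zip(p_target, q_draft)]
    x = rng.choices(vocab, weights=residual, k=1)[0]
    return x, False


def empirical_distribution(p_target, q_draft, n_samples=200_000, seed=0):
    """Monte Carlo estimate of the output distribution and acceptance rate."""
    rng = random.Random(seed)
    counts = [0] * len(p_target)
    accepted = 0
    for _ in range(n_samples):
        token, ok = speculative_accept(p_target, q_draft, rng)
        counts[token] += 1
        accepted += ok
    return [c / n_samples for c in counts], accepted / n_samples


p = [0.5, 0.3, 0.2]  # toy target distribution (assumed for illustration)
q = [0.2, 0.5, 0.3]  # toy draft distribution (assumed for illustration)
freqs, accept_rate = empirical_distribution(p, q)
print(freqs)        # close to p, even though proposals came from q
print(accept_rate)  # close to sum(min(p_i, q_i)) = 0.7
```

The acceptance rate equals the overlap between the two distributions, which is why a draft model that tracks the target closely is worth more than one that is merely fast.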
speculative-decoding/
│
├── src/
│ ├── models/
│ │ ├── registry.py ✅ ModelPair dataclass + 4 registered pairs
│ │ ├── loader.py ✅ load_model() — HuggingFace weights + dtype handling
│ │ ├── ollama_loader.py ✅ Ollama convenience wrapper (hardware validation only)
│ │ └── wrapper.py ✅ ModelWrapper — forward pass + KV cache management
│ │
│ ├── decoding/
│ │ ├── autoregressive.py ✅ Baseline token-by-token generation with per-token timing
│ │ ├── speculative.py 🔲 Speculative decoding loop
│ │ └── rejection.py 🔲 Rejection sampling + adjusted distribution
│ │
│ ├── profiling/
│ │ ├── timer.py ✅ cuda_sync_time, CUDATimer, CUDATimerCollection
│ │ ├── memory.py 🔲 GPU memory tracking
│ │ └── metrics.py ✅ GenerationResult + BenchmarkConfig dataclasses
│ │
│ ├── data/
│ │ └── prompts.py ✅ PromptDataset — 150 prompts across 3 domains
│ │
│ └── utils/
│ ├── logging.py 🔲 Structured logging
│ └── reproducibility.py 🔲 Seed control, deterministic flags
│
├── configs/ 🔲 YAML experiment configs (not yet populated)
├── benchmarks/ 🔲 Benchmark entry-point scripts
├── tests/
│ ├── test_kv_cache.py ✅ 21 integration tests for KV cache shape + truncation
│ ├── test_timer.py ✅ 26 tests for GPU timing utilities (CPU + CUDA tiers)
│ ├── test_metrics.py ✅ 66 tests for GenerationResult + BenchmarkConfig
│ └── test_autoregressive.py ✅ 22 tests for AutoregressiveDecoder (incl. HF match)
├── notebooks/ 🔲 Analysis notebooks
├── results/ (gitignored — generated outputs)
├── figures/ (gitignored — generated plots)
│
├── docs/
│ └── speculative_decoding_theory.md ✅ Theory deep-dive with worked examples
│
├── environment.yml ✅ Conda env (Python 3.11, PyTorch 2.3, CUDA 12.1)
└── pyproject.toml ✅ Package metadata + src layout
✅ = implemented and runnable 🔲 = scaffolded, implementation pending
- Conda (or Mamba)
- CUDA 12.1-compatible GPU recommended; CPU works for gpt2-scale models
conda env create -f environment.yml
conda activate speculative-decoding
pip install -e .
python -c "import torch; print(torch.cuda.get_device_name(0))"
The tinyllama_llama3, phi3_llama3, and llama3_self pairs require access to meta-llama/Meta-Llama-3-8B-Instruct. Request access at huggingface.co/meta-llama, then:
huggingface-cli login
The gpt2_dev pair has no access restrictions and works immediately.
Loads a HuggingFace causal LM, prints parameter count and memory footprint, then generates a short test sequence to confirm the model is functional.
# Smallest pair — no token required, fits on any GPU with ~3 GB VRAM
python -m src.models.loader --model gpt2 --dtype float16
python -m src.models.loader --model gpt2-xl --dtype float16
# Larger models (requires HF token + sufficient VRAM)
python -m src.models.loader --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --dtype float16
python -m src.models.loader --model meta-llama/Meta-Llama-3-8B-Instruct --dtype float16
python -m src.models.loader --model meta-llama/Meta-Llama-3-8B-Instruct --dtype int4
# Control how many tokens to generate in the smoke test
python -m src.models.loader --model gpt2 --dtype float16 --n-tokens 20
Expected output:
[loader] tokenizer ← gpt2
[loader] model ← gpt2 (dtype=float16)
model : gpt2
parameters : 124.4 M
memory : 0.23 GB
[loader] generating 10 tokens …
[loader] output: 'The quick brown fox jumps over the lazy dog ...'
[loader] OK
from src.models.registry import list_pairs, get_pair
print(list_pairs())
# ['gpt2_dev', 'tinyllama_llama3', 'phi3_llama3', 'llama3_self']
pair = get_pair("gpt2_dev")
print(pair.draft_model_id) # gpt2
print(pair.target_model_id) # gpt2-xl
print(pair.draft_dtype) # float16
Downloads 150 prompts (50 per domain) from HuggingFace datasets and saves them as JSON with pre-computed token counts. Falls back to built-in prompts automatically if a dataset is unavailable.
# Full dataset (150 prompts) + tiny dev set (15 prompts)
python -m src.data.prompts --tokenizer gpt2 --output data/prompts.json
# Use a different tokenizer (important: token counts depend on the tokenizer)
python -m src.data.prompts --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0
Expected output:
[prompts] loading tokenizer: gpt2
[prompts] building full dataset (50 prompts per domain) …
domain n min mean max
-----------------------------------------------
code 50 34 89.4 213
conversation 50 10 24.7 76
summarization 50 132 158.2 187
[prompts] saved 150 prompts → data/prompts.json
[prompts] saved 15 prompts → data/prompts_tiny.json
Load saved prompts in your own code:
from src.data.prompts import PromptDataset
dataset = PromptDataset.load("data/prompts.json")
code_prompts = dataset.get_by_domain("code") # list[Prompt]
all_prompts = dataset.get_all() # list[Prompt]
print(len(dataset)) # 150
print(code_prompts[0].prompt_id) # "code_000"
print(code_prompts[0].token_count)
Prompt domains:
| Domain | Source | Format |
|---|---|---|
| code | openai/openai_humaneval | Python function signature + docstring |
| conversation | tatsu-lab/alpaca | Instruction text |
| summarization | cnn_dailymail 3.0.0 | Article (≤512 tokens) + "Summarize the above article:" |
ModelWrapper is the interface every decoding component uses to talk to a
model. It exposes forward passes, cache length queries, and — critically —
truncate_cache(), which rolls the KV cache back to an arbitrary position
when speculative decoding rejects a draft token.
import torch
from src.models.loader import load_model
from src.models.wrapper import ModelWrapper
model, tokenizer = load_model("gpt2", "float32")
device = next(model.parameters()).device
wrapper = ModelWrapper(model, tokenizer, device)
# --- Prefill ---
prompt_ids = tokenizer.encode("The quick brown fox", return_tensors="pt").to(device)
logits, cache = wrapper.forward(prompt_ids)
print(wrapper.get_cache_length(cache)) # 4 (one entry per prompt token)
# logits shape: [1, 4, 50257]
# --- Decode one token ---
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True) # greedy
logits, cache = wrapper.forward(next_token, past_key_values=cache)
print(wrapper.get_cache_length(cache)) # 5
# --- Simulate rejection at position 3: roll back to position 3 ---
cache = wrapper.truncate_cache(cache, keep_length=3)
print(wrapper.get_cache_length(cache)) # 3
# Continue from position 3 with a replacement token
replacement = torch.tensor([[1234]]).to(device)
logits, cache = wrapper.forward(replacement, past_key_values=cache)
print(wrapper.get_cache_length(cache)) # 4
KV cache shape: every tensor inside past_key_values has shape [batch, num_heads, seq_len, head_dim]. For GPT-2 small that is [1, 12, seq_len, 64] per layer across 12 layers.
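Those shapes also determine how much memory the cache occupies and how much a rollback frees. The helper below is a back-of-the-envelope sketch, not part of `ModelWrapper`; the function name is hypothetical, and it assumes one key and one value tensor per layer stored in fp16 (2 bytes per element).

```python
def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len,
                   batch=1, bytes_per_elem=2):
    """Bytes held by a KV cache whose per-layer key and value tensors
    have shape [batch, num_heads, seq_len, head_dim]."""
    elems_per_tensor = batch * num_heads * seq_len * head_dim
    return 2 * num_layers * elems_per_tensor * bytes_per_elem  # 2 = key + value


# GPT-2 small: 12 layers, 12 heads, head_dim 64, fp16 elements (assumed).
per_token = kv_cache_bytes(12, 12, 64, seq_len=1)
print(per_token)  # 36864 bytes per cached token

full_context = kv_cache_bytes(12, 12, 64, seq_len=1024)
print(full_context / 2**20)  # 36.0 MiB at 1024 tokens

# Rolling back from 5 cached positions to 3 drops two tokens' worth of K/V.
freed = kv_cache_bytes(12, 12, 64, seq_len=5) - kv_cache_bytes(12, 12, 64, seq_len=3)
print(freed)  # 73728
```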
# All tests (requires GPT-2 weights — ~500 MB download on first run)
pytest tests/test_kv_cache.py -v
# Just shape checks (faster)
pytest tests/test_kv_cache.py -v -k "Shape"
21 tests across 4 classes:
| Class | What it checks |
|---|---|
| TestCacheShape | len(past_kv)==12, each entry is (key, value), shapes [1, 12, 20, 64], get_cache_length |
| TestCacheTruncation | Shapes after keep_length=10, truncate to 1, original cache unmodified |
| TestForwardWithCache | Logits [1, 1, 50257] after truncated-cache pass, cache grows by 1, use_cache=False→None |
| TestProperties | vocab_size, param_count, memory_footprint_mb, repr |
CUDATimer uses CUDA Events recorded inside the GPU stream for precise
measurement without pipeline stalls. CUDATimerCollection manages named
phases — draft, verify, sample — the same way the speculative decoding loop
will use it.
import torch
from src.profiling.timer import CUDATimer, CUDATimerCollection, cuda_sync_time
# --- Single operation ---
with CUDATimer() as t:
result = torch.mm(
torch.randn(2048, 2048, device="cuda"),
torch.randn(2048, 2048, device="cuda"),
)
print(f"matmul: {t.elapsed_ms:.2f} ms")
# --- Multi-phase profiling ---
timers = CUDATimerCollection()
timers.start("draft")
# ... draft model forward passes ...
timers.stop("draft")
timers.start("verify")
# ... target model verification pass ...
timers.stop("verify")
print(timers)
# CUDATimerCollection:
# draft : 1.243 ms
# verify : 4.817 ms
print(timers.summary()) # {'draft': 1.243, 'verify': 4.817}
On CPU-only machines both classes fall back to time.perf_counter() automatically.
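For readers without a GPU, the CPU fallback path can be pictured as a plain wall-clock context manager. The class below is a hypothetical stand-in written for illustration; it only assumes the `elapsed_ms` attribute shown in the example above and is not the code in `src/profiling/timer.py`.

```python
import time


class WallClockTimer:
    """Hypothetical CPU-only timer exposing elapsed_ms like CUDATimer does."""

    def __enter__(self):
        self.elapsed_ms = None
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        # perf_counter() returns seconds; convert to milliseconds.
        self.elapsed_ms = (time.perf_counter() - self._start) * 1000.0
        return False  # never swallow exceptions raised inside the block


with WallClockTimer() as t:
    total = sum(i * i for i in range(100_000))
print(f"sum of squares: {t.elapsed_ms:.3f} ms")
```

On a GPU this approach alone would under-report, because CUDA kernels are launched asynchronously and the host clock can stop before the device has finished. That is the gap CUDA events or an explicit synchronise are meant to close.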
Both decoders will return a GenerationResult. Speculative runs populate the
extra fields; autoregressive runs leave them as None. BenchmarkConfig
describes what was measured and produces a stable hash for deduplication.
from src.profiling.metrics import GenerationResult, BenchmarkConfig
# --- Build a result (normally returned by the decoder) ---
result = GenerationResult(
generated_ids=[1, 2, 3, 4, 5],
per_token_latencies=[0.08, 0.09, 0.08, 0.10, 0.09], # seconds
total_time=0.44,
peak_memory_mb=1823.0,
time_to_first_token=0.08,
# Speculative-only fields:
acceptance_rate=0.78,
tokens_per_step=4.12,
num_speculation_rounds=18,
draft_time_total_ms=120.5,
verify_time_total_ms=310.2,
sampling_time_total_ms=9.3,
)
print(result.tokens_per_second) # 5 / 0.44 ≈ 11.36
print(result.latency_p50) # median per-token latency
print(result.latency_p95) # 95th-percentile latency
print(result.summary())
# SD | 5 tok | 11.4 tok/s | p50=90ms p95=100ms | mem=1823MB | accept=0.78 rounds=18
# --- Serialise to JSONL for results files ---
with open("results/run.jsonl", "a") as f:
f.write(result.to_json_line()) # one JSON object per line
# --- Round-trip from JSON ---
import json
restored = GenerationResult.from_dict(json.loads(result.to_json_line()))
# --- Config hashing for deduplication ---
config = BenchmarkConfig(
model_pair_name="gpt2_dev",
K=4,
temperature=0.0,
max_new_tokens=200,
prompt_domain="code",
seed=42,
)
print(config.decoder_label) # "SD-K4"
print(config.config_hash()) # e.g. "a3f9c1b20d44" (12-char SHA-256 prefix)
Key computed properties on GenerationResult:
| Property | Formula |
|---|---|
| tokens_per_second | len(generated_ids) / total_time |
| latency_p50/p95/p99 | np.percentile(per_token_latencies, 50/95/99) |
| is_speculative | acceptance_rate is not None |
| num_tokens | len(generated_ids) |
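The stable hash that BenchmarkConfig uses for deduplication is described above only as a 12-character SHA-256 prefix. One plausible way to get such a hash is sketched below; the `ExampleConfig` class and the JSON-with-sorted-keys serialisation are assumptions for illustration, and the real `config_hash()` may be built differently.

```python
import hashlib
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ExampleConfig:
    """Hypothetical stand-in mirroring the BenchmarkConfig fields shown above."""
    model_pair_name: str
    K: int
    temperature: float
    max_new_tokens: int
    prompt_domain: str
    seed: int

    def config_hash(self) -> str:
        # sort_keys makes the serialisation independent of field order,
        # so equal settings always produce the same bytes.
        payload = json.dumps(asdict(self), sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:12]


a = ExampleConfig("gpt2_dev", 4, 0.0, 200, "code", 42)
b = ExampleConfig("gpt2_dev", 4, 0.0, 200, "code", 42)
c = ExampleConfig("gpt2_dev", 8, 0.0, 200, "code", 42)

print(len(a.config_hash()))                # 12
print(a.config_hash() == b.config_hash())  # True: same settings, same hash
print(a.config_hash() == c.config_hash())  # False: K differs
```

Hashing a canonical serialisation rather than relying on Python's built-in `hash()` keeps the value identical across processes and runs, which is what makes it usable as a key in a results file.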
AutoregressiveDecoder implements the baseline generation loop with explicit
KV cache management and GPU-synchronised per-token latency measurement. It
does not use model.generate() — it builds the loop explicitly so every
timing measurement reflects real GPU work.
import torch
from src.models.loader import load_model
from src.models.wrapper import ModelWrapper
from src.decoding.autoregressive import AutoregressiveDecoder
model, tokenizer = load_model("gpt2", "float32")
device = next(model.parameters()).device
wrapper = ModelWrapper(model, tokenizer, device)
decoder = AutoregressiveDecoder(wrapper)
# --- Token-ID interface (returns GenerationResult) ---
prompt_ids = tokenizer.encode("The quick brown fox", return_tensors="pt").to(device)
result = decoder.generate(prompt_ids, max_new_tokens=50, temperature=0.0)
print(result.num_tokens) # ≤ 50
print(f"{result.tokens_per_second:.1f} tok/s")
print(f"p50={result.latency_p50*1000:.0f}ms p95={result.latency_p95*1000:.0f}ms")
print(result.summary())
# AR | 50 tok | 63.2 tok/s | p50=15ms p95=21ms | mem=0MB
# --- Plain-text convenience wrapper ---
text = decoder.generate_text("The quick brown fox", max_new_tokens=50, temperature=0.0)
print(text)
# --- Temperature sampling ---
result_sampled = decoder.generate(prompt_ids, max_new_tokens=50, temperature=0.8)
Two-phase generation loop:
| Phase | What happens | Why |
|---|---|---|
| Prefill | Full prompt processed in one parallel forward pass | All prompt tokens are known upfront — no sequential dependency |
| Decode | One token generated per step using cached K/V | Each new token depends on the previously sampled token |
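The time_to_first_token field in GenerationResult measures prefill latency, the time a user waits before seeing any output. Subsequent tokens reuse the KV cache, so each decode step processes a single token (matrix-vector products, GEMV) rather than the whole prompt (matrix-matrix products, GEMM) and is faster.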
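The control flow of the two phases can be shown without any model weights. The sketch below swaps the real model for a deterministic toy function, so `toy_forward`, `VOCAB_SIZE`, and the fake logits are all assumptions made for illustration; only the prefill-then-decode structure and the per-token timing mirror what is described above.

```python
import time

VOCAB_SIZE = 16  # toy vocabulary size (assumed)


def toy_forward(new_tokens, cache):
    """Hypothetical stand-in for a model forward pass.

    `cache` plays the role of the KV cache: it holds every token already
    processed, so each call only receives the tokens that are new.
    """
    cache = cache + list(new_tokens)
    # Deterministic fake logits that depend on the whole cached context.
    state = sum((i + 1) * tok for i, tok in enumerate(cache))
    logits = [float((state * (v + 3)) % 17) for v in range(VOCAB_SIZE)]
    return logits, cache


def greedy_generate(prompt_ids, max_new_tokens):
    if max_new_tokens < 1:
        raise ValueError("max_new_tokens must be at least 1")
    generated, latencies = [], []

    # Prefill: the whole prompt goes through in a single call.
    start = time.perf_counter()
    logits, cache = toy_forward(prompt_ids, cache=[])
    next_token = max(range(VOCAB_SIZE), key=logits.__getitem__)  # greedy argmax
    generated.append(next_token)
    latencies.append(time.perf_counter() - start)

    # Decode: one new token per call, reusing the cache.
    while len(generated) < max_new_tokens:
        start = time.perf_counter()
        logits, cache = toy_forward([next_token], cache)
        next_token = max(range(VOCAB_SIZE), key=logits.__getitem__)
        generated.append(next_token)
        latencies.append(time.perf_counter() - start)

    time_to_first_token = latencies[0]  # prefill latency
    return generated, latencies, time_to_first_token, len(cache)


tokens, latencies, ttft, cache_len = greedy_generate([3, 1, 4, 1], max_new_tokens=5)
print(len(tokens), cache_len)  # 5 8
```

The cache ends at 8 entries rather than 9 because the last generated token is returned to the caller but never fed back through the model.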
# Everything (135 tests total); only test_kv_cache and test_autoregressive require model weights
pytest tests/ -v
# Individual suites
pytest tests/test_metrics.py -v # 66 tests, no GPU or model needed
pytest tests/test_timer.py -v # 26 tests, CUDA tier auto-skipped on CPU
pytest tests/test_kv_cache.py -v # 21 tests, downloads GPT-2 on first run
pytest tests/test_autoregressive.py -v # 22 tests, includes HF greedy-match check
# Skip CUDA-only timer tests on a CPU machine
pytest tests/test_timer.py -v -k "not cuda"
# Skip the slower HF-match parametrised tests
pytest tests/test_autoregressive.py -v -k "not hf_match"
Test coverage by file:
| File | Tests | Requires |
|---|---|---|
| test_metrics.py | 66 | numpy only |
| test_timer.py | 26 (18 CPU + 8 CUDA) | torch; CUDA tier skipped if no GPU |
| test_kv_cache.py | 21 | torch + transformers + GPT-2 weights |
| test_autoregressive.py | 22 | torch + transformers + GPT-2 weights |
If you have Ollama installed, use the Ollama loader to check whether a model fits on your hardware before downloading the full HuggingFace weights. This is a convenience tool only; see src/models/ollama_loader.py for why it cannot be used in the benchmarking pipeline.
# Check if Ollama is running and list available models
python -m src.models.ollama_loader
# Test generation on a specific model
python -m src.models.ollama_loader --model llama3:8b --prompt "Explain attention in transformers"
# Adjust generation length
python -m src.models.ollama_loader --model llama3:8b --max-tokens 100
Expected output (status check):
[ollama] checking server at http://localhost:11434 …
[ollama] ✓ server is running.
model params quant size (GB)
--------------------------------------------------------------
llama3:8b 8B Q4_0 4.7
phi3:mini 3.8B Q4_0 2.2
Four pairs are registered in src/models/registry.py:
| Pair name | Draft model | Target model | Draft dtype | Notes |
|---|---|---|---|---|
| gpt2_dev | gpt2 (124M) | gpt2-xl (1.5B) | fp16 | No token required. Fast to iterate on. |
| tinyllama_llama3 | TinyLlama-1.1B-Chat | Llama-3-8B-Instruct | fp16 | Shared BPE vocab → decent acceptance rate |
| phi3_llama3 | Phi-3-mini-4k (3.8B) | Llama-3-8B-Instruct | fp16 | Higher draft cost, higher expected acceptance |
| llama3_self | Llama-3-8B-Instruct (int4) | Llama-3-8B-Instruct | int4 / fp16 | Self-speculation: near-perfect acceptance, speedup from quantised draft bandwidth |
VRAM requirements (approximate):
| Pair | Draft VRAM | Target VRAM | Total |
|---|---|---|---|
| gpt2_dev | 0.25 GB | 3.0 GB | ~3.3 GB |
| tinyllama_llama3 | 2.2 GB | 16 GB | ~18 GB |
| phi3_llama3 | 7.6 GB | 16 GB | ~24 GB |
| llama3_self | 4.7 GB (int4) | 16 GB | ~21 GB |
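Most of these figures follow from parameter count times bytes per parameter. The sketch below is a weights-only estimate under that assumption; the helper name and the bytes-per-parameter table are illustrative, and real footprints are higher once the KV cache, activations, and quantisation metadata are included.

```python
# Assumed storage cost per parameter for each dtype label used in this README.
BYTES_PER_PARAM = {"float32": 4.0, "float16": 2.0, "bfloat16": 2.0, "int4": 0.5}


def weight_memory_gb(num_params, dtype):
    """Weights-only estimate in GB (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


print(round(weight_memory_gb(124.4e6, "float16"), 2))  # 0.25  (gpt2)
print(round(weight_memory_gb(1.5e9, "float16"), 2))    # 3.0   (gpt2-xl)
print(round(weight_memory_gb(3.8e9, "float16"), 2))    # 7.6   (Phi-3-mini)
print(round(weight_memory_gb(8e9, "float16"), 2))      # 16.0  (Llama-3-8B)
print(round(weight_memory_gb(8e9, "int4"), 2))         # 4.0   (int4, weights only)
```

The int4 estimate of 4.0 GB sits below the 4.7 GB in the table, which is consistent with the table including overhead beyond the raw 4-bit weights.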
- Project scaffold, environment, packaging
- Model registry with 4 draft/target pairs
- HuggingFace model loader with dtype handling (fp16, bf16, fp32, int4)
- Ollama convenience loader for hardware validation
- 150-prompt benchmark dataset across code / conversation / summarization
- ModelWrapper: forward pass, KV cache truncation, cache length utilities
- CUDATimer / CUDATimerCollection: GPU-event-based latency measurement
- GenerationResult + BenchmarkConfig: result dataclasses with JSON serialisation
- 135 tests across test_kv_cache.py, test_timer.py, test_metrics.py, test_autoregressive.py
- src/decoding/autoregressive.py: baseline greedy/sampling loop with per-token GPU timing ✅
- src/decoding/rejection.py: rejection sampling + adjusted distribution
- src/decoding/speculative.py: full speculative decoding loop (K draft tokens → parallel target verification)
- src/profiling/timer.py: CUDA event-based latency measurement ✅
- src/profiling/memory.py: peak VRAM tracking per phase
- src/profiling/metrics.py: GenerationResult + BenchmarkConfig dataclasses ✅
- End-to-end benchmark runner across all 4 model pairs × 3 domains × K ∈ {1,2,4,8}
- Acceptance rate vs prompt domain analysis
- Speedup breakdown: time in draft vs target vs overhead
- Figures: speedup heatmap, acceptance rate distributions, latency CDFs
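The planned sweep over pairs, domains, and K values is a plain Cartesian product. The snippet below enumerates it as a rough sketch; the dictionary keys borrow BenchmarkConfig's field names, but how the eventual benchmark runner builds its configs is an assumption.

```python
from itertools import product

PAIRS = ["gpt2_dev", "tinyllama_llama3", "phi3_llama3", "llama3_self"]
DOMAINS = ["code", "conversation", "summarization"]
K_VALUES = [1, 2, 4, 8]

# One entry per (pair, domain, K) combination.
grid = [
    {"model_pair_name": pair, "prompt_domain": domain, "K": k}
    for pair, domain, k in product(PAIRS, DOMAINS, K_VALUES)
]

print(len(grid))  # 48 speculative configurations
print(grid[0])    # {'model_pair_name': 'gpt2_dev', 'prompt_domain': 'code', 'K': 1}
```

Each pair and domain also needs one autoregressive baseline run to compute a speedup against, so the full sweep is somewhat larger than 48 runs.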
docs/speculative_decoding_theory.md covers:
- Why autoregressive inference is slow: arithmetic intensity analysis with concrete numbers for a 7B FP16 model (1 FLOP/byte vs 333 FLOP/byte ridge point, 0.3% GPU compute utilisation)
- How speculative decoding works: complete K=4 walkthrough with a 5-word vocabulary, including rejection decisions with actual probability values
- Why rejection sampling preserves the target distribution: full proof that P(output=x) = p_target(x) in both the p≥q and p<q cases
- What determines the speedup: derivation of S = (avg_accepted + 1) / (1 + C_draft/T_target), numerical examples, and conditions under which speculative decoding hurts
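The speedup formula in the last bullet is easy to evaluate numerically. The sketch below makes two assumptions the README does not spell out: C_draft is the total draft time for one round (K draft passes), and each draft token is accepted independently with probability alpha, with acceptance stopping at the first rejection. The function names are illustrative.

```python
def expected_accepted(alpha, k):
    """Expected accepted draft tokens out of k under the independence
    assumption: alpha + alpha**2 + ... + alpha**k."""
    return sum(alpha ** i for i in range(1, k + 1))


def speedup(avg_accepted, draft_cost, target_cost):
    """S = (avg_accepted + 1) / (1 + C_draft / T_target)."""
    return (avg_accepted + 1) / (1 + draft_cost / target_cost)


K = 4
T_TARGET = 1.0  # one target forward pass, in arbitrary time units

# Favourable case: 75% acceptance, each draft pass assumed 15x cheaper.
acc_good = expected_accepted(0.75, K)
s_good = speedup(acc_good, draft_cost=K * T_TARGET / 15, target_cost=T_TARGET)
print(round(acc_good, 2), round(s_good, 2))  # 2.05 2.41

# Unfavourable case: 20% acceptance, each draft pass only 2x cheaper.
acc_bad = expected_accepted(0.2, K)
s_bad = speedup(acc_bad, draft_cost=K * T_TARGET / 2, target_cost=T_TARGET)
print(round(acc_bad, 2), round(s_bad, 2))    # 0.25 0.42
```

The first case lands inside the 2–3× range quoted earlier for a 70–80% acceptance rate. The second shows the failure mode: when the draft is expensive and rarely accepted, S drops below 1 and speculative decoding is slower than the baseline.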
MIT — see LICENSE.