Sequence-packing utilities for training transformers on variable-length data. Given a list (or histogram) of sequence lengths and a fixed max_seq_len, seqpack produces bins that combine multiple short sequences into one packed context, with per-epoch randomized pairings and an attention mask that prevents tokens from one sequence from attending to another.
For the algorithm internals, see seqpack/PACKING.md.
seqpack/
    packing.py        # bin-packing functions
    masking.py        # attention-bias helper
    stats.py          # PackingStats: uniform efficiency reporting
tests/
    test_seqpack.py
pyproject.toml
Installation:
pip install seqpack

Then run the tests from the project root:
python -m unittest discover -s tests

Quick start:
from seqpack.packing import heap_pack_sequences
lengths = [128, 64, 32, 16, 200, 80, 50]
bins = heap_pack_sequences(lengths, max_seq_len=256)
# e.g. bins == [[4], [0, 5], [1, 6, 2, 3]] (exact grouping may vary; values are indices into `lengths`)

Each inner list is one packed sample whose lengths sum to at most max_seq_len. Use the indices to fetch the underlying sequences from your dataset.
heap_pack_sequences: worst-fit-decreasing using a max-heap, O(n log n). Use this when packing more than a few thousand sequences.
from seqpack.packing import heap_pack_sequences
bins = heap_pack_sequences(lengths, max_seq_len=2048)

greedy_pack_sequences: first-fit-decreasing with a linear bin scan, O(n²) worst case. It produces slightly fewer bins than heap_pack_sequences on most inputs, but is slow at scale.
from seqpack.packing import greedy_pack_sequences
bins = greedy_pack_sequences(lengths, max_seq_len=2048)

| Scenario | Use |
|---|---|
| < ~10k sequences | either; greedy is tighter |
| Hundreds of thousands or more | heap_pack_sequences |
| Billions | histogram path (below) |
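To make the trade-off in the table concrete, here is a minimal, dependency-free sketch of the two strategies named above. The names wfd_pack and ffd_pack are hypothetical and are not seqpack's API; the real packers also take seed and tolerance, which this sketch leaves out.

```python
import heapq


def wfd_pack(lengths, max_seq_len):
    """Worst-fit-decreasing: put each item in the bin with the most free space."""
    order = sorted(range(len(lengths)), key=lambda i: -lengths[i])
    bins, heap = [], []  # heap entries: (-free_space, bin_index)
    for i in order:
        n = lengths[i]
        if n > max_seq_len:
            raise ValueError(f"length {n} exceeds max_seq_len {max_seq_len}")
        if heap and -heap[0][0] >= n:
            neg_free, b = heapq.heappop(heap)
            bins[b].append(i)
            heapq.heappush(heap, (neg_free + n, b))
        else:
            bins.append([i])
            heapq.heappush(heap, (-(max_seq_len - n), len(bins) - 1))
    return bins


def ffd_pack(lengths, max_seq_len):
    """First-fit-decreasing: put each item in the first bin that has room."""
    order = sorted(range(len(lengths)), key=lambda i: -lengths[i])
    bins, free = [], []
    for i in order:
        n = lengths[i]
        if n > max_seq_len:
            raise ValueError(f"length {n} exceeds max_seq_len {max_seq_len}")
        for b in range(len(bins)):  # linear scan: this is the O(n^2) part
            if free[b] >= n:
                bins[b].append(i)
                free[b] -= n
                break
        else:
            bins.append([i])
            free.append(max_seq_len - n)
    return bins


lengths = [128, 64, 32, 16, 200, 80, 50]
print(wfd_pack(lengths, 256))  # [[4], [0, 5], [1, 6, 2, 3]]
print(ffd_pack(lengths, 256))  # [[4, 6], [0, 5, 2, 3], [1]]
```

Both variants need three bins on this input; the difference shows up in how evenly the bins are filled and in how the inner loop scales.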
Both packers accept seed and tolerance. With seed=None (the default) no shuffling is applied, so the output is identical on every call. With an integer seed, items of equal length (or within tolerance of each other) are shuffled before packing, reproducibly for that seed: same templates, different pairings each epoch.
from seqpack.packing import heap_pack_sequences, seed_from_components
for epoch in range(num_epochs):
    seed = seed_from_components(epoch)  # stable, reproducible
    bins = heap_pack_sequences(
        lengths,
        max_seq_len=2048,
        tolerance=1,  # only true ties shuffle
        seed=seed,
    )
    # ... train on `bins`

- tolerance=1: only sequences of exactly equal length get reshuffled.
- tolerance=4: lengths within ±3 of each other share a bucket. More mixing, very small packing penalty.
- tolerance=16: aggressive mixing; packing efficiency may drop slightly.
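One way to picture the tolerance buckets is the sketch below. It assumes a bucket is simply length // tolerance, which may differ from how seqpack defines it, and sort_with_bucket_shuffle is a hypothetical name.

```python
import random
from itertools import groupby


def sort_with_bucket_shuffle(lengths, tolerance=1, seed=None):
    """Order indices by descending length bucket, shuffling inside each bucket."""

    def bucket(i):
        return lengths[i] // tolerance  # assumption: fixed-width buckets

    order = sorted(range(len(lengths)), key=bucket, reverse=True)
    if seed is None:
        return order  # no shuffle: same order on every call
    rng = random.Random(seed)
    shuffled = []
    for _, group in groupby(order, key=bucket):
        members = list(group)
        rng.shuffle(members)  # only items in the same bucket trade places
        shuffled.extend(members)
    return shuffled


print(sort_with_bucket_shuffle([5, 5, 3, 5]))  # [0, 1, 3, 2]
print(sort_with_bucket_shuffle([5, 5, 3, 5], seed=0))  # some order of 0, 1, 3, then 2
```

With tolerance > 1 the order inside a bucket is no longer strictly descending by length, which is where the small packing penalty comes from.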
seed_from_components(*ints) is a small helper that gives a stable seed from any combination of integers (epoch, dataset_id, worker_rank, etc.) — useful when you need different streams that are all reproducible.
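For intuition, a stable seed helper of this kind can be sketched with hashlib. This is an illustration only: stable_seed is a hypothetical name, and the hash, byte layout and 32-bit range are assumptions, not seqpack's actual scheme.

```python
import hashlib


def stable_seed(*components: int) -> int:
    """Hash a tuple of ints to a 32-bit seed that is identical across runs."""
    h = hashlib.blake2b(digest_size=8)
    for c in components:
        # Fixed-width encoding keeps (1, 23) and (12, 3) distinct.
        # Assumes each component fits in a signed 64-bit integer.
        h.update(int(c).to_bytes(8, "little", signed=True))
    return int.from_bytes(h.digest(), "little") % (2**32)


epoch, rank, worker = 3, 0, 1
print(stable_seed(epoch, rank, worker) == stable_seed(epoch, rank, worker))  # True
```

Unlike Python's built-in hash(), which is randomized per process for str and bytes inputs, a cryptographic digest gives the same value on every machine and every restart.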
If you want to keep the same packing but shuffle the order bins are visited in (e.g., for the dataloader sampler), use shuffle_bins:
from seqpack.packing import shuffle_bins
bin_indices = list(range(len(bins)))
order = shuffle_bins(bin_indices, seed=epoch)
for i in order:
    yield bins[i]  # inside a generator, e.g. your sampler's __iter__

shuffle_bins doesn't mutate the input list and produces the same output for the same seed.
Packing on its own only solves the throughput problem — it tells you which short sequences should share a context. You still need two more things before you can feed the result to a transformer:
1. A sequence_ids tensor that labels each token position with the packed sequence it belongs to.
2. An attention bias built from sequence_ids that prevents one sequence in the packed context from attending to tokens belonging to a different sequence.
Without (2) packing is broken. The whole point of packing is to share GPU work without sharing semantics: sequence A and sequence B happen to live in the same row of tokens, but the model must treat them as completely independent. Skipping the attention bias means token 0 of B would attend to every token of A and vice versa, so:
- The model learns spurious dependencies between concatenated documents.
- Causal LMs leak future tokens of "the next sequence" into "the previous sequence" — disastrous for any autoregressive loss.
- Loss/gradient updates couple unrelated training examples; sample efficiency drops, and convergence behavior diverges from the non-packed baseline.
build_packing_attention_bias produces the additive mask that fixes this: 0 where attention is allowed (within the same sequence), -inf where it must be blocked (across sequences). You add it to the QK^T scores in every attention layer.
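The rule behind the bias fits in a few lines. The sketch below uses nested Python lists instead of tensors so it runs without PyTorch; packing_bias is a hypothetical stand-in for illustration, not the library function.

```python
import math


def packing_bias(seq_ids):
    """Return an L x L additive bias: 0.0 within a sequence, -inf across sequences."""
    return [[0.0 if a == b else -math.inf for b in seq_ids] for a in seq_ids]


bias = packing_bias([0, 0, 1, 1, 1])
for row in bias:
    print(" ".join("0" if v == 0.0 else "-inf" for v in row))
```

The result is block-diagonal: one square of zeros per packed sequence, -inf everywhere else.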
A bin from any of the packers is List[int] — sequence indices into your dataset. To build the model input, fetch the actual token tensors, concatenate them, and emit a sequence_ids array the same length as the concatenated tokens. Each sequence in the bin gets a unique integer label inside this packed row.
import numpy as np
import torch
def pack_one_bin(bin_seq_ids, raw_dataset, max_seq_len, pad_token_id):
    """
    bin_seq_ids: List[int] -- e.g. [4, 1] from materialize_epoch or heap_pack_sequences
    raw_dataset[i]: 1-D token tensor for sequence i
    Returns: (tokens [L], sequence_ids [L]) both of length max_seq_len.
    """
    tokens_chunks = []
    seq_id_chunks = []
    for local_id, seq_idx in enumerate(bin_seq_ids):
        seq_tokens = raw_dataset[seq_idx]  # [n_i]
        tokens_chunks.append(seq_tokens)
        seq_id_chunks.append(
            torch.full_like(seq_tokens, local_id)  # all positions in seq -> local_id
        )
    tokens = torch.cat(tokens_chunks, dim=0)  # [sum(n_i)]
    seq_ids = torch.cat(seq_id_chunks, dim=0)  # [sum(n_i)]
    # Pad to max_seq_len. Padding positions reuse the last sequence's ID:
    # earlier sequences are blocked from them by the cross-sequence bias, and
    # under a causal mask the last sequence never looks ahead into them.
    pad_len = max_seq_len - tokens.numel()
    if pad_len > 0:
        tokens = torch.cat([tokens, tokens.new_full((pad_len,), pad_token_id)])
        seq_ids = torch.cat([seq_ids, seq_ids.new_full((pad_len,), seq_ids[-1].item())])
    return tokens, seq_ids

Two important details:
- Local sequence IDs: the labels are local to the bin (0, 1, 2, ...), not global dataset indices. The bias only needs to know "are these two positions from the same packed sequence?", so a per-bin counter is sufficient and keeps the integer range small.
- Padding positions reuse the last sequence's ID, not a sentinel like -1. Cross-sequence blocking keeps tokens of every earlier sequence from attending into the padding region, and with a causal mask the last sequence cannot look ahead into it either, so no extra special-case logic is needed inside the mask. With bidirectional attention the last sequence shares its ID with the padding, so add a padding mask in that case.
build_packing_attention_bias accepts [B, L] (or [L]) and returns [B, 1, L, L]:
from seqpack.masking import build_packing_attention_bias
tokens_batch, seq_ids_batch = [], []
for bin_seq_ids in next_B_bins:
    t, s = pack_one_bin(bin_seq_ids, raw_dataset, max_seq_len, pad_token_id)
    tokens_batch.append(t)
    seq_ids_batch.append(s)
tokens = torch.stack(tokens_batch).to("cuda")  # [B, L]
seq_ids = torch.stack(seq_ids_batch).to("cuda")  # [B, L]
attn_bias = build_packing_attention_bias(
    seq_ids,
    dtype=torch.bfloat16,
    device="cuda",
)
# attn_bias.shape == [B, 1, L, L]; broadcast over heads.

Build the bias once in the parent model's forward, then pass it to every layer. Recomputing it per layer is wasted work since seq_ids doesn't change:
from torch import nn

class PackedTransformer(nn.Module):
    def forward(self, tokens, seq_ids):
        attn_bias = build_packing_attention_bias(seq_ids, self.dtype, tokens.device)
        x = self.embed(tokens)
        for layer in self.layers:
            x = layer(x, attn_bias=attn_bias)
        return self.head(x)

Inside each attention layer you add the bias to the raw scores before the softmax:
scores = (q @ k.transpose(-1, -2)) * self.scale # [B, H, L, L]
scores = scores + attn_bias # [B, 1, L, L] broadcasts over heads
weights = scores.softmax(dim=-1)

Entries set to -inf become exactly 0 after the softmax: they contribute nothing to the output and no gradient flows back through them. Allowed entries have zero added, so the bias is a no-op for them.
The packing bias only blocks cross-sequence attention. For autoregressive training you still need a causal (lower-triangular) mask. Add both:
L = seq_ids.size(-1)
causal = torch.zeros(L, L, dtype=self.dtype, device=tokens.device)
causal.masked_fill_(torch.triu(torch.ones(L, L, dtype=torch.bool, device=tokens.device), diagonal=1), float("-inf"))
attn_bias = build_packing_attention_bias(seq_ids, self.dtype, tokens.device) + causal

The sum is still 0 for allowed positions and -inf everywhere else, since -inf + 0 = -inf and -inf + -inf = -inf. Each sequence in the packed row gets its own independent causal triangle.
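The "independent causal triangle" pattern is easy to see on a tiny row. The following sketch adds a packing bias and a causal bias element-wise, using plain Python floats rather than tensors; the helper names are hypothetical.

```python
import math

NEG_INF = -math.inf


def packing_plus_causal(seq_ids):
    """Sum a block-diagonal packing bias and a lower-triangular causal bias."""
    n = len(seq_ids)
    packing = [[0.0 if seq_ids[i] == seq_ids[j] else NEG_INF for j in range(n)]
               for i in range(n)]
    causal = [[0.0 if j <= i else NEG_INF for j in range(n)] for i in range(n)]
    # -inf + 0.0 and -inf + -inf both stay -inf, so the sum is a valid bias.
    return [[packing[i][j] + causal[i][j] for j in range(n)] for i in range(n)]


def show(bias):
    """Render allowed positions as 'x' and blocked positions as '.'."""
    return ["".join("x" if v == 0.0 else "." for v in row) for row in bias]


rows = show(packing_plus_causal([0, 0, 0, 1, 1]))
print("\n".join(rows))
# x....
# xx...
# xxx..
# ...x.
# ...xx
```

The second sequence starts a fresh triangle at position 3 and sees nothing from the first.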
A quick test that the bias is doing its job: pick a packed batch, run the model, and verify that the loss on each sequence matches the loss you'd get by feeding that sequence on its own (no packing). If the two diverge, your bias isn't being applied somewhere — most often it's a layer that's recomputing its own mask and ignoring the one you passed in.
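The same check can be rehearsed on a toy model before touching a real one. The sketch below is a simplification: it compares attention outputs rather than losses, and uses a hand-written single-head attention with q = k = v. All names in it are hypothetical.

```python
import math


def causal_attention(x, seq_ids=None):
    """Toy single-head causal self-attention with q = k = v = x.

    x is a list of equal-length float vectors. If seq_ids is given, attention
    across different sequence IDs is blocked with -inf before the softmax.
    """
    out = []
    for i, xi in enumerate(x):
        scores = []
        for j, xj in enumerate(x):
            allowed = j <= i and (seq_ids is None or seq_ids[i] == seq_ids[j])
            dot = sum(a * b for a, b in zip(xi, xj))
            scores.append(dot if allowed else -math.inf)
        top = max(scores)  # finite, because the diagonal is always allowed
        weights = [math.exp(s - top) for s in scores]  # exp(-inf) == 0.0
        total = sum(weights)
        out.append([sum(w / total * xj[d] for w, xj in zip(weights, x))
                    for d in range(len(xi))])
    return out


def close(u, v):
    return all(math.isclose(a, b, abs_tol=1e-12)
               for row_u, row_v in zip(u, v) for a, b in zip(row_u, row_v))


seq_a = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4]]
seq_b = [[-0.2, 0.7], [0.9, 0.1]]

packed = causal_attention(seq_a + seq_b, seq_ids=[0, 0, 0, 1, 1])
separate = causal_attention(seq_a) + causal_attention(seq_b)
leaky = causal_attention(seq_a + seq_b)  # packing bias "forgotten"

print(close(packed, separate))  # True
print(close(leaky, separate))   # False
```

In a real model the comparison would use per-sequence losses with a tolerance suited to the dtype (bfloat16 needs a looser one than float32).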
The [B, 1, L, L] bias is the dense representation: it forces the kernel to compute every QK^T entry and then mask them out. Modern fast-attention kernels (FlashAttention 2, PyTorch SDPA's varlen variant, xformers BlockDiagonalMask) prefer the compact representation — a cu_seqlens offset array that tells the kernel exactly which token spans to self-attend over and lets it skip the cross-sequence work entirely.
build_packing_cu_seqlens builds this from the same sequence_ids tensor:
from seqpack import build_packing_cu_seqlens
sequence_ids = torch.tensor([
[0, 0, 0, 0, 0, 1, 1, 1, 0, 0], # 5+3 real tokens, then 2 padding
[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], # 4+2+4 real tokens, no padding
])
attention_mask = torch.tensor([
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
])
cu_seqlens, max_seqlen = build_packing_cu_seqlens(sequence_ids, attention_mask)
# cu_seqlens = tensor([ 0, 5, 8, 12, 14, 18], dtype=torch.int32)
# max_seqlen = 5

Sub-sequence i lives at positions [cu_seqlens[i], cu_seqlens[i+1]) in the flattened, padding-stripped token stream. Pass both values directly to FlashAttention's varlen kernel:
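How the offsets in the example above come about can be traced with a pure-Python sketch. It assumes a new sub-sequence starts whenever the ID changes or a new row begins, and that padding is skipped; cu_seqlens_sketch is a hypothetical name and returns lists rather than tensors.

```python
def cu_seqlens_sketch(sequence_ids, attention_mask):
    """Return (cumulative offsets, longest sub-sequence) over non-padding tokens."""
    cu = [0]
    max_len = 0
    for ids_row, mask_row in zip(sequence_ids, attention_mask):
        run_len, prev = 0, None
        for sid, keep in zip(ids_row, mask_row):
            if not keep:
                continue  # padding never counts toward any sub-sequence
            if prev is not None and sid != prev:
                cu.append(cu[-1] + run_len)  # close the previous sub-sequence
                max_len = max(max_len, run_len)
                run_len = 0
            run_len += 1
            prev = sid
        if run_len:
            cu.append(cu[-1] + run_len)  # close the last sub-sequence in the row
            max_len = max(max_len, run_len)
    return cu, max_len


sequence_ids = [
    [0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
    [0, 0, 0, 0, 1, 1, 2, 2, 2, 2],
]
attention_mask = [
    [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
]
print(cu_seqlens_sketch(sequence_ids, attention_mask))  # ([0, 5, 8, 12, 14, 18], 5)
```

The last offset equals the total number of real tokens in the batch, which is also the leading dimension of the flattened q/k/v.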
from flash_attn import flash_attn_varlen_func
# Strip padding and flatten q/k/v to [total_real_tokens, n_heads, head_dim].
mask_flat = attention_mask.view(-1).bool()
q_flat = q.view(-1, n_heads, head_dim)[mask_flat]
k_flat = k.view(-1, n_heads, head_dim)[mask_flat]
v_flat = v.view(-1, n_heads, head_dim)[mask_flat]
out = flash_attn_varlen_func(
q_flat, k_flat, v_flat,
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
causal=True, # applied per sub-sequence
)
# out: [total_real_tokens, n_heads, head_dim]; scatter back to [B, L, ...] if needed.

Use the dense bias when (a) you're not on a varlen-capable kernel, (b) you need a single mask object for portability, or (c) you have very few sequences per row and the kernel overhead of varlen outweighs the savings. Use varlen for any serious training workload; the wins grow with packing density.
When the dataset gets big, the per-sequence packers stop being practical. At 3B sequences, just the input lengths: List[int] is ~24 GB and the output List[List[int]] is tens of GB. The histogram path packs length classes (a Mapping[length, count]) instead of individuals — the packing plan ends up as a tiny Counter, and concrete sequence IDs are bound to bins lazily per epoch.
There are three functions involved:
- pack_length_histogram / pack_length_histogram_batched: produce a packing plan (Counter of bin templates).
- length_histogram_from_lengths: in-memory helper to build (counts, pools) from a flat list.
- materialize_epoch: yield concrete List[int] bins for one epoch by drawing from per-length pools.
from seqpack.packing import (
length_histogram_from_lengths,
materialize_epoch,
pack_length_histogram_batched,
seed_from_components,
)
# Pretend this came from your dataset:
lengths = [7, 5, 5, 5, 5, 5, 3, 3, 3] # 9 sequences, ids 0..8
# Step 1: build the histogram + per-length pools.
counts, pools = length_histogram_from_lengths(lengths)
# counts = Counter({5: 5, 3: 3, 7: 1})
# pools = {7: array([0]), 5: array([1, 2, 3, 4, 5]), 3: array([6, 7, 8])}
# Step 2: pack the histogram into a plan. Run this ONCE per dataset.
templates = pack_length_histogram_batched(counts, max_seq_len=10)
# Counter({(5, 5): 2, (7, 3): 1, (5, 3): 1, (3,): 1})
# 5 bins total; each key is a sorted-descending tuple of lengths.
# Step 3: materialize concrete bins for an epoch.
for epoch in range(3):
    seed = seed_from_components(epoch)
    for bin_seq_ids in materialize_epoch(templates, pools, seed=seed):
        print(epoch, bin_seq_ids, "lengths:", [lengths[i] for i in bin_seq_ids])

Sample output (the exact IDs change per epoch; templates do not):
0 [4, 1] lengths: [5, 5]
0 [3, 2] lengths: [5, 5]
0 [0, 7] lengths: [7, 3]
0 [5, 8] lengths: [5, 3]
0 [6] lengths: [3]
1 [2, 5] lengths: [5, 5]
1 [0, 6] lengths: [7, 3]
...
Same plan, different concrete pairings each epoch — no re-packing required.
| Function | When to use |
|---|---|
| pack_length_histogram | Reference implementation. Strict per-item WFD. O(N log B), slow at scale. |
| pack_length_histogram_batched | Recommended default. Same output shape, O(distinct_lengths × log B + B). |
Bin counts from the two are typically identical or within ~1% on realistic distributions.
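As a rough picture of the per-item reference behaviour, the sketch below runs worst-fit-decreasing over a histogram and collapses the bins into templates. It is an illustration under assumptions: pack_histogram_wfd is a hypothetical name, and unlike the real packers it keeps every bin in memory, so it is only suitable for small inputs.

```python
import heapq
from collections import Counter


def pack_histogram_wfd(length_counts, max_seq_len):
    """Per-item worst-fit-decreasing over a {length: count} histogram."""
    heap = []      # (-free_space, bin_id)
    contents = []  # contents[bin_id] = lengths placed in that bin
    for length in sorted(length_counts, reverse=True):
        if length > max_seq_len:
            raise ValueError(f"length {length} exceeds max_seq_len {max_seq_len}")
        for _ in range(length_counts[length]):
            if heap and -heap[0][0] >= length:
                neg_free, b = heapq.heappop(heap)
                contents[b].append(length)
                heapq.heappush(heap, (neg_free + length, b))
            else:
                contents.append([length])
                heapq.heappush(heap, (-(max_seq_len - length), len(contents) - 1))
    # Items arrive in descending order, so each bin is already sorted descending.
    return Counter(tuple(c) for c in contents)


plan = pack_histogram_wfd({5: 5, 3: 3, 7: 1}, max_seq_len=10)
print(plan == Counter({(5, 5): 2, (7, 3): 1, (5, 3): 1, (3,): 1}))  # True
```

On the nine-sequence example used earlier this reproduces the five-bin plan shown there.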
The output of either packer is a Counter[BinTemplate] where BinTemplate = Tuple[int, ...] is a sorted-descending tuple of lengths summing to ≤ max_seq_len. For a homogeneous corpus of 1 billion length-512 sequences with max_seq_len = 2048, the entire packing plan is one entry:
Counter({(512, 512, 512, 512): 250_000_000})

That's the whole plan: a single Counter entry describing how to pack a billion sequences.
Inside the generator:
- Each per-length pool is shuffled with the epoch seed (one rng.permutation per length).
- Distinct template types are visited in a seed-shuffled order.
- Each emitted bin pulls one ID per slot from the shuffled pool at that length, advancing a cursor.
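Iterator bookkeeping (template order plus one cursor per length) is O(distinct_template_types + distinct_lengths), independent of dataset size. The shuffled per-length pools are the part that scales with N. The pools themselves live wherever you put them (RAM, or np.memmap on disk for huge datasets).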
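The three steps listed above can be sketched with the standard library. This is a simplified illustration: materialize_sketch is a hypothetical name, it uses random.Random instead of a NumPy generator, and it copies each pool into a list, so its output will not match materialize_epoch ID for ID.

```python
import random
from collections import Counter


def materialize_sketch(templates, pools, seed=None):
    """Yield one list of sequence IDs per bin, following the packing plan."""
    rng = random.Random(seed)  # seed=None draws entropy from the OS
    shuffled = {}
    for length in sorted(pools):  # step 1: shuffle each per-length pool
        ids = list(pools[length])
        rng.shuffle(ids)
        shuffled[length] = ids
    cursor = {length: 0 for length in pools}
    kinds = sorted(templates)  # step 2: visit template types in shuffled order
    rng.shuffle(kinds)
    for template in kinds:
        for _ in range(templates[template]):
            bin_ids = []
            for length in template:  # step 3: one ID per slot, advance cursor
                bin_ids.append(shuffled[length][cursor[length]])
                cursor[length] += 1
            yield bin_ids


lengths = [7, 5, 5, 5, 5, 5, 3, 3, 3]
templates = Counter({(5, 5): 2, (7, 3): 1, (5, 3): 1, (3,): 1})
pools = {7: [0], 5: [1, 2, 3, 4, 5], 3: [6, 7, 8]}

epoch_bins = list(materialize_sketch(templates, pools, seed=0))
print(len(epoch_bins))  # 5
print(sorted(i for b in epoch_bins for i in b) == list(range(9)))  # True
```

Every sequence ID is used exactly once per epoch, and the multiset of bin shapes always equals the plan.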
The histogram path is designed to scale, but a few practical points matter once N gets into the hundreds-of-millions to billions range.
length_histogram_from_lengths materializes the full lengths array in RAM. For billion-scale, skip it and produce counts and pools from your dataset directly:
# In your data prep job (Spark, DuckDB, polars, etc.):
# SELECT length, COUNT(*) FROM dataset GROUP BY length -> counts
# SELECT sequence_id FROM dataset WHERE length = L -> pools[L]
#
# Save counts as a tiny JSON/parquet, pools as one file per length (np.save).

At training time:
import json
import numpy as np
from pathlib import Path
counts = {int(k): v for k, v in json.loads(Path("counts.json").read_text()).items()}
pools = {L: np.load(f"pools/{L}.npy", mmap_mode="r") for L in counts}

mmap_mode="r" keeps the per-length arrays on disk and only pages in the bytes you actually touch.
pack_length_histogram_batched is fast but not free, and the templates only depend on (counts, max_seq_len) — both stable across epochs. Pack once, pickle the Counter, reload it next run:
import pickle
from pathlib import Path

# The file name only encodes max_seq_len; delete the cache if counts change.
cache = Path(f"templates_max{max_seq_len}.pkl")
if cache.exists():
    templates = pickle.loads(cache.read_bytes())
else:
    templates = pack_length_histogram_batched(counts, max_seq_len)
    cache.write_bytes(pickle.dumps(templates))

A typical templates Counter is a few hundred KB even for billion-sequence corpora.
pack_length_histogram does N heap operations internally — fine for tens of millions, painful for billions. pack_length_histogram_batched does O(distinct_lengths × log B), usually completing in seconds for any realistic input. Use it as the default; keep the per-item one only as a reference / for unit tests.
The packer rejects any length greater than max_seq_len with a ValueError. Truncate or filter your lengths before building the histogram — a single rogue entry will crash the job several hours into data prep.
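A pre-filter along these lines avoids the late crash. It is a minimal sketch operating on a plain {length: count} mapping; drop_over_cap is a hypothetical helper, not part of seqpack.

```python
def drop_over_cap(length_counts, max_seq_len):
    """Split a histogram into the part that fits and a count of dropped items."""
    kept = {n: c for n, c in length_counts.items() if n <= max_seq_len}
    dropped = sum(c for n, c in length_counts.items() if n > max_seq_len)
    return kept, dropped


kept, dropped = drop_over_cap({512: 10, 2048: 3, 4096: 1}, max_seq_len=2048)
print(kept, dropped)  # {512: 10, 2048: 3} 1
```

If you filter the counts this way, remove the same lengths from the pools so the two stay consistent, and log the dropped count so silent data loss is visible.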
Bucket-shuffling within tolerance widens which sequences are interchangeable for the per-sequence packers (heap_pack_sequences, greedy_pack_sequences). It has no effect on pack_length_histogram* — there's no sort-with-ties step in that path. The per-epoch shuffling you get from materialize_epoch's pool permutation already gives you new pairings; you don't need tolerance on top.
materialize_epoch calls rng.permutation(arr) once per length, which materializes a shuffled int64 array of size len(arr). Summed across lengths, that's N integers — fine for hundreds of millions, but at 3B sequences it's ~24 GB of RAM if you don't use memmaps.
For truly RAM-light billion-scale, the right fix is a stateless seeded permutation (a Feistel or LCG-based bijection on [0, n)) that returns the i-th element of a permuted pool in O(1) without materializing the permuted array. This isn't implemented yet — for now, either:
- Use memmapped pools and let the OS handle the working set, or
- Shard your dataset (e.g., 10 chunks of 300M sequences each) and pack each chunk independently. Templates stay tiny per shard; the dataloader concatenates them.
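For reference, the stateless seeded permutation mentioned above could look roughly like the sketch below: a small Feistel network over the next power-of-four domain, with cycle-walking to stay inside [0, n). It is not part of seqpack; permuted_index, the round function and the round count are all assumptions made for illustration, and no claim is made about the statistical quality of the shuffle.

```python
import hashlib


def _round_value(value, seed, round_no, half_bits):
    """Keyed round function: hash (seed, round, value) down to half_bits bits."""
    data = f"{seed}:{round_no}:{value}".encode()
    digest = hashlib.blake2b(data, digest_size=8).digest()
    return int.from_bytes(digest, "little") & ((1 << half_bits) - 1)


def permuted_index(i, n, seed, rounds=4):
    """Return element i of a seeded permutation of range(n) without building it."""
    if not 0 <= i < n:
        raise IndexError(i)
    half_bits = max(1, ((n - 1).bit_length() + 1) // 2)
    mask = (1 << half_bits) - 1
    x = i
    while True:
        left, right = x >> half_bits, x & mask
        for r in range(rounds):
            # A Feistel round is invertible whatever the round function is.
            left, right = right, left ^ _round_value(right, seed, r, half_bits)
        x = (left << half_bits) | right
        if x < n:  # cycle-walk: re-encrypt until we land back inside [0, n)
            return x


n = 1000
perm = [permuted_index(i, n, seed=7) for i in range(n)]
print(sorted(perm) == list(range(n)))  # True
```

Each lookup needs only a handful of hash calls and no per-pool array, so memory stays constant; the domain is at most about four times n, which bounds the expected number of cycle-walking steps.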
If you can't keep all the pools in scope, shard the dataset by some natural key (corpus, modality, file) and run the histogram pipeline per shard:
shards = ["pile.txt", "code.parquet", "books.arrow"]
per_shard_templates = {}
per_shard_pools = {}
for shard in shards:
    counts, pools = build_counts_and_pools(shard)  # your code
    per_shard_templates[shard] = pack_length_histogram_batched(counts, max_seq_len)
    per_shard_pools[shard] = pools
# At training time, interleave shards round-robin or with a weighted sampler.

The total bin count is the sum across shards, with negligible packing efficiency loss at shard boundaries.
For large datasets, build counts/pools/templates once with the bundled CLI instead of doing it at training time. The CLI reads a parquet file — you tell it which column holds sequence lengths (and optionally which column holds stable sequence IDs):
pip install 'seqpack[prepare]' # pulls in pyarrow
python -m seqpack.prepare \
    --input dataset.parquet \
    --length-column seq_len_tokens \
    --id-column doc_id \
    --max-seq-len 2048 \
    --output prepared/

Output layout:
prepared/
    manifest.json     # max_seq_len, n_sequences, n_bins, efficiency, ...
    counts.json       # length -> count
    templates.json    # [(template_lengths, multiplicity), ...]
    pools/
        {length}.npy  # int64 sequence-ID array for each length
Load it at training time without re-packing:
from seqpack.prepare import load_prepared
from seqpack.packing import materialize_epoch, seed_from_components
manifest, counts, templates, pools = load_prepared("prepared/")
# pools are loaded with mmap_mode="r" by default — zero-copy.
for epoch in range(num_epochs):
    seed = seed_from_components(epoch)
    for bin_seq_ids in materialize_epoch(templates, pools, seed=seed):
        ...

Flags worth knowing:
- --length-column NAME (required): the parquet column holding sequence lengths. It is required because column names are dataset-specific.
- --id-column NAME (optional): a column holding stable per-sequence IDs. If omitted, the row index 0..N-1 is used.
- --filter-over-cap: drop sequences whose length exceeds --max-seq-len instead of erroring out.
- --packer {batched,per-item}: defaults to batched. Use per-item only for parity testing.
The CLI uses Python logging (not print) — set LOGLEVEL=DEBUG or wire it into your training pipeline's logging config if needed.
For distributed training, seed each worker differently but reproducibly:
seed = seed_from_components(epoch, dist.get_rank(), worker_info.id)

seed_from_components is a stable hash, so the same (epoch, rank, worker) always gets the same seed across restarts, which matters for resumability.
PackingStats gives you a uniform summary across either output shape — List[List[int]] from the per-sequence packers or Counter[BinTemplate] from the histogram path — so you can A/B compare them directly on the same input:
from seqpack import PackingStats, heap_pack_sequences, pack_length_histogram_batched
bins = heap_pack_sequences(lengths, max_seq_len=2048)
print(PackingStats.from_bins(bins, lengths, max_seq_len=2048))
templates = pack_length_histogram_batched(counts, max_seq_len=2048)
print(PackingStats.from_templates(templates, max_seq_len=2048))

Sample output:
PackingStats(
bins: 5
sequences: 15 (avg 3.00/bin)
tokens: 1,026 / 1,280 capacity
efficiency: 80.16%
padding: 254 (19.84%)
fullness: p50=96.1% p90=98.8% p99=99.9%
)
Fields are accessible directly for programmatic use (stats.efficiency, stats.padding_ratio, stats.bin_fullness_p50, ...). For typical natural-language length distributions and a reasonable max_seq_len, expect ≥97% efficiency. If it drops below 90%, the length distribution probably has long-tail items larger than max_seq_len / 2 — consider raising max_seq_len or splitting long sequences. The from_templates path works in O(distinct_templates) without materializing per-bin lists, so it's safe to call on billion-scale histograms.
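Why the template path is cheap can be seen from a small sketch: every quantity is a sum over distinct templates weighted by multiplicity. The function name and the returned dict are hypothetical and much smaller than what PackingStats reports.

```python
from collections import Counter


def efficiency_from_templates(templates, max_seq_len):
    """Compute bin count, token count and efficiency in O(distinct templates)."""
    n_bins = sum(templates.values())
    tokens = sum(sum(template) * mult for template, mult in templates.items())
    capacity = n_bins * max_seq_len
    return {
        "bins": n_bins,
        "tokens": tokens,
        "capacity": capacity,
        "efficiency": tokens / capacity,
    }


templates = Counter({(5, 5): 2, (7, 3): 1, (5, 3): 1, (3,): 1})
print(efficiency_from_templates(templates, max_seq_len=10))
# {'bins': 5, 'tokens': 41, 'capacity': 50, 'efficiency': 0.82}
```

The loop touches four entries here and would touch one entry for the billion-sequence homogeneous corpus, regardless of how many bins the plan describes.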
from torch.utils.data import Dataset
from seqpack.packing import (
length_histogram_from_lengths,
materialize_epoch,
pack_length_histogram_batched,
seed_from_components,
)
class PackedDataset(Dataset):
    def __init__(self, raw_dataset, lengths, max_seq_len):
        self.raw = raw_dataset
        counts, self.pools = length_histogram_from_lengths(lengths)
        self.templates = pack_length_histogram_batched(counts, max_seq_len)
        self.bins = []

    def set_epoch(self, epoch: int):
        seed = seed_from_components(epoch)
        self.bins = list(materialize_epoch(self.templates, self.pools, seed=seed))

    def __len__(self):
        return len(self.bins)

    def __getitem__(self, idx):
        return [self.raw[i] for i in self.bins[idx]]

Call dataset.set_epoch(epoch) before each epoch; until then the dataset is empty.
from torch.utils.data import IterableDataset
class PackedDataset(IterableDataset):
    def __init__(self, raw_dataset, length_counts, length_to_indices, max_seq_len):
        self.raw = raw_dataset
        self.pools = length_to_indices  # can be memmaps
        self.templates = pack_length_histogram_batched(length_counts, max_seq_len)
        self.epoch = 0

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __iter__(self):
        seed = seed_from_components(self.epoch)
        for seq_ids in materialize_epoch(self.templates, self.pools, seed=seed):
            yield [self.raw[i] for i in seq_ids]

The generator yields one bin at a time; beyond the pools themselves, peak memory is one bin plus the cursor state.
| Function | Purpose |
|---|---|
| sort_indices_by_length(lengths_arr, tolerance=1, seed=None) | Sort indices descending by length, with optional intra-bucket shuffle. Used internally by both packers; exposed for advanced use. |
| greedy_pack_sequences(lengths, max_seq_len, tolerance=1, seed=None) | First-fit-decreasing bin packing. Tightest packing, O(n²) worst case. |
| heap_pack_sequences(lengths, max_seq_len, tolerance=1, seed=None) | Worst-fit-decreasing via max-heap. O(n log n). Recommended for most workloads. |
| pack_length_histogram(length_counts, max_seq_len) | WFD on a length histogram. Returns Counter[BinTemplate]. Memory O(distinct_lengths). |
| pack_length_histogram_batched(length_counts, max_seq_len) | Same output shape; batched placement for billion-scale inputs. |
| materialize_epoch(templates, length_to_indices, seed=None) | Yield concrete List[int] bins for one epoch from templates and per-length pools. |
| length_histogram_from_lengths(lengths) | In-memory helper: builds (counts, pools) from a flat list of lengths. |
| shuffle_bins(bin_indices, seed=None) | Permute a list of bin indices; doesn't mutate input. |
| seed_from_components(*components) | Stable hash → seed integer from any combination of ints. |
| Function | Purpose |
|---|---|
| PackingStats.from_bins(bins, lengths, max_seq_len) | Build a PackingStats from List[List[int]] output (greedy/heap packers). |
| PackingStats.from_templates(templates, max_seq_len) | Build a PackingStats from Counter[BinTemplate] output (histogram packers). Runs in O(distinct_templates). |
| Function | Purpose |
|---|---|
| build_packing_attention_bias(sequence_ids, dtype, device) | Dense [B, 1, L, L] additive bias: 0 within a sequence, -inf across sequences. Use when the attention kernel doesn't support varlen. |
| build_packing_cu_seqlens(sequence_ids, attention_mask=None) | Compact (cu_seqlens, max_seqlen) for FlashAttention 2 / SDPA varlen / xformers. Skips cross-sequence and padding compute entirely. |
python -m unittest discover -s tests

If you use this software in your research, please cite it as follows:
@software{alsamkary2026seqpack,
author = {Hazem Alsamkary},
title = {seqpack: r1.0.1},
month = jun,
year = 2026,
publisher = {Zenodo},
version = {v1.0.1},
doi = {10.5281/zenodo.20530927},
url = {https://doi.org/10.5281/zenodo.20530927}
}