cortex is a PyTorch library for building recurrent backbones, memory modules, and routed mixtures of stateful experts.
It separates stateful computation from architectural wrapping so you can compose new memory stacks quickly and safely.
At the public API level, Cortex is organized around three concepts:
- core: the primitive stateful computation unit
- scaffold: the residual/projection/gating wrapper around a core
- cell: a programmer-facing pairing of scaffold + core
The package is published as `cortexcore` and imported as `cortex`.

```bash
pip install cortexcore
```

- Install
- Architecture
- Quick Start
- Supported Components
- Column and MoE Routing
- Advanced Setup
- AxonLayer
- Evaluate Quickly
- Backend Configuration
- Extending Cortex
- Package Layout
Cortex has five important abstractions:
- **Core**
  - The runtime stateful unit.
  - Responsible for recurrence, attention-style memory, or temporal state updates.
  - Examples: `LSTMCore`, `mLSTMCore`, `sLSTMCore`, `XLCore`, `AGaLiTeCore`, `AxonCore`.
- **Scaffold**
  - A wrapper around a core that handles projection, gating, residual structure, normalization, and routing.
  - Examples: `PassThroughScaffold`, `PreUpScaffold`, `PostUpScaffold`, `AdapterScaffold`, `ColumnScaffold`.
- **Cell**
  - The public config-level unit most users work with.
  - A `CellConfig` bundles a scaffold and a core so you can specify expert sets directly.
  - Examples: `AxonCellConfig()`, `XLCellConfig()`, `mLSTMCellConfig()`.
- **Column**
  - A routed mixture of expert scaffolds executed in parallel and mixed by a router.
  - Used by the auto builders to form MoE-style memory layers.
- **Stack**
  - A layered composition where each layer is a Column, and each Column is a mixture of cells.
Separating cores from scaffolds gives you:
- Modularity: swap temporal mechanisms without rewriting projection or residual logic
- Composability: reuse the same cores in different scaffold patterns and routed mixtures
- Clarity: keep memory computation separate from architectural wrapping
- Auto-configuration: infer working dimensions from `d_hidden` instead of manually wiring every hidden size
- Flexibility: define explicit cells for programmer-facing recipes while keeping low-level control available
- Gradient stability: use skip paths, gating, normalization, and routed mixtures without baking those concerns into the core
- Performance: select Triton, CUDA, or PyTorch backends per core as available
A core design principle is that cores, scaffolds, columns, and stacks all share the same runtime interface. A cell is a config object, but the modules it expands into follow the same stateful contract:
```python
def forward(
    x: Tensor,
    state: TensorDict | None,
    *,
    resets: ResetMask | None = None,
) -> tuple[Tensor, TensorDict | None]:
    ...
```

- Input: `[B, T, H]` for sequence mode or `[B, H]` for step mode
- State: always a `TensorDict`, with nesting determined by the current abstraction
- Output: same outer shape as input
- Resets: optional episode-boundary mask, passed automatically through stacks and scaffolds down into cores
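As an illustration of reset semantics (a pure-Python sketch, not the library's implementation), a reset mask in step mode zeroes exactly the state rows that begin a new episode:

```python
def apply_resets(state, resets):
    """Zero out state rows flagged by the reset mask.

    state:  list of per-batch-row hidden vectors, shape [B][H]
    resets: list of per-batch-row booleans, shape [B]
    """
    return [
        [0.0] * len(row) if reset else row
        for row, reset in zip(state, resets)
    ]

state = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
resets = [False, True, False]
# row 1 is zeroed; rows 0 and 2 pass through unchanged
state = apply_resets(state, resets)
```

In sequence mode the same idea applies per timestep, with a `[B, T]` mask selecting which rows reset at which step.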
Typical state structure:
- Core state: flat, for example `{"h": ..., "c": ...}`
- Scaffold state: wraps core state under the core module class name, for example `{"LSTMCore": {"h": ..., "c": ...}}`
- Column state: one entry per expert scaffold
- Stack state: one entry per scaffold, indexed by scaffold type and position
This uniformity lets you treat a deep memory stack like any other recurrent module.
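The nesting pattern can be made concrete with plain dicts standing in for TensorDicts; the `expert_*` and `layer_*` key names below are illustrative, not the library's exact keys:

```python
core_state = {"h": [0.0], "c": [0.0]}        # Core state: flat
scaffold_state = {"LSTMCore": core_state}    # Scaffold: nested under the core class name
column_state = {                             # Column: one entry per expert scaffold
    "expert_0": {"LSTMCore": {"h": [0.0], "c": [0.0]}},
    "expert_1": {"XLCore": {"mem": [0.0]}},
}
stack_state = {"layer_0": column_state}      # Stack: indexed by position

# drilling down through every nesting level
assert stack_state["layer_0"]["expert_0"]["LSTMCore"]["h"] == [0.0]
```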
The highest-level builder is build_cortex_auto_stack. It creates a stack of Column layers from explicit cell lists.
```python
import torch

from cortex.cells import AxonCellConfig, XLCellConfig, mLSTMCellConfig, sLSTMCellConfig
from cortex.stacks import build_cortex_auto_stack

stack = build_cortex_auto_stack(
    d_hidden=256,
    num_layers=2,
    layers=[
        [AxonCellConfig(), XLCellConfig(), mLSTMCellConfig(), sLSTMCellConfig()],
        [AxonCellConfig(), XLCellConfig(), mLSTMCellConfig(), sLSTMCellConfig()],
    ],
)

x = torch.randn(4, 32, 256)
y, state = stack(x, state=None)
```

Step mode uses `[B, H]` inputs:

```python
x_t = torch.randn(4, 256)
state = stack.init_state(batch=4, device=x_t.device, dtype=x_t.dtype)
y_t, state = stack.step(x_t, state)
```

If `layers` is omitted, the builder repeats `default_cells()` for `num_layers` layers.
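The step/sequence parity this contract promises can be seen in miniature with a toy exponential-moving-average recurrence (pure Python, not cortex code): consuming a sequence at once and stepping through it one input at a time produce identical outputs and final state.

```python
DECAY = 0.5

def seq_forward(xs, h):
    """Sequence mode: consume all T inputs at once."""
    ys = []
    for x in xs:
        h = DECAY * h + (1 - DECAY) * x
        ys.append(h)
    return ys, h

def step_forward(x, h):
    """Step mode: one input, one state update."""
    h = DECAY * h + (1 - DECAY) * x
    return h, h

xs = [1.0, 2.0, 3.0]
ys_seq, h_seq = seq_forward(xs, 0.0)

h = 0.0
ys_step = []
for x in xs:
    y, h = step_forward(x, h)
    ys_step.append(y)

assert ys_seq == ys_step and h_seq == h  # parity holds exactly
```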
Cells are the easiest way to specify the expert set you want. Each built-in cell chooses a default scaffold and core:
| Cell Config | Default Scaffold | Default Core | Notes |
|---|---|---|---|
| `AxonCellConfig()` | `PostUpGatedScaffoldConfig` | `AxonCoreConfig` | Streaming RTU-style memory with stateful linear dynamics |
| `XLCellConfig()` | `PostUpGatedScaffoldConfig` | `XLCoreConfig` | Transformer-XL style memory with rolling cache |
| `mLSTMCellConfig()` | `PreUpGatedScaffoldConfig` | `mLSTMCoreConfig` | Matrix-LSTM with chunked sequence updates |
| `sLSTMCellConfig()` | `PostUpGatedScaffoldConfig` | `sLSTMCoreConfig` | Structured LSTM with stabilized accumulators |
| `LSTMCellConfig()` | `PassThroughScaffoldConfig` | `LSTMCoreConfig` | Plain recurrent baseline |
| `CausalConv1dCellConfig()` | `PassThroughScaffoldConfig` | `CausalConv1dCoreConfig` | Causal convolutional memory core |
| `AGaLiTeCellConfig()` | `PostUpGatedScaffoldConfig` | `AGaLiTeCoreConfig` | Attention-style recurrent discounted state |
You can override the nested core directly:
```python
from cortex.cells import XLCellConfig, mLSTMCellConfig
from cortex.config import XLCoreConfig, mLSTMCoreConfig

layer = [
    XLCellConfig(core=XLCoreConfig(mem_len=256, use_axon_qkv=True)),
    mLSTMCellConfig(core=mLSTMCoreConfig(num_heads=8, chunk_size=32)),
]
```

And you can always construct an explicit cell yourself:
```python
from cortex import CellConfig
from cortex.config import PostUpScaffoldConfig, XLCoreConfig

cell = CellConfig(
    scaffold=PostUpScaffoldConfig(proj_factor=2.0),
    core=XLCoreConfig(mem_len=128),
)
```

`override_global_configs` lets you modify every matching config type inside an auto-built stack without rebuilding each layer manually.
```python
from cortex.cells import XLCellConfig
from cortex.config import RouterConfig, XLCoreConfig
from cortex.stacks import build_cortex_auto_stack

stack = build_cortex_auto_stack(
    d_hidden=256,
    num_layers=3,
    layers=[[XLCellConfig()]] * 3,
    router=RouterConfig(d_key=128, temperature=0.7),
    override_global_configs=[XLCoreConfig(mem_len=512)],
)
```

Notes:

- Overrides apply by model type across the generated config tree.
- Only explicitly set fields are merged.
- Core `hidden_size` is still inferred by the enclosing scaffold/stack where applicable.
Cortex can inject route-ID-aware low-rank adapters into linear-like modules across the stack:
```python
import torch

from cortex import RoutedAdapterConfig
from cortex.cells import AxonCellConfig, XLCellConfig
from cortex.stacks import build_cortex_auto_stack

stack = build_cortex_auto_stack(
    d_hidden=256,
    layers=[[AxonCellConfig(), XLCellConfig()]],
    routed_adapter=RoutedAdapterConfig(num_slots=64, rank=8, dropout=0.0),
)

B, T = 4, 16
x = torch.randn(B, T, 256)
route_ids = torch.tensor([0, 5, 12, 7], dtype=torch.long)
out, state = stack(x, route_ids=route_ids)
```

Notes:

- `route_ids` has shape `[B]`
- sequence inputs `[B, T, H]` reuse the same route per batch row across time
- `compile_scaffolds` is disabled automatically when routed adapters are enabled
- `set_trunk_lr_mult_(stack, value)` can scale gradients on non-adapter parameters at runtime

Current caveats:

- `LSTMCore` uses `nn.LSTM` weights directly, so routed adapters do not currently affect it
- `CausalConv1dCore` has no linear projections to adapt
- some Axon-backed sequence projections flatten batch and time, so not every linear path can be route-adapted
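Where routed adapters do apply, the underlying idea is the standard low-rank pattern: each route slot owns a rank-r down/up projection pair whose output is added to the base linear path. A pure-Python sketch with illustrative shapes and slot numbers, not cortex's implementation:

```python
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def adapted_linear(W, A_slots, B_slots, x, route_id):
    base = matvec(W, x)                      # base trunk path
    low = matvec(A_slots[route_id], x)       # down-project: H -> rank
    delta = matvec(B_slots[route_id], low)   # up-project:   rank -> H
    return [b + d for b, d in zip(base, delta)]

W = [[1.0, 0.0], [0.0, 1.0]]                  # identity trunk weight
A_slots = {0: [[0.0, 0.0]], 5: [[1.0, 1.0]]}  # rank-1 adapters per route slot
B_slots = {0: [[0.0], [0.0]], 5: [[0.5], [0.5]]}

x = [2.0, 4.0]
y0 = adapted_linear(W, A_slots, B_slots, x, route_id=0)  # slot 0 is a no-op
y5 = adapted_linear(W, A_slots, B_slots, x, route_id=5)  # slot 5 adds a delta
```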
Public cell configs live in `cortex.cells`. They are the main user-facing selection surface for auto-built stacks and columns.
The default expert recipe returned by `default_cells()` is:

```python
from cortex.cells import default_cells

cells = default_cells()
```

This expands to:

- `AxonCellConfig()`
- `XLCellConfig()`
- `mLSTMCellConfig()`
- `sLSTMCellConfig()`
Low-level core configs live in `cortex.config`; runtime core modules live in `cortex.cores`.
| Core | Config | Description | Triton Accelerated | CUDA Accelerated |
|---|---|---|---|---|
| `LSTMCore` | `LSTMCoreConfig` | LSTM wrapper with TensorDict state (h, c), step/sequence parity, and optional resets | Yes | No |
| `mLSTMCore` | `mLSTMCoreConfig` | Matrix-LSTM with per-head state, chunked closed-form updates, and optional causal Conv1D preprocessing | Yes | No |
| `sLSTMCore` | `sLSTMCoreConfig` | Structured LSTM with stabilized accumulators (c, n, m) and optional causal Conv1D preprocessing | Yes | No |
| `CausalConv1dCore` | `CausalConv1dCoreConfig` | Depthwise causal Conv1D memory with ring-buffer state and optional channel mixing | Yes (channel-mixing mode) | No |
| `AxonCore` | `AxonCoreConfig` | Streaming RTU with diagonal recurrent transitions and stateful linear dynamics | Yes | Yes |
| `XLCore` | `XLCoreConfig` | Transformer-XL style multi-head attention with rolling memory and optional Axon-backed Q/K/V projections | No | No |
| `AGaLiTeCore` | `AGaLiTeCoreConfig` | Attention-style recurrent discounted state with fused discounted-sum support | No | Yes |
Notes:
- Triton kernels are selected automatically on CUDA when constraints are satisfied; otherwise the PyTorch fallback is used
- Resets are optional and broadcast-safe: `[B, T]` in sequence mode and `[B]` in step mode
- `HFLlamaLayerConfig` is also available for wrapping Hugging Face LLaMA decoder layers inside a `CortexStack`
Scaffolds wrap cores with projection, gating, routing, and residual structure.
| Scaffold | Config | Description |
|---|---|---|
| `PassThroughScaffold` | `PassThroughScaffoldConfig` | Runs the core directly at `d_hidden` with a residual connection; no learned up/down projection |
| `PreUpScaffold` | `PreUpScaffoldConfig` | Projects to `d_inner = int(proj_factor * d_hidden)`, runs the core at `d_inner`, then projects back |
| `PreUpGatedScaffold` | `PreUpGatedScaffoldConfig` | Pre-up scaffold with GRU-style gating around the transformed path |
| `PostUpScaffold` | `PostUpScaffoldConfig` | Runs the core at `d_hidden`, then applies an FFN-style up/down projection before the residual merge |
| `PostUpGatedScaffold` | `PostUpGatedScaffoldConfig` | Post-up scaffold with GTrXL-style gating, useful for deeper stacks |
| `AdapterScaffold` | `AdapterScaffoldConfig` | Wraps another scaffold with an identity-initialized trainable residual adapter |
| `ColumnScaffold` | `ColumnScaffoldConfig` | Runs multiple expert scaffolds in parallel and mixes them through a router |
Hidden Size Inference in Scaffolds
Some scaffolds control the working dimension of their nested core and will override the core's `hidden_size` when building through `CortexStackConfig` or the auto builders.
Scaffolds that set the core hidden size:
- `PassThroughScaffold`: `core.hidden_size = d_hidden`
- `PostUpScaffold`: `core.hidden_size = d_hidden`
- `PostUpGatedScaffold`: `core.hidden_size = d_hidden`
- `PreUpScaffold`: `core.hidden_size = int(proj_factor * d_hidden)`
- `PreUpGatedScaffold`: `core.hidden_size = int(proj_factor * d_hidden)`
Best practice is to make this explicit by setting `hidden_size=None` in the core config:

```python
from cortex.config import LSTMCoreConfig, PreUpScaffoldConfig

cfg = PreUpScaffoldConfig(
    core=LSTMCoreConfig(hidden_size=None, num_layers=2),
    proj_factor=2.0,
)
```

If you instantiate scaffolds and cores manually without the stack builder, you must satisfy these dimension relationships yourself.
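The inference rules above can be summarized as a small helper; this is a sketch of the rule, not the builder's actual code, and the scaffold-type strings are illustrative:

```python
def infer_core_hidden(scaffold_type: str, d_hidden: int, proj_factor: float = 1.0) -> int:
    if scaffold_type in ("PassThrough", "PostUp", "PostUpGated"):
        return d_hidden                      # core runs at the stack width
    if scaffold_type in ("PreUp", "PreUpGated"):
        return int(proj_factor * d_hidden)   # core runs at the projected width
    raise ValueError(f"unknown scaffold type: {scaffold_type}")

assert infer_core_hidden("PassThrough", 256) == 256
assert infer_core_hidden("PreUp", 256, proj_factor=2.0) == 512
```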
ColumnScaffold is the mixture-of-experts building block used by the auto stack builders. A Column runs multiple expert
scaffolds in parallel and mixes their deltas with:
- a global prior router
- optional per-token refinement
- an expert-axis mixer
- outer ReZero-style residual control
Build a Column from explicit cells:
```python
from cortex import CellConfig, RouterConfig, build_column_auto_config
from cortex.cells import AxonCellConfig, XLCellConfig
from cortex.config import PostUpScaffoldConfig, XLCoreConfig, sLSTMCoreConfig

col_cfg = build_column_auto_config(
    d_hidden=256,
    cells=[
        AxonCellConfig(),
        XLCellConfig(core=XLCoreConfig(mem_len=128, use_axon_qkv=True)),
        CellConfig(
            scaffold=PostUpScaffoldConfig(proj_factor=1.5),
            core=sLSTMCoreConfig(num_heads=4),
        ),
    ],
    router=RouterConfig(top_k=2, temperature=0.7, whisper_lambda=0.1),
)
```

You can also pass raw scaffold configs directly to `cells=[...]` when you want exact expert definitions.
Conceptually, for each token every expert scaffold proposes a delta, the router assigns mixture weights, and the weighted combination of deltas is added back to the residual stream. In practice, this gives you routed expert mixing while preserving the clean stack interface.
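A pure-Python sketch of that mixing under stated assumptions (softmax-with-temperature routing, top-k renormalization, and a ReZero-style scale `alpha` on the mixed delta); this is illustrative, not the library's router code:

```python
import math

def route_and_mix(x, expert_deltas, scores, top_k=2, temperature=0.7, alpha=0.1):
    # temperature-scaled softmax over per-expert router scores
    exps = [math.exp(s / temperature) for s in scores]
    probs = [e / sum(exps) for e in exps]
    # keep only the top-k experts and renormalize their weights
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    total = sum(probs[i] for i in keep)
    mixed = sum(probs[i] / total * expert_deltas[i] for i in keep)
    # ReZero-style residual: alpha scales the mixed delta into the stream
    return x + alpha * mixed

# with top_k=1, only the best-scoring expert's delta is applied
y = route_and_mix(x=1.0, expert_deltas=[10.0, 20.0, 30.0], scores=[2.0, 1.0, 0.0], top_k=1)
```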
For full control, construct a CortexStackConfig directly:
```python
import torch

from cortex import CortexStackConfig, build_cortex
from cortex.config import LSTMCoreConfig, PassThroughScaffoldConfig, PreUpScaffoldConfig

cfg = CortexStackConfig(
    d_hidden=256,
    scaffolds=[
        PreUpScaffoldConfig(
            core=LSTMCoreConfig(hidden_size=None, num_layers=2),
            proj_factor=2.0,
        ),
        PassThroughScaffoldConfig(
            core=LSTMCoreConfig(hidden_size=256, num_layers=1),
        ),
    ],
    post_norm=True,
)

stack = build_cortex(cfg)
state = stack.init_state(batch=4, device="cpu", dtype=torch.float32)
x = torch.randn(4, 16, 256)
out, state = stack(x, state)
```

If you want exact control over a routed layer, build the `ColumnScaffoldConfig` first and place it inside a stack:
```python
from cortex import CortexStackConfig, RouterConfig, build_column_auto_config, build_cortex
from cortex.cells import AxonCellConfig, mLSTMCellConfig

column = build_column_auto_config(
    d_hidden=256,
    cells=[AxonCellConfig(), mLSTMCellConfig()],
    router=RouterConfig(top_k=1),
)

stack = build_cortex(
    CortexStackConfig(
        d_hidden=256,
        scaffolds=[column],
        post_norm=True,
    )
)
```

`cortex.stacks` can wrap LLaMA decoder layers inside a `CortexStack`:
```python
from cortex.stacks import build_hf_stack

stack = build_hf_stack(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    num_layers=8,
    mem_len=128,
    compile_scaffolds=False,
)
```

This integration is currently LLaMA-focused.
`AxonLayer` is a stateful linear-like module built on top of `AxonCore`. It behaves like:

```python
y = Linear(x) + AxonCore(x, state)
```

but stores its recurrent substate inside a caller-provided parent `TensorDict`. This makes it useful as a drop-in stateful replacement for selected `nn.Linear` paths inside custom cores.
```python
import torch
from tensordict import TensorDict

from cortex import AxonCoreConfig, AxonLayer, MemoryCore
from cortex.cores.core.axon_layer import update_parent_state


class MyCore(MemoryCore):
    def __init__(self, hidden_size: int) -> None:
        super().__init__(hidden_size=hidden_size)
        self.proj = AxonLayer(
            hidden_size,
            hidden_size,
            cfg=AxonCoreConfig(hidden_size=hidden_size, out_dim=hidden_size),
            name="proj",
            group="mycore",
        )

    def init_state(self, batch: int, *, device: torch.device | str, dtype: torch.dtype) -> TensorDict:
        return TensorDict({}, batch_size=[batch], device=torch.device(device))

    def forward(self, x, state, *, resets=None):
        if state is None:
            state = self.init_state(x.shape[0], device=x.device, dtype=x.dtype)
        y = self.proj(x, state=state, resets=resets)
        next_state = TensorDict({}, batch_size=[x.shape[0]], device=x.device)
        update_parent_state(next_state, state)
        return y, next_state

    def reset_state(self, state, mask):
        return self.proj.reset_state(mask, state=state)
```

Internally, `AxonLayer` combines:

- a plain linear branch
- an `AxonCore` branch with local recurrent dynamics
- parent-state integration through `TensorDict`
Important properties:
- State-augmented linear projection: adds temporal structure to a linear-like path
- Optional SRHT mixing: improves conditioning for diagonal recurrent updates
- Streaming traces: supports long-horizon gradient flow with compact state
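The linear-plus-recurrent structure can be illustrated with scalar toy dynamics; the transition `a` and input gain `b` below are assumptions chosen for illustration, and the real layer operates on vectors with per-channel diagonal transitions:

```python
def axon_like_step(x, r, w=1.0, a=0.9, b=0.1):
    r = a * r + b * x      # diagonal recurrent trace update
    y = w * x + r          # linear branch plus recurrent branch
    return y, r

r = 0.0
outs = []
for x in [1.0, 1.0, 1.0]:
    y, r = axon_like_step(x, r)
    outs.append(y)
# on a constant input, the recurrent trace accumulates, so outputs drift upward
```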
Several cores can opt into Axon-backed internal projections:
| Core | Components Replaced | Flags |
|---|---|---|
| `sLSTMCore` | fused gate projections | `use_axon_layer` |
| `mLSTMCore` | input/forget gates and optional QKV path | `use_axon_layer`, `use_axon_qkv` |
| `XLCore` | Q/K/V projections | `use_axon_qkv` |
The built-in cells that use these options can still be specified directly through their cell configs.
Cortex includes lightweight synthetic evaluations for sanity checking stacks and recipes.
Available tasks:
- `delayed_recall`
- `majority`
- `dyck`
Run a single stack:
```bash
uv run python packages/cortex/evaluations/run.py --task delayed_recall --stack slstm_postup
```

Run all registered evaluation stacks:

```bash
uv run python packages/cortex/evaluations/run.py --task majority --stack all
```

Common flags:

- `--epochs`
- `--batch-size`
- `--lr`
- `--seed`
- `--log-level {DEBUG, INFO, WARNING, ERROR}`
More detail is in evaluations/README.md.
Cortex selects between Triton, CUDA, and PyTorch implementations based on availability and per-core constraints.
You can force the PyTorch path with:
```bash
CORTEX_DISABLE_TRITON=1 python your_script.py
```

or:

```bash
CORTEX_FORCE_PYTORCH=1 python your_script.py
```

Use this when:
- debugging numerical differences between backends
- testing on a machine without Triton
- forcing a simpler execution path during development
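The precedence these flags imply can be sketched as follows; `select_backend` and its arguments are illustrative, not cortex's actual dispatch code:

```python
import os

def select_backend(triton_available: bool, cuda_available: bool) -> str:
    if os.environ.get("CORTEX_FORCE_PYTORCH") == "1":
        return "pytorch"    # global kill switch wins over everything
    if triton_available and os.environ.get("CORTEX_DISABLE_TRITON") != "1":
        return "triton"     # fastest path when allowed
    if cuda_available:
        return "cuda"
    return "pytorch"        # portable fallback

os.environ["CORTEX_DISABLE_TRITON"] = "1"
assert select_backend(triton_available=True, cuda_available=False) == "pytorch"
```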
Cortex is registry-driven. New cores and scaffolds become available to the config system once registered.
```python
import torch
from tensordict import TensorDict

from cortex import CoreConfig, MemoryCore, register_core


class GRUCoreConfig(CoreConfig):
    num_layers: int = 1


@register_core(GRUCoreConfig)
class GRUCore(MemoryCore):
    def __init__(self, cfg: GRUCoreConfig) -> None:
        super().__init__(hidden_size=cfg.hidden_size)
        self.net = torch.nn.GRU(
            input_size=cfg.hidden_size,
            hidden_size=cfg.hidden_size,
            num_layers=cfg.num_layers,
            batch_first=True,
        )
        self.num_layers = cfg.num_layers

    def init_state(self, batch: int, *, device: torch.device | str, dtype: torch.dtype) -> TensorDict:
        h = torch.zeros(batch, self.num_layers, self.hidden_size, device=device, dtype=dtype)
        return TensorDict({"h": h}, batch_size=[batch])

    def forward(self, x, state, *, resets=None):
        ...

    def reset_state(self, state, mask):
        ...
```

```python
from cortex import BaseScaffold, ScaffoldConfig, register_scaffold


class GatedResidualScaffoldConfig(ScaffoldConfig):
    gate_bias: float = 0.0


@register_scaffold(GatedResidualScaffoldConfig)
class GatedResidualScaffold(BaseScaffold):
    def forward(self, x, state, *, resets=None):
        ...
```

Both custom cores and scaffolds are then usable inside stack configs and builder flows.
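The registry pattern itself reduces to a decorator that maps config types to module classes. A self-contained toy version, with names deliberately unrelated to cortex's real registry:

```python
CORE_REGISTRY: dict[type, type] = {}

def register_demo_core(cfg_cls):
    """Decorator mapping a config class to the module class it builds."""
    def deco(module_cls):
        CORE_REGISTRY[cfg_cls] = module_cls
        return module_cls
    return deco

class DemoCoreConfig:
    pass

@register_demo_core(DemoCoreConfig)
class DemoCore:
    def __init__(self, cfg):
        self.cfg = cfg

def build_demo_core(cfg):
    # look up the registered module class by the config's type
    return CORE_REGISTRY[type(cfg)](cfg)

core = build_demo_core(DemoCoreConfig())
assert isinstance(core, DemoCore)
```

Config-driven builders then only ever see config objects; the registry resolves which module class to instantiate.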
If you want a reusable programmer-facing cell preset, define it as a composed CellConfig:
```python
from cortex import CellConfig
from cortex.config import PostUpScaffoldConfig, XLCoreConfig


def LongMemXLCell() -> CellConfig:
    return CellConfig(
        scaffold=PostUpScaffoldConfig(proj_factor=2.0),
        core=XLCoreConfig(mem_len=512, use_axon_qkv=True),
    )
```

Then use it directly in a stack recipe:
```python
from cortex.stacks import build_cortex_auto_stack

stack = build_cortex_auto_stack(
    d_hidden=256,
    layers=[
        [LongMemXLCell()],
        [LongMemXLCell()],
    ],
)
```

Complete worked examples:
- `cortex.cells`: public composed cell configs
- `cortex.config`: low-level config models
- `cortex.cores`: core implementations and registration
- `cortex.scaffolds`: scaffold implementations, columns, and builders
- `cortex.stacks`: stack classes and auto-builders
For most use cases:
- choose cells in `cortex.cells`
- build with `cortex.stacks.build_cortex_auto_stack`
- drop to `cortex.config` when you want exact low-level control