
perf(gated_delta_net): fold q/k L2-norm into the gated_delta_rule kernel#5396

Draft
yuchenwang3 wants to merge 1 commit into
NVIDIA:mainfrom
yuchenwang3:fix/gdn-qk-l2norm-in-kernel


@yuchenwang3


Background. Found while running ms-swift + Megatron-Core SFT of Qwen3.5-35B-A3B (GatedDeltaNet hybrid) at 128K context on 16× B200 across 2 nodes (2×8). This change was used in that training run.

What

GatedDeltaNet L2-normalizes q/k with an explicit l2norm(query_key) before the FLA gated_delta_rule kernel, while passing use_qk_l2norm_in_kernel=False. This materializes a normalized query_key [B, T, 2*Hk, 128] activation that must be kept for backward.
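To give a feel for the size of that activation, here is a rough back-of-the-envelope sketch. The helper name and the example values (B=1, T=128K, Hk=16, bf16) are assumptions for illustration only, not numbers from this PR; with context or sequence parallelism the per-GPU T would be smaller.

```python
def qk_activation_bytes(batch, seq_len, num_k_heads, head_dim=128, bytes_per_elem=2):
    """Bytes held by one normalized query_key tensor of shape [B, T, 2*Hk, head_dim]."""
    return batch * seq_len * 2 * num_k_heads * head_dim * bytes_per_elem


# Hypothetical per-GPU values (not from the PR): B=1, T=128K tokens, Hk=16, bf16.
size = qk_activation_bytes(batch=1, seq_len=128 * 1024, num_k_heads=16)
print(f"{size / 2**30:.2f} GiB per GatedDeltaNet layer")
```

Under these assumed values the tensor is 1 GiB per layer, and it is held once per GatedDeltaNet layer until that layer's backward runs.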

Folding the L2-norm into the kernel (use_qk_l2norm_in_kernel=self.use_qk_l2norm) lets FLA keep only the small rstd [B, T, H] vectors and recompute normalized q/k in backward via l2norm_bwd, removing the materialized normalized-q/k activation. GatedDeltaNet layers make up roughly 3/4 of the layers in Qwen3.5-class hybrids, so this is a meaningful saving in activations held for backward at long context.
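Why rstd alone is enough for backward can be sketched in plain Python. This is a scalar reference under the assumption that the norm has the form y = x * rsqrt(sum(x^2) + eps); the `_sketch` functions are illustrative stand-ins, not FLA's actual l2norm_fwd / l2norm_bwd signatures.

```python
import math

EPS = 1e-6


def l2norm_fwd_sketch(x, eps=EPS):
    """Return the normalized vector and the single scalar rstd saved for backward."""
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) + eps)
    return [v * rstd for v in x], rstd


def l2norm_bwd_sketch(x, rstd, dy):
    """Recompute y from (x, rstd) instead of loading a stored copy, then form dx."""
    y = [v * rstd for v in x]
    dot = sum(g * v for g, v in zip(dy, y))
    # dx_i = rstd * (dy_i - y_i * sum_j(dy_j * y_j))
    return [rstd * (g - v * dot) for g, v in zip(dy, y)]


def numeric_grad(x, dy, h=1e-6):
    """Central finite differences of sum(dy * l2norm(x)) as an independent check."""
    grads = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        yp, _ = l2norm_fwd_sketch(xp)
        ym, _ = l2norm_fwd_sketch(xm)
        grads.append(sum(g * (a - b) for g, a, b in zip(dy, yp, ym)) / (2 * h))
    return grads


x = [0.5, -1.0, 2.0, 0.25]
dy = [1.0, 0.5, -0.25, 2.0]
_, rstd = l2norm_fwd_sketch(x)
analytic = l2norm_bwd_sketch(x, rstd, dy)
numeric = numeric_grad(x, dy)
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
print(f"max |analytic - numeric| = {max_err:.2e}")
```

The un-normalized input is already available to backward, so the only extra state is one scalar per head vector rather than a second copy of all 128 values.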

Numerically lossless

  • eps matches on all paths: FLA l2norm default eps=1e-6; FLA in-kernel l2norm_fwd default eps=1e-6; the torch deterministic path uses explicit eps=1e-6 — same as the previous explicit l2norm() default.
  • GQA: l2norm is per-head over dim=-1; repeat_interleave only duplicates heads, so l2norm(repeat_interleave(x)) == repeat_interleave(l2norm(x)). The in-kernel path normalizes post-repeat (Hv) heads; the previous explicit path normalized pre-repeat (Hk) heads — numerically identical per head.
  • When use_qk_l2norm=False, behavior is unchanged (kernel receives False, no norm).
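The GQA argument above can be checked on a toy example. This is a pure-Python sketch, assuming the same x * rsqrt(sum(x^2) + eps) form of the norm; `repeat_interleave` here is a small list-based stand-in for torch.repeat_interleave along the head axis, and the head values are made up. It does not exercise the Triton kernel itself.

```python
import math


def l2norm(head, eps=1e-6):
    """L2-normalize one head vector over its last (feature) dimension."""
    rstd = 1.0 / math.sqrt(sum(v * v for v in head) + eps)
    return [v * rstd for v in head]


def repeat_interleave(heads, repeats):
    """Duplicate each head in place: [h0, h0, ..., h1, h1, ...]."""
    return [list(h) for h in heads for _ in range(repeats)]


k_heads = [[3.0, 4.0], [1.0, -2.0], [0.5, 0.5]]  # Hk = 3 toy heads
repeats = 2                                      # Hv / Hk

# Previous explicit path: normalize the Hk heads, then repeat to Hv.
norm_then_repeat = repeat_interleave([l2norm(h) for h in k_heads], repeats)
# In-kernel path: repeat to Hv first, then normalize each head.
repeat_then_norm = [l2norm(h) for h in repeat_interleave(k_heads, repeats)]

print(norm_then_repeat == repeat_then_norm)
```

Because each output head sees exactly the same input values and the same arithmetic in both orders, the two results agree element for element in this sketch.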

Testing

Ran in a real Qwen3.5-35B-A3B 128K SFT job on 16× B200 (2 nodes). I could not run Megatron's GPU/FLA test suite locally (no GPU and no triton+FLA on my machine), so I am relying on CI.

Question for maintainers

The flag was hardcoded False with an explicit pre-kernel l2norm. If that was intentional for a specific path (e.g. cu_seqlens/packed-sequence or CP correctness of the in-kernel l2norm), please advise — happy to gate the fold behind a condition instead.

GatedDeltaNet L2-normalizes q/k with an explicit l2norm(query_key) before the
FLA kernel while passing use_qk_l2norm_in_kernel=False, materializing a
normalized query_key [B,T,2*Hk,128] activation kept for backward. Folding it
into the kernel (use_qk_l2norm_in_kernel=self.use_qk_l2norm) lets FLA keep only
rstd [B,T,H] and recompute normalized q/k in backward, removing that activation.
GatedDeltaNet is ~3/4 of layers in Qwen3.5-class hybrids -> meaningful backward
activation saving at long context.

Lossless: eps=1e-6 on all paths (FLA l2norm / in-kernel l2norm_fwd / torch
deterministic), matching the previous explicit l2norm default. GQA: per-head
l2norm commutes with repeat_interleave, so in-kernel (post-repeat Hv) ==
previous explicit (pre-repeat Hk) per head. use_qk_l2norm=False unchanged.

Signed-off-by: yuchenwang3 <eang333cms@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 17, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.
