Skip to content

fix(spin): fix tied embedding corruption in silent_run() and unconditional untie/fuse in run()#319

Open
sunnyxiaohu wants to merge 1 commit into
Tencent:mainfrom
sunnyxiaohu:fix/spinquant-tied-embedding
Open

fix(spin): fix tied embedding corruption in silent_run() and unconditional untie/fuse in run()#319
sunnyxiaohu wants to merge 1 commit into
Tencent:mainfrom
sunnyxiaohu:fix/spinquant-tied-embedding

Conversation

@sunnyxiaohu

Copy link
Copy Markdown
Contributor

Summary

Fix SpinQuant rotation transform corrupting the embedding lookup table when tie_word_embeddings=True, affecting all models with tied embeddings (e.g., Qwen3, LLaMA-3).

Problems

  1. silent_run() corrupts tied embeddingssilent_run() calls _apply_fused_ln() without first untying word embeddings. When lm_head.weight and embed_tokens.weight share the same underlying tensor (tied), fuse_ln_linear() modifies both simultaneously, corrupting the embedding lookup table.

  2. run() gates untie/fuse inside if "R1" branch_untie_word_embeddings() and _apply_fused_ln() are incorrectly placed inside the if "R1" in self.spin_config.rotation conditional block. Models using only R2 rotation skip these critical steps, leading to either tied embedding corruption or missing norm fusion.

Changes

File Fix
spin.py Move _untie_word_embeddings() to the top of both run() and silent_run(), unconditionally before any fuse/rotation operations
spin.py Move _apply_fused_ln() out of the if "R1" branch to execute unconditionally
spin.py Remove redundant _untie_word_embeddings() call inside _apply_fused_ln()

Root Cause

When tie_word_embeddings=True, lm_head.weight is a reference (not a copy) to embed_tokens.weight. Any in-place modification to one affects the other. The fused layer norm operation scales lm_head.weight in-place, which simultaneously corrupts embed_tokens.weight, causing garbage token embeddings during inference.

…ional untie/fuse in run()

- Move _untie_word_embeddings() to unconditional top of silent_run() (was missing entirely)
- Move _untie_word_embeddings() and _apply_fused_ln() outside 'if R1' condition in run()
- Remove duplicate _untie_word_embeddings() call from _apply_fused_ln()

This fixes embedding lookup table corruption when tie_word_embeddings=True
(affects Qwen3 and all models with tied embeddings).
@gavingavin99

Copy link
Copy Markdown
Collaborator

In Spinquant, only the R1 section should require the fuse layernorm operation, which is why untie_embedding needs to be executed. Verification has shown that fuse layernorm can negatively impact model performance to some extent (due to the use of low-precision weights), so only untie_embedding and fuse layernorm are placed in the R1 branch.

@sunnyxiaohu

Copy link
Copy Markdown
Contributor Author

@gavingavin99 That's true — the precision loss doesn't only come from fused LayerNorm, but also from other rotation matrix multiplications. For the Qwen3-Omni model, we experimented with SpinQuant (R1, R2, R4) + GPTQ, and compared to GPTQ alone, the performance actually degraded significantly. Do you have any suggestions?

@gavingavin99

gavingavin99 commented Jun 10, 2026

Copy link
Copy Markdown
Collaborator

@sunnyxiaohu Could you show us some relevant test result? If you are using rotation alone, it is recommended to disable R1 rotation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants