fix(spin): fix tied embedding corruption in silent_run() and unconditional untie/fuse in run() #319
…ional untie/fuse in run()
- Move _untie_word_embeddings() to unconditional top of silent_run() (was missing entirely)
- Move _untie_word_embeddings() and _apply_fused_ln() outside 'if R1' condition in run()
- Remove duplicate _untie_word_embeddings() call from _apply_fused_ln()
This fixes embedding lookup table corruption when tie_word_embeddings=True (affects Qwen3 and all models with tied embeddings).
In SpinQuant, only the R1 rotation requires the fuse layernorm operation, which in turn is why untie_embedding needs to be executed. Verification has shown that fuse layernorm can degrade model performance to some extent (due to the use of low-precision weights), so untie_embedding and fuse layernorm are placed only in the R1 branch.
@gavingavin99 That's true. The precision loss comes not only from fused LayerNorm but also from the other rotation matrix multiplications. For the Qwen3-Omni model, we experimented with SpinQuant (R1, R2, R4) + GPTQ, and performance degraded significantly compared to GPTQ alone. Do you have any suggestions?
@sunnyxiaohu Could you show us some relevant test results? If you are using rotation alone, it is recommended to disable the R1 rotation.
Summary
Fix the SpinQuant rotation transform corrupting the embedding lookup table when tie_word_embeddings=True, affecting every model that ties its embeddings (e.g., Qwen3 and LLaMA-3 variants configured this way).
Problems
- silent_run() corrupts tied embeddings: silent_run() calls _apply_fused_ln() without first untying word embeddings. When lm_head.weight and embed_tokens.weight share the same underlying tensor (tied), fuse_ln_linear() modifies both simultaneously, corrupting the embedding lookup table.
- run() gates untie/fuse inside the if "R1" branch: _untie_word_embeddings() and _apply_fused_ln() are incorrectly placed inside the if "R1" in self.spin_config.rotation conditional block. Models using only R2 rotation skip these critical steps, leading to either tied embedding corruption or missing norm fusion.
Changes
- spin.py: Move _untie_word_embeddings() to the top of both run() and silent_run(), unconditionally before any fuse/rotation operations
- spin.py: Move _apply_fused_ln() out of the if "R1" branch to execute unconditionally
- spin.py: Remove the duplicate _untie_word_embeddings() call inside _apply_fused_ln()
Root Cause
When tie_word_embeddings=True, lm_head.weight is a reference (not a copy) to embed_tokens.weight. Any in-place modification to one affects the other. The fused layer norm operation scales lm_head.weight in-place, which simultaneously corrupts embed_tokens.weight, causing garbage token embeddings during inference.
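The aliasing described above can be reproduced without any framework. The following is a minimal sketch, not the code in spin.py: nested lists stand in for weight tensors, and TinyModel, untie_word_embeddings() and fuse_norm_into_head() are hypothetical names chosen for illustration. It shows an in-place fusion step corrupting the embedding table when the weights are tied, and the same step leaving it intact when the weights are untied first.

```python
import copy


class TinyModel:
    """Hypothetical toy model; nested lists play the role of weight tensors."""

    def __init__(self, tie_word_embeddings: bool):
        self.embed_tokens = [[1.0, 2.0], [3.0, 4.0]]  # vocab=2, hidden=2
        if tie_word_embeddings:
            # Tied: lm_head is the very same object, a reference and not a copy.
            self.lm_head = self.embed_tokens
        else:
            self.lm_head = copy.deepcopy(self.embed_tokens)
        self.norm_weight = [0.5, 2.0]  # scale of the final norm layer


def untie_word_embeddings(model: TinyModel) -> None:
    """Give lm_head its own storage if it currently aliases embed_tokens."""
    if model.lm_head is model.embed_tokens:
        model.lm_head = copy.deepcopy(model.embed_tokens)


def fuse_norm_into_head(model: TinyModel) -> None:
    """Fold the norm scale into lm_head in place, then reset the scale to 1."""
    for row in model.lm_head:
        for j, gamma in enumerate(model.norm_weight):
            row[j] *= gamma
    model.norm_weight = [1.0] * len(model.norm_weight)


# Fusing while tied: the in-place scaling also rewrites the embedding table.
broken = TinyModel(tie_word_embeddings=True)
fuse_norm_into_head(broken)

# Untying first: only lm_head is scaled, the embedding table is preserved.
fixed = TinyModel(tie_word_embeddings=True)
untie_word_embeddings(fixed)
fuse_norm_into_head(fixed)

print("tied, fused directly:", broken.embed_tokens)
print("untied, then fused:  ", fixed.embed_tokens)
```

In PyTorch terms, the untie step would presumably amount to replacing lm_head.weight with a fresh parameter built from a clone of embed_tokens.weight. That is an assumption about the general approach, not a description of what _untie_word_embeddings() does internally.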