
Commit 321d751

[quantization] Support SpinQuant
This commit adds support for SpinQuant.

TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent 6c9e38f commit 321d751

12 files changed

Lines changed: 2709 additions & 11 deletions


Lines changed: 321 additions & 0 deletions
@@ -0,0 +1,321 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from tico.quantization import convert, prepare
from tico.quantization.algorithm.spinquant.quantizer import SpinQuantQuantizer
from tico.quantization.algorithm.spinquant.spin_llama import SpinLlamaForCausalLM
from tico.quantization.config.spinquant import SpinQuantConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM


class SpinQuantTest(unittest.TestCase):
    def _build_llama_model(
        self,
        *,
        hidden_size: int = 32,
        intermediate_size: int = 64,
        num_hidden_layers: int = 2,
        num_attention_heads: int = 4,
        vocab_size: int = 64,
        tie_word_embeddings: bool = True,
    ) -> LlamaForCausalLM:
        """
        Build a small LLaMA model for unit tests.

        Parameters:
            hidden_size: Hidden dimension.
            intermediate_size: MLP intermediate dimension.
            num_hidden_layers: Number of decoder layers.
            num_attention_heads: Number of attention heads.
            vocab_size: Vocabulary size.
            tie_word_embeddings: Whether to tie embedding and lm_head weights.

        Returns:
            A small LlamaForCausalLM instance.
        """
        config = LlamaConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_attention_heads,
            max_position_embeddings=64,
            tie_word_embeddings=tie_word_embeddings,
            pad_token_id=0,
            bos_token_id=1,
            eos_token_id=2,
        )
        model = LlamaForCausalLM(config)
        model.eval()
        return model

    def _clone_state_dict(self, model: torch.nn.Module) -> dict[str, torch.Tensor]:
        """
        Clone a model state_dict into detached tensors.

        Parameters:
            model: Source model.

        Returns:
            A copied state_dict.
        """
        return {k: v.detach().clone() for k, v in model.state_dict().items()}

    def _assert_identity_linear(self, layer: torch.nn.Linear) -> None:
        """
        Assert that a linear layer is initialized as identity.

        Parameters:
            layer: Target linear layer.
        """
        self.assertEqual(layer.in_features, layer.out_features)
        expected = torch.eye(
            layer.in_features,
            device=layer.weight.device,
            dtype=layer.weight.dtype,
        )
        self.assertTrue(torch.allclose(layer.weight, expected))
        if layer.bias is not None:
            self.assertTrue(torch.allclose(layer.bias, torch.zeros_like(layer.bias)))

    def test_spinquant_config_validate_accepts_random(self):
        cfg = SpinQuantConfig(init_method="random")
        self.assertEqual(cfg.init_method, "random")
        self.assertEqual(cfg.name, "spinquant")

    def test_spinquant_config_validate_accepts_hadamard(self):
        cfg = SpinQuantConfig(init_method="hadamard")
        self.assertEqual(cfg.init_method, "hadamard")

    def test_spinquant_config_validate_requires_r1_for_external(self):
        with self.assertRaises(ValueError):
            SpinQuantConfig(init_method="external")

    def test_spinquant_config_validate_rejects_non_tensor_r1(self):
        with self.assertRaises(ValueError):
            SpinQuantConfig(init_method="random", r1="invalid")  # type: ignore[arg-type]

    def test_spinquant_config_validate_rejects_non_tensor_r2(self):
        with self.assertRaises(ValueError):
            SpinQuantConfig(
                init_method="random",
                r2_map={"model.layers.0.self_attn.R2": "invalid"},  # type: ignore[dict-item]
            )

    @torch.inference_mode()
    def test_prepare_converts_llama_to_spin_llama(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig())

        self.assertIsInstance(q_m, SpinLlamaForCausalLM)
        self.assertTrue(hasattr(q_m.model, "rotate_embedding"))
        self.assertTrue(hasattr(q_m, "rotate_lm_head"))
        self._assert_identity_linear(q_m.model.rotate_embedding)
        self._assert_identity_linear(q_m.rotate_lm_head)

    @torch.inference_mode()
    def test_prepare_preserves_generation_related_attributes(self):
        model = self._build_llama_model()
        model.name_or_path = "dummy-llama"
        model._keep_in_fp32_modules = {"lm_head"}

        q_m = prepare(model, SpinQuantConfig())

        self.assertEqual(q_m.name_or_path, "dummy-llama")
        self.assertEqual(q_m._keep_in_fp32_modules, {"lm_head"})
        self.assertIs(q_m.config, model.config)

    @torch.inference_mode()
    def test_prepare_preserves_original_weights_before_convert(self):
        model = self._build_llama_model()
        original_state = self._clone_state_dict(model)

        q_m = prepare(model, SpinQuantConfig())

        # Check that original model weights are copied into the converted model.
        self.assertTrue(
            torch.allclose(
                q_m.model.embed_tokens.weight,
                original_state["model.embed_tokens.weight"],
            )
        )
        self.assertTrue(
            torch.allclose(
                q_m.lm_head.weight,
                original_state["lm_head.weight"],
            )
        )
        self.assertTrue(
            torch.allclose(
                q_m.model.layers[0].self_attn.q_proj.weight,
                original_state["model.layers.0.self_attn.q_proj.weight"],
            )
        )

    @torch.inference_mode()
    def test_prepare_preserves_tied_embedding_sharing(self):
        model = self._build_llama_model(tie_word_embeddings=True)
        self.assertIs(model.model.embed_tokens.weight, model.lm_head.weight)

        q_m = prepare(model, SpinQuantConfig())

        # Check that the converted model still uses tied weights.
        self.assertIs(q_m.model.embed_tokens.weight, q_m.lm_head.weight)

    @torch.inference_mode()
    def test_convert_changes_decoder_weights(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig(init_method="random"))

        before_q = q_m.model.layers[0].self_attn.q_proj.weight.detach().clone()
        before_o = q_m.model.layers[0].self_attn.o_proj.weight.detach().clone()
        before_gate = q_m.model.layers[0].mlp.gate_proj.weight.detach().clone()
        before_down = q_m.model.layers[0].mlp.down_proj.weight.detach().clone()

        q_m = convert(q_m)

        after_q = q_m.model.layers[0].self_attn.q_proj.weight
        after_o = q_m.model.layers[0].self_attn.o_proj.weight
        after_gate = q_m.model.layers[0].mlp.gate_proj.weight
        after_down = q_m.model.layers[0].mlp.down_proj.weight

        self.assertFalse(torch.allclose(before_q, after_q))
        self.assertFalse(torch.allclose(before_o, after_o))
        self.assertFalse(torch.allclose(before_gate, after_gate))
        self.assertFalse(torch.allclose(before_down, after_down))

    @torch.inference_mode()
    def test_convert_updates_rotation_side_layers(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig(init_method="random"))

        before_embed_rot = q_m.model.rotate_embedding.weight.detach().clone()
        before_lm_head_rot = q_m.rotate_lm_head.weight.detach().clone()

        q_m = convert(q_m)

        after_embed_rot = q_m.model.rotate_embedding.weight
        after_lm_head_rot = q_m.rotate_lm_head.weight

        self.assertFalse(torch.allclose(before_embed_rot, after_embed_rot))
        self.assertFalse(torch.allclose(before_lm_head_rot, after_lm_head_rot))

    @torch.inference_mode()
    def test_convert_resets_folded_layer_norms_to_identity(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig(init_method="random"))
        q_m = convert(q_m)

        for layer in q_m.model.layers:
            self.assertTrue(
                torch.allclose(
                    layer.input_layernorm.weight,
                    torch.ones_like(layer.input_layernorm.weight),
                )
            )
            self.assertTrue(
                torch.allclose(
                    layer.post_attention_layernorm.weight,
                    torch.ones_like(layer.post_attention_layernorm.weight),
                )
            )

        self.assertTrue(
            torch.allclose(
                q_m.model.norm.weight,
                torch.ones_like(q_m.model.norm.weight),
            )
        )

    @torch.inference_mode()
    def test_quantizer_convert_with_external_identity_r1(self):
        model = self._build_llama_model(hidden_size=32, num_attention_heads=4)

        hidden_size = model.config.hidden_size
        head_dim = hidden_size // model.config.num_attention_heads

        r1 = torch.eye(hidden_size, dtype=torch.float64)
        r2_map = {
            f"model.layers.{idx}.self_attn.R2": torch.eye(head_dim, dtype=torch.float64)
            for idx in range(model.config.num_hidden_layers)
        }

        quantizer = SpinQuantQuantizer(
            SpinQuantConfig(
                init_method="external",
                r1=r1,
                r2_map=r2_map,
            )
        )

        q_m = quantizer.prepare(model)
        q_m = quantizer.convert(q_m)

        expected_embed_rot = torch.eye(
            hidden_size,
            device=q_m.model.rotate_embedding.weight.device,
            dtype=q_m.model.rotate_embedding.weight.dtype,
        )
        self.assertTrue(
            torch.allclose(q_m.model.rotate_embedding.weight, expected_embed_rot)
        )

        # The final norm scale is folded into rotate_lm_head, so this layer should
        # become a diagonal matrix equal to the original final norm weights.
        expected_lm_rot = torch.diag(
            model.model.norm.weight.detach().to(
                device=q_m.rotate_lm_head.weight.device,
                dtype=q_m.rotate_lm_head.weight.dtype,
            )
        )
        self.assertTrue(torch.allclose(q_m.rotate_lm_head.weight, expected_lm_rot))

    def test_quantizer_prepare_rejects_non_module(self):
        quantizer = SpinQuantQuantizer(SpinQuantConfig())
        with self.assertRaises(TypeError):
            quantizer.prepare("not a module")  # type: ignore[arg-type]

    def test_quantizer_prepare_rejects_non_llama_model_type(self):
        class DummyConfig:
            model_type = "not_llama"

        class DummyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.config = DummyConfig()
                self.model = torch.nn.Module()
                self.lm_head = torch.nn.Linear(4, 4, bias=False)

        quantizer = SpinQuantQuantizer(SpinQuantConfig())
        with self.assertRaises(ValueError):
            quantizer.prepare(DummyModel())

    @torch.inference_mode()
    def test_forward_runs_after_spinquant_prepare_and_convert(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig(init_method="random"))
        q_m = convert(q_m)

        input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
        outputs = q_m(input_ids=input_ids)

        self.assertEqual(outputs.logits.shape[0], 1)
        self.assertEqual(outputs.logits.shape[1], 4)
        self.assertEqual(outputs.logits.shape[2], q_m.config.vocab_size)
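For reference, a minimal end-to-end sketch of the flow these tests exercise, using the imports from the test file above and a tiny LLaMA configuration mirroring _build_llama_model (the exact dimensions are illustrative, not mandated by the commit):

import torch

from tico.quantization import convert, prepare
from tico.quantization.config.spinquant import SpinQuantConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

# Tiny LLaMA with the same shape as the test helper above.
config = LlamaConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    max_position_embeddings=64,
)
model = LlamaForCausalLM(config).eval()

with torch.inference_mode():
    # prepare() rewraps the model with rotation layers initialized to identity;
    # convert() then folds the chosen rotations (random/hadamard/external) into the weights.
    q_m = prepare(model, SpinQuantConfig(init_method="random"))
    q_m = convert(q_m)
    logits = q_m(input_ids=torch.tensor([[1, 2, 3, 4]], dtype=torch.long)).logits
    # logits has shape (1, 4, config.vocab_size), as checked in the forward test.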
