# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from tico.quantization import convert, prepare
from tico.quantization.algorithm.spinquant.quantizer import SpinQuantQuantizer
from tico.quantization.algorithm.spinquant.spin_llama import SpinLlamaForCausalLM
from tico.quantization.config.spinquant import SpinQuantConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM


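# These tests exercise the SpinQuant flow exposed by tico.quantization
# (prepare -> convert) on a tiny randomly initialized LLaMA. SpinQuant-style
# rotation multiplies model weights by orthogonal matrices (a global R1 over
# the hidden dimension plus per-head R2 matrices in self-attention) so the
# rotated weights are easier to quantize; the cases below only check the
# structural effects of prepare()/convert(), not quantization accuracy.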
class SpinQuantTest(unittest.TestCase):
    def _build_llama_model(
        self,
        *,
        hidden_size: int = 32,
        intermediate_size: int = 64,
        num_hidden_layers: int = 2,
        num_attention_heads: int = 4,
        vocab_size: int = 64,
        tie_word_embeddings: bool = True,
    ) -> LlamaForCausalLM:
        """
        Build a small LLaMA model for unit tests.

        Parameters:
            hidden_size: Hidden dimension.
            intermediate_size: MLP intermediate dimension.
            num_hidden_layers: Number of decoder layers.
            num_attention_heads: Number of attention heads.
            vocab_size: Vocabulary size.
            tie_word_embeddings: Whether to tie embedding and lm_head weights.

        Returns:
            A small LlamaForCausalLM instance.
        """
        config = LlamaConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_attention_heads,
            max_position_embeddings=64,
            tie_word_embeddings=tie_word_embeddings,
            pad_token_id=0,
            bos_token_id=1,
            eos_token_id=2,
        )
        model = LlamaForCausalLM(config)
        model.eval()
        return model

    def _clone_state_dict(self, model: torch.nn.Module) -> dict[str, torch.Tensor]:
        """
        Clone a model state_dict into detached tensors.

        Parameters:
            model: Source model.

        Returns:
            A copied state_dict.
        """
        return {k: v.detach().clone() for k, v in model.state_dict().items()}

    def _assert_identity_linear(self, layer: torch.nn.Linear) -> None:
        """
        Assert that a linear layer is initialized as identity.

        Parameters:
            layer: Target linear layer.
        """
        self.assertEqual(layer.in_features, layer.out_features)
        expected = torch.eye(
            layer.in_features,
            device=layer.weight.device,
            dtype=layer.weight.dtype,
        )
        self.assertTrue(torch.allclose(layer.weight, expected))
        if layer.bias is not None:
            self.assertTrue(torch.allclose(layer.bias, torch.zeros_like(layer.bias)))

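    # SpinQuantConfig validation: "random" and "hadamard" need no extra inputs,
    # "external" requires a user-supplied R1, and any provided R1/R2 values
    # must be torch tensors.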
    def test_spinquant_config_validate_accepts_random(self):
        cfg = SpinQuantConfig(init_method="random")
        self.assertEqual(cfg.init_method, "random")
        self.assertEqual(cfg.name, "spinquant")

    def test_spinquant_config_validate_accepts_hadamard(self):
        cfg = SpinQuantConfig(init_method="hadamard")
        self.assertEqual(cfg.init_method, "hadamard")

    def test_spinquant_config_validate_requires_r1_for_external(self):
        with self.assertRaises(ValueError):
            SpinQuantConfig(init_method="external")

    def test_spinquant_config_validate_rejects_non_tensor_r1(self):
        with self.assertRaises(ValueError):
            SpinQuantConfig(init_method="random", r1="invalid")  # type: ignore[arg-type]

    def test_spinquant_config_validate_rejects_non_tensor_r2(self):
        with self.assertRaises(ValueError):
            SpinQuantConfig(
                init_method="random",
                r2_map={"model.layers.0.self_attn.R2": "invalid"},  # type: ignore[dict-item]
            )

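    # prepare() wraps the LLaMA model as a SpinLlamaForCausalLM: it adds
    # identity-initialized rotation linears, copies the original weights, and
    # preserves generation-related attributes and embedding/lm_head weight tying.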
    @torch.inference_mode()
    def test_prepare_converts_llama_to_spin_llama(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig())

        self.assertIsInstance(q_m, SpinLlamaForCausalLM)
        self.assertTrue(hasattr(q_m.model, "rotate_embedding"))
        self.assertTrue(hasattr(q_m, "rotate_lm_head"))
        self._assert_identity_linear(q_m.model.rotate_embedding)
        self._assert_identity_linear(q_m.rotate_lm_head)

    @torch.inference_mode()
    def test_prepare_preserves_generation_related_attributes(self):
        model = self._build_llama_model()
        model.name_or_path = "dummy-llama"
        model._keep_in_fp32_modules = {"lm_head"}

        q_m = prepare(model, SpinQuantConfig())

        self.assertEqual(q_m.name_or_path, "dummy-llama")
        self.assertEqual(q_m._keep_in_fp32_modules, {"lm_head"})
        self.assertIs(q_m.config, model.config)

    @torch.inference_mode()
    def test_prepare_preserves_original_weights_before_convert(self):
        model = self._build_llama_model()
        original_state = self._clone_state_dict(model)

        q_m = prepare(model, SpinQuantConfig())

        # Check that original model weights are copied into the converted model.
        self.assertTrue(
            torch.allclose(
                q_m.model.embed_tokens.weight,
                original_state["model.embed_tokens.weight"],
            )
        )
        self.assertTrue(
            torch.allclose(
                q_m.lm_head.weight,
                original_state["lm_head.weight"],
            )
        )
        self.assertTrue(
            torch.allclose(
                q_m.model.layers[0].self_attn.q_proj.weight,
                original_state["model.layers.0.self_attn.q_proj.weight"],
            )
        )

    @torch.inference_mode()
    def test_prepare_preserves_tied_embedding_sharing(self):
        model = self._build_llama_model(tie_word_embeddings=True)
        self.assertIs(model.model.embed_tokens.weight, model.lm_head.weight)

        q_m = prepare(model, SpinQuantConfig())

        # Check that the converted model still uses tied weights.
        self.assertIs(q_m.model.embed_tokens.weight, q_m.lm_head.weight)

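    # convert() folds the fitted rotations into the existing projection weights
    # in place, so with a random init the attention and MLP projections should
    # all differ from their pre-convert values.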
    @torch.inference_mode()
    def test_convert_changes_decoder_weights(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig(init_method="random"))

        before_q = q_m.model.layers[0].self_attn.q_proj.weight.detach().clone()
        before_o = q_m.model.layers[0].self_attn.o_proj.weight.detach().clone()
        before_gate = q_m.model.layers[0].mlp.gate_proj.weight.detach().clone()
        before_down = q_m.model.layers[0].mlp.down_proj.weight.detach().clone()

        q_m = convert(q_m)

        after_q = q_m.model.layers[0].self_attn.q_proj.weight
        after_o = q_m.model.layers[0].self_attn.o_proj.weight
        after_gate = q_m.model.layers[0].mlp.gate_proj.weight
        after_down = q_m.model.layers[0].mlp.down_proj.weight

        self.assertFalse(torch.allclose(before_q, after_q))
        self.assertFalse(torch.allclose(before_o, after_o))
        self.assertFalse(torch.allclose(before_gate, after_gate))
        self.assertFalse(torch.allclose(before_down, after_down))

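    # The identity rotate_embedding / rotate_lm_head layers created by prepare()
    # should be overwritten with the fitted rotations during convert().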
    @torch.inference_mode()
    def test_convert_updates_rotation_side_layers(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig(init_method="random"))

        before_embed_rot = q_m.model.rotate_embedding.weight.detach().clone()
        before_lm_head_rot = q_m.rotate_lm_head.weight.detach().clone()

        q_m = convert(q_m)

        after_embed_rot = q_m.model.rotate_embedding.weight
        after_lm_head_rot = q_m.rotate_lm_head.weight

        self.assertFalse(torch.allclose(before_embed_rot, after_embed_rot))
        self.assertFalse(torch.allclose(before_lm_head_rot, after_lm_head_rot))

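    # convert() folds the RMSNorm scales into neighboring linear layers, so every
    # folded norm should end up with an all-ones weight vector.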
    @torch.inference_mode()
    def test_convert_resets_folded_layer_norms_to_identity(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig(init_method="random"))
        q_m = convert(q_m)

        for layer in q_m.model.layers:
            self.assertTrue(
                torch.allclose(
                    layer.input_layernorm.weight,
                    torch.ones_like(layer.input_layernorm.weight),
                )
            )
            self.assertTrue(
                torch.allclose(
                    layer.post_attention_layernorm.weight,
                    torch.ones_like(layer.post_attention_layernorm.weight),
                )
            )

        self.assertTrue(
            torch.allclose(
                q_m.model.norm.weight,
                torch.ones_like(q_m.model.norm.weight),
            )
        )

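    # With externally supplied identity R1/R2 the rotations are no-ops, so the
    # only visible effect of convert() is the norm folding: rotate_embedding
    # stays the identity and rotate_lm_head becomes a diagonal matrix built from
    # the original final-norm weights.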
    @torch.inference_mode()
    def test_quantizer_convert_with_external_identity_r1(self):
        model = self._build_llama_model(hidden_size=32, num_attention_heads=4)

        hidden_size = model.config.hidden_size
        head_dim = hidden_size // model.config.num_attention_heads

        r1 = torch.eye(hidden_size, dtype=torch.float64)
        r2_map = {
            f"model.layers.{idx}.self_attn.R2": torch.eye(head_dim, dtype=torch.float64)
            for idx in range(model.config.num_hidden_layers)
        }

        quantizer = SpinQuantQuantizer(
            SpinQuantConfig(
                init_method="external",
                r1=r1,
                r2_map=r2_map,
            )
        )

        q_m = quantizer.prepare(model)
        q_m = quantizer.convert(q_m)

        expected_embed_rot = torch.eye(
            hidden_size,
            device=q_m.model.rotate_embedding.weight.device,
            dtype=q_m.model.rotate_embedding.weight.dtype,
        )
        self.assertTrue(
            torch.allclose(q_m.model.rotate_embedding.weight, expected_embed_rot)
        )

        # The final norm scale is folded into rotate_lm_head, so this layer should
        # become a diagonal matrix equal to the original final norm weights.
        expected_lm_rot = torch.diag(
            model.model.norm.weight.detach().to(
                device=q_m.rotate_lm_head.weight.device,
                dtype=q_m.rotate_lm_head.weight.dtype,
            )
        )
        self.assertTrue(torch.allclose(q_m.rotate_lm_head.weight, expected_lm_rot))

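    # Input validation: prepare() rejects objects that are not torch.nn.Module
    # and models whose config does not declare a LLaMA model type.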
    def test_quantizer_prepare_rejects_non_module(self):
        quantizer = SpinQuantQuantizer(SpinQuantConfig())
        with self.assertRaises(TypeError):
            quantizer.prepare("not a module")  # type: ignore[arg-type]

    def test_quantizer_prepare_rejects_non_llama_model_type(self):
        class DummyConfig:
            model_type = "not_llama"

        class DummyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.config = DummyConfig()
                self.model = torch.nn.Module()
                self.lm_head = torch.nn.Linear(4, 4, bias=False)

        quantizer = SpinQuantQuantizer(SpinQuantConfig())
        with self.assertRaises(ValueError):
            quantizer.prepare(DummyModel())

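    # End-to-end smoke test: after prepare() and convert() the rotated model
    # should still run a forward pass and produce logits of shape
    # (batch, sequence_length, vocab_size).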
    @torch.inference_mode()
    def test_forward_runs_after_spinquant_prepare_and_convert(self):
        model = self._build_llama_model()
        q_m = prepare(model, SpinQuantConfig(init_method="random"))
        q_m = convert(q_m)

        input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
        outputs = q_m(input_ids=input_ids)

        self.assertEqual(outputs.logits.shape[0], 1)
        self.assertEqual(outputs.logits.shape[1], 4)
        self.assertEqual(outputs.logits.shape[2], q_m.config.vocab_size)
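

if __name__ == "__main__":
    # Optional convenience entry point so the module can also be run directly
    # with `python`; pytest/unittest discovery does not require it.
    unittest.main()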