diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index 53ea1759..9f6101e1 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -151,7 +151,6 @@ def save_layers_to(q_m, max_seq_len, save_layers_to_folder): def quantize_using_PTQ(q_m, calib_inputs, args): print("Wrapping layers with PTQWrapper …") - qcfg = build_llm_ptq_config( model_type="llama", num_hidden_layers=len(q_m.model.layers), @@ -187,7 +186,16 @@ def quantize_using_PTQ(q_m, calib_inputs, args): device = torch.device(args.device) with torch.no_grad(): for inp in tqdm.tqdm(calib_inputs): - q_m(inp.to(device)) + if args.calibrate_use_cache: + outputs = q_m(inp[..., :-1].to(device), use_cache=True) + # TODO add padding? + q_m( + inp[..., -1:].to(device), + past_key_values=outputs.past_key_values, + use_cache=True, + ) + else: + q_m(inp.to(device)) # Freeze all Q-params (scale, zero-point) q_m = convert(q_m) @@ -341,6 +349,12 @@ def main(): type=str, default=None, ) + parser.add_argument( + "--calibrate_use_cache", + action="store_true", + default=False, + help="Calibrate using cache (e.g. 
for PTQ-model evaluation on `truthfulqa` benchmark)", + ) args = parser.parse_args() print(args) @@ -383,7 +397,7 @@ def main(): else: print("Skipping SpinQuant preprocessing …") - model.config.use_cache = False # TODO use args for it + model.config.use_cache = False if args.calibrate_seq_len is not None: model.config.max_position_embeddings = min( model.config.max_position_embeddings, args.calibrate_seq_len diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attention.py b/tico/quantization/wrapq/wrappers/llama/quant_attention.py index 12b4bef4..4096ca1c 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_attention.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_attention.py @@ -205,10 +205,16 @@ def _concat_kv( if past is None: return k_new, v_new past_k, past_v = past - past_k = self._fq(past_k, self.obs_past_key) - past_v = self._fq(past_v, self.obs_past_value) - k = torch.cat([past_k[:, kv_idx, :, :], k_new], dim=1) - v = torch.cat([past_v[:, kv_idx, :, :], v_new], dim=1) + if past_k is not None: + past_k = self._fq(past_k, self.obs_past_key) + k = torch.cat([past_k[:, kv_idx, :, :], k_new], dim=1) + else: + k = k_new + if past_v is not None: + past_v = self._fq(past_v, self.obs_past_value) + v = torch.cat([past_v[:, kv_idx, :, :], v_new], dim=1) + else: + v = v_new return k, v def _build_attention_mask( @@ -228,7 +234,11 @@ def _build_attention_mask( - additive mask: use as-is. 
""" q_len = hidden_states.size(1) - past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2]) + past_len = ( + 0 + if (past_key_value is None or past_key_value[0] is None) + else int(past_key_value[0].shape[2]) + ) k_len = past_len + q_len if attention_mask is None: @@ -267,7 +277,11 @@ def forward( cos = self._fq(cos, self.obs_cos) sin = self._fq(sin, self.obs_sin) - past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2]) + past_len = ( + 0 + if (past_key_value is None or past_key_value[0] is None) + else int(past_key_value[0].shape[2]) + ) key_len = past_len + S attn_mask = self._build_attention_mask( diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py index 3cd50744..e5f11da4 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py @@ -204,7 +204,11 @@ def _normalize_attention_mask( - additive mask: use as-is. 
""" q_len = hidden_states.size(1) - past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2]) + past_len = ( + 0 + if (past_key_value is None or past_key_value[0] is None) + else int(past_key_value[0].shape[2]) + ) k_len = past_len + q_len if attention_mask is None: @@ -305,7 +309,7 @@ def forward( if use_cache: outputs += (present_key_value,) # type: ignore[assignment] - if self.return_type == "tuple": + if self.return_type == "tuple" or use_cache: return outputs if self.return_type == "tensor": return hidden_states diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model.py b/tico/quantization/wrapq/wrappers/llama/quant_model.py index 7a1696ff..f7a5c43e 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_model.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_model.py @@ -143,26 +143,89 @@ def __init__( self.register_buffer("rope_cos_template", cos_t, persistent=False) self.register_buffer("rope_sin_template", sin_t, persistent=False) - def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor: - """Return `[1,1,L,L]` causal mask slice on *device*.""" - assert isinstance(self.causal_mask_template, torch.Tensor) - return self.causal_mask_template[..., :seq_len, :seq_len].to(device) - - def get_attention_mask_for(self, x): - L = x.size(1) - attention_mask = self._slice_causal(L, x.device) - return attention_mask + def _slice_rope( + self, + *, + start: int, + seq_len: int, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert isinstance(self.rope_cos_template, torch.Tensor) + assert isinstance(self.rope_sin_template, torch.Tensor) + end = start + seq_len + cos = self.rope_cos_template[:, start:end, :].to(device=device, dtype=dtype) + sin = self.rope_sin_template[:, start:end, :].to(device=device, dtype=dtype) + return cos, sin + + def _normalize_position_embeddings( + self, + *, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, 
torch.Tensor]], + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]], + ) -> tuple[torch.Tensor, torch.Tensor]: + if position_embeddings is None: + q_len = hidden_states.size(1) + past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2]) + cos, sin = self._slice_rope( + start=past_len, + seq_len=q_len, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + else: + cos, sin = position_embeddings + return self._fq(cos, self.obs_cos), self._fq(sin, self.obs_sin) - def get_position_embeddings_for(self, hidden_states): - return ( - self.rope_cos_template.to( - dtype=hidden_states.dtype, device=hidden_states.device - ), - self.rope_sin_template.to( - dtype=hidden_states.dtype, device=hidden_states.device - ), + def _normalize_attention_mask( + self, + *, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]], + device: torch.device, + ) -> torch.Tensor: + """ + Return an additive attention mask usable with per-head logits. + + Supported cases: + - None: build a causal mask slice. + - bool mask: convert to additive mask using 0 / -120. + - additive mask: use as-is. 
+ """ + seq_len = hidden_states.size(1) + past_len = ( + 0 + if (past_key_value is None or past_key_value[0] is None) + else int(past_key_value[0].shape[2]) ) + if attention_mask is None: + assert isinstance(self.causal_mask_template, torch.Tensor) + mask = self.causal_mask_template[ + ..., past_len : past_len + seq_len, : past_len + seq_len + ].to(device) + return mask.squeeze(0) + + if attention_mask.dtype == torch.bool or attention_mask.dtype == torch.int64: + if attention_mask.dtype == torch.int64: + attention_mask = attention_mask == 1 # convert to bool + mask = self.causal_mask_template[ + ..., past_len : past_len + seq_len, : past_len + seq_len + ].to( + device + ) # so for q_len == 1 mask will be the last row of causal_mask_template + # only padding which is assumed to change causal_mask + additive = torch.zeros_like(attention_mask, dtype=torch.float32) + additive = additive.masked_fill(~attention_mask, float("-120")) + mask = torch.max( + torch.tensor(float("-120")).to(device), additive + mask + ) # so -120-120->-120, -120+0->-120, 0-120->-120, 0+0->0 + return mask.squeeze(0) + + return attention_mask + def forward( self, input_ids: torch.LongTensor = None, @@ -202,52 +265,68 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + past_key_values = [] + present_key_values = [] hidden_states = inputs_embeds # Apply the SpinQuant rotation only when the source model provides it. 
if self.rotate_embedding is not None: hidden_states = self.rotate_embedding(hidden_states) + past_key_value = None # sample kv-cache to infer past_seq_len + if past_key_values is not None: + if isinstance(past_key_values, DynamicCache): + if past_key_values.layers[0].keys is not None: + past_key_value = ( + past_key_values.layers[0].keys, + past_key_values.layers[0].values, + ) + elif len(past_key_values) > 0: + past_key_value = past_key_values[0] + # create position_embeddings and causal_mask to be shared across all the decoder layers - causal_mask = self.get_attention_mask_for(hidden_states) - causal_mask = causal_mask.squeeze(0) + causal_mask = self._normalize_attention_mask( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + device=hidden_states.device, + ) causal_mask = self._fq(causal_mask, self.obs_causal_mask) - position_embeddings = self.get_position_embeddings_for(hidden_states) - cos, sin = position_embeddings - position_embeddings = ( - self._fq(cos[:, : hidden_states.size(1), :], self.obs_cos), - self._fq(sin[:, : hidden_states.size(1), :], self.obs_sin), + position_embeddings = self._normalize_position_embeddings( + hidden_states=hidden_states, + position_embeddings=None, + past_key_value=past_key_value, ) - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for idx, decoder_layer in enumerate( + self.layers[: self.config.num_hidden_layers] + ): if output_hidden_states: all_hidden_states += (hidden_states,) # type: ignore[operator] + if use_cache is True: + if isinstance(past_key_values, DynamicCache): + past_key_value = ( + past_key_values.layers[idx].keys, + past_key_values.layers[idx].values, + ) + else: + past_key_value = ( + past_key_values[idx] if idx < len(past_key_values) else None + ) + else: + past_key_value = None + layer_outputs = decoder_layer( 
hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -257,6 +336,16 @@ def forward( if decoder_layer.wrapped.return_type == "tuple": hidden_states = layer_outputs[0] + elif use_cache is True: + hidden_states = layer_outputs[0] + if isinstance(past_key_values, DynamicCache): + past_key_values.update( + layer_outputs[1][0], layer_outputs[1][1], layer_idx=idx + ) + else: + present_key_values.append( + (layer_outputs[1][0], layer_outputs[1][1]) + ) else: hidden_states = layer_outputs @@ -271,7 +360,11 @@ def forward( output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, + past_key_values=( + past_key_values + if use_cache and isinstance(past_key_values, DynamicCache) + else present_key_values + ), hidden_states=all_hidden_states, attentions=all_self_attns, ) diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py b/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py index 1553e42b..37e7ee51 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py @@ -18,6 +18,7 @@ import torch.nn as nn from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast from tico.quantization.config.ptq import PTQConfig @@ -30,7 +31,9 @@ "transformers.models.llama.modeling_llama.LlamaForCausalLM", "tico.quantization.algorithm.spinquant.spin_llama.SpinLlamaForCausalLM", ) -class QuantLlamaForCausalLM(QuantModuleBase): +class QuantLlamaForCausalLM(QuantModuleBase, GenerationMixin): + _is_stateful = False + def __init__( self, model_fp: nn.Module, @@ -78,10 +81,15 @@ def __init__( self.config = model_fp.config 
        self.loss_function = model_fp.loss_function
         self.device = model_fp.device
+        self.generation_config = model_fp.generation_config
+        self.main_input_name = model_fp.main_input_name
 
     def tie_weights(self):
         pass
 
+    def is_remote_code(self):
+        return False
+
     def forward(
         self,
         input_ids: torch.LongTensor | None = None,
@@ -99,6 +107,9 @@ def forward(
             output_attentions = self.config.output_attentions
             output_hidden_states = self.config.output_hidden_states
             return_dict = self.config.use_return_dict
+        if "return_dict" in kwargs:
+            # lm_eval sets return_dict explicitly in kwargs
+            return_dict = kwargs.pop("return_dict")
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(