-
Notifications
You must be signed in to change notification settings - Fork 26
[quantization] Fix truthfulqa
#620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -151,7 +151,6 @@ def save_layers_to(q_m, max_seq_len, save_layers_to_folder): | |
|
|
||
| def quantize_using_PTQ(q_m, calib_inputs, args): | ||
| print("Wrapping layers with PTQWrapper …") | ||
|
|
||
| qcfg = build_llm_ptq_config( | ||
| model_type="llama", | ||
| num_hidden_layers=len(q_m.model.layers), | ||
|
|
@@ -187,7 +186,16 @@ def quantize_using_PTQ(q_m, calib_inputs, args): | |
| device = torch.device(args.device) | ||
| with torch.no_grad(): | ||
| for inp in tqdm.tqdm(calib_inputs): | ||
| q_m(inp.to(device)) | ||
| if args.calibrate_use_cache: | ||
| outputs = q_m(inp[..., :-1].to(device), use_cache=True) | ||
| # TODO add padding? | ||
| q_m( | ||
| inp[..., -1:].to(device), | ||
| past_key_values=outputs.past_key_values, | ||
| use_cache=True, | ||
| ) | ||
| else: | ||
| q_m(inp.to(device)) | ||
|
|
||
| # Freeze all Q-params (scale, zero-point) | ||
| q_m = convert(q_m) | ||
|
|
@@ -341,6 +349,12 @@ def main(): | |
| type=str, | ||
| default=None, | ||
| ) | ||
| parser.add_argument( | ||
| "--calibrate_use_cache", | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. not sure about the option name, but it should be used only in PTQ quantizer calibration. |
||
| action="store_true", | ||
| default=False, | ||
| help="Calibrate using cache (e.g. for PTQ-model evaluation on `truthfulqa` benchmark)", | ||
| ) | ||
| args = parser.parse_args() | ||
| print(args) | ||
|
|
||
|
|
@@ -383,7 +397,7 @@ def main(): | |
| else: | ||
| print("Skipping SpinQuant preprocessing …") | ||
|
|
||
| model.config.use_cache = False # TODO use args for it | ||
| model.config.use_cache = False | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. it occurred that GPTQ is not very kv-cache friendly, as other possible techniques. Right now it's needed only in PTQ. |
||
| if args.calibrate_seq_len is not None: | ||
| model.config.max_position_embeddings = min( | ||
| model.config.max_position_embeddings, args.calibrate_seq_len | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -205,10 +205,16 @@ def _concat_kv( | |
| if past is None: | ||
| return k_new, v_new | ||
| past_k, past_v = past | ||
| past_k = self._fq(past_k, self.obs_past_key) | ||
| past_v = self._fq(past_v, self.obs_past_value) | ||
| k = torch.cat([past_k[:, kv_idx, :, :], k_new], dim=1) | ||
| v = torch.cat([past_v[:, kv_idx, :, :], v_new], dim=1) | ||
| if past_k is not None: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| past_k = self._fq(past_k, self.obs_past_key) | ||
| k = torch.cat([past_k[:, kv_idx, :, :], k_new], dim=1) | ||
| else: | ||
| k = k_new | ||
| if past_v is not None: | ||
| past_v = self._fq(past_v, self.obs_past_value) | ||
| v = torch.cat([past_v[:, kv_idx, :, :], v_new], dim=1) | ||
| else: | ||
| v = v_new | ||
| return k, v | ||
|
|
||
| def _build_attention_mask( | ||
|
|
@@ -228,7 +234,11 @@ def _build_attention_mask( | |
| - additive mask: use as-is. | ||
| """ | ||
| q_len = hidden_states.size(1) | ||
| past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2]) | ||
| past_len = ( | ||
| 0 | ||
| if (past_key_value is None or past_key_value[0] is None) | ||
| else int(past_key_value[0].shape[2]) | ||
| ) | ||
| k_len = past_len + q_len | ||
|
|
||
| if attention_mask is None: | ||
|
|
@@ -267,7 +277,11 @@ def forward( | |
| cos = self._fq(cos, self.obs_cos) | ||
| sin = self._fq(sin, self.obs_sin) | ||
|
|
||
| past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2]) | ||
| past_len = ( | ||
| 0 | ||
| if (past_key_value is None or past_key_value[0] is None) | ||
| else int(past_key_value[0].shape[2]) | ||
| ) | ||
| key_len = past_len + S | ||
|
|
||
| attn_mask = self._build_attention_mask( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -143,26 +143,89 @@ def __init__( | |
| self.register_buffer("rope_cos_template", cos_t, persistent=False) | ||
| self.register_buffer("rope_sin_template", sin_t, persistent=False) | ||
|
|
||
| def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor: | ||
| """Return `[1,1,L,L]` causal mask slice on *device*.""" | ||
| assert isinstance(self.causal_mask_template, torch.Tensor) | ||
| return self.causal_mask_template[..., :seq_len, :seq_len].to(device) | ||
|
|
||
| def get_attention_mask_for(self, x): | ||
| L = x.size(1) | ||
| attention_mask = self._slice_causal(L, x.device) | ||
| return attention_mask | ||
| def _slice_rope( | ||
| self, | ||
| *, | ||
| start: int, | ||
| seq_len: int, | ||
| device: torch.device, | ||
| dtype: torch.dtype, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| assert isinstance(self.rope_cos_template, torch.Tensor) | ||
| assert isinstance(self.rope_sin_template, torch.Tensor) | ||
| end = start + seq_len | ||
| cos = self.rope_cos_template[:, start:end, :].to(device=device, dtype=dtype) | ||
| sin = self.rope_sin_template[:, start:end, :].to(device=device, dtype=dtype) | ||
| return cos, sin | ||
|
|
||
| def _normalize_position_embeddings( | ||
| self, | ||
| *, | ||
| hidden_states: torch.Tensor, | ||
| position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]], | ||
| past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]], | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| if position_embeddings is None: | ||
| q_len = hidden_states.size(1) | ||
| past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2]) | ||
| cos, sin = self._slice_rope( | ||
| start=past_len, | ||
| seq_len=q_len, | ||
| device=hidden_states.device, | ||
| dtype=hidden_states.dtype, | ||
| ) | ||
| else: | ||
| cos, sin = position_embeddings | ||
| return self._fq(cos, self.obs_cos), self._fq(sin, self.obs_sin) | ||
|
|
||
| def get_position_embeddings_for(self, hidden_states): | ||
| return ( | ||
| self.rope_cos_template.to( | ||
| dtype=hidden_states.dtype, device=hidden_states.device | ||
| ), | ||
| self.rope_sin_template.to( | ||
| dtype=hidden_states.dtype, device=hidden_states.device | ||
| ), | ||
| def _normalize_attention_mask( | ||
| self, | ||
| *, | ||
| hidden_states: torch.Tensor, | ||
| attention_mask: Optional[torch.Tensor], | ||
| past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]], | ||
| device: torch.device, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Return an additive attention mask usable with per-head logits. | ||
|
|
||
| Supported cases: | ||
| - None: build a causal mask slice. | ||
| - bool mask: convert to additive mask using 0 / -120. | ||
| - additive mask: use as-is. | ||
| """ | ||
| seq_len = hidden_states.size(1) | ||
| past_len = ( | ||
| 0 | ||
| if (past_key_value is None or past_key_value[0] is None) | ||
| else int(past_key_value[0].shape[2]) | ||
| ) | ||
|
|
||
| if attention_mask is None: | ||
| assert isinstance(self.causal_mask_template, torch.Tensor) | ||
| mask = self.causal_mask_template[ | ||
| ..., past_len : past_len + seq_len, : past_len + seq_len | ||
| ].to(device) | ||
| return mask.squeeze(0) | ||
|
|
||
| if attention_mask.dtype == torch.bool or attention_mask.dtype == torch.int64: | ||
| if attention_mask.dtype == torch.int64: | ||
| attention_mask = attention_mask == 1 # convert to bool | ||
| mask = self.causal_mask_template[ | ||
| ..., past_len : past_len + seq_len, : past_len + seq_len | ||
| ].to( | ||
| device | ||
| ) # so for q_len == 1 mask will be the last row of causal_mask_template | ||
| # only padding which is assumed to change causal_mask | ||
| additive = torch.zeros_like(attention_mask, dtype=torch.float32) | ||
| additive = additive.masked_fill(~attention_mask, float("-120")) | ||
| mask = torch.max( | ||
| torch.tensor(float("-120")).to(device), additive + mask | ||
| ) # so -120-120->-120, -120+0->-120, 0-120->-120, 0+0->0 | ||
| return mask.squeeze(0) | ||
|
|
||
| return attention_mask | ||
|
|
||
| def forward( | ||
| self, | ||
| input_ids: torch.LongTensor = None, | ||
|
|
@@ -202,52 +265,68 @@ def forward( | |
| inputs_embeds = self.embed_tokens(input_ids) | ||
|
|
||
| if use_cache and past_key_values is None: | ||
| past_key_values = DynamicCache() | ||
|
|
||
| if cache_position is None: | ||
| past_seen_tokens = ( | ||
| past_key_values.get_seq_length() if past_key_values is not None else 0 | ||
| ) | ||
| cache_position = torch.arange( | ||
| past_seen_tokens, | ||
| past_seen_tokens + inputs_embeds.shape[1], | ||
| device=inputs_embeds.device, | ||
| ) | ||
|
|
||
| if position_ids is None: | ||
| position_ids = cache_position.unsqueeze(0) | ||
| past_key_values = [] | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| present_key_values = [] | ||
| hidden_states = inputs_embeds | ||
|
|
||
| # Apply the SpinQuant rotation only when the source model provides it. | ||
| if self.rotate_embedding is not None: | ||
| hidden_states = self.rotate_embedding(hidden_states) | ||
|
|
||
| past_key_value = None # sample kv-cache to infer past_seq_len | ||
| if past_key_values is not None: | ||
| if isinstance(past_key_values, DynamicCache): | ||
| if past_key_values.layers[0].keys is not None: | ||
| past_key_value = ( | ||
| past_key_values.layers[0].keys, | ||
| past_key_values.layers[0].values, | ||
| ) | ||
| elif len(past_key_values) > 0: | ||
| past_key_value = past_key_values[0] | ||
|
|
||
| # create position_embeddings and causal_mask to be shared across all the decoder layers | ||
| causal_mask = self.get_attention_mask_for(hidden_states) | ||
| causal_mask = causal_mask.squeeze(0) | ||
| causal_mask = self._normalize_attention_mask( | ||
| hidden_states=hidden_states, | ||
| attention_mask=attention_mask, | ||
| past_key_value=past_key_value, | ||
| device=hidden_states.device, | ||
| ) | ||
| causal_mask = self._fq(causal_mask, self.obs_causal_mask) | ||
|
|
||
| position_embeddings = self.get_position_embeddings_for(hidden_states) | ||
| cos, sin = position_embeddings | ||
| position_embeddings = ( | ||
| self._fq(cos[:, : hidden_states.size(1), :], self.obs_cos), | ||
| self._fq(sin[:, : hidden_states.size(1), :], self.obs_sin), | ||
| position_embeddings = self._normalize_position_embeddings( | ||
| hidden_states=hidden_states, | ||
| position_embeddings=None, | ||
| past_key_value=past_key_value, | ||
| ) | ||
|
|
||
| # decoder layers | ||
| all_hidden_states = () if output_hidden_states else None | ||
| all_self_attns = () if output_attentions else None | ||
|
|
||
| for decoder_layer in self.layers[: self.config.num_hidden_layers]: | ||
| for idx, decoder_layer in enumerate( | ||
| self.layers[: self.config.num_hidden_layers] | ||
| ): | ||
| if output_hidden_states: | ||
| all_hidden_states += (hidden_states,) # type: ignore[operator] | ||
|
|
||
| if use_cache is True: | ||
| if isinstance(past_key_values, DynamicCache): | ||
| past_key_value = ( | ||
| past_key_values.layers[idx].keys, | ||
| past_key_values.layers[idx].values, | ||
| ) | ||
| else: | ||
| past_key_value = ( | ||
| past_key_values[idx] if idx < len(past_key_values) else None | ||
| ) | ||
| else: | ||
| past_key_value = None | ||
|
|
||
| layer_outputs = decoder_layer( | ||
| hidden_states, | ||
| attention_mask=causal_mask, | ||
| position_ids=position_ids, | ||
| past_key_value=past_key_values, | ||
| past_key_value=past_key_value, | ||
| output_attentions=output_attentions, | ||
| use_cache=use_cache, | ||
| cache_position=cache_position, | ||
|
|
@@ -257,6 +336,16 @@ def forward( | |
|
|
||
| if decoder_layer.wrapped.return_type == "tuple": | ||
| hidden_states = layer_outputs[0] | ||
| elif use_cache is True: | ||
| hidden_states = layer_outputs[0] | ||
| if isinstance(past_key_values, DynamicCache): | ||
| past_key_values.update( | ||
| layer_outputs[1][0], layer_outputs[1][1], layer_idx=idx | ||
| ) | ||
| else: | ||
| present_key_values.append( | ||
| (layer_outputs[1][0], layer_outputs[1][1]) | ||
| ) | ||
| else: | ||
| hidden_states = layer_outputs | ||
|
|
||
|
|
@@ -271,7 +360,11 @@ def forward( | |
|
|
||
| output = BaseModelOutputWithPast( | ||
| last_hidden_state=hidden_states, | ||
| past_key_values=past_key_values if use_cache else None, | ||
| past_key_values=( | ||
| past_key_values | ||
| if use_cache and isinstance(past_key_values, DynamicCache) | ||
| else present_key_values | ||
| ), | ||
|
Comment on lines
+363
to
+367
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in case DynamicCache is used it was sent externally for non-export purposes, so let's return updated kv-cache. |
||
| hidden_states=all_hidden_states, | ||
| attentions=all_self_attns, | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| import torch.nn as nn | ||
|
|
||
| from transformers.cache_utils import Cache | ||
| from transformers.generation import GenerationMixin | ||
| from transformers.modeling_outputs import CausalLMOutputWithPast | ||
|
|
||
| from tico.quantization.config.ptq import PTQConfig | ||
|
|
@@ -30,7 +31,9 @@ | |
| "transformers.models.llama.modeling_llama.LlamaForCausalLM", | ||
| "tico.quantization.algorithm.spinquant.spin_llama.SpinLlamaForCausalLM", | ||
| ) | ||
| class QuantLlamaForCausalLM(QuantModuleBase): | ||
| class QuantLlamaForCausalLM(QuantModuleBase, GenerationMixin): | ||
| _is_stateful = False | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_fp: nn.Module, | ||
|
|
@@ -78,10 +81,15 @@ def __init__( | |
| self.config = model_fp.config | ||
| self.loss_function = model_fp.loss_function | ||
| self.device = model_fp.device | ||
| self.generation_config = model_fp.generation_config | ||
| self.main_input_name = model_fp.main_input_name | ||
|
|
||
| def tie_weights(self): | ||
| pass | ||
|
|
||
| def is_remote_code(self): | ||
| return False | ||
|
|
||
| def forward( | ||
| self, | ||
| input_ids: torch.LongTensor | None = None, | ||
|
|
@@ -99,6 +107,9 @@ def forward( | |
| output_attentions = self.config.output_attentions | ||
| output_hidden_states = self.config.output_hidden_states | ||
| return_dict = self.config.use_return_dict | ||
| if "return_dict" in kwargs: | ||
| # lm_eval set return_dict explicitely in kwargs | ||
| return_dict = kwargs.pop("return_dict") | ||
|
Comment on lines
109
to
+112
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. it occurs that |
||
|
|
||
| # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) | ||
| outputs = self.model( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to turn on the
`use_cache` option to calibrate kv-cache observers (both input and output), that is why a prefill-decode model is used. First run produces the kv-cache, second run uses it on a single-token inference.