Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def save_layers_to(q_m, max_seq_len, save_layers_to_folder):

def quantize_using_PTQ(q_m, calib_inputs, args):
print("Wrapping layers with PTQWrapper …")

qcfg = build_llm_ptq_config(
model_type="llama",
num_hidden_layers=len(q_m.model.layers),
Expand Down Expand Up @@ -187,7 +186,16 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
device = torch.device(args.device)
with torch.no_grad():
for inp in tqdm.tqdm(calib_inputs):
q_m(inp.to(device))
if args.calibrate_use_cache:
outputs = q_m(inp[..., :-1].to(device), use_cache=True)
# TODO add padding?
q_m(
inp[..., -1:].to(device),
past_key_values=outputs.past_key_values,
use_cache=True,
)
Comment on lines +190 to +196
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to turn on the use_cache option to calibrate the kv-cache observers (both input and output); that is why a prefill-decode scheme is used. The first run produces the kv-cache, and the second run uses it for single-token inference.

else:
q_m(inp.to(device))

# Freeze all Q-params (scale, zero-point)
q_m = convert(q_m)
Expand Down Expand Up @@ -341,6 +349,12 @@ def main():
type=str,
default=None,
)
parser.add_argument(
"--calibrate_use_cache",
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the option name, but it should be used only for PTQ quantizer calibration.
Any other suggestions are welcome.

action="store_true",
default=False,
help="Calibrate using cache (e.g. for PTQ-model evaluation on `truthfulqa` benchmark)",
)
args = parser.parse_args()
print(args)

Expand Down Expand Up @@ -383,7 +397,7 @@ def main():
else:
print("Skipping SpinQuant preprocessing …")

model.config.use_cache = False # TODO use args for it
model.config.use_cache = False
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turned out that GPTQ is not very kv-cache friendly, as are other possible techniques. Right now the kv-cache is needed only in PTQ.

if args.calibrate_seq_len is not None:
model.config.max_position_embeddings = min(
model.config.max_position_embeddings, args.calibrate_seq_len
Expand Down
26 changes: 20 additions & 6 deletions tico/quantization/wrapq/wrappers/llama/quant_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,16 @@ def _concat_kv(
if past is None:
return k_new, v_new
past_k, past_v = past
past_k = self._fq(past_k, self.obs_past_key)
past_v = self._fq(past_v, self.obs_past_value)
k = torch.cat([past_k[:, kv_idx, :, :], k_new], dim=1)
v = torch.cat([past_v[:, kv_idx, :, :], v_new], dim=1)
if past_k is not None:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DynamicCache tuple itself may be non-None while the values inside the tuple are still None, so each element has to be checked individually.

past_k = self._fq(past_k, self.obs_past_key)
k = torch.cat([past_k[:, kv_idx, :, :], k_new], dim=1)
else:
k = k_new
if past_v is not None:
past_v = self._fq(past_v, self.obs_past_value)
v = torch.cat([past_v[:, kv_idx, :, :], v_new], dim=1)
else:
v = v_new
return k, v

def _build_attention_mask(
Expand All @@ -228,7 +234,11 @@ def _build_attention_mask(
- additive mask: use as-is.
"""
q_len = hidden_states.size(1)
past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2])
past_len = (
0
if (past_key_value is None or past_key_value[0] is None)
else int(past_key_value[0].shape[2])
)
k_len = past_len + q_len

if attention_mask is None:
Expand Down Expand Up @@ -267,7 +277,11 @@ def forward(
cos = self._fq(cos, self.obs_cos)
sin = self._fq(sin, self.obs_sin)

past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2])
past_len = (
0
if (past_key_value is None or past_key_value[0] is None)
else int(past_key_value[0].shape[2])
)
key_len = past_len + S

attn_mask = self._build_attention_mask(
Expand Down
8 changes: 6 additions & 2 deletions tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,11 @@ def _normalize_attention_mask(
- additive mask: use as-is.
"""
q_len = hidden_states.size(1)
past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2])
past_len = (
0
if (past_key_value is None or past_key_value[0] is None)
else int(past_key_value[0].shape[2])
)
k_len = past_len + q_len

if attention_mask is None:
Expand Down Expand Up @@ -305,7 +309,7 @@ def forward(
if use_cache:
outputs += (present_key_value,) # type: ignore[assignment]

if self.return_type == "tuple":
if self.return_type == "tuple" or use_cache:
return outputs
if self.return_type == "tensor":
return hidden_states
Expand Down
177 changes: 135 additions & 42 deletions tico/quantization/wrapq/wrappers/llama/quant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,26 +143,89 @@ def __init__(
self.register_buffer("rope_cos_template", cos_t, persistent=False)
self.register_buffer("rope_sin_template", sin_t, persistent=False)

def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor:
"""Return `[1,1,L,L]` causal mask slice on *device*."""
assert isinstance(self.causal_mask_template, torch.Tensor)
return self.causal_mask_template[..., :seq_len, :seq_len].to(device)

def get_attention_mask_for(self, x):
L = x.size(1)
attention_mask = self._slice_causal(L, x.device)
return attention_mask
def _slice_rope(
self,
*,
start: int,
seq_len: int,
device: torch.device,
dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
assert isinstance(self.rope_cos_template, torch.Tensor)
assert isinstance(self.rope_sin_template, torch.Tensor)
end = start + seq_len
cos = self.rope_cos_template[:, start:end, :].to(device=device, dtype=dtype)
sin = self.rope_sin_template[:, start:end, :].to(device=device, dtype=dtype)
return cos, sin

def _normalize_position_embeddings(
self,
*,
hidden_states: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]],
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> tuple[torch.Tensor, torch.Tensor]:
if position_embeddings is None:
q_len = hidden_states.size(1)
past_len = 0 if past_key_value is None else int(past_key_value[0].shape[2])
cos, sin = self._slice_rope(
start=past_len,
seq_len=q_len,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
else:
cos, sin = position_embeddings
return self._fq(cos, self.obs_cos), self._fq(sin, self.obs_sin)

def get_position_embeddings_for(self, hidden_states):
return (
self.rope_cos_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
self.rope_sin_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
def _normalize_attention_mask(
self,
*,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
device: torch.device,
) -> torch.Tensor:
"""
Return an additive attention mask usable with per-head logits.

Supported cases:
- None: build a causal mask slice.
- bool mask: convert to additive mask using 0 / -120.
- additive mask: use as-is.
"""
seq_len = hidden_states.size(1)
past_len = (
0
if (past_key_value is None or past_key_value[0] is None)
else int(past_key_value[0].shape[2])
)

if attention_mask is None:
assert isinstance(self.causal_mask_template, torch.Tensor)
mask = self.causal_mask_template[
..., past_len : past_len + seq_len, : past_len + seq_len
].to(device)
return mask.squeeze(0)

if attention_mask.dtype == torch.bool or attention_mask.dtype == torch.int64:
if attention_mask.dtype == torch.int64:
attention_mask = attention_mask == 1 # convert to bool
mask = self.causal_mask_template[
..., past_len : past_len + seq_len, : past_len + seq_len
].to(
device
) # so for q_len == 1 mask will be the last row of causal_mask_template
# only padding which is assumed to change causal_mask
additive = torch.zeros_like(attention_mask, dtype=torch.float32)
additive = additive.masked_fill(~attention_mask, float("-120"))
mask = torch.max(
torch.tensor(float("-120")).to(device), additive + mask
) # so -120-120->-120, -120+0->-120, 0-120->-120, 0+0->0
return mask.squeeze(0)

return attention_mask

def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down Expand Up @@ -202,52 +265,68 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)
past_key_values = []
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DynamicCache seems to be non-exportable to circle, so let's just use a plain list instead.


present_key_values = []
hidden_states = inputs_embeds

# Apply the SpinQuant rotation only when the source model provides it.
if self.rotate_embedding is not None:
hidden_states = self.rotate_embedding(hidden_states)

past_key_value = None # sample kv-cache to infer past_seq_len
if past_key_values is not None:
if isinstance(past_key_values, DynamicCache):
if past_key_values.layers[0].keys is not None:
past_key_value = (
past_key_values.layers[0].keys,
past_key_values.layers[0].values,
)
elif len(past_key_values) > 0:
past_key_value = past_key_values[0]

# create position_embeddings and causal_mask to be shared across all the decoder layers
causal_mask = self.get_attention_mask_for(hidden_states)
causal_mask = causal_mask.squeeze(0)
causal_mask = self._normalize_attention_mask(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_value=past_key_value,
device=hidden_states.device,
)
causal_mask = self._fq(causal_mask, self.obs_causal_mask)

position_embeddings = self.get_position_embeddings_for(hidden_states)
cos, sin = position_embeddings
position_embeddings = (
self._fq(cos[:, : hidden_states.size(1), :], self.obs_cos),
self._fq(sin[:, : hidden_states.size(1), :], self.obs_sin),
position_embeddings = self._normalize_position_embeddings(
hidden_states=hidden_states,
position_embeddings=None,
past_key_value=past_key_value,
)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
for idx, decoder_layer in enumerate(
self.layers[: self.config.num_hidden_layers]
):
if output_hidden_states:
all_hidden_states += (hidden_states,) # type: ignore[operator]

if use_cache is True:
if isinstance(past_key_values, DynamicCache):
past_key_value = (
past_key_values.layers[idx].keys,
past_key_values.layers[idx].values,
)
else:
past_key_value = (
past_key_values[idx] if idx < len(past_key_values) else None
)
else:
past_key_value = None

layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
Expand All @@ -257,6 +336,16 @@ def forward(

if decoder_layer.wrapped.return_type == "tuple":
hidden_states = layer_outputs[0]
elif use_cache is True:
hidden_states = layer_outputs[0]
if isinstance(past_key_values, DynamicCache):
past_key_values.update(
layer_outputs[1][0], layer_outputs[1][1], layer_idx=idx
)
else:
present_key_values.append(
(layer_outputs[1][0], layer_outputs[1][1])
)
else:
hidden_states = layer_outputs

Expand All @@ -271,7 +360,11 @@ def forward(

output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
past_key_values=(
past_key_values
if use_cache and isinstance(past_key_values, DynamicCache)
else present_key_values
),
Comment on lines +363 to +367
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a DynamicCache is used, it was passed in externally for non-export purposes, so we return the updated kv-cache.
Otherwise we return the newly produced values to be managed by the runtime.

hidden_states=all_hidden_states,
attentions=all_self_attns,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch.nn as nn

from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from tico.quantization.config.ptq import PTQConfig
Expand All @@ -30,7 +31,9 @@
"transformers.models.llama.modeling_llama.LlamaForCausalLM",
"tico.quantization.algorithm.spinquant.spin_llama.SpinLlamaForCausalLM",
)
class QuantLlamaForCausalLM(QuantModuleBase):
class QuantLlamaForCausalLM(QuantModuleBase, GenerationMixin):
_is_stateful = False

def __init__(
self,
model_fp: nn.Module,
Expand Down Expand Up @@ -78,10 +81,15 @@ def __init__(
self.config = model_fp.config
self.loss_function = model_fp.loss_function
self.device = model_fp.device
self.generation_config = model_fp.generation_config
self.main_input_name = model_fp.main_input_name

def tie_weights(self):
pass

def is_remote_code(self):
return False

def forward(
self,
input_ids: torch.LongTensor | None = None,
Expand All @@ -99,6 +107,9 @@ def forward(
output_attentions = self.config.output_attentions
output_hidden_states = self.config.output_hidden_states
return_dict = self.config.use_return_dict
if "return_dict" in kwargs:
# lm_eval sets return_dict explicitly in kwargs
return_dict = kwargs.pop("return_dict")
Comment on lines 109 to +112
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out that lm_eval passes return_dict explicitly in kwargs, so we need to remove it from kwargs to avoid a duplicate-argument error.


# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
Expand Down
Loading