
Commit 070ae87

committed
Refactor and clean up hf_ptq.py
This script contains several separate pieces of logic whose code is entangled, making it really hard to add new features. Refactor so that these concerns are separated:

1. sparsity: all logic goes to sparsity_main. TODO: we may actually move this logic out to a separate script.
2. quantize: all logic goes to quantize_main.
   2.1 plain quantization with a single quantization format
   2.2 auto quantization

Within the quantization path, split the pipeline into:
1. model loading
2. calibration dataset loading
3. pre-quantize processing
4. actual quantization
5. post-quantize processing
6. quantized model export

Signed-off-by: Shengliang Xu <[email protected]>
1 parent 53a2dde commit 070ae87
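
A minimal sketch of how quantize_main might wire together the six stages named in the commit message above. The stage helpers (load_model, load_calib_dataset, create_forward_loop, postprocess_model, export_quantized_model) and the QUANT_CFG_CHOICES / KV_QUANT_CFG_CHOICES dicts are hypothetical placeholders, not code from this commit; only build_quant_cfg (from example_utils.py) and mtq.quantize (from modelopt) are taken from existing APIs.

    # Sketch only: stage helper names and choice dicts are assumptions, not this commit's code.
    import modelopt.torch.quantization as mtq

    from example_utils import build_quant_cfg, get_model_type  # get_model_type assumed to exist here


    def quantize_main(args):
        # 1. model loading
        model, tokenizer = load_model(args.pyt_ckpt_path)
        model_type = get_model_type(model)

        # 2. calibration dataset loading
        calib_loader = load_calib_dataset(tokenizer, args.dataset, args.calib_size)

        # 3. pre-quantize processing: resolve the quantization config
        quant_cfg = build_quant_cfg(
            args.qformat,
            args.kv_cache_qformat,
            args.awq_block_size,
            model_type,
            QUANT_CFG_CHOICES,
            KV_QUANT_CFG_CHOICES,
        )

        # 4. actual quantization, calibrated by running the forward loop
        model = mtq.quantize(model, quant_cfg, forward_loop=create_forward_loop(calib_loader))

        # 5. post-quantize processing (model-specific fix-ups)
        postprocess_model(model)

        # 6. quantized model export
        export_quantized_model(model, tokenizer, args.export_path)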

File tree

5 files changed: +551 -407 lines changed


examples/llm_ptq/example_utils.py

Lines changed: 50 additions & 43 deletions
@@ -20,20 +20,27 @@
 import sys
 import warnings
 from pathlib import Path
+from typing import Any
 
 import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
-from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+    ProcessorMixin,
+)
 
 try:
     from huggingface_hub import snapshot_download
 except ImportError:
     snapshot_download = None
 
 import modelopt.torch.quantization as mtq
-from modelopt.torch.utils.image_processor import MllamaImageProcessor
+from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
 
 SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]
 
@@ -127,52 +134,50 @@ def build_quant_cfg(
     qformat,
     kv_cache_qformat,
     awq_block_size,
-    auto_quantize,
     model_type,
     quant_cfg_choices,
     kv_quant_cfg_choices,
-):
+) -> dict[str, Any]:
     quant_cfg = {}
-    if not auto_quantize:
-        assert qformat in quant_cfg_choices, (
-            f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
-        )
+    assert qformat in quant_cfg_choices, (
+        f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
+    )
 
-        quant_cfg = quant_cfg_choices[qformat]
-
-        if "awq" in qformat:
-            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
-            weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
-            if isinstance(weight_quantizer, list):
-                weight_quantizer = weight_quantizer[0]
-            # If awq_block_size argument is provided, update weight_quantizer
-            if awq_block_size:
-                weight_quantizer["block_sizes"][-1] = awq_block_size
-
-        # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
-        if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
-            quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
-
-        enable_quant_kv_cache = kv_cache_qformat != "none"
-        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
-
-        # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
-        if enable_quant_kv_cache:
-            quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
-                quant_cfg,
-                getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
-            )
+    quant_cfg = quant_cfg_choices[qformat]
+
+    if "awq" in qformat:
+        quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
+        weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
+        if isinstance(weight_quantizer, list):
+            weight_quantizer = weight_quantizer[0]
+        # If awq_block_size argument is provided, update weight_quantizer
+        if awq_block_size:
+            weight_quantizer["block_sizes"][-1] = awq_block_size
+
+    # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
+    if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
+        quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
+
+    enable_quant_kv_cache = kv_cache_qformat != "none"
+    print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
+
+    # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
+    if enable_quant_kv_cache:
+        quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
+            quant_cfg,
+            getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
+        )
 
-    # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
-    if model_type == "gemma" and "int8_sq" in qformat:
-        quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
+    # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
+    if model_type == "gemma" and "int8_sq" in qformat:
+        quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
 
-    if model_type == "phi4mm":
-        # Only quantize the language model
-        quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+    if model_type == "phi4mm":
+        # Only quantize the language model
+        quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
 
     return quant_cfg
 
@@ -205,8 +210,8 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs):
 
 
 def get_processor(
-    ckpt_path, model_type, device=None, trust_remote_code=False, attn_implementation=None
-):
+    ckpt_path, model_type, device: str = "auto", trust_remote_code=False, attn_implementation=None
+) -> BaseImageProcessor | ProcessorMixin | None:
     """
     Returns a :class:`modelopt.torch.utils.image_processor.MllamaImageProcessor` object.
     """
@@ -241,6 +246,8 @@ def get_processor(
 
         return MllamaImageProcessor(processor, device)
 
+    return None
+
 
 def get_dtype(dtype):
     if dtype == "bf16":