Commit a89625b

Refactor and clean up hf_ptq.py
This script has several separate pieces of logic whose code is entangled, making it really hard to add new features. Refactor it so that these logics are separated:

1. sparsity: all logic goes to sparsity_main. TODO: we may actually move this logic out to a separate script.
2. quantize: all logic goes to quantize_main.
   2.1 plain quantization with a single quantization format
   2.2 auto quantization

In the quantization pipeline, split the pipeline into:

1. model loading
2. calibration dataset loading
3. pre-quantize processing
4. actual quantize
5. post-quantize processing
6. quantized model export

Signed-off-by: Shengliang Xu <[email protected]>
1 parent 53a2dde commit a89625b
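
The hf_ptq.py diff itself is not part of this excerpt. As orientation only, the separation described in the commit message could be laid out like the minimal sketch below; apart from sparsity_main, quantize_main, and the six numbered pipeline stages taken from the message, every name and flag here is an illustrative placeholder, not the actual script contents.

# Structural sketch only -- not the actual hf_ptq.py.
import argparse

import modelopt.torch.quantization as mtq


def sparsity_main(args: argparse.Namespace) -> None:
    """All sparsity logic lives here (may later move to its own script)."""


def quantize_main(args: argparse.Namespace) -> None:
    """Covers both plain (single-format) quantization and auto quantization."""
    # 1. model loading
    # 2. calibrate dataset loading
    # 3. pre-quantize processing
    # 4. actual quantize, e.g. via mtq.quantize(model, quant_cfg, forward_loop)
    # 5. post-quantize processing
    # 6. quantized model export


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--sparsity_fmt", default="dense")  # assumed flag name and default
    args = parser.parse_args()
    if args.sparsity_fmt != "dense":
        sparsity_main(args)
    else:
        quantize_main(args)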

File tree

6 files changed: +557 −412 lines

examples/llm_ptq/example_utils.py

Lines changed: 56 additions & 44 deletions
@@ -20,20 +20,28 @@
 import sys
 import warnings
 from pathlib import Path
+from typing import Any
 
 import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
-from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+    PreTrainedTokenizerBase,
+    ProcessorMixin,
+)
 
 try:
     from huggingface_hub import snapshot_download
 except ImportError:
     snapshot_download = None
 
 import modelopt.torch.quantization as mtq
-from modelopt.torch.utils.image_processor import MllamaImageProcessor
+from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
 
 SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]
 
@@ -127,52 +135,50 @@ def build_quant_cfg(
     qformat,
     kv_cache_qformat,
     awq_block_size,
-    auto_quantize,
     model_type,
     quant_cfg_choices,
     kv_quant_cfg_choices,
-):
+) -> dict[str, Any]:
     quant_cfg = {}
-    if not auto_quantize:
-        assert qformat in quant_cfg_choices, (
-            f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
-        )
+    assert qformat in quant_cfg_choices, (
+        f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
+    )
 
-        quant_cfg = quant_cfg_choices[qformat]
-
-        if "awq" in qformat:
-            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
-            weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
-            if isinstance(weight_quantizer, list):
-                weight_quantizer = weight_quantizer[0]
-            # If awq_block_size argument is provided, update weight_quantizer
-            if awq_block_size:
-                weight_quantizer["block_sizes"][-1] = awq_block_size
-
-            # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
-            if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
-                quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
-
-        enable_quant_kv_cache = kv_cache_qformat != "none"
-        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
-
-        # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
-        if enable_quant_kv_cache:
-            quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
-                quant_cfg,
-                getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
-            )
+    quant_cfg = quant_cfg_choices[qformat]
+
+    if "awq" in qformat:
+        quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
+        weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
+        if isinstance(weight_quantizer, list):
+            weight_quantizer = weight_quantizer[0]
+        # If awq_block_size argument is provided, update weight_quantizer
+        if awq_block_size:
+            weight_quantizer["block_sizes"][-1] = awq_block_size
+
+        # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
+        if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
+            quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
+
+    enable_quant_kv_cache = kv_cache_qformat != "none"
+    print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
+
+    # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
+    if enable_quant_kv_cache:
+        quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
+            quant_cfg,
+            getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
+        )
 
-        # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
-        if model_type == "gemma" and "int8_sq" in qformat:
-            quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
+    # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
+    if model_type == "gemma" and "int8_sq" in qformat:
+        quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
 
-        if model_type == "phi4mm":
-            # Only quantize the language model
-            quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
-            quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
-            quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
-            quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+    if model_type == "phi4mm":
+        # Only quantize the language model
+        quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
 
     return quant_cfg
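
With the auto-quantize branch removed, build_quant_cfg now always validates qformat against the plain-quantization choices. A minimal call-site sketch, assuming the script is run from examples/llm_ptq/ and that the two choice dicts below roughly mirror the ones defined in hf_ptq.py (the "FP8_KV_CFG" attribute name on mtq is an assumption):

# Hypothetical call site for the refactored helper; choice dicts are illustrative.
import modelopt.torch.quantization as mtq
from example_utils import build_quant_cfg

QUANT_CFG_CHOICES = {"fp8": mtq.FP8_DEFAULT_CFG}   # assumed subset of the real mapping
KV_QUANT_CFG_CHOICES = {"fp8": "FP8_KV_CFG"}       # assumed attribute name on mtq

quant_cfg = build_quant_cfg(
    qformat="fp8",
    kv_cache_qformat="fp8",   # pass "none" to skip KV-cache quantization
    awq_block_size=0,         # only consulted for AWQ formats
    model_type="llama",
    quant_cfg_choices=QUANT_CFG_CHOICES,
    kv_quant_cfg_choices=KV_QUANT_CFG_CHOICES,
)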

@@ -184,7 +190,7 @@ def is_speculative(hf_config):
     )
 
 
-def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs):
+def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs) -> PreTrainedTokenizerBase:
     print(f"Initializing tokenizer from {ckpt_path}")
 
     if "vila" in ckpt_path.lower():
@@ -205,8 +211,12 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs):
 
 
 def get_processor(
-    ckpt_path, model_type, device=None, trust_remote_code=False, attn_implementation=None
-):
+    ckpt_path,
+    model_type,
+    device: torch.device = "auto",
+    trust_remote_code=False,
+    attn_implementation=None,
+) -> BaseImageProcessor | ProcessorMixin | None:
     """
     Returns a :class:`modelopt.torch.utils.image_processor.MllamaImageProcessor` object.
     """
@@ -241,6 +251,8 @@ def get_processor(
 
         return MllamaImageProcessor(processor, device)
 
+    return None
+
 
 def get_dtype(dtype):
     if dtype == "bf16":
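
The refactored helpers also gain return-type annotations, and get_processor now returns an explicit None for checkpoints without a multimodal processor instead of falling off the end of the function. A rough usage sketch, assuming the script is run from examples/llm_ptq/; the checkpoint name and the "mllama" model_type value are placeholders:

# Illustrative use of the updated helpers; names below are assumptions.
from example_utils import get_processor, get_tokenizer

tokenizer = get_tokenizer("meta-llama/Llama-3.2-11B-Vision-Instruct")
processor = get_processor(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    model_type="mllama",
    device="cuda",
)
if processor is None:
    # Text-only checkpoints yield no image processor.
    print("No multimodal processor for this checkpoint; using the tokenizer only.")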
