20 | 20 | import sys |
21 | 21 | import warnings |
22 | 22 | from pathlib import Path |
| 23 | +from typing import Any |
23 | 24 |
24 | 25 | import torch |
25 | 26 | import transformers |
26 | 27 | from accelerate import infer_auto_device_map, init_empty_weights |
27 | 28 | from accelerate.utils import get_max_memory |
28 | | -from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer |
| 29 | +from transformers import ( |
| 30 | + AutoConfig, |
| 31 | + AutoModelForCausalLM, |
| 32 | + AutoProcessor, |
| 33 | + AutoTokenizer, |
| 34 | + ProcessorMixin, |
| 35 | +) |
29 | 36 |
30 | 37 | try: |
31 | 38 | from huggingface_hub import snapshot_download |
32 | 39 | except ImportError: |
33 | 40 | snapshot_download = None |
34 | 41 |
35 | 42 | import modelopt.torch.quantization as mtq |
36 | | -from modelopt.torch.utils.image_processor import MllamaImageProcessor |
| 43 | +from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor |
37 | 44 |
38 | 45 | SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] |
39 | 46 |
@@ -127,52 +134,50 @@ def build_quant_cfg( |
127 | 134 | qformat, |
128 | 135 | kv_cache_qformat, |
129 | 136 | awq_block_size, |
130 | | - auto_quantize, |
131 | 137 | model_type, |
132 | 138 | quant_cfg_choices, |
133 | 139 | kv_quant_cfg_choices, |
134 | | -): |
| 140 | +) -> dict[str, Any]: |
135 | 141 | quant_cfg = {} |
136 | | - if not auto_quantize: |
137 | | - assert qformat in quant_cfg_choices, ( |
138 | | - f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache" |
139 | | - ) |
| 142 | + assert qformat in quant_cfg_choices, ( |
| 143 | + f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache" |
| 144 | + ) |
140 | 145 |
141 | | - quant_cfg = quant_cfg_choices[qformat] |
142 | | - |
143 | | - if "awq" in qformat: |
144 | | - quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) |
145 | | - weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] |
146 | | - if isinstance(weight_quantizer, list): |
147 | | - weight_quantizer = weight_quantizer[0] |
148 | | - # If awq_block_size argument is provided, update weight_quantizer |
149 | | - if awq_block_size: |
150 | | - weight_quantizer["block_sizes"][-1] = awq_block_size |
151 | | - |
152 | | - # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models |
153 | | - if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: |
154 | | - quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} |
155 | | - |
156 | | - enable_quant_kv_cache = kv_cache_qformat != "none" |
157 | | - print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") |
158 | | - |
159 | | - # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. |
160 | | - if enable_quant_kv_cache: |
161 | | - quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( |
162 | | - quant_cfg, |
163 | | - getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"], |
164 | | - ) |
| 146 | + quant_cfg = quant_cfg_choices[qformat] |
| 147 | + |
| 148 | + if "awq" in qformat: |
| 149 | + quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) |
| 150 | + weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] |
| 151 | + if isinstance(weight_quantizer, list): |
| 152 | + weight_quantizer = weight_quantizer[0] |
| 153 | + # If awq_block_size argument is provided, update weight_quantizer |
| 154 | + if awq_block_size: |
| 155 | + weight_quantizer["block_sizes"][-1] = awq_block_size |
| 156 | + |
| 157 | + # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models |
| 158 | + if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: |
| 159 | + quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} |
| 160 | + |
| 161 | + enable_quant_kv_cache = kv_cache_qformat != "none" |
| 162 | + print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") |
| 163 | + |
| 164 | + # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. |
| 165 | + if enable_quant_kv_cache: |
| 166 | + quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( |
| 167 | + quant_cfg, |
| 168 | + getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"], |
| 169 | + ) |
165 | 170 |
166 | | - # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. |
167 | | - if model_type == "gemma" and "int8_sq" in qformat: |
168 | | - quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} |
| 171 | + # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. |
| 172 | + if model_type == "gemma" and "int8_sq" in qformat: |
| 173 | + quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} |
169 | 174 |
170 | | - if model_type == "phi4mm": |
171 | | - # Only quantize the language model |
172 | | - quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} |
173 | | - quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} |
174 | | - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} |
175 | | - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} |
| 175 | + if model_type == "phi4mm": |
| 176 | + # Only quantize the language model |
| 177 | + quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} |
| 178 | + quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} |
| 179 | + quant_cfg["quant_cfg"]["*image*"] = {"enable": False} |
| 180 | + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} |
176 | 181 |
177 | 182 | return quant_cfg |
178 | 183 |
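With the `auto_quantize` flag gone, `build_quant_cfg` always validates `qformat` against `quant_cfg_choices` and returns a plain config dict. A minimal usage sketch of the refactored helper; the choice tables and config names below (`FP8_DEFAULT_CFG`, `INT4_AWQ_CFG`, `"FP8_KV_CFG"`) are illustrative assumptions, not values taken from this diff:

```python
import modelopt.torch.quantization as mtq

# Illustrative choice tables; the real script defines its own
# QUANT_CFG_CHOICES / KV_QUANT_CFG_CHOICES mappings.
quant_cfg_choices = {
    "fp8": mtq.FP8_DEFAULT_CFG,
    "int4_awq": mtq.INT4_AWQ_CFG,
}
kv_quant_cfg_choices = {
    "fp8": "FP8_KV_CFG",  # resolved via getattr(mtq, ...) inside build_quant_cfg
}

quant_cfg = build_quant_cfg(
    qformat="int4_awq",
    kv_cache_qformat="fp8",
    awq_block_size=64,
    model_type="llama",
    quant_cfg_choices=quant_cfg_choices,
    kv_quant_cfg_choices=kv_quant_cfg_choices,
)
# The returned dict is the config that mtq.quantize(model, quant_cfg, forward_loop) consumes.
```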
@@ -205,8 +210,8 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs): |
205 | 210 |
206 | 211 |
207 | 212 | def get_processor( |
208 | | - ckpt_path, model_type, device=None, trust_remote_code=False, attn_implementation=None |
209 | | -): |
| 213 | + ckpt_path, model_type, device: str = "auto", trust_remote_code=False, attn_implementation=None |
| 214 | +) -> BaseImageProcessor | ProcessorMixin | None: |
210 | 215 | """ |
211 | 216 | Returns a :class:`modelopt.torch.utils.image_processor.MllamaImageProcessor` object. |
212 | 217 | """ |
@@ -241,6 +246,8 @@ def get_processor( |
241 | 246 |
242 | 247 | return MllamaImageProcessor(processor, device) |
243 | 248 |
| 249 | + return None |
| 250 | + |
244 | 251 |
245 | 252 | def get_dtype(dtype): |
246 | 253 | if dtype == "bf16": |
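With the explicit `return None` fallthrough, `get_processor` yields a processor only for the multimodal branches it recognizes (the `mllama` branch returning `MllamaImageProcessor` is the one visible in this hunk) and `None` otherwise, which the new `BaseImageProcessor | ProcessorMixin | None` annotation spells out. A hypothetical caller sketch; the checkpoint path is a placeholder, not something this commit references:

```python
ckpt_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # placeholder checkpoint

processor = get_processor(ckpt_path, model_type="mllama", device="cuda", trust_remote_code=True)

if processor is None:
    # Text-only models fall through to `return None`; calibrate with the
    # tokenizer instead of an image processor in that case.
    tokenizer = get_tokenizer(ckpt_path, trust_remote_code=True)
```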