 import sys
 import warnings
 from pathlib import Path
+from typing import Any
 
 import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
-from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+    PreTrainedTokenizerBase,
+    ProcessorMixin,
+)
 
 try:
     from huggingface_hub import snapshot_download
 except ImportError:
     snapshot_download = None
 
 import modelopt.torch.quantization as mtq
-from modelopt.torch.utils.image_processor import MllamaImageProcessor
+from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
 
 SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]
 
@@ -127,52 +135,50 @@ def build_quant_cfg(
     qformat,
     kv_cache_qformat,
     awq_block_size,
-    auto_quantize,
     model_type,
     quant_cfg_choices,
     kv_quant_cfg_choices,
-):
+) -> dict[str, Any]:
     quant_cfg = {}
-    if not auto_quantize:
-        assert qformat in quant_cfg_choices, (
-            f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
-        )
+    assert qformat in quant_cfg_choices, (
+        f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
+    )
 
-        quant_cfg = quant_cfg_choices[qformat]
-
-        if "awq" in qformat:
-            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
-            weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
-            if isinstance(weight_quantizer, list):
-                weight_quantizer = weight_quantizer[0]
-            # If awq_block_size argument is provided, update weight_quantizer
-            if awq_block_size:
-                weight_quantizer["block_sizes"][-1] = awq_block_size
-
-        # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
-        if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
-            quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
-
-        enable_quant_kv_cache = kv_cache_qformat != "none"
-        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
-
-        # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
-        if enable_quant_kv_cache:
-            quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
-                quant_cfg,
-                getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
-            )
+    quant_cfg = quant_cfg_choices[qformat]
+
+    if "awq" in qformat:
+        quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
+        weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
+        if isinstance(weight_quantizer, list):
+            weight_quantizer = weight_quantizer[0]
+        # If awq_block_size argument is provided, update weight_quantizer
+        if awq_block_size:
+            weight_quantizer["block_sizes"][-1] = awq_block_size
+
+    # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
+    if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
+        quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
+
+    enable_quant_kv_cache = kv_cache_qformat != "none"
+    print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
+
+    # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
+    if enable_quant_kv_cache:
+        quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
+            quant_cfg,
+            getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
+        )
 
-        # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
-        if model_type == "gemma" and "int8_sq" in qformat:
-            quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
+    # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
+    if model_type == "gemma" and "int8_sq" in qformat:
+        quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
 
-    if model_type == "phi4mm":
-        # Only quantize the language model
-        quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+    if model_type == "phi4mm":
+        # Only quantize the language model
+        quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
 
     return quant_cfg
 
@@ -184,7 +190,7 @@ def is_speculative(hf_config):
     )
 
 
-def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs):
+def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs) -> PreTrainedTokenizerBase:
     print(f"Initializing tokenizer from {ckpt_path}")
 
     if "vila" in ckpt_path.lower():
@@ -205,8 +211,12 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs):
 
 
 def get_processor(
-    ckpt_path, model_type, device=None, trust_remote_code=False, attn_implementation=None
-):
+    ckpt_path,
+    model_type,
+    device: str | torch.device = "auto",
+    trust_remote_code=False,
+    attn_implementation=None,
+) -> BaseImageProcessor | ProcessorMixin | None:
     """
     Returns a :class:`modelopt.torch.utils.image_processor.MllamaImageProcessor` object.
     """
@@ -241,6 +251,8 @@ def get_processor(
 
         return MllamaImageProcessor(processor, device)
 
+    return None
+
 
 def get_dtype(dtype):
     if dtype == "bf16":