diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py
index 306348f2c..93edb1822 100755
--- a/modelopt/torch/export/model_config.py
+++ b/modelopt/torch/export/model_config.py
@@ -31,6 +31,7 @@
 QUANTIZATION_INT8_SQ = "int8_sq"
 QUANTIZATION_INT8_WO = "int8_wo"
 QUANTIZATION_INT4_AWQ = "int4_awq"
+QUANTIZATION_INT4_WO = "int4_wo"
 QUANTIZATION_W4A8_AWQ = "w4a8_awq"
 QUANTIZATION_NVFP4 = "nvfp4"
 QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index eee13dc51..4b9463ea8 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -51,6 +51,7 @@
     QUANTIZATION_FP8_PB_WO,
     QUANTIZATION_FP8_PC_PT,
     QUANTIZATION_INT4_AWQ,
+    QUANTIZATION_INT4_WO,
     QUANTIZATION_INT8_SQ,
     QUANTIZATION_INT8_WO,
     QUANTIZATION_MXFP4,
@@ -463,7 +464,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
         assert len(weight_quantizer.block_sizes) > 0 and weight_quantizer.block_sizes[-1] > 0, (
             "Invalid block_sizes for INT4 quantizer"
         )
-        return QUANTIZATION_INT4_AWQ
+        if getattr(weight_quantizer, "pre_quant_scale", None) is not None:
+            return QUANTIZATION_INT4_AWQ
+        else:
+            return QUANTIZATION_INT4_WO
 
     if weight_quantizer.num_bits == 8:
         if input_quantizer is not None and input_quantizer.is_enabled:
@@ -634,6 +638,13 @@ def process_layer_quant_config(layer_config_dict):
                 "has_zero_point": False,
                 "pre_quant_scale": True,
             }
+        elif v == "int4_wo":
+            layer_config = {
+                "quant_algo": "W4A16",
+                "group_size": block_size_value,
+                "has_zero_point": False,
+                "pre_quant_scale": False,
+            }
         elif v == "w4a8_awq":
             layer_config = {
                 "quant_algo": "W4A8_AWQ",
@@ -810,7 +821,7 @@ def to_quantized_weight(
         )
         return (weight / weights_scaling_factor[:, None]).to(torch.float8_e4m3fn)
 
-    if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:
+    if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ, QUANTIZATION_INT4_WO]:
         return pack_int4_in_uint8(weight, weights_scaling_factor)
 
     if quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_W4A8_NVFP4_FP8]:
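
For reviewers, a minimal sketch of the detection rule the quant_utils.py hunk introduces: a block-quantized 4-bit weight maps to int4_awq only when the weight quantizer carries a pre-quant scale, and to the new int4_wo otherwise. The detect_int4_format helper and the SimpleNamespace quantizer stand-ins below are hypothetical, not modelopt APIs; only the branching mirrors the patch. Comparing against None matters because a per-channel scale is a multi-element tensor, whose bare truthiness would raise a RuntimeError.

# Illustrative sketch (not part of the patch): hypothetical stand-ins for
# modelopt's TensorQuantizer; only the branching logic mirrors the diff above.
from types import SimpleNamespace

import torch

QUANTIZATION_INT4_AWQ = "int4_awq"
QUANTIZATION_INT4_WO = "int4_wo"


def detect_int4_format(weight_quantizer):
    """Return int4_awq when a pre-quant scale is attached, else int4_wo."""
    assert len(weight_quantizer.block_sizes) > 0 and weight_quantizer.block_sizes[-1] > 0, (
        "Invalid block_sizes for INT4 quantizer"
    )
    # A per-channel scale tensor is present only after AWQ calibration;
    # checking `is not None` avoids ambiguous truthiness on a tensor.
    if getattr(weight_quantizer, "pre_quant_scale", None) is not None:
        return QUANTIZATION_INT4_AWQ
    return QUANTIZATION_INT4_WO


# AWQ calibration leaves a per-channel pre-quant scale on the quantizer ...
awq_q = SimpleNamespace(num_bits=4, block_sizes={-1: 128}, pre_quant_scale=torch.ones(4096))
# ... while plain INT4 weight-only quantization does not.
wo_q = SimpleNamespace(num_bits=4, block_sizes={-1: 128}, pre_quant_scale=None)

assert detect_int4_format(awq_q) == QUANTIZATION_INT4_AWQ
assert detect_int4_format(wo_q) == QUANTIZATION_INT4_WO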