
Conversation

@jingyu-ml (Contributor)

What does this PR do?

Type of change: new example

Overview: Support Kimi-K2 calibration

Usage

```bash
python hf_ptq.py --pyt_ckpt_path unsloth/Kimi-K2-Thinking-BF16 --qformat nvfp4_mlp_experts_only --export_path <quantized_ckpt_path> --trust_remote_code
```

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Signed-off-by: Jingyu Xin <[email protected]>
@jingyu-ml jingyu-ml requested review from a team as code owners December 5, 2025 20:32
Signed-off-by: Jingyu Xin <[email protected]>
Comment on lines 626 to 644
```python
NVFP4_MLP_EXPERTS_ONLY_CFG = {
    "quant_cfg": {
        "*mlp.experts*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "enable": True,
            "pass_through_bwd": True,
        },
        "*mlp.experts*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "enable": True,
            "pass_through_bwd": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": "max",
}
```
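For context, a minimal sketch of how a config like this is typically applied with Model Optimizer's PTQ API (`model` and `calib_dataloader` here are placeholders, not part of this PR):

```python
import modelopt.torch.quantization as mtq

# Placeholder calibration loop that feeds a few batches through the model.
def forward_loop(model):
    for batch in calib_dataloader:
        model(**batch)

# Quantize only the expert MLP weights/activations per the config above.
model = mtq.quantize(model, NVFP4_MLP_EXPERTS_ONLY_CFG, forward_loop)
```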

Contributor

We already have this config. See NVFP4_MLP_ONLY_CFG

Suggested change: delete the NVFP4_MLP_EXPERTS_ONLY_CFG block above.

@jingyu-ml (Contributor Author), Dec 5, 2025

That config would also quantize mlp.shared_experts; the intent here is to quantize only the routed experts.
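To make the wildcard difference concrete, a standalone sketch (module names are hypothetical, and fnmatch-style matching is assumed for these quantizer patterns):

```python
from fnmatch import fnmatch

# Hypothetical module names; layer indices and projection names are placeholders.
names = [
    "model.layers.3.mlp.experts.7.down_proj.weight_quantizer",
    "model.layers.3.mlp.shared_experts.down_proj.weight_quantizer",
]

for name in names:
    broad = fnmatch(name, "*mlp*weight_quantizer")                 # MLP-only style pattern
    experts_only = fnmatch(name, "*mlp.experts*weight_quantizer")  # experts-only pattern
    print(name, broad, experts_only)
# The broad pattern matches both modules; the experts-only pattern skips mlp.shared_experts.
```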

Collaborator

This is great. Let's not create more configs, though.


> *This is a subset of the models supported. For the full list please check the [TensorRT-LLM support matrix](https://nvidia.github.io/TensorRT-LLM/reference/precision.html#support-matrix)*
> We recommend upcasting Kimi-K2-Thinking from INT4 to BF16 before running quantization.
Collaborator

Is it a recommendation, or something we have to do? An alternative is to upcast the INT4 weights to BF16 during calibration, like we did with DeepSeek.

Contributor Author

But there's no INT4 support in PyTorch, as we discussed. People have to use vLLM if they want INT4. Zhiyu and I are looking into vLLM calibration for this model.
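For reference, a minimal sketch of loading the already-upcast BF16 checkpoint on the PyTorch path (plain Transformers usage, not code from this PR; it assumes enough host memory for the full BF16 weights):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the BF16 upcast of Kimi-K2-Thinking for calibration on the PyTorch path.
model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Kimi-K2-Thinking-BF16",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "unsloth/Kimi-K2-Thinking-BF16", trust_remote_code=True
)
```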

> We recommend upcasting Kimi-K2-Thinking from INT4 to BF16 before running quantization.
```python
from transformers import AutoModelForCausalLM
```

"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_mlp_experts_only": mtq.NVFP4_MLP_EXPERTS_ONLY_CFG,
Collaborator

I'm still against adding more configs here. I think in this MR we should just stick with MLP_only if we have to. People can tune the recipe themselves if they want experts only.

If you really want to add this config, let's name it experts_only for short; experts are always in the MLP.

Contributor Author

Renamed it to nvfp4_experts_only. Let’s keep this config for now; once the YAML config system is released, we can avoid using these recipe dictionaries.
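With the rename, the usage command from the PR description would presumably become (a sketch based on the comment above; the export path placeholder is kept as-is):

```bash
python hf_ptq.py --pyt_ckpt_path unsloth/Kimi-K2-Thinking-BF16 --qformat nvfp4_experts_only --export_path <quantized_ckpt_path> --trust_remote_code
```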

Signed-off-by: Jingyu Xin <[email protected]>

codecov bot commented Dec 5, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.57%. Comparing base (c6c9905) to head (647755c).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
```
@@            Coverage Diff             @@
##             main     #655      +/-   ##
==========================================
- Coverage   74.58%   74.57%   -0.02%     
==========================================
  Files         183      183              
  Lines       18451    18452       +1     
==========================================
- Hits        13762    13760       -2     
- Misses       4689     4692       +3     
```


Signed-off-by: Jingyu Xin <[email protected]>