modelopt/torch/quantization/plugins/huggingface.py (+93 −0)

@@ -487,6 +487,78 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        return self.w2_linear[expert_idx](x1)


class _QuantQwen3VLMoeTextExperts(QuantModule):
    def _setup(self):
        """Replace the fused expert weights of Qwen3VLMoeTextExperts with per-expert nn.Linear layers."""
        from accelerate import init_empty_weights

        dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device

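        # Helper: materialize a meta-device Linear, then copy a slice of the fused weights into it.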
        def _copy_weight(module, weight):
            module.to_empty(device=device)
            with torch.no_grad():
                module.weight.data = weight.detach().data.to(dtype=dtype, device=device)

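        # Allocate the per-expert Linear layers on the meta device; real weights are copied in below.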
        with init_empty_weights():
            gate_proj = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )
            up_proj = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )
            down_proj = nn.ModuleList(
                [
                    nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )

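        # gate_up_proj stacks the gate and up projections along the last dim (the slicing below
        # implies a [num_experts, hidden_size, 2 * expert_dim] layout). nn.Linear stores weight
        # as [out_features, in_features], hence the transposes.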
        for idx in range(self.num_experts):
            _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, : self.expert_dim].T)
            _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, self.expert_dim :].T)
            _copy_weight(down_proj[idx], self.down_proj[idx, :].T)

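        # Drop the fused parameters so only the quantizable Linear modules remain.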
        delattr(self, "gate_up_proj")
        delattr(self, "down_proj")
        self.gate_proj = gate_proj
        self.up_proj = up_proj
        self.down_proj = down_proj

    def forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        router_indices: torch.Tensor,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        next_states = torch.zeros_like(hidden_states)
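        # Build an expert mask of shape [num_experts, top_k, tokens] from the per-token
        # routing indices, then visit only the experts that actually received tokens.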
        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit[:]:
            assert expert_idx.numel() == 1, expert_idx
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx[0]])
            current_state = hidden_states[token_idx]
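            # SwiGLU expert MLP: down(up(x) * act(gate(x))), scaled by the token's routing weight.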
            gate = self.gate_proj[expert_idx](current_state)
            up = self.up_proj[expert_idx](current_state)
            gated_output = up * self.act_fn(gate)
            out = self.down_proj[expert_idx](gated_output)
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        next_states = next_states.view(batch_size, -1, self.hidden_size)

        return next_states


class _QuantDbrxFFN(_QuantSparseMoe):
    @property
    def num_experts(self):

@@ -576,6 +648,27 @@ def top_k(self, value):
except ImportError:
    pass

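# Register the Qwen3-VL MoE modules with ModelOpt's QuantModuleRegistry, guarding the
# imports so older transformers versions without Qwen3-VL-MoE still work.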
try:
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock

    if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
        QuantModuleRegistry.register(
            {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
        )(_QuantSparseMoe)
except ImportError:
    pass


try:
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts

    if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
            _QuantQwen3VLMoeTextExperts
        )
except ImportError:
    pass

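# Usage sketch (an assumption, not part of this diff): once registered, these modules are
# converted automatically during quantization, e.g.
#   import modelopt.torch.quantization as mtq
#   model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calib_loop)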

class _QuantGptOssExperts(_QuantFunctionalMixin):
    """Quantized wrapper for `transformers.GptOssExperts`.