diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index a3fa6ef1a..1d1d02929 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -487,6 +487,82 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
         return self.w2_linear[expert_idx](x1)
 
 
+class _QuantQwen3VLMoeTextExperts(QuantModule):
+    def _setup(self):
+        """Replace the fused Qwen3VLMoeTextExperts weights with per-expert nn.Linear layers."""
+        from accelerate import init_empty_weights
+
+        dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device
+
+        def _copy_weight(module, weight):
+            module.to_empty(device=device)
+            with torch.no_grad():
+                module.weight.data = weight.detach().data.to(dtype=dtype, device=device)
+
+        with init_empty_weights():
+            gate_proj = nn.ModuleList(
+                [
+                    nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+                    for _ in range(self.num_experts)
+                ]
+            )
+            up_proj = nn.ModuleList(
+                [
+                    nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+                    for _ in range(self.num_experts)
+                ]
+            )
+            down_proj = nn.ModuleList(
+                [
+                    nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+                    for _ in range(self.num_experts)
+                ]
+            )
+
+        # Split the fused (num_experts, hidden_size, 2 * expert_dim) gate_up_proj into
+        # gate/up halves and transpose into nn.Linear's (out_features, in_features) layout.
+        for idx in range(self.num_experts):
+            _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, : self.expert_dim].T)
+            _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, self.expert_dim :].T)
+            _copy_weight(down_proj[idx], self.down_proj[idx, :].T)
+
+        delattr(self, "gate_up_proj")
+        delattr(self, "down_proj")
+        self.gate_proj = gate_proj
+        self.up_proj = up_proj
+        self.down_proj = down_proj
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        routing_weights: torch.Tensor,
+        router_indices: torch.Tensor,
+    ) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)
+        next_states = torch.zeros_like(hidden_states)
+        with torch.no_grad():
+            # expert_mask: (num_experts, top_k, num_tokens); expert_hit keeps only the
+            # experts that were routed at least one token.
+            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
+            expert_mask = expert_mask.permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+        for expert_idx in expert_hit[:]:
+            assert expert_idx.numel() == 1, expert_idx
+            with torch.no_grad():
+                _, token_idx = torch.where(expert_mask[expert_idx[0]])
+            current_state = hidden_states[token_idx]
+            gate = self.gate_proj[expert_idx](current_state)
+            up = self.up_proj[expert_idx](current_state)
+            gated_output = up * self.act_fn(gate)
+            out = self.down_proj[expert_idx](gated_output)
+            weighted_output = out * routing_weights[token_idx, expert_idx, None]
+            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+        next_states = next_states.view(batch_size, -1, self.hidden_size)
+
+        return next_states
+
+
 class _QuantDbrxFFN(_QuantSparseMoe):
     @property
     def num_experts(self):
@@ -576,6 +652,27 @@ def top_k(self, value):
 except ImportError:
     pass
 
+try:
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock
+
+    if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
+        QuantModuleRegistry.register(
+            {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
+        )(_QuantSparseMoe)
+except ImportError:
+    pass
+
+
+try:
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts
+
+    if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
+        QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
+            _QuantQwen3VLMoeTextExperts
+        )
+except ImportError:
+    pass
+
 
 class _QuantGptOssExperts(_QuantFunctionalMixin):
     """Quantized wrapper for `transformers.GptOssExperts`.
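For review context, a standalone sketch (not part of the patch) that checks the weight-unpacking logic in `_setup`. It assumes the HF fused layout `(num_experts, hidden_size, 2 * expert_dim)` with `expert_dim == intermediate_size`, and verifies that the sliced-and-transposed per-expert `nn.Linear` weights reproduce `x @ gate_up_proj[e]` split into chunked gate/up halves; all sizes below are toy placeholders:

```python
import torch
import torch.nn as nn

num_experts, hidden, inter = 4, 8, 16  # toy sizes; expert_dim == inter assumed
gate_up_proj = torch.randn(num_experts, hidden, 2 * inter)  # fused HF layout
x = torch.randn(3, hidden)
e = 2  # arbitrary expert index

gate_lin = nn.Linear(hidden, inter, bias=False)
up_lin = nn.Linear(hidden, inter, bias=False)
with torch.no_grad():
    # Same slicing as _setup: first half -> gate, second half -> up,
    # transposed into nn.Linear's (out_features, in_features) layout.
    gate_lin.weight.copy_(gate_up_proj[e, :, :inter].T)
    up_lin.weight.copy_(gate_up_proj[e, :, inter:].T)

# Reference: fused matmul, then chunk into gate/up halves.
gate_ref, up_ref = (x @ gate_up_proj[e]).chunk(2, dim=-1)
assert torch.allclose(gate_lin(x), gate_ref, atol=1e-5)
assert torch.allclose(up_lin(x), up_ref, atol=1e-5)
```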
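And a minimal end-to-end usage sketch, assuming the standard modelopt flow (`mtq.quantize` with a calibration `forward_loop`); the checkpoint id and calibration body are placeholders, not part of this change:

```python
import modelopt.torch.quantization as mtq
import torch
from transformers import AutoModelForImageTextToText

# Placeholder checkpoint id for a Qwen3 VL MoE model.
model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/<qwen3-vl-moe-checkpoint>", torch_dtype=torch.bfloat16, device_map="auto"
)

def forward_loop(model):
    # Run a few representative batches through the model so the
    # inserted quantizers can collect calibration statistics.
    ...

# With the registrations above, quantize() converts Qwen3VLMoeTextSparseMoeBlock
# and Qwen3VLMoeTextExperts modules to their quantized counterparts.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=forward_loop)
```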