From bb058eec74ef516c383ff7dff2011e8cc88103e5 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Mon, 17 Nov 2025 05:43:07 +0000 Subject: [PATCH] [Feature] Support group router based balance loss in `BalancingLoss` --- xtuner/v1/loss/moe_loss.py | 61 ++++++++++++++++++++++++++++---------- xtuner/v1/model/moe/moe.py | 2 +- 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/xtuner/v1/loss/moe_loss.py b/xtuner/v1/loss/moe_loss.py index d5aa64263..659a0d97c 100644 --- a/xtuner/v1/loss/moe_loss.py +++ b/xtuner/v1/loss/moe_loss.py @@ -1,8 +1,9 @@ -from typing import Literal +from typing import cast import torch import torch.nn as nn from torch import distributed as dist +from torch.distributed import ProcessGroup from torch.distributed._functional_collectives import all_reduce @@ -29,33 +30,34 @@ def __init__( self, balancing_loss_alpha: float, balancing_loss_global_average: bool, - router_scoring_func: Literal["sigmoid", "softmax"], ) -> None: super().__init__() self.loss_weight = balancing_loss_alpha self.global_average = balancing_loss_global_average - def forward(self, router_weights, n_routed_experts, num_experts_per_tok): + def forward( + self, + router_weights: torch.Tensor, + n_routed_experts: int, + num_experts_per_tok: int, + router_n_groups: int, + ): if self.loss_weight == 0: return torch.tensor(0.0, device=router_weights.device, dtype=torch.float32) - num_layers = router_weights.shape[0] router_weights = router_weights.float() # (nlayers, seq, ne) - _, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1) - selected_experts_flat = selected_experts.view(num_layers, -1) - offset = torch.arange(num_layers, device=router_weights.device).unsqueeze(1) * n_routed_experts - selected_experts_offset = selected_experts_flat + offset - tokens_per_expert_flat = torch.histc( - selected_experts_offset.view(-1), - bins=num_layers * n_routed_experts, - min=0, - max=num_layers * n_routed_experts, - ) - tokens_per_expert = 
tokens_per_expert_flat.view(num_layers, n_routed_experts) # (nlayers, ne) + tokens_per_expert = self._get_tokens_per_experts( + router_weights, + n_routed_experts, + num_experts_per_tok, + router_n_groups, + ) # (nlayers, ne) tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype) # (nlayers, ne) if self.global_average and dist.is_initialized(): - tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD) # (nlayers, ne) + tokens_per_expert_global = all_reduce( # (nlayers, ne) + tokens_per_expert_global, "sum", cast(ProcessGroup, dist.group.WORLD) + ) tokens_global = tokens_per_expert_global.sum(-1) # (nlayers, ) seqlen_global = tokens_global // num_experts_per_tok routing_weights_sum_global = all_reduce_autograd( @@ -74,6 +76,33 @@ def forward(self, router_weights, n_routed_experts, num_experts_per_tok): # ProberList.record_tensor(scale_global, "[balancing_loss][after]scale_global") return loss * self.loss_weight + def _get_tokens_per_experts( + self, + router_weights: torch.Tensor, # (nlayers, seq, ne) + n_routed_experts: int, + num_experts_per_tok: int, + n_groups: int, + ): + num_layers, seq, n_routed_experts = router_weights.shape # NOTE: rebinds the n_routed_experts argument + group_size = max(1, n_routed_experts // n_groups) # NOTE(review): view below needs ne % n_groups == 0 + + scores_for_choice = router_weights.view(num_layers, seq, n_groups, group_size) + _, group_local_max_idx = torch.topk( + scores_for_choice, k=num_experts_per_tok // n_groups, dim=3 # NOTE(review): k per group is floored + ) # nlayers, seq, n_groups, top_k_per_group + group_offsets = torch.arange(num_layers * n_groups, device=router_weights.device) * group_size + group_offsets = group_offsets.view(num_layers, 1, n_groups, 1) # first flat bin of each (layer, group) + + topk_ids = (group_local_max_idx + group_offsets).to(torch.long) # (nlayers, seq, n_groups, top_k_per_group) + tokens_per_expert_flat = torch.histc( + topk_ids.view(-1), + bins=num_layers * n_routed_experts, + min=0, + max=num_layers * n_routed_experts, + ) + tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts) # (nlayers, ne) + return 
tokens_per_expert + def z_loss(router_logits: torch.Tensor, global_average: bool = False): router_logits = router_logits.float() # (nlayers, seq, ne) diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 41b35bc3d..331d7096a 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -69,7 +69,6 @@ def build(self, router_scoring_func) -> BalancingLoss: return BalancingLoss( self.balancing_loss_alpha, self.balancing_loss_global_average, - router_scoring_func=router_scoring_func, ) @@ -549,6 +548,7 @@ def _forward( router_weights=router_weights, n_routed_experts=self.config.n_routed_experts, num_experts_per_tok=self.config.num_experts_per_tok, + router_n_groups=self.config.router.router_n_groups or 1, # None or 0 falls back to 1 ) output["balancing_loss"] = balancing_loss