Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 45 additions & 16 deletions xtuner/v1/loss/moe_loss.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Literal
from typing import cast

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed._functional_collectives import all_reduce


Expand All @@ -29,33 +30,34 @@ def __init__(
self,
balancing_loss_alpha: float,
balancing_loss_global_average: bool,
router_scoring_func: Literal["sigmoid", "softmax"],
) -> None:
super().__init__()
self.loss_weight = balancing_loss_alpha
self.global_average = balancing_loss_global_average

def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
def forward(
self,
router_weights: torch.Tensor,
n_routed_experts: int,
num_experts_per_tok: int,
router_n_groups: int,
):
if self.loss_weight == 0:
return torch.tensor(0.0, device=router_weights.device, dtype=torch.float32)

num_layers = router_weights.shape[0]
router_weights = router_weights.float() # (nlayers, seq, ne)
_, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
selected_experts_flat = selected_experts.view(num_layers, -1)
offset = torch.arange(num_layers, device=router_weights.device).unsqueeze(1) * n_routed_experts
selected_experts_offset = selected_experts_flat + offset
tokens_per_expert_flat = torch.histc(
selected_experts_offset.view(-1),
bins=num_layers * n_routed_experts,
min=0,
max=num_layers * n_routed_experts,
)
tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts) # (nlayers, ne)
tokens_per_expert = self._get_tokens_per_experts(
router_weights,
n_routed_experts,
num_experts_per_tok,
router_n_groups,
) # (nlayers, ne)

tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype) # (nlayers, ne)
if self.global_average and dist.is_initialized():
tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD) # (nlayers, ne)
tokens_per_expert_global = all_reduce( # (nlayers, ne)
tokens_per_expert_global, "sum", cast(ProcessGroup, dist.group.WORLD)
)
tokens_global = tokens_per_expert_global.sum(-1) # (nlayers, )
seqlen_global = tokens_global // num_experts_per_tok
routing_weights_sum_global = all_reduce_autograd(
Expand All @@ -74,6 +76,33 @@ def forward(self, router_weights, n_routed_experts, num_experts_per_tok):
# ProberList.record_tensor(scale_global, "[balancing_loss][after]scale_global")
return loss * self.loss_weight

def _get_tokens_per_experts(
self,
router_weights: torch.Tensor, # (nlayers, seq, ne)
n_routed_experts: int,
num_experts_per_tok: int,
n_groups: int,
):
num_layers, seq, n_routed_experts = router_weights.shape
group_size = max(1, n_routed_experts // n_groups)

scores_for_choice = router_weights.view(num_layers, seq, n_groups, group_size)
_, group_local_max_idx = torch.topk(
scores_for_choice, k=num_experts_per_tok // n_groups, dim=3
) # nlayers, seq, n_groups, top_k_per_group
group_offsets = torch.arange(num_layers * n_groups, device=router_weights.device) * group_size
group_offsets = group_offsets.view(num_layers, 1, n_groups, 1)

topk_ids = (group_local_max_idx + group_offsets).to(torch.long) # [seq, n_groups, top_k_per_group]
tokens_per_expert_flat = torch.histc(
topk_ids.view(-1),
bins=num_layers * n_routed_experts,
min=0,
max=num_layers * n_routed_experts,
)
tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts)
return tokens_per_expert


def z_loss(router_logits: torch.Tensor, global_average: bool = False):
router_logits = router_logits.float() # (nlayers, seq, ne)
Expand Down
2 changes: 1 addition & 1 deletion xtuner/v1/model/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def build(self, router_scoring_func) -> BalancingLoss:
return BalancingLoss(
self.balancing_loss_alpha,
self.balancing_loss_global_average,
router_scoring_func=router_scoring_func,
)


Expand Down Expand Up @@ -549,6 +548,7 @@ def _forward(
router_weights=router_weights,
n_routed_experts=self.config.n_routed_experts,
num_experts_per_tok=self.config.num_experts_per_tok,
router_n_groups=self.config.router.router_n_groups or 1,
)
output["balancing_loss"] = balancing_loss

Expand Down