From 36809e591f515d33260793a744a3973beef5a808 Mon Sep 17 00:00:00 2001
From: Daniel Shen
Date: Mon, 27 Apr 2026 17:36:30 -0700
Subject: [PATCH 1/2] fix: topkgating major bug

Signed-off-by: Daniel Shen
---
 deepspeed/moe/sharded_moe.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index ea1022d9d4ae..6786ffdb164d 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -384,14 +384,16 @@ def topkgating(
     """Implements TopKGating on logits."""
     # everything is in fp32 in this function
 
-    # get topk gates
-    top_gate, top_idx = torch.topk(logits, k=k, dim=1)
     # gating decisions
     gates = F.softmax(logits, dim=1)
     num_experts = int(gates.shape[1])
 
+    # get topk gates: use softmax probs so non-assigned slots (zero) never
+    # outrank legitimately-assigned tokens that happen to have negative logits
+    top_gate, top_idx = torch.topk(gates, k=k, dim=1)
+
     # get topk mask
-    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate)
+    topk_masked_gates = torch.zeros_like(gates).scatter(1, top_idx, top_gate)
 
     mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1)
 

From d334266741d72fb6ecd7c7cc6027bc0515f22c3e Mon Sep 17 00:00:00 2001
From: Daniel Shen
Date: Mon, 27 Apr 2026 17:41:02 -0700
Subject: [PATCH 2/2] fix: further cleanup

Signed-off-by: Daniel Shen
---
 deepspeed/moe/sharded_moe.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 6786ffdb164d..3d73c62f5a8a 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -388,13 +388,9 @@ def topkgating(
     gates = F.softmax(logits, dim=1)
     num_experts = int(gates.shape[1])
 
-    # get topk gates: use softmax probs so non-assigned slots (zero) never
-    # outrank legitimately-assigned tokens that happen to have negative logits
+    # get topk gates
     top_gate, top_idx = torch.topk(gates, k=k, dim=1)
 
-    # get topk mask
-    topk_masked_gates = torch.zeros_like(gates).scatter(1, top_idx, top_gate)
-
     mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1)
 
     exp_counts = torch.sum(mask, dim=0).detach().to(logits.device)
@@ -410,9 +406,10 @@ def topkgating(
 
     # update mask and locations by capacity
     if drop_policy == 'probs':
-        capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
-        capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
-        mask = torch.logical_and(mask, capacity_mask)
+        topk_masked_gates = torch.zeros_like(gates).scatter(1, top_idx, top_gate)
+        _, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
+        capacity_mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(0, capacity_indices, True)
+        mask &= capacity_mask
         locations = torch.cumsum(mask, dim=0) - 1
 
     elif drop_policy == "position":
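
Reviewer note (not part of either patch): a minimal standalone sketch of the failure
mode the first commit fixes, as it surfaces under drop_policy='probs'. The tensor
values and the choice of k=1, capacity=2 are made up purely for illustration; only
torch and torch.nn.functional are assumed.

    import torch
    import torch.nn.functional as F

    # 3 tokens, 2 experts; tokens 1 and 2 route to expert 0, token 0 to expert 1.
    logits = torch.tensor([[-1.0, 0.5],
                           [-0.5, -2.0],
                           [1.0, 0.2]])
    k, capacity = 1, 2

    # Old behaviour: top-k over raw logits scattered into zeros. Token 1's gate for
    # expert 0 is negative (-0.5), so the zero entries of unassigned tokens outrank it
    # in the per-expert capacity selection.
    top_gate, top_idx = torch.topk(logits, k=k, dim=1)
    masked = torch.zeros_like(logits).scatter(1, top_idx, top_gate)
    keep_old = torch.topk(masked, k=capacity, dim=0).indices
    print(keep_old[:, 0])  # expert 0 keeps tokens {2, 0}: token 1 is wrongly dropped
                           # even though only 2 tokens were routed to expert 0

    # Fixed behaviour: top-k over softmax probabilities, which are strictly positive,
    # so an unassigned (zero) slot can never outrank a genuinely assigned token.
    gates = F.softmax(logits, dim=1)
    top_gate, top_idx = torch.topk(gates, k=k, dim=1)
    masked = torch.zeros_like(gates).scatter(1, top_idx, top_gate)
    keep_new = torch.topk(masked, k=capacity, dim=0).indices
    print(keep_new[:, 0])  # expert 0 keeps tokens {1, 2} as intended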