diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index ea1022d9d4ae..3d73c62f5a8a 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -384,14 +384,12 @@ def topkgating(
     """Implements TopKGating on logits."""
 
     # everything is in fp32 in this function
-    # get topk gates
-    top_gate, top_idx = torch.topk(logits, k=k, dim=1)
     # gating decisions
     gates = F.softmax(logits, dim=1)
     num_experts = int(gates.shape[1])
 
-    # get topk mask
-    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate)
+    # get topk gates
+    top_gate, top_idx = torch.topk(gates, k=k, dim=1)
 
     mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1)
 
@@ -408,9 +406,10 @@ def topkgating(
 
     # update mask and locations by capacity
     if drop_policy == 'probs':
-        capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
-        capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
-        mask = torch.logical_and(mask, capacity_mask)
+        topk_masked_gates = torch.zeros_like(gates).scatter(1, top_idx, top_gate)
+        _, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
+        capacity_mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(0, capacity_indices, True)
+        mask &= capacity_mask
         locations = torch.cumsum(mask, dim=0) - 1
 
     elif drop_policy == "position":
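
For context, a minimal standalone sketch of the failure mode this patch addresses; the tensor values and shapes below are illustrative, not taken from the PR. Raw logits can be negative, so scattering them into a zero-filled tensor lets the 0.0 fill value outrank actually-selected tokens in the per-expert (column-wise) capacity top-k, meaning the 'probs' drop policy could keep positions that were never routed while dropping ones that were. Taking top-k over the softmax gates instead yields the same indices (softmax is monotonic per row) but strictly positive values, so the zero fill can never win:

import torch
import torch.nn.functional as F

logits = torch.tensor([[-2.0, -1.0, -3.0],
                       [-0.5, -4.0, -1.5]])  # 2 tokens (rows) x 3 experts (cols)
k, capacity = 1, 1

# Old behavior: top-k over raw logits, scattered into a zero tensor.
top_gate, top_idx = torch.topk(logits, k=k, dim=1)
old_masked = torch.zeros_like(logits).scatter(1, top_idx, top_gate)
# -> [[ 0.0, -1.0, 0.0], [-0.5, 0.0, 0.0]]: the unselected 0.0 fill entries
#    outrank the selected negative logits in a column-wise top-k.
_, bad_idx = torch.topk(old_masked, k=capacity, dim=0)
print(bad_idx)   # expert 0 "keeps" token 0, which was never routed to it,
                 # so token 1 (its real assignment) gets dropped downstream

# Fixed behavior: top-k over softmax gates (same indices, positive values).
gates = F.softmax(logits, dim=1)
top_gate, top_idx = torch.topk(gates, k=k, dim=1)
new_masked = torch.zeros_like(gates).scatter(1, top_idx, top_gate)
_, good_idx = torch.topk(new_masked, k=capacity, dim=0)
print(good_idx)  # expert 0 keeps token 1, its actual top-1 assignment

A side note on the second hunk: replacing torch.logical_and with `mask &= capacity_mask` relies on both operands being boolean, which the added `dtype=torch.bool` scatter_ now guarantees.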