Merged
13 changes: 6 additions & 7 deletions deepspeed/moe/sharded_moe.py
@@ -384,14 +384,12 @@ def topkgating(
"""Implements TopKGating on logits."""

# everything is in fp32 in this function
# get topk gates
top_gate, top_idx = torch.topk(logits, k=k, dim=1)
# gating decisions
gates = F.softmax(logits, dim=1)
num_experts = int(gates.shape[1])

# get topk mask
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate)
# get topk gates
top_gate, top_idx = torch.topk(gates, k=k, dim=1)
    mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1)
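As an aside (not part of the diff), a minimal sketch of why taking torch.topk after the softmax is safe for expert selection: softmax is strictly increasing within each row, so the selected top-k indices are unchanged, but top_gate now holds normalized probabilities rather than raw logits, which is what the later gating logic (e.g. the 'probs' drop policy below) consumes.

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
k = 2

# old order: top-k over raw logits
old_gate, old_idx = torch.topk(logits, k=k, dim=1)
# new order: softmax first, then top-k over the probabilities
gates = F.softmax(logits, dim=1)
new_gate, new_idx = torch.topk(gates, k=k, dim=1)

assert torch.equal(old_idx, new_idx)   # same experts are selected either way
print(old_gate, old_idx)   # values are raw logits (2.0, 1.0), indices [0, 2]
print(new_gate, new_idx)   # values are probabilities (~0.61, ~0.22), same indices [0, 2]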

@@ -408,9 +406,10 @@ def topkgating(
        # update mask and locations by capacity

        if drop_policy == 'probs':
-            capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
-            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
-            mask = torch.logical_and(mask, capacity_mask)
+            topk_masked_gates = torch.zeros_like(gates).scatter(1, top_idx, top_gate)
+            _, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
+            capacity_mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(0, capacity_indices, True)
+            mask &= capacity_mask
            locations = torch.cumsum(mask, dim=0) - 1

        elif drop_policy == "position":
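For context, a self-contained toy sketch (illustrative shapes, not from the PR) of what the updated drop_policy == 'probs' branch computes: only the positions picked by the row-wise top-k carry their gate probability, each expert column keeps its `capacity` highest-probability tokens, and the logical-and with the original mask drops everything else before buffer slots are assigned with a cumulative sum.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts, k, capacity = 6, 4, 2, 2

gates = F.softmax(torch.randn(num_tokens, num_experts), dim=1)
top_gate, top_idx = torch.topk(gates, k=k, dim=1)
mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, True)

# gate probabilities only at the selected (token, expert) positions, zero elsewhere
topk_masked_gates = torch.zeros_like(gates).scatter(1, top_idx, top_gate)
# per expert (column), pick the `capacity` highest-probability tokens
_, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
capacity_mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(0, capacity_indices, True)
mask &= capacity_mask                        # drop tokens that exceed expert capacity
locations = torch.cumsum(mask, dim=0) - 1    # slot of each kept token within its expert's buffer
print(mask.sum(dim=0))                       # at most `capacity` tokens kept per expert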