diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 14db1c9d9e15..73e291de2516 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -96,16 +96,19 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
 class _AllToAll(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, async_op=False) -> Tensor:  # type: ignore
         ctx.group = group
         input = input.contiguous()
         output = torch.empty_like(input)
-        dist.all_to_all_single(output, input, group=group)
-        return output
+        work = dist.all_to_all_single(output, input, group=group, async_op=async_op)
+        if async_op:
+            return output, work
+        else:
+            return output
 
     @staticmethod
     def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
+        return (None, _AllToAll.apply(ctx.group, *grad_output), None)
 
 
 # einsum rewrites are on par or more performant
@@ -550,6 +553,7 @@ class MOELayer(Base):
         expert (torch.nn.Module): expert network
     """
+    d2d_stream = torch.cuda.Stream()
 
     def __init__(self,
                  gate: Module,
@@ -572,6 +576,8 @@ def __init__(self,
         self.wall_clock_breakdown = False
 
         self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
+        self.enable_pipeline = True
+        self.shard_num = 4
 
         if self.use_tutel:
             logger.info('Using Tutel optimizations.')
@@ -586,8 +592,54 @@ def _set_ep_group(self, ep_group):
         self.ep_group = ep_group
         self.gate._set_ep_group(ep_group)
 
-    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
+    # During multi-machine MoE training, all-to-all is the inter-machine communication and
+    # allgather is the intra-machine communication. They use different communication links,
+    # so they can run in parallel.
+    # The input has shape (E, C, M). We shard it along the C dim and run the all-to-all shard
+    # by shard, so the allgather of one shard overlaps with the all-to-all of the next shard.
+    #  A  E  I  M
+    #  A1 E1 I1 M1
+    #  A2 E2 I2 M2
+    #  A3 E3 I3 M3
+    #  A4 E4 I4 M4
+    def pipeline_alltoall_with_allgather(self, input, shard_dim=1) -> Tensor:
+        if not self.enable_pipeline:
+            input = _AllToAll.apply(self.ep_group, input)
+            input = gather_tokens(input, dim=shard_dim)
+            return input
+
+        assert self.shard_num > 0, f"shard_num must be a positive number, but got {self.shard_num}"
+        input_chunks = list(input.chunk(self.shard_num, dim=shard_dim))
+        world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
+        dims = list(input.size())
+        dims[shard_dim] = dims[shard_dim] * world_size
+        output = torch.empty(dims, device=input.device, dtype=input.dtype)
+        input_gather_dim_len = input.shape[shard_dim]
+        have_gather_len = 0
+        works = []
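+        # Launch the all-to-all for every shard with async_op=True up front, so the
+        # allgather/copy of an earlier shard can overlap with the all-to-all of later shards.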
+        for i in range(len(input_chunks)):
+            input_chunks[i], work = _AllToAll.apply(self.ep_group, input_chunks[i], True)
+            works.append(work)
+
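+        # For each shard: wait for its all-to-all, allgather the result across the tensor
+        # parallel group, then copy the gathered pieces into their slots of the output on a
+        # side stream (d2d_stream) so the copies overlap with the next shard's communication.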
+        current_stream = torch.cuda.current_stream()
+        for i in range(len(input_chunks)):
+            works[i].wait()
+            # allgather and chunk along dim 0 so we can avoid an unnecessary cat in gather_tokens
+            gather_out = gather_tokens(input_chunks[i], dim=0)
+            gather_list = gather_out.chunk(world_size, dim=0)
+            dim_len = gather_list[0].shape[shard_dim]
+            MOELayer.d2d_stream.wait_stream(current_stream)
+
+            for j in range(len(gather_list)):
+                start = input_gather_dim_len * j + have_gather_len
+                with torch.cuda.stream(MOELayer.d2d_stream):
+                    torch.narrow(output, shard_dim, start, dim_len).copy_(gather_list[j])
+            have_gather_len += dim_len
+
+        current_stream.wait_stream(MOELayer.d2d_stream)
+        return output
+
+    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
             self.timers(MOE_TIMER).start()
@@ -611,9 +663,6 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
             dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
 
-        if self.wall_clock_breakdown:
-            self.timers(FIRST_ALLTOALL_TIMER).start()
-
         tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
         if tensor_model_world_size > 1:
             # If the non-expert is tensor-parallel,
@@ -628,18 +677,17 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             # an allgather to ensure correctness,
             dispatched_input = drop_tokens(dispatched_input, dim=1)
 
-        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
-
         if self.wall_clock_breakdown:
-            self.timers(FIRST_ALLTOALL_TIMER).stop()
-            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
+            self.timers(FIRST_ALLTOALL_TIMER).start()
 
         if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
-            # if both expert and non-expert are tensor-parallel
-            # the dropped duplicate tokens need to be gathered on each
-            # tensor parallel rank again to ensure correctness
-            dispatched_input = gather_tokens(dispatched_input, dim=1)
+            dispatched_input = self.pipeline_alltoall_with_allgather(dispatched_input)
+        else:
+            dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
+
+        if self.wall_clock_breakdown:
+            self.timers(FIRST_ALLTOALL_TIMER).stop()
+            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
 
         # Re-shape after all-to-all: ecm -> gecm
         dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
 
         expert_output = self.experts(dispatched_input)
@@ -654,18 +702,12 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).start()
 
-        expert_output = _AllToAll.apply(self.ep_group, expert_output)
+        expert_output = self.pipeline_alltoall_with_allgather(expert_output)
 
         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).stop()
             self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)
 
-        if tensor_model_world_size > 1:
-            # the dropped duplicate tokens need to be gathered on each
-            # tensor parallel rank again for the tensor-parallel
-            # non-expert of the next layer.
-            expert_output = gather_tokens(expert_output, dim=1)
-
         if self.use_tutel:
             combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
         else:
diff --git a/tests/unit/moe/test_pipeline.py b/tests/unit/moe/test_pipeline.py
new file mode 100644
index 000000000000..b22ae6313277
--- /dev/null
+++ b/tests/unit/moe/test_pipeline.py
@@ -0,0 +1,89 @@
+import torch
+import deepspeed
+import pytest
+from unit.common import DistributedTest
+from deepspeed.accelerator import get_accelerator
+from deepspeed.moe.sharded_moe import _AllToAll
+from deepspeed.moe.mappings import gather_tokens
+from deepspeed.moe.layer import MoE
+
+
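+# Minimal mpu stub exposing the model-parallel / data-parallel accessors DeepSpeed expects:
+# tensor-parallel groups are built from contiguous ranks, data-parallel groups from strided ranks.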
+class MPU():
+
+    def __init__(self, tp_world_size):
+        self.rank = deepspeed.comm.get_rank()
+        self.world_size = deepspeed.comm.get_world_size()
+        self.tp_world_size = tp_world_size
+
+        for i in range(0, self.world_size, tp_world_size):
+            ranks = range(i, i + tp_world_size)
+            group = deepspeed.comm.new_group(ranks)
+            if self.rank in ranks:
+                self.tp_group = group
+
+        for i in range(0, tp_world_size):
+            ranks = range(i, self.world_size, tp_world_size)
+            group = deepspeed.comm.new_group(ranks)
+            if self.rank in ranks:
+                self.dp_group = group
+
+    def get_model_parallel_rank(self):
+        return self.rank % self.tp_world_size
+
+    def get_model_parallel_world_size(self):
+        return self.tp_world_size
+
+    def get_data_parallel_rank(self):
+        return self.rank // self.tp_world_size
+
+    def get_data_parallel_world_size(self):
+        return self.world_size // self.tp_world_size
+
+    def get_data_parallel_group(self):
+        return self.dp_group
+
+    def get_model_parallel_group(self):
+        return self.tp_group
+
+
+@pytest.mark.parametrize("shard_num", [6, 10])
+@pytest.mark.parametrize("C, M, scale", [(92, 32, 1), (209, 128, 5)])
+class TestPipelineCommunication(DistributedTest):
+    world_size = 8
+
+    def test(self, shard_num, C, M, scale):
+        tp_size = 2
+        world_size = deepspeed.comm.get_world_size()
+        E = world_size
+        ep_size = 4
+        config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}}
+        hidden_dim = M
+        device = get_accelerator().current_device_name()
+        tensor_parallel_expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, 4 * hidden_dim // tp_size),
+                                                     torch.nn.ReLU(),
+                                                     torch.nn.Linear(4 * hidden_dim // tp_size, hidden_dim))
+
+        model = MoE(
+            hidden_size=hidden_dim,
+            expert=tensor_parallel_expert,
+            num_experts=world_size * scale,
+            ep_size=ep_size,
+            use_residual=True,
+            enable_expert_tensor_parallelism=True,
+        )
+        optimizer = torch.optim.AdamW(params=model.parameters())
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
+                                              model=model,
+                                              optimizer=optimizer,
+                                              dist_init_required=False,
+                                              mpu=MPU(tp_size))
+        model.deepspeed_moe.shard_num = shard_num
+        input = torch.rand(E, C, M, device=device)
+
+        # pipelined alltoall overlapped with allgather
+        pipeline_output = model.deepspeed_moe.pipeline_alltoall_with_allgather(input)
+
+        # reference path: alltoall first, then allgather
+        alltoall_output = _AllToAll.apply(model.deepspeed_moe.ep_group, input)
+        gather_output = gather_tokens(alltoall_output, dim=1)
+        assert torch.allclose(pipeline_output, gather_output, atol=1e-07), f"pipeline_output {pipeline_output} is not equal to gather_output {gather_output}"
\ No newline at end of file