route collectives through torchcomms#5385
Draft
tushar00jain wants to merge 1 commit into
Draft
Conversation
Contributor
|
Thanks for opening this PR. Could you please update the PR description to follow the Megatron-LM pull request template? The current PR body is empty, so reviewers do not have the required summary, issue tracking, testing, documentation, and pre-check information. A few specific items to address:
Once those are updated, we can continue review. If the PR remains without the required template/policy information, we may need to close it until it is resubmitted or updated accordingly. |
Summary: # Route Megatron-LM collectives through PyTorch TorchComms ## Summary This PR makes Megatron-LM's process-group setup compatible with [**torchcomms**](https://meta-pytorch.org/torchcomms/main/index.html), so every `torch.distributed` collective — both the NCCL device path and the Gloo CPU path — can be routed through TorchComms by enabling PyTorch's `torch.distributed.config.use_torchcomms` (env `TORCH_DISTRIBUTED_USE_TORCHCOMMS`). No call site switches to a new API. Existing `new_group` / `init_process_group` calls route through TorchComms' `split_group` path automatically when the flag is on, so the change set is small and the default (NCCL/Gloo `ProcessGroup`) path is untouched when the flag is off. --- ## Motivation ### 1. Migration to torchcomms torchcomms is the modern PyTorch communications library designed to replace the legacy `ProcessGroup` + `Backend` abstraction. We want Megatron-LM to be able to run its entire distributed stack over torchcomms with only an environment-variable flip, as a step toward adopting it as the default collective backend. ### 2. Minimal, reversible change Keeping `new_group` (rather than calling `split_group` directly) means the diff is small, the non-torchcomms path is byte-for-byte unchanged, and the whole behavior is gated behind a single env var. ### 3. No silent config loss Where `split_group` would drop `ProcessGroupNCCL.Options` on the floor, we translate the relevant knobs (`is_high_priority_stream`, `cga_cluster_size`, `max_ctas`, `min_ctas`) into TorchComms' `CommOptions.hints` and build a standalone comm so they're actually honored. --- ## What changed ### `megatron/core/parallel_state.py` — torchcomms-compatible group creation - Torchcomms routes `new_group` through `split_group`, which requires (a) the parent PG to be eagerly **device-bound** (`bound_device_id`) and (b) the backend filter handed to subgroups to be **device-qualified** and to include the parent's default device backend ### `megatron/training/initialize.py` — eager device-bound world PG `_initialize_distributed` now, when torchcomms is enabled and a CUDA `device_id` exists: - Seeds `TORCHCOMM_RANK` / `TORCHCOMM_SIZE` for the TorchComms bootstrap. - Inits the world PG with `backend='cpu:gloo,cuda:nccl'` and `device_id=…` so the parent is eagerly device-bound. - Issues `dist.barrier(device_ids=[device_id.index])` immediately after init as a defensive eager-init flush. `device_id` alone sets `bound_device_id` (which `split_group` checks) but the underlying NCCL comm is still created lazily on first collective; the no-op device barrier forces that creation, sidestepping the intermittent init-time hang documented in [pytorch/pytorch#153960](pytorch/pytorch#153960). One collective at boot — essentially free. ### `megatron/core/process_groups_config.py` — singleton group inheritance The singleton `expt_dp_group` now routes through `parallel_state.create_group(...)` so it picks up the same backend-qualification and torchcomms routing as every other group. ### `tests/unit_tests/test_utilities.py` — Utils.initialize_distributed mirror `Utils.initialize_distributed` is the test-side analogue of `_initialize_distributed`. Under torchcomms it now inits with `backend='cpu:gloo,cuda:nccl'`, passes `device_id`, seeds `TORCHCOMM_RANK`/`SIZE`, and barriers — so unit tests that subsequently ask for a `backend='gloo'` subgroup don't trip `split_group`'s "Requested backend for device 'cpu' is not present in the parent" error. With torchcomms off it keeps the original `backend='nccl'` path. --- ## Tests Validated on a 4 × H100 (Hopper) host against PyTorch + torchcomms nightlies. 1. End-to-end smokes (`smoke_*.py`) 2. CI unit-test subset (`CI_LIGHT`) 3. No regression with the flag off All of the above pass with `TORCH_DISTRIBUTED_USE_TORCHCOMMS=0`, using the standard `ProcessGroupNCCL` / `ProcessGroupGloo` backends. --- ## Rollback / gating The whole change is gated behind `TORCH_DISTRIBUTED_USE_TORCHCOMMS`. It is a no-op unless the `torchcomms` package is installed (torch's `_use_torchcomms_enabled()` also checks availability), and it can be disabled at any time with `TORCH_DISTRIBUTED_USE_TORCHCOMMS=0` without touching code — the default NCCL/Gloo `ProcessGroup` path is unchanged. Signed-off-by: Tushar Jain <tushar00jain@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
Route Megatron-LM collectives through PyTorch TorchComms
Summary
This PR makes Megatron-LM's process-group setup compatible with torchcomms, so every
torch.distributedcollective — both the NCCL device path and the Gloo CPU path — can be routed through TorchComms by enabling PyTorch'storch.distributed.config.use_torchcomms(envTORCH_DISTRIBUTED_USE_TORCHCOMMS).No call site switches to a new API. Existing
new_group/init_process_groupcalls route through TorchComms'split_grouppath automatically when the flag is on, so the change set is small and the default (NCCL/GlooProcessGroup) path is untouched when the flag is off.Motivation
1. Migration to torchcomms
torchcomms is the modern PyTorch communications library designed to replace the legacy
ProcessGroup+Backendabstraction. We want Megatron-LM to be able to run its entire distributed stack over torchcomms with only an environment-variable flip, as a step toward adopting it as the default collective backend.2. Minimal, reversible change
Keeping
new_group(rather than callingsplit_groupdirectly) means the diff is small, the non-torchcomms path is byte-for-byte unchanged, and the whole behavior is gated behind a single env var.3. No silent config loss
Where
split_groupwould dropProcessGroupNCCL.Optionson the floor, we translate the relevant knobs (is_high_priority_stream,cga_cluster_size,max_ctas,min_ctas) into TorchComms'CommOptions.hintsand build a standalone comm so they're actually honored.What changed
megatron/core/parallel_state.py— torchcomms-compatible group creationnew_groupthroughsplit_group, which requires (a) the parent PG to be eagerly device-bound (bound_device_id) and (b) the backend filter handed to subgroups to be device-qualified and to include the parent's default device backendmegatron/training/initialize.py— eager device-bound world PG_initialize_distributednow, when torchcomms is enabled and a CUDAdevice_idexists:TORCHCOMM_RANK/TORCHCOMM_SIZEfor the TorchComms bootstrap.backend='cpu:gloo,cuda:nccl'anddevice_id=…so the parent is eagerly device-bound.dist.barrier(device_ids=[device_id.index])immediately after init as a defensive eager-init flush.device_idalone setsbound_device_id(whichsplit_groupchecks) but the underlying NCCL comm is still created lazily on first collective; the no-op device barrier forces that creation, sidestepping the intermittent init-time hang documented in pytorch/pytorch#153960. One collective at boot — essentially free.megatron/core/process_groups_config.py— singleton group inheritanceThe singleton
expt_dp_groupnow routes throughparallel_state.create_group(...)so it picks up the same backend-qualification and torchcomms routing as every other group.tests/unit_tests/test_utilities.py— Utils.initialize_distributed mirrorUtils.initialize_distributedis the test-side analogue of_initialize_distributed. Under torchcomms it now inits withbackend='cpu:gloo,cuda:nccl', passesdevice_id, seedsTORCHCOMM_RANK/SIZE, and barriers — so unit tests that subsequently ask for abackend='gloo'subgroup don't tripsplit_group's "Requested backend for device 'cpu' is not present in the parent" error. With torchcomms off it keeps the originalbackend='nccl'path.Tests
Validated on a 4 × H100 (Hopper) host against PyTorch + torchcomms nightlies.
smoke_*.py)CI_LIGHT)All of the above pass with
TORCH_DISTRIBUTED_USE_TORCHCOMMS=0, using the standardProcessGroupNCCL/ProcessGroupGloobackends.Rollback / gating
The whole change is gated behind
TORCH_DISTRIBUTED_USE_TORCHCOMMS. It is a no-op unless thetorchcommspackage is installed (torch's_use_torchcomms_enabled()also checks availability), and it can be disabled at any time withTORCH_DISTRIBUTED_USE_TORCHCOMMS=0without touching code — the default NCCL/GlooProcessGrouppath is unchanged.Signed-off-by: Tushar Jain tushar00jain@users.noreply.github.com