Add hetero grid args and MoE process groups for MIMO example#5375
Add hetero grid args and MoE process groups for MIMO example#5375yashaswikarnati wants to merge 5 commits into
Conversation
Add the heterogeneous-grid CLI args module (add_hetero_grid_args, validate_hetero_grid_args, build_module_grid_specs) used to configure per-module grids for MIMO training, and extend the example topology helper: - topology.py: create the additional dense process groups (tp+dp, tp+dp+cp, tp+cp+dp+pp) and set pgc.tp_dp / pgc.tp_dp_cp explicitly. MoE layers/router and finalize_model_grads read tp_dp_cp; cuda-graph capture reads tp_dp; the ProcessGroupCollection leaves them init=False by default. These edits extend the topology helper that landed with MM1; they are purely additive. The grid-args unit test travels with args.py. The existing topology unit test on main is unchanged. Validated in the 8-GPU 20L Nemotron VLM e2e (trains + checkpoint save/resume, lm loss 12.18->11.54 across resume). Signed-off-by: ykarnati <ykarnati@nvidia.com>
| grid = parser.add_argument_group("hetero module grids") | ||
|
|
||
| # Encoder grid placement + factorization. | ||
| grid.add_argument("--encoder-offset", type=int, default=0, |
There was a problem hiding this comment.
we dont need encoder offset
| def validate_hetero_grid_args(args: argparse.Namespace, world_size: int) -> tuple[int, int]: | ||
| """Validate the disjoint-grid hetero layout. Returns ``(encoder_size, llm_size)``. | ||
|
|
||
| Call AFTER stock ``validate_args`` (so ``micro_batch_size`` / ``num_experts`` |
There was a problem hiding this comment.
doc strings are verbose. can be concsely written, dont explain code in doc strings here. validate is a simple check function
| return parser.parse_args(argv) | ||
|
|
||
|
|
||
| def _layout_8gpu_20l(**overrides): |
There was a problem hiding this comment.
tests not adding value, just smoke tests should be avoided. tests can be authored more concise here
| pgc.dp_cp = grid.get_pg(["dp", "cp"]) | ||
| pgc.intra_dp_cp = pgc.dp_cp | ||
| pgc.tp_cp = grid.get_pg(["tp", "cp"]) | ||
| # MoE layers/router and finalize_model_grads read tp_dp_cp (tensor+data+context |
There was a problem hiding this comment.
remove - # MoE layers/router and finalize_model_grads read tp_dp_cp (tensor+data+context
# parallel group); cuda-graph capture reads tp_dp. Set them explicitly since the
# ProcessGroupCollection leaves them init=False.
There was a problem hiding this comment.
verbose comment and not necessary here
| # ``args.vision_encoder_key = "radio_encoder"``; we mirror that default here so | ||
| # the encoder ModuleGridSpec carries the same module name the provider/runtime | ||
| # look up. Resolved at spec-build time via ``getattr(args, "vision_encoder_key")``. | ||
| DEFAULT_ENCODER_MODULE_NAME = "radio_encoder" |
There was a problem hiding this comment.
why do we need this - DEFAULT_ENCODER_MODULE_NAME ?
| def add_hetero_grid_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: | ||
| """Register the hetero parallelism/topology arg group. | ||
|
|
||
| Stock-hook compatible: returns the parser so it can be passed straight to |
There was a problem hiding this comment.
dont use the wording Stock-hook . stock has no meanin here. also currnt doc string looks very vague and verbose.
| help="Language pipeline-model-parallel size.") | ||
| grid.add_argument("--llm-dp", type=int, default=2, | ||
| help="Language data-parallel size. Global batch is keyed on this.") | ||
| # MoE expert parallelism for the language grid. Relocated here from the E1 |
There was a problem hiding this comment.
remove this verbose comment
| return [encoder_spec, language_spec] | ||
|
|
||
|
|
||
| def _resolve_expt_tp(expt_tp, tp: int) -> int: |
There was a problem hiding this comment.
why we need special function for this
|
|
||
|
|
||
| def _num_experts(args: argparse.Namespace) -> int: | ||
| """Resolve MoE expert count from stock (--num-experts) or prototype args.""" |
There was a problem hiding this comment.
why do we need two args? wdym by prototype args ?
| return 0 | ||
|
|
||
|
|
||
| def _num_microbatches(args: argparse.Namespace) -> int: |
There was a problem hiding this comment.
why we need num microbatches? why args needs to pass it ?
| @@ -0,0 +1,253 @@ | |||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | |||
|
|
|||
| """Hetero grid/topology CLI args + validation for the stock-args MIMO examples. | |||
There was a problem hiding this comment.
verbose doc string / comments for the file
|
|
||
| # Sample-based scheduler resolution: derive --train-iters from --train-samples | ||
| # using the llm_dp-keyed global batch size. | ||
| if getattr(args, "train_samples", None) is not None: |
There was a problem hiding this comment.
why do we need to check this here? isnt this part of megatron/training training loop?
| """ | ||
| encoder_size, llm_size = validate_hetero_grid_args(args, world_size) | ||
|
|
||
| language_spec = ModuleGridSpec( |
There was a problem hiding this comment.
should call this language_grid_spec to avoid confusion with module spec we use for model init?
Trim verbose module/function docstrings and comments, remove the unused --encoder-offset arg (the encoder span always starts at rank 0), inline the trivial expert-TP default, read MoE expert count from stock --num-experts only, and rename the local grid specs to language_grid_spec / encoder_grid_spec. Slim the pure-args tests to the value-asserting cases. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: ykarnati <ykarnati@nvidia.com>
Drop the _num_microbatches helper and key the --train-samples to --train-iters conversion off the explicit --global-batch-size. num_microbatches is derived by the stock calculator (gbs / (mbs * llm_dp)), so the redundant --num-microbatches read is removed here. The conversion stays: stock validate_args does not derive train_iters from train_samples and the MIMO loop reads args.train_iters. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: ykarnati <ykarnati@nvidia.com>
| validate_hetero_grid_args(args, WORLD_SIZE_8) | ||
|
|
||
|
|
||
| def test_cp_must_be_one(): |
| validate_hetero_grid_args(args, WORLD_SIZE_8) | ||
|
|
||
|
|
||
| def test_llm_only_requires_offset_zero(): |
| assert specs[0].name == MIMO_LANGUAGE_MODULE_KEY | ||
|
|
||
|
|
||
| def test_train_samples_resolves_iters(): |
…aram The caller (model provider) owns the encoder module name. Drop the duplicate DEFAULT_ENCODER_MODULE_NAME constant and the dead vision_encoder_key getattr; RADIO_ENCODER_MODULE_NAME (radio_encoder.py) is now the single source. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: ykarnati <ykarnati@nvidia.com>
| gbs = getattr(args, "global_batch_size", None) | ||
| if not gbs or gbs <= 0: | ||
| raise ValueError("--train-samples requires a positive --global-batch-size") | ||
| args.train_iters = math.ceil(args.train_samples / gbs) |
There was a problem hiding this comment.
why are we updating train iters ourself? in args? how does megatron train loop supposed ot handle this?
|
|
||
| Returns ``[encoder_grid_spec, language_grid_spec]`` (or just the language spec | ||
| when ``--llm-only``). The caller supplies ``encoder_module_name`` (the model | ||
| provider owns it). ``num_ranks`` is the ground truth ModuleGridSpec field; |
There was a problem hiding this comment.
verbose commentary/doc string - not required
The conversion belongs to the training loop, not grid-layout validation. Stock update_train_iters owns it; the MIMO entry invokes it. validate_hetero_grid_args is now purely about grid layout. Also trim the build_module_grid_specs docstring. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: ykarnati <ykarnati@nvidia.com>
What
Add the heterogeneous-grid CLI args module (
add_hetero_grid_args,validate_hetero_grid_args,build_module_grid_specs) for MIMO training, and extend the example topology helper:topology.py: create additional dense process groups (tp+dp,tp+dp+cp,tp+cp+dp+pp) and setpgc.tp_dp/pgc.tp_dp_cpexplicitly. MoE layers/router andfinalize_model_gradsreadtp_dp_cp; cuda-graph capture readstp_dp; theProcessGroupCollectionleaves theminit=Falseby default.Why
topology.pyis the file that landed with MM1 (PR #5331-series); these edits are purely additive.args.pyis a new leaf module that importsModuleGridSpecfromtopology.py. The grid-args unit test travels withargs.py; the existing topology unit test on main is unchanged.Validation
Validated in the 8-GPU 20L Nemotron VLM e2e (trains + checkpoint save/resume, lm loss 12.18->11.54 across resume).
CODEOWNERS
examples/mimo/...+tests/unit_tests/...-> repo default owners.🤖 Generated with Claude Code