Skip to content

Add hetero grid args and MoE process groups for MIMO example#5375

Draft
yashaswikarnati wants to merge 5 commits into
NVIDIA:mainfrom
yashaswikarnati:x2-args-topology
Draft

Add hetero grid args and MoE process groups for MIMO example#5375
yashaswikarnati wants to merge 5 commits into
NVIDIA:mainfrom
yashaswikarnati:x2-args-topology

Conversation

@yashaswikarnati

Copy link
Copy Markdown
Contributor

What

Add the heterogeneous-grid CLI args module (add_hetero_grid_args, validate_hetero_grid_args, build_module_grid_specs) for MIMO training, and extend the example topology helper:

  • topology.py: create additional dense process groups (tp+dp, tp+dp+cp, tp+cp+dp+pp) and set pgc.tp_dp / pgc.tp_dp_cp explicitly. MoE layers/router and finalize_model_grads read tp_dp_cp; cuda-graph capture reads tp_dp; the ProcessGroupCollection leaves them init=False by default.

Why

topology.py is the file that landed with MM1 (PR #5331-series); these edits are purely additive. args.py is a new leaf module that imports ModuleGridSpec from topology.py. The grid-args unit test travels with args.py; the existing topology unit test on main is unchanged.

Validation

Validated in the 8-GPU 20L Nemotron VLM e2e (trains + checkpoint save/resume, lm loss 12.18->11.54 across resume).

CODEOWNERS

examples/mimo/... + tests/unit_tests/... -> repo default owners.

🤖 Generated with Claude Code

Add the heterogeneous-grid CLI args module (add_hetero_grid_args,
validate_hetero_grid_args, build_module_grid_specs) used to configure
per-module grids for MIMO training, and extend the example topology helper:

- topology.py: create the additional dense process groups (tp+dp, tp+dp+cp,
  tp+cp+dp+pp) and set pgc.tp_dp / pgc.tp_dp_cp explicitly. MoE layers/router
  and finalize_model_grads read tp_dp_cp; cuda-graph capture reads tp_dp; the
  ProcessGroupCollection leaves them init=False by default. These edits extend
  the topology helper that landed with MM1; they are purely additive.

The grid-args unit test travels with args.py. The existing topology unit
test on main is unchanged.

Validated in the 8-GPU 20L Nemotron VLM e2e (trains + checkpoint
save/resume, lm loss 12.18->11.54 across resume).

Signed-off-by: ykarnati <ykarnati@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Comment thread examples/mimo/training/args.py Outdated
grid = parser.add_argument_group("hetero module grids")

# Encoder grid placement + factorization.
grid.add_argument("--encoder-offset", type=int, default=0,

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we dont need encoder offset

Comment thread examples/mimo/training/args.py Outdated
def validate_hetero_grid_args(args: argparse.Namespace, world_size: int) -> tuple[int, int]:
"""Validate the disjoint-grid hetero layout. Returns ``(encoder_size, llm_size)``.

Call AFTER stock ``validate_args`` (so ``micro_batch_size`` / ``num_experts``

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc strings are verbose. can be concsely written, dont explain code in doc strings here. validate is a simple check function

return parser.parse_args(argv)


def _layout_8gpu_20l(**overrides):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests not adding value, just smoke tests should be avoided. tests can be authored more concise here

Comment thread examples/mimo/training/topology.py Outdated
pgc.dp_cp = grid.get_pg(["dp", "cp"])
pgc.intra_dp_cp = pgc.dp_cp
pgc.tp_cp = grid.get_pg(["tp", "cp"])
# MoE layers/router and finalize_model_grads read tp_dp_cp (tensor+data+context

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove - # MoE layers/router and finalize_model_grads read tp_dp_cp (tensor+data+context
# parallel group); cuda-graph capture reads tp_dp. Set them explicitly since the
# ProcessGroupCollection leaves them init=False.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

verbose comment and not necessary here

Comment thread examples/mimo/training/args.py Outdated
# ``args.vision_encoder_key = "radio_encoder"``; we mirror that default here so
# the encoder ModuleGridSpec carries the same module name the provider/runtime
# look up. Resolved at spec-build time via ``getattr(args, "vision_encoder_key")``.
DEFAULT_ENCODER_MODULE_NAME = "radio_encoder"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this - DEFAULT_ENCODER_MODULE_NAME ?

Comment thread examples/mimo/training/args.py Outdated
def add_hetero_grid_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Register the hetero parallelism/topology arg group.

Stock-hook compatible: returns the parser so it can be passed straight to

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dont use the wording Stock-hook . stock has no meanin here. also currnt doc string looks very vague and verbose.

Comment thread examples/mimo/training/args.py Outdated
help="Language pipeline-model-parallel size.")
grid.add_argument("--llm-dp", type=int, default=2,
help="Language data-parallel size. Global batch is keyed on this.")
# MoE expert parallelism for the language grid. Relocated here from the E1

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this verbose comment

Comment thread examples/mimo/training/args.py
Comment thread examples/mimo/training/args.py Outdated
return [encoder_spec, language_spec]


def _resolve_expt_tp(expt_tp, tp: int) -> int:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need special function for this

Comment thread examples/mimo/training/args.py Outdated


def _num_experts(args: argparse.Namespace) -> int:
"""Resolve MoE expert count from stock (--num-experts) or prototype args."""

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need two args? wdym by prototype args ?

Comment thread examples/mimo/training/args.py Outdated
return 0


def _num_microbatches(args: argparse.Namespace) -> int:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need num microbatches? why args needs to pass it ?

Comment thread examples/mimo/training/args.py Outdated
@@ -0,0 +1,253 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

"""Hetero grid/topology CLI args + validation for the stock-args MIMO examples.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

verbose doc string / comments for the file

Comment thread examples/mimo/training/args.py Outdated

# Sample-based scheduler resolution: derive --train-iters from --train-samples
# using the llm_dp-keyed global batch size.
if getattr(args, "train_samples", None) is not None:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to check this here? isnt this part of megatron/training training loop?

Comment thread examples/mimo/training/args.py Outdated
"""
encoder_size, llm_size = validate_hetero_grid_args(args, world_size)

language_spec = ModuleGridSpec(

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should call this language_grid_spec to avoid confusion with module spec we use for model init?

yashaswikarnati and others added 2 commits June 18, 2026 17:30
Trim verbose module/function docstrings and comments, remove the unused
--encoder-offset arg (the encoder span always starts at rank 0), inline the
trivial expert-TP default, read MoE expert count from stock --num-experts only,
and rename the local grid specs to language_grid_spec / encoder_grid_spec.
Slim the pure-args tests to the value-asserting cases.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
Drop the _num_microbatches helper and key the --train-samples to --train-iters
conversion off the explicit --global-batch-size. num_microbatches is derived by
the stock calculator (gbs / (mbs * llm_dp)), so the redundant --num-microbatches
read is removed here. The conversion stays: stock validate_args does not derive
train_iters from train_samples and the MIMO loop reads args.train_iters.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
validate_hetero_grid_args(args, WORLD_SIZE_8)


def test_cp_must_be_one():

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

validate_hetero_grid_args(args, WORLD_SIZE_8)


def test_llm_only_requires_offset_zero():

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

assert specs[0].name == MIMO_LANGUAGE_MODULE_KEY


def test_train_samples_resolves_iters():

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

…aram

The caller (model provider) owns the encoder module name. Drop the duplicate
DEFAULT_ENCODER_MODULE_NAME constant and the dead vision_encoder_key getattr;
RADIO_ENCODER_MODULE_NAME (radio_encoder.py) is now the single source.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
Comment thread examples/mimo/training/args.py Outdated
gbs = getattr(args, "global_batch_size", None)
if not gbs or gbs <= 0:
raise ValueError("--train-samples requires a positive --global-batch-size")
args.train_iters = math.ceil(args.train_samples / gbs)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we updating train iters ourself? in args? how does megatron train loop supposed ot handle this?

Comment thread examples/mimo/training/args.py Outdated

Returns ``[encoder_grid_spec, language_grid_spec]`` (or just the language spec
when ``--llm-only``). The caller supplies ``encoder_module_name`` (the model
provider owns it). ``num_ranks`` is the ground truth ModuleGridSpec field;

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

verbose commentary/doc string - not required

The conversion belongs to the training loop, not grid-layout validation. Stock
update_train_iters owns it; the MIMO entry invokes it. validate_hetero_grid_args
is now purely about grid layout. Also trim the build_module_grid_specs docstring.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant