
[ci] H100 CI #1679

Open

hao-aaron wants to merge 2 commits into NovaSky-AI:main from hao-aaron:h100-ci

Conversation

@hao-aaron (Collaborator)

H100 CI tests for large MoE models

Introduces an opt-in H100 CI lane that exercises two ~30B-class MoE models end-to-end:

  • Qwen/Qwen3.5-35B-A3B — FSDP e2e (test_policy_local_engines_e2e), both colocated and non-colocated with vLLM. Megatron logprob roundtrip (test_megatron_models).
  • nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 — Megatron logprob roundtrip (test_megatron_models), re-enabled after the original 8-GPU param was skipped.

Tests are gated by pytest.mark.h100 (auto-skipped unless -m h100 is passed). The new GitHub workflow submits them as an Anyscale staging job on the llm-team-h100-4x:1 compute config.

Test infrastructure

  • Register the h100 marker in tests/backends/skyrl_train/gpu/conftest.py and auto-skip those tests unless -m h100 is explicitly passed (see the conftest sketch after this list).
  • Add Qwen3.5-35B-A3B parameterizations to test_policy_local_engines_e2e (colocated + non-colocated) and test_megatron_models (TP=4 EP=4).
  • Re-enable the Nemotron-3-Nano param on 4 GPUs (TP=1 EP=4) with the h100 marker.
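
A minimal sketch of that conftest gating, assuming standard pytest hooks (the PR's actual hook bodies may differ):

    # tests/backends/skyrl_train/gpu/conftest.py (illustrative sketch)
    import pytest

    def pytest_configure(config):
        config.addinivalue_line("markers", "h100: tests that need H100 GPUs")

    def pytest_collection_modifyitems(config, items):
        # Auto-skip h100-marked tests unless `-m h100` was requested on the CLI.
        if "h100" in config.getoption("markexpr", default=""):
            return
        skip_h100 = pytest.mark.skip(reason="pass -m h100 to run H100 CI tests")
        for item in items:
            if "h100" in item.keywords:
                item.add_marker(skip_h100)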

Core fixes uncovered along the way

  • FSDP2 meta-init reordering (fsdp_strategy.py): swap the module to meta before apply_fsdp2, so sharded DTensors allocate directly at their final shard size instead of materializing the full model on rank 0 first. Snapshot/restore non-persistent buffers (e.g. RotaryEmbedding.inv_freq) around the meta swap so they survive the broadcast (see the init-flow sketch after this list).
  • Replace the manual broadcast loop with set_model_state_dict (fsdp_utils.py): the ~80-line per-parameter broadcast + distribute_tensor loop is now PyTorch's set_model_state_dict(..., StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)) (covered by the same sketch below).
  • Remove the FSDP1-era offload+load dance (fsdp_utils.py): the trick (offload_fsdp2_model_to_cpu → empty_cache → load_fsdp2_model_to_gpu) was meant to clear reserved-but-unallocated PyTorch memory. For FSDP2 it's a no-op (model.to("cpu") doesn't move FSDPParam-managed storage during init), and the reload then allocates a second copy, doubling memory. Removed; set_model_state_dict already leaves us with exactly the shard on GPU.
  • NCCL weight-sync receive must be wrapped in set_current_vllm_config (new_inference_worker_wrap.py, remote_inference_client.py, broadcast_strategy.py): MoE models (FlashInfer CUTLASS kernel) read get_current_vllm_config() during load_weights, which is only set around init_device/load_model. Added a new update_weights_nccl worker method that wraps weight_transfer_engine.receive_weights in set_current_vllm_config, and routed the broadcast sender through start_weight_update + update_weights_nccl + finish_weight_update instead of vLLM's native /update_weights endpoint. This is the same pattern CUDA IPC already used. Tracked against upstream vllm-project/vllm#42577 (see the wrapper sketch after this list).
  • Rename update_weights_chunk → update_weights_ipc across new_inference_worker_wrap.py, remote_inference_client.py, and cuda_ipc_strategy.py so the IPC and NCCL paths have parallel names.
  • Drop stray bf16=False in policy FSDP init (fsdp_worker.py:162): this was forcing fp32 master weights even though the wrapper default is bf16=True (and the critic path at line 408 already respects the config). This halves the per-rank shard size, and was the difference between fitting and OOMing on non-colocated 2-GPU FSDP.
  • Large-MoE Megatron config (test_megatron_models.py): for Qwen3.5-35B and Nemotron-3-Nano, set use_precision_aware_optimizer=True + optimizer_cpu_offload=True + optimizer_offload_fraction=1.0. Megatron eagerly materializes the fp32 AdamW state on GPU at init (unlike PyTorch's lazy AdamW), so the optimizer state alone OOMs without offload (see the override sketch after this list).
  • Lower vLLM gpu_memory_utilization to 0.5 for these large MoE models in the Megatron test (_engine_overrides_for_model) so the policy shard + vLLM pool both fit on each H100.
  • sleep_level=2 in test_megatron_models: the test explicitly syncs weights, so full sleep (matching the InferenceEngineState.create default and the FSDP e2e test) is the right level. Previously it was hardcoded to 1.
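
A condensed sketch of the reordered FSDP2 init flow from the first two bullets above. apply_fsdp2 and the file names come from the PR; the exact ordering, the buffer snapshot/restore helper, and the function name here are illustrative, assuming a torch version whose set_model_state_dict supports broadcast_from_rank0:

    import torch.distributed as dist
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_model_state_dict,
    )

    def fsdp2_meta_init(module, apply_fsdp2):
        # Rank 0 keeps the full weights; set_model_state_dict broadcasts them.
        full_state = module.state_dict() if dist.get_rank() == 0 else {}

        # Snapshot non-persistent buffers (e.g. RotaryEmbedding.inv_freq):
        # they are excluded from state_dict, so the broadcast below would
        # otherwise leave them holding to_empty() garbage.
        sd_keys = set(module.state_dict().keys())
        non_persistent = {
            name: buf.detach().clone()
            for name, buf in module.named_buffers()
            if name not in sd_keys
        }

        # Swap to meta *before* sharding: apply_fsdp2 then produces DTensors
        # whose local storage is allocated at final shard size, instead of
        # materializing the full model on rank 0 first.
        module.to("meta")
        apply_fsdp2(module)
        module.to_empty(device="cuda")

        # One call replaces the old ~80-line per-parameter broadcast +
        # distribute_tensor loop.
        set_model_state_dict(
            module,
            full_state,
            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
        )

        # Restore the snapshotted non-persistent buffers on every rank.
        for name, buf in non_persistent.items():
            parent_path, _, attr = name.rpartition(".")
            parent = module.get_submodule(parent_path) if parent_path else module
            getattr(parent, attr).copy_(buf.to("cuda"))
        return module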
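
The NCCL receive fix boils down to re-entering vLLM's config context around the receive. A hedged sketch of the worker-side method; the class and attribute names (weight_transfer_engine, model_runner) are stand-ins for the PR's actual plumbing:

    from vllm.config import set_current_vllm_config

    class WorkerExtensionSketch:
        def update_weights_nccl(self):
            # MoE load_weights (FlashInfer CUTLASS) reads
            # get_current_vllm_config(), which vLLM only sets around
            # init_device/load_model. Re-enter the context for the receive.
            with set_current_vllm_config(self.vllm_config):
                self.weight_transfer_engine.receive_weights(
                    self.model_runner.model
                )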
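
And the Megatron override bullet, shown as the kind of per-model dict _engine_overrides_for_model might return (the flag names are copied from the bullet; the surrounding plumbing is illustrative):

    # Offload config for the ~30B MoE Megatron runs (sketch).
    LARGE_MOE_OVERRIDES = {
        # Megatron materializes AdamW's fp32 state eagerly at init (PyTorch's
        # AdamW allocates lazily at the first step), so without offload the
        # optimizer state alone OOMs a 4xH100 node.
        "use_precision_aware_optimizer": True,
        "optimizer_cpu_offload": True,
        "optimizer_offload_fraction": 1.0,
    }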

hao-aaron added 2 commits May 16, 2026 03:39
x
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
x
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request adds H100 GPU support, new CI configurations, and optimizations for large MoE models, specifically refactoring FSDP2 initialization to use meta-device swapping and PyTorch's set_model_state_dict. The weight update protocol was also updated to support NCCL transfers alongside existing IPC methods. Review feedback identified opportunities to further optimize memory usage by restricting state_dict calls to rank 0 and removing unused arguments in buffer synchronization.

@@ -219,6 +219,33 @@ def _fsdp_init_model(self, model, is_train=True, is_wrapped=False):
}
module = model.model if is_wrapped else model
full_state = module.state_dict()
Contributor


high

Calling state_dict() on all ranks materializes the full model state in memory on every process. Since fsdp2_load_full_state_dict uses set_model_state_dict with broadcast_from_rank0=True, only rank 0 needs the actual data. On non-zero ranks, this leads to significant and unnecessary CPU memory usage, which can cause OOMs on nodes with many GPUs when training large models.

Suggested change:
-full_state = module.state_dict()
+full_state = module.state_dict() if dist.get_rank() == 0 else {}

# Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding)
# that are excluded from state_dict. On non-rank-0 meta-init these are
# still on the meta device with no data; rank 0 has the correct values.
_sync_non_persistent_buffers(model, model.state_dict())
Contributor


medium

The second argument to _sync_non_persistent_buffers is not utilized by the function implementation. Calling model.state_dict() on an FSDP2 model, while relatively cheap as it returns sharded DTensors, still incurs unnecessary overhead and creates a large dictionary of tensors for no reason.

Suggested change:
-_sync_non_persistent_buffers(model, model.state_dict())
+_sync_non_persistent_buffers(model, {})

Comment on lines 159 to 164
wrapped_model = HFModelWrapper(
    model_path,
    use_flash_attention_2=self.cfg.flash_attn,
    bf16=False,
    lora_rank=self.cfg.policy.model.lora.rank,
    lora_alpha=self.cfg.policy.model.lora.alpha,
    lora_dropout=self.cfg.policy.model.lora.dropout,
Member


This change (using bf16=True) is going to initialize the model weights in BF16, and will lead to pure BF16 training.

You want to initialize the model weights in FP32 and let FSDP handle casting to BF16 during forward pass.

https://docs.pytorch.org/docs/2.12/fsdp.html#torch.distributed.fsdp.MixedPrecision

"Outside forward and backward, the sharded parameters are kept in full precision (e.g. for the optimizer step), and for model checkpointing, the parameters are always saved in full precision."

Comment on lines +75 to +79
# Large MoE models: AdamW's fp32 optimizer state (~6x model) is the
# dominant per-rank GPU cost during Megatron init.
# use_precision_aware_optimizer keeps it in bf16 (halves it), and
# optimizer_cpu_offload moves it off GPU entirely. Without these the
# optimizer state alone OOMs the H100 on 4 GPUs.
Member

@SumanthRH May 17, 2026


It would be good to just initialize the policy without any optimizer state.

We might need a fix in the SkyRL FSDP code for this; that would be a better approach than using optimizer offload.

We can explicitly pass optimizer_config as None to the FSDPStrategy :

optim_config = self.optimizer_config
if optim_config is not None:
    new_optimizer = optim.AdamW(
        fsdp_module.parameters(),

Currently this assertion will prevent us from having the optimizer as None for PolicyWorker:

assert (
    self.optimizer is not None and self.scheduler is not None
), "FSDP preparation should create optimizer and scheduler"

but we can just modify it to say that the optimizer should be non-null only if optimizer_config is non-null.
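
A sketch of the relaxed check that suggestion implies (a fragment mirroring the quoted assertion, not the final patch):

    assert self.optimizer_config is None or (
        self.optimizer is not None and self.scheduler is not None
    ), "FSDP preparation should create optimizer/scheduler when optimizer_config is set"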

Member

@SumanthRH left a comment


Can you break up this PR into

  1. A minimal PR to initialize H100 CI with Qwen 3.5 35B
  2. A PR to add support for Nemotron 30B

Also, it is hard to see which of the changes in FSDP worker initialization were essential, and which model each applies to. When you break up the PR, please include only the changes each model actually needs.

This PR also switches to initializing the model weights directly in BF16. While this saves memory, it introduces a correctness bug: training will no longer be in mixed precision but instead in pure BF16.

