[WIP] Add changes needed for FP8 megatron training #1543

Open
pcmoritz wants to merge 2 commits into NovaSky-AI:main from pcmoritz:megatron-fp8

@pcmoritz pcmoritz commented Apr 21, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces several optimizations and fixes for distributed training, including memory-efficient gradient computation, FP8 alignment adjustments, and enhanced sequence parallelism support. Feedback focuses on correcting a potential typo in the CUDA version for nvidia-nvshmem, increasing the chunk_size for log-probability calculations to mitigate communication overhead, and vectorizing loops used for attention mask reconstruction to improve performance.

Comment thread pyproject.toml Outdated
"hf_transfer",
"cloudpathlib>=0.23.0",
# "deep-ep",
"nvidia-nvshmem-cu13>=3.6.5; sys_platform == 'linux'",

medium

Potential mismatch: nvidia-nvshmem-cu13 is the CUDA 13 build of NVSHMEM, whereas your index URL targets cu128 (CUDA 12.8). If the environment is on CUDA 12, you likely intended nvidia-nvshmem-cu12 or another build that matches your current CUDA version.

inference_only=True,
cp_group=None, # we handle cp gathering in `postprocess_packed_seqs`
- chunk_size=None,
+ chunk_size=16,

medium

The hardcoded chunk_size=16 is extremely small. Since the log-probability computation involves multiple all_reduce operations per chunk (3 per chunk in the current implementation), such a small value will lead to significant communication overhead. Consider increasing this to a larger value (e.g., 128 or 256) to improve performance while maintaining memory efficiency.
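To make the overhead concrete, here is a rough count of collective calls per sequence as a function of chunk size. It is a minimal sketch: the figure of 3 all_reduce calls per chunk comes from the comment above, while the helper name `count_all_reduces` and the sequence length of 4096 are assumptions chosen only for illustration.

```python
import math


def count_all_reduces(seq_len: int, chunk_size: int, per_chunk: int = 3) -> int:
    """Hypothetical helper: all_reduce calls needed to cover one sequence."""
    num_chunks = math.ceil(seq_len / chunk_size)
    return num_chunks * per_chunk


# Assumed sequence length of 4096 tokens, for illustration only.
for chunk_size in (16, 128, 256):
    print(chunk_size, count_all_reduces(4096, chunk_size))
# 16  -> 768
# 128 -> 96
# 256 -> 48
```

Under these assumptions, moving from 16 to 128 cuts the number of collectives by 8x, while the per-chunk memory footprint grows by the same factor.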

Comment on lines +179 to +185
gather_attention_mask = torch.zeros(
    outputs.shape[0], gathered_seq_len,
    dtype=new_attention_mask.dtype, device=new_attention_mask.device
)
seq_lens = new_attention_mask.sum(dim=1)
for i in range(outputs.shape[0]):
    gather_attention_mask[i, :seq_lens[i]] = 1

medium

This loop for reconstructing the attention mask after gathering can be vectorized to improve performance, especially as the batch size or sequence length increases.

                    seq_lens = new_attention_mask.sum(dim=1, keepdim=True)
                    gather_attention_mask = (torch.arange(outputs.shape[1], device=new_attention_mask.device) < seq_lens).to(new_attention_mask.dtype)
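The equivalence between the loop and the broadcast comparison can be checked on a toy input. This is a sketch under stated assumptions: `mask_loop` and `mask_vectorized` are hypothetical helper names, and the mask width is passed explicitly. Note that the original loop allocates `gathered_seq_len` columns while the suggestion uses `outputs.shape[1]`; the two forms agree only if those values are equal, which the snippet above does not confirm.

```python
import torch


def mask_loop(seq_lens: torch.Tensor, width: int) -> torch.Tensor:
    """Per-row loop, mirroring the code under review."""
    mask = torch.zeros(seq_lens.shape[0], width, dtype=torch.long)
    for i in range(seq_lens.shape[0]):
        mask[i, : seq_lens[i]] = 1
    return mask


def mask_vectorized(seq_lens: torch.Tensor, width: int) -> torch.Tensor:
    """Broadcast comparison: position index < sequence length."""
    positions = torch.arange(width, device=seq_lens.device)  # shape (width,)
    return (positions < seq_lens.unsqueeze(1)).to(torch.long)  # shape (batch, width)


# Toy padded mask; the gathered width of 6 is an assumed value.
new_attention_mask = torch.tensor([[1, 1, 1, 0], [0, 0, 0, 0], [1, 1, 1, 1]])
seq_lens = new_attention_mask.sum(dim=1)
gathered_seq_len = 6

loop_mask = mask_loop(seq_lens, gathered_seq_len)
vec_mask = mask_vectorized(seq_lens, gathered_seq_len)
```

Both versions assume left-aligned (right-padded) sequences, as the original loop does.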

inference_only=False,
cp_group=None, # we handle cp gathering in `postprocess_packed_seqs`
- chunk_size=None,
+ chunk_size=16,

medium

The hardcoded chunk_size=16 is extremely small. Since the log-probability computation involves multiple all_reduce operations per chunk, such a small value will lead to significant communication overhead. Consider increasing this to a larger value (e.g., 128 or 256) to improve performance while maintaining memory efficiency.

Comment on lines +493 to +499
gather_attention_mask = torch.zeros(
    outputs.shape[0], gathered_seq_len,
    dtype=new_attention_mask.dtype, device=new_attention_mask.device
)
seq_lens = new_attention_mask.sum(dim=1)
for i in range(outputs.shape[0]):
    gather_attention_mask[i, :seq_lens[i]] = 1

medium

This loop for reconstructing the attention mask after gathering can be vectorized to improve performance, especially as the batch size or sequence length increases.

                    seq_lens = new_attention_mask.sum(dim=1, keepdim=True)
                    gather_attention_mask = (torch.arange(outputs.shape[1], device=new_attention_mask.device) < seq_lens).to(new_attention_mask.dtype)

@devin-ai-integration devin-ai-integration Bot left a comment

✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no bugs or issues to report.

@devin-ai-integration devin-ai-integration Bot left a comment

Devin Review found 1 new potential issue.

View 5 additional findings in Devin Review.

all_grad_input.append(grad_input)

grad_input = torch.cat(all_grad_input, dim=1)
vocab_parallel_logits[:, chunk_start:chunk_end, :] = chunk_grad.to(vocab_parallel_logits.dtype)

🔴 In-place overwrite of saved logits tensor in ChunkedDistributedLogprob.backward() crashes when use_entropy_loss=True

The new ChunkedDistributedLogprob.backward() overwrites vocab_parallel_logits in-place with gradient data and returns it. This tensor is the original model output logits, which is also saved by _VocabParallelEntropy (via ctx.save_for_backward at model_utils.py:561) when use_entropy_loss=True. Both autograd functions share the same underlying tensor storage (since action_logits = logits[:, -num_actions - 1 : -1, :] at megatron_model_wrapper.py:349 is a view of logits). During backward, whichever function runs second will encounter a PyTorch RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation because ctx.saved_tensors detects the version mismatch caused by the first function's in-place modification.

Previously this was not an issue because chunk_size=None was used, which routed through DistributedLogprob — that class saves softmax_output (a derived tensor) rather than the original logits (model_utils.py:101). The PR switches to chunk_size=16 (megatron_model_wrapper.py:292), forcing all training paths through ChunkedDistributedLogprob, which saves the original logits (model_utils.py:197) and now destructively overwrites them in backward.
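The version-counter mechanism described here can be reproduced in isolation. The following is a minimal sketch, not the PR's code: `SavesLogits` and `OverwritesLogits` are hypothetical stand-ins for the entropy and chunked log-prob functions, and two separate backward calls are used only to make the execution order deterministic (in the PR both would run inside a single backward pass).

```python
import torch


class SavesLogits(torch.autograd.Function):
    """Hypothetical stand-in for a function that saves its input logits."""

    @staticmethod
    def forward(ctx, logits):
        ctx.save_for_backward(logits)
        return logits * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        (logits,) = ctx.saved_tensors  # the saved-tensor version check happens here
        return grad_out * 2.0


class OverwritesLogits(torch.autograd.Function):
    """Hypothetical stand-in for a backward that reuses saved logits as its output buffer."""

    @staticmethod
    def forward(ctx, logits):
        ctx.save_for_backward(logits)
        return logits.sum(dim=-1)

    @staticmethod
    def backward(ctx, grad_out):
        (logits,) = ctx.saved_tensors
        buf = logits.detach()  # shares storage and version counter with logits
        buf.copy_(grad_out.unsqueeze(-1).expand_as(buf))  # in-place overwrite
        return buf


base = torch.randn(2, 4, 8, requires_grad=True)
logits = base.clone()  # non-leaf tensor consumed by both functions
entropy_like = SavesLogits.apply(logits)
logprob_like = OverwritesLogits.apply(logits)

# First backward overwrites the shared tensor in place.
logprob_like.sum().backward(retain_graph=True)

# Second backward unpacks a saved tensor whose version has changed.
try:
    entropy_like.sum().backward()
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)
```

In this sketch the second backward fails when it unpacks its saved tensor, because the in-place `copy_` bumped the version counter shared by every view and detached alias of `logits`.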

Prompt for agents
The in-place write `vocab_parallel_logits[:, chunk_start:chunk_end, :] = chunk_grad.to(vocab_parallel_logits.dtype)` at model_utils.py:240 destructively overwrites the saved logits tensor, which conflicts with _VocabParallelEntropy (model_utils.py:561-572) that also saves a view of the same tensor. When both are active (use_entropy_loss=True), PyTorch's saved-tensor version check causes a RuntimeError in backward.

Possible fixes:
1. Revert to the previous approach of collecting chunk gradients into a list and concatenating them (avoids in-place modification of the saved tensor entirely).
2. Allocate a separate output tensor of the same shape as vocab_parallel_logits and write chunk gradients into it, then return that new tensor.
3. Use `vocab_parallel_logits.clone()` at the start of backward to create a separate buffer for writing gradients, leaving the original saved tensor untouched.

Option 2 is closest to the memory-saving intent of this PR while avoiding the version conflict. The key files involved are model_utils.py (ChunkedDistributedLogprob.backward, lines 217-243) and megatron_model_wrapper.py (loss_func which calls both from_parallel_logits_to_logprobs and vocab_parallel_entropy).
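Option 2 might look roughly like the sketch below. It is an illustration under assumptions rather than the PR's implementation: `ChunkedLogSumExpSketch` is a hypothetical single-device function that uses `logsumexp` as a simple stand-in for the log-prob computation, with no vocab-parallel communication. The point is only the buffer handling in `backward`.

```python
import torch


class ChunkedLogSumExpSketch(torch.autograd.Function):
    """Hypothetical stand-in: chunked backward that leaves the saved logits untouched."""

    @staticmethod
    def forward(ctx, logits, chunk_size):
        ctx.save_for_backward(logits)
        ctx.chunk_size = chunk_size
        return torch.logsumexp(logits, dim=-1)

    @staticmethod
    def backward(ctx, grad_out):
        (logits,) = ctx.saved_tensors
        # Separate output buffer: the saved logits are only read, never written.
        grad_input = torch.empty_like(logits)
        seq_len = logits.shape[1]
        for start in range(0, seq_len, ctx.chunk_size):
            end = min(start + ctx.chunk_size, seq_len)
            chunk = logits[:, start:end, :].float()
            # d logsumexp / d logits = softmax(logits), scaled by the upstream gradient
            chunk_grad = torch.softmax(chunk, dim=-1) * grad_out[:, start:end].unsqueeze(-1)
            grad_input[:, start:end, :] = chunk_grad.to(logits.dtype)
        return grad_input, None  # no gradient for chunk_size


base = torch.randn(2, 5, 7, requires_grad=True)
logits = base.clone()
version_before = logits._version
ChunkedLogSumExpSketch.apply(logits, 2).sum().backward()
version_after = logits._version
```

Compared with collecting chunks in a list and concatenating, this writes each chunk directly into one preallocated tensor, so it avoids holding both the list and the concatenated result at once. It still costs one extra logits-sized allocation relative to the in-place version.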