[WIP] Add changes needed for FP8 megatron training#1543
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces several optimizations and fixes for distributed training, including memory-efficient gradient computation, FP8 alignment adjustments, and enhanced sequence parallelism support. Feedback focuses on correcting a potential typo in the CUDA version for nvidia-nvshmem, increasing the chunk_size for log-probability calculations to mitigate communication overhead, and vectorizing loops used for attention mask reconstruction to improve performance.
| "hf_transfer", | ||
| "cloudpathlib>=0.23.0", | ||
| # "deep-ep", | ||
| "nvidia-nvshmem-cu13>=3.6.5; sys_platform == 'linux'", |
| inference_only=True, | ||
| cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` | ||
| chunk_size=None, | ||
| chunk_size=16, |
There was a problem hiding this comment.
The hardcoded chunk_size=16 is extremely small. Since the log-probability computation involves multiple all_reduce operations per chunk (3 per chunk in the current implementation), such a small value will lead to significant communication overhead. Consider increasing this to a larger value (e.g., 128 or 256) to improve performance while maintaining memory efficiency.
| gather_attention_mask = torch.zeros( | ||
| outputs.shape[0], gathered_seq_len, | ||
| dtype=new_attention_mask.dtype, device=new_attention_mask.device | ||
| ) | ||
| seq_lens = new_attention_mask.sum(dim=1) | ||
| for i in range(outputs.shape[0]): | ||
| gather_attention_mask[i, :seq_lens[i]] = 1 |
There was a problem hiding this comment.
This loop for reconstructing the attention mask after gathering can be vectorized to improve performance, especially as the batch size or sequence length increases.
seq_lens = new_attention_mask.sum(dim=1, keepdim=True)
gather_attention_mask = (torch.arange(outputs.shape[1], device=new_attention_mask.device) < seq_lens).to(new_attention_mask.dtype)| inference_only=False, | ||
| cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` | ||
| chunk_size=None, | ||
| chunk_size=16, |
There was a problem hiding this comment.
The hardcoded chunk_size=16 is extremely small. Since the log-probability computation involves multiple all_reduce operations per chunk, such a small value will lead to significant communication overhead. Consider increasing this to a larger value (e.g., 128 or 256) to improve performance while maintaining memory efficiency.
| gather_attention_mask = torch.zeros( | ||
| outputs.shape[0], gathered_seq_len, | ||
| dtype=new_attention_mask.dtype, device=new_attention_mask.device | ||
| ) | ||
| seq_lens = new_attention_mask.sum(dim=1) | ||
| for i in range(outputs.shape[0]): | ||
| gather_attention_mask[i, :seq_lens[i]] = 1 |
There was a problem hiding this comment.
This loop for reconstructing the attention mask after gathering can be vectorized to improve performance, especially as the batch size or sequence length increases.
seq_lens = new_attention_mask.sum(dim=1, keepdim=True)
gather_attention_mask = (torch.arange(outputs.shape[1], device=new_attention_mask.device) < seq_lens).to(new_attention_mask.dtype)| all_grad_input.append(grad_input) | ||
|
|
||
| grad_input = torch.cat(all_grad_input, dim=1) | ||
| vocab_parallel_logits[:, chunk_start:chunk_end, :] = chunk_grad.to(vocab_parallel_logits.dtype) |
There was a problem hiding this comment.
🔴 In-place overwrite of saved logits tensor in ChunkedDistributedLogprob.backward() crashes when use_entropy_loss=True
The new ChunkedDistributedLogprob.backward() overwrites vocab_parallel_logits in-place with gradient data and returns it. This tensor is the original model output logits, which is also saved by _VocabParallelEntropy (via ctx.save_for_backward at model_utils.py:561) when use_entropy_loss=True. Both autograd functions share the same underlying tensor storage (since action_logits = logits[:, -num_actions - 1 : -1, :] at megatron_model_wrapper.py:349 is a view of logits). During backward, whichever function runs second will encounter a PyTorch RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation because ctx.saved_tensors detects the version mismatch caused by the first function's in-place modification.
Previously this was not an issue because chunk_size=None was used, which routed through DistributedLogprob — that class saves softmax_output (a derived tensor) rather than the original logits (model_utils.py:101). The PR switches to chunk_size=16 (megatron_model_wrapper.py:292), forcing all training paths through ChunkedDistributedLogprob, which saves the original logits (model_utils.py:197) and now destructively overwrites them in backward.
Prompt for agents
The in-place write `vocab_parallel_logits[:, chunk_start:chunk_end, :] = chunk_grad.to(vocab_parallel_logits.dtype)` at model_utils.py:240 destructively overwrites the saved logits tensor, which conflicts with _VocabParallelEntropy (model_utils.py:561-572) that also saves a view of the same tensor. When both are active (use_entropy_loss=True), PyTorch's saved-tensor version check causes a RuntimeError in backward.
Possible fixes:
1. Revert to the previous approach of collecting chunk gradients into a list and concatenating them (avoids in-place modification of the saved tensor entirely).
2. Allocate a separate output tensor of the same shape as vocab_parallel_logits and write chunk gradients into it, then return that new tensor.
3. Use `vocab_parallel_logits.clone()` at the start of backward to create a separate buffer for writing gradients, leaving the original saved tensor untouched.
Option 2 is closest to the memory-saving intent of this PR while avoiding the version conflict. The key files involved are model_utils.py (ChunkedDistributedLogprob.backward, lines 217-243) and megatron_model_wrapper.py (loss_func which calls both from_parallel_logits_to_logprobs and vocab_parallel_entropy).
Was this helpful? React with 👍 or 👎 to provide feedback.
Uh oh!
There was an error while loading. Please reload this page.