Fix prompt batch alignment validation for synchronous PPO #1615
taivu1998 wants to merge 1 commit into
Conversation
Code Review
This pull request implements strict validation for prompt batch alignment in the RayPPOTrainer to avoid silent data truncation. It introduces a BatchAlignmentInfo utility to centralize data-parallel size calculations and updates the trainer to raise ValueError when batch sizes are not properly aligned with the model's stride. Feedback includes suggestions to refine error message phrasing for zero-sized batches and to improve the clarity of parallel dimension assertions.
```python
f"Prompt batch size {prompt_batch_size} is not divisible by the prompt DP alignment stride "
f"{prompt_alignment_stride}. With trainer.train_batch_size={self.cfg.trainer.train_batch_size}, "
```
The error message states that the prompt batch size is 'not divisible' by the stride. However, this method is also called when prompt_batch_size is 0 (line 405), and 0 is technically divisible by any non-zero integer. While the suggestion for the next valid batch size is correct, the phrasing 'is not divisible' is slightly inaccurate for the zero case. Consider rephrasing to something like 'is invalid or not divisible' or handling the zero case with a specific message.
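One way to act on this suggestion is to branch on the zero case before falling back to the divisibility wording. The sketch below is hypothetical (a standalone `format_alignment_error` helper; the PR's actual formatter lives inside the trainer and takes more context):

```python
# Hypothetical formatter sketching the reviewer's suggestion: give the
# zero-prompt case its own message instead of saying "not divisible",
# since 0 is technically divisible by any non-zero stride.
def format_alignment_error(prompt_batch_size: int, stride: int) -> str:
    if prompt_batch_size == 0:
        return (
            f"Prompt batch size is 0; at least {stride} prompts "
            f"(one alignment stride) are required."
        )
    return (
        f"Prompt batch size {prompt_batch_size} is not divisible by "
        f"the prompt DP alignment stride {stride}."
    )

print(format_alignment_error(0, 3))
print(format_alignment_error(2, 3))
```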
```python
tp = megatron_config.tensor_model_parallel_size
parallel_size = pp * cp * tp
assert world_size % parallel_size == 0, (
    f"{model_name}_world_size {world_size} should be divisible by (pp * cp * tp) {parallel_size}. "
```
The f-string here is slightly redundant as it repeats the value of parallel_size. It might be clearer to show the equality, e.g., (pp * cp * tp = {parallel_size}).
```diff
- f"{model_name}_world_size {world_size} should be divisible by (pp * cp * tp) {parallel_size}. "
+ f"{model_name}_world_size {world_size} should be divisible by the product of parallel dimensions (pp * cp * tp = {parallel_size}). "
```
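For context, the check under discussion can be reproduced standalone. This is a minimal sketch with made-up values (`pp`/`cp`/`tp` stand in for the `megatron_config` parallel sizes above; the numbers are illustrative, not from the PR):

```python
# Standalone sketch of the parallel-size divisibility check, with made-up values.
pp, cp, tp = 2, 1, 4          # pipeline / context / tensor parallel sizes
parallel_size = pp * cp * tp  # ranks consumed per data-parallel replica = 8
world_size = 16

assert world_size % parallel_size == 0, (
    f"world_size {world_size} should be divisible by the product of "
    f"parallel dimensions (pp * cp * tp = {parallel_size})."
)
dp_size = world_size // parallel_size  # remaining factor is the DP size
print(dp_size)  # -> 2
```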
Summary
Fixes #1609.
This PR makes synchronous prompt-batch alignment fail fast instead of silently truncating prompt batches before generation. It:
- Validates `RayPPOTrainer` prompt batch sizes against `lcm_dp_size` and `generator.n_samples_per_prompt` during trainer construction
- Turns `_remove_tail_data` into a runtime validation guard with an actionable error message
- Keeps trainers that do not use the `_remove_tail_data` path opted out, including fully async and full-context dummy training

Root Cause
`RayPPOTrainer._remove_tail_data` previously selected the largest prompt prefix whose expanded sample count could shard evenly across the enabled model DP layout. For the issue-1609 configuration, `trainer.train_batch_size=2`, `generator.n_samples_per_prompt=4`, and `lcm_dp_size=6` imply a prompt alignment stride of 3, so the largest valid prefix for a two-prompt batch is zero prompts. That silently turned a non-empty training batch into an empty generator input and later surfaced as a much less helpful `ValueError`.

The existing batch-size validation caught only the effective expanded sample lower bound, so `2 * 4 >= 6` passed even though the original prompt count could not be aligned without truncation.
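The stride arithmetic in this root cause can be reproduced in a few lines. This is a sketch, not the PR's implementation (which the description says lives in `get_batch_alignment_info()`):

```python
from math import gcd

def prompt_alignment_stride(lcm_dp_size: int, n_samples_per_prompt: int) -> int:
    # Smallest prompt count whose expanded sample count
    # (prompts * n_samples_per_prompt) shards evenly across lcm_dp_size ranks.
    return lcm_dp_size // gcd(lcm_dp_size, n_samples_per_prompt)

stride = prompt_alignment_stride(6, 4)         # issue-1609 configuration -> 3
largest_valid_prefix = (2 // stride) * stride  # 2-prompt batch truncates to 0
print(stride, largest_valid_prefix)            # -> 3 0
```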
Changes

- Added `BatchAlignmentInfo` and `get_batch_alignment_info()` in `skyrl/train/utils/utils.py`.
- Updated `validate_batch_sizes()` and included critic DP size in the enabled-model LCM.
- Added `RayPPOTrainer._validate_prompt_batch_alignment()` and a detailed formatter that explains the stride, enabled DP sizes, retained prompt count, and next valid batch size.
- Replaced `_remove_tail_data()` with explicit validation.
- Added `requires_prompt_batch_alignment` so trainer variants that do not consume standard synchronous prompt batches can opt out.
- Marked `FullyAsyncRayPPOTrainer` and `FullCtxTrainer` as opt-outs.
- Added tests in `tests/train/test_trainer.py`.
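A minimal sketch of the fail-fast behavior these changes describe; the function name and message follow the PR summary, but the actual implementation may differ:

```python
from math import gcd

def validate_prompt_batch_alignment(
    prompt_batch_size: int, lcm_dp_size: int, n_samples_per_prompt: int
) -> None:
    # Hypothetical guard: raise instead of silently truncating, and suggest
    # the next valid batch size, as the PR's error formatter is described to do.
    stride = lcm_dp_size // gcd(lcm_dp_size, n_samples_per_prompt)
    if prompt_batch_size == 0 or prompt_batch_size % stride != 0:
        next_valid = ((prompt_batch_size // stride) + 1) * stride
        raise ValueError(
            f"Prompt batch size {prompt_batch_size} does not align with the "
            f"prompt DP stride {stride}; next valid size is {next_valid}."
        )

validate_prompt_batch_alignment(3, 6, 4)  # ok: 3 % 3 == 0
try:
    validate_prompt_batch_alignment(2, 6, 4)
except ValueError as e:
    print(e)  # suggests batch size 3
```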
Validation

- `git diff --check`
- `uv run --isolated --python 3.12 --extra dev --with ruff ruff check skyrl/train/utils/utils.py skyrl/train/trainer.py skyrl/train/fully_async_trainer.py examples/train_scripts/full_context/trainer_full_ctx.py tests/train/test_trainer.py`
- `uv run --isolated --python 3.12 --extra dev --extra skyrl-train --with transformers pytest tests/train/test_trainer.py`

The final pytest run passed with `14 passed, 1 warning`; the warning was Ray's accelerator environment `FutureWarning` and was unrelated to these changes.