Fix prompt batch alignment validation for synchronous PPO #1615

Open

taivu1998 wants to merge 1 commit into NovaSky-AI:main from taivu1998:tdv/issue-1609-batch-alignment

Conversation

@taivu1998

Summary

Fixes #1609.

This PR makes synchronous prompt-batch alignment fail fast instead of silently truncating prompt batches before generation. It:

  • adds shared batch-alignment metadata for enabled policy, critic, and reference data-parallel sizes
  • computes the required prompt alignment stride from lcm_dp_size and generator.n_samples_per_prompt
  • validates synchronous RayPPOTrainer prompt batch sizes during trainer construction
  • changes _remove_tail_data into a runtime validation guard with an actionable error message
  • keeps trainers that do not go through the synchronous _remove_tail_data path opted out, including the fully async and full-context dummy trainers
  • adds regression coverage for the issue-1609 repro, valid workaround, critic DP inclusion, runtime truncation rejection, and trainer opt-out behavior

Root Cause

RayPPOTrainer._remove_tail_data previously selected the largest prompt prefix whose expanded sample count could shard evenly across the enabled model DP layout. For the issue-1609 configuration, trainer.train_batch_size=2, generator.n_samples_per_prompt=4, and lcm_dp_size=6 imply a prompt alignment stride of 3, so the largest valid prefix for a two-prompt batch is zero prompts. That silently turned a non-empty training batch into an empty generator input and later surfaced as a much less helpful ValueError.
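
As a reading aid, the stride arithmetic above can be reproduced in a few lines. This is a sketch of the math the PR describes, not its code; the helper name and signature are illustrative:

```python
from math import gcd

# Hypothetical helper mirroring the alignment math described above:
# the expanded sample count (prompts * n_samples_per_prompt) must shard
# evenly across lcm_dp_size, which is equivalent to the prompt count
# being a multiple of lcm_dp_size / gcd(lcm_dp_size, n_samples_per_prompt).
def prompt_alignment_stride(lcm_dp_size: int, n_samples_per_prompt: int) -> int:
    return lcm_dp_size // gcd(lcm_dp_size, n_samples_per_prompt)

stride = prompt_alignment_stride(lcm_dp_size=6, n_samples_per_prompt=4)  # -> 3
largest_valid_prefix = (2 // stride) * stride  # old truncation keeps 0 prompts
```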

The existing batch-size validation caught only the effective expanded sample lower bound, so 2 * 4 >= 6 passed even though the original prompt count could not be aligned without truncation.
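
Concretely, with the issue-1609 numbers, the old lower-bound check and the new stride check disagree (a sketch, assuming the two checks reduce to these expressions):

```python
prompts, n_samples, lcm_dp_size, stride = 2, 4, 6, 3

old_check = prompts * n_samples >= lcm_dp_size  # True: 8 >= 6, so validation passed
new_check = prompts % stride == 0               # False: 2 % 3 != 0, so the guard fails fast
```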

Changes

  • Introduced BatchAlignmentInfo and get_batch_alignment_info() in skyrl/train/utils/utils.py.
  • Reused the shared alignment math in validate_batch_sizes() and included critic DP size in the enabled-model LCM.
  • Added RayPPOTrainer._validate_prompt_batch_alignment() and a detailed formatter that explains the stride, enabled DP sizes, retained prompt count, and next valid batch size (a rough sketch follows this list).
  • Replaced silent slicing in _remove_tail_data() with explicit validation.
  • Added requires_prompt_batch_alignment so trainer variants that do not consume standard synchronous prompt batches can opt out.
  • Marked FullyAsyncRayPPOTrainer and FullCtxTrainer as opt-outs.
  • Added focused tests in tests/train/test_trainer.py.
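
The shapes involved might look roughly like the following. This is a sketch under assumptions, not the PR's exact code: only the identifiers named above (BatchAlignmentInfo, the stride, the enabled DP sizes) come from the PR; the field layout and validation body are illustrative.

```python
from dataclasses import dataclass
from math import gcd, lcm

@dataclass(frozen=True)
class BatchAlignmentInfo:
    # DP sizes of the models that are actually enabled; the mapping
    # shape is an assumption made for illustration.
    enabled_dp_sizes: dict[str, int]  # e.g. {"policy": 2, "critic": 3}

    @property
    def lcm_dp_size(self) -> int:
        return lcm(*self.enabled_dp_sizes.values())

    def prompt_alignment_stride(self, n_samples_per_prompt: int) -> int:
        return self.lcm_dp_size // gcd(self.lcm_dp_size, n_samples_per_prompt)

def validate_prompt_batch_alignment(
    info: BatchAlignmentInfo, prompt_batch_size: int, n_samples_per_prompt: int
) -> None:
    """Illustrative stand-in for RayPPOTrainer._validate_prompt_batch_alignment."""
    stride = info.prompt_alignment_stride(n_samples_per_prompt)
    if prompt_batch_size == 0 or prompt_batch_size % stride != 0:
        # Round up to the next multiple of the stride for the hint.
        next_valid = max(stride, -(-prompt_batch_size // stride) * stride)
        raise ValueError(
            f"Prompt batch size {prompt_batch_size} cannot shard evenly: the prompt "
            f"DP alignment stride is {stride} (enabled DP sizes {info.enabled_dp_sizes}, "
            f"n_samples_per_prompt={n_samples_per_prompt}). "
            f"Next valid prompt batch size: {next_valid}."
        )
```

With the issue-1609 configuration, a guard of this shape raises at trainer construction instead of after the batch has already been truncated to an empty generator input.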

Validation

  • git diff --check
  • uv run --isolated --python 3.12 --extra dev --with ruff ruff check skyrl/train/utils/utils.py skyrl/train/trainer.py skyrl/train/fully_async_trainer.py examples/train_scripts/full_context/trainer_full_ctx.py tests/train/test_trainer.py
  • uv run --isolated --python 3.12 --extra dev --extra skyrl-train --with transformers pytest tests/train/test_trainer.py

The final pytest run passed with 14 passed, 1 warning; the warning was Ray's accelerator environment FutureWarning and was unrelated to these changes.

@taivu1998 marked this pull request as ready for review May 3, 2026 06:47
@gemini-code-assist (Bot, Contributor) left a comment


Code Review

This pull request implements strict validation for prompt batch alignment in the RayPPOTrainer to avoid silent data truncation. It introduces a BatchAlignmentInfo utility to centralize data-parallel size calculations and updates the trainer to raise ValueError when batch sizes are not properly aligned with the model's stride. Feedback includes suggestions to refine error message phrasing for zero-sized batches and to improve the clarity of parallel dimension assertions.

Comment thread: skyrl/train/trainer.py, lines +392 to +393
f"Prompt batch size {prompt_batch_size} is not divisible by the prompt DP alignment stride "
f"{prompt_alignment_stride}. With trainer.train_batch_size={self.cfg.trainer.train_batch_size}, "

Severity: medium

The error message states that the prompt batch size is 'not divisible' by the stride. However, this method is also called when prompt_batch_size is 0 (line 405), and 0 is technically divisible by any non-zero integer. While the suggestion for the next valid batch size is correct, the phrasing 'is not divisible' is slightly inaccurate for the zero case. Consider rephrasing to something like 'is invalid or not divisible' or handling the zero case with a specific message.
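
One way to act on this suggestion (purely illustrative, not a committed change) would be to branch on the zero case before formatting the message:

```python
if prompt_batch_size == 0:
    reason = "is empty"
else:
    reason = f"is not divisible by the prompt DP alignment stride {prompt_alignment_stride}"
raise ValueError(f"Prompt batch size {prompt_batch_size} {reason}. ...")
```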

tp = megatron_config.tensor_model_parallel_size
parallel_size = pp * cp * tp
assert world_size % parallel_size == 0, (
f"{model_name}_world_size {world_size} should be divisible by (pp * cp * tp) {parallel_size}. "

Severity: medium

The f-string here is slightly redundant as it repeats the value of parallel_size. It might be clearer to show the equality, e.g., (pp * cp * tp = {parallel_size}).

Suggested change
- f"{model_name}_world_size {world_size} should be divisible by (pp * cp * tp) {parallel_size}. "
+ f"{model_name}_world_size {world_size} should be divisible by the product of parallel dimensions (pp * cp * tp = {parallel_size}). "


Development

Successfully merging this pull request may close these issues.

RayPPOTrainer._remove_tail_data silently truncates a training batch to empty and throws ValueError (#1609)
