diff --git a/examples/train_scripts/full_context/trainer_full_ctx.py b/examples/train_scripts/full_context/trainer_full_ctx.py index 6022639d9d..7cbb27a3a5 100644 --- a/examples/train_scripts/full_context/trainer_full_ctx.py +++ b/examples/train_scripts/full_context/trainer_full_ctx.py @@ -1,6 +1,8 @@ -from skyrl.train.trainer import RayPPOTrainer -from loguru import logger import random + +from loguru import logger + +from skyrl.train.trainer import RayPPOTrainer from skyrl.train.utils.utils import Timer @@ -14,6 +16,8 @@ class FullCtxTrainer(RayPPOTrainer): This helps catch OOM issues early before running full training. """ + requires_prompt_batch_alignment = False + async def train(self): """Run a few training steps with max sequence length.""" logger.info("Starting dummy training with max sequence length...") diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index 2286e57f8d..e689d19e6e 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -262,6 +262,7 @@ def get_consumed_uids_list(self) -> List[str]: class FullyAsyncRayPPOTrainer(RayPPOTrainer): + requires_prompt_batch_alignment = False def __init__(self, *args, **kwargs): # Extract cfg before base init so we can initialize async-specific knobs used by our overrides. diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 61655f0f65..27327d53f4 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -80,10 +80,16 @@ validate_generator_output, zero_variance_filter, ) -from skyrl.train.utils.utils import ResolvedPlacementGroup, configure_ray_worker_logging +from skyrl.train.utils.utils import ( + ResolvedPlacementGroup, + configure_ray_worker_logging, + get_batch_alignment_info, +) class RayPPOTrainer: + requires_prompt_batch_alignment = True + def __init__( self, cfg: SkyRLTrainConfig, @@ -103,6 +109,8 @@ def __init__( self.eval_dataset = eval_dataset self.inference_engine_client = inference_engine_client self.generator = generator + if self.train_dataset is not None and self.requires_prompt_batch_alignment: + self._validate_prompt_batch_alignment(self.cfg.trainer.train_batch_size) self.train_dataloader = None self.total_training_steps = None self._build_train_dataloader_and_compute_training_steps() @@ -363,8 +371,59 @@ async def train(self): self.tracker.finish() logger.info("Training done!") + def _format_prompt_batch_alignment_error( + self, + prompt_batch_size: int, + kept_prompts: int, + lcm_dp_size: Optional[int] = None, + prompt_alignment_stride: Optional[int] = None, + ) -> str: + alignment_info = get_batch_alignment_info(self.cfg) + lcm_dp_size = lcm_dp_size if lcm_dp_size is not None else alignment_info.lcm_dp_size + prompt_alignment_stride = ( + prompt_alignment_stride + if prompt_alignment_stride is not None + else alignment_info.prompt_alignment_stride + ) + next_valid_batch_size = ( + math.ceil(max(prompt_batch_size, 1) / prompt_alignment_stride) * prompt_alignment_stride + ) + return ( + f"Prompt batch size {prompt_batch_size} is not divisible by the prompt DP alignment stride " + f"{prompt_alignment_stride}. With trainer.train_batch_size={self.cfg.trainer.train_batch_size}, " + f"generator.n_samples_per_prompt={self.cfg.generator.n_samples_per_prompt}, lcm_dp_size={lcm_dp_size} " + f"(policy_dp_size={alignment_info.policy_dp_size}, critic_dp_size={alignment_info.critic_dp_size}, " + f"ref_dp_size={alignment_info.ref_dp_size}), _remove_tail_data would keep {kept_prompts} prompts before " + "generation. Use a prompt batch size that is a multiple of " + f"{prompt_alignment_stride}, for example {next_valid_batch_size}, or adjust n_samples_per_prompt / " + "placement settings so the stride is smaller." + ) + + def _validate_prompt_batch_alignment(self, prompt_batch_size: int) -> None: + alignment_info = get_batch_alignment_info(self.cfg) + stride = alignment_info.prompt_alignment_stride + if prompt_batch_size <= 0: + raise ValueError( + self._format_prompt_batch_alignment_error( + prompt_batch_size, + kept_prompts=0, + lcm_dp_size=alignment_info.lcm_dp_size, + prompt_alignment_stride=stride, + ) + ) + if prompt_batch_size % stride != 0: + kept_prompts = (prompt_batch_size // stride) * stride + raise ValueError( + self._format_prompt_batch_alignment_error( + prompt_batch_size, + kept_prompts=kept_prompts, + lcm_dp_size=alignment_info.lcm_dp_size, + prompt_alignment_stride=stride, + ) + ) + def _remove_tail_data(self, entries: List[Any]) -> List[Any]: - """Remove tail data to have even shards in terms of *effective* samples. + """Validate prompt batch alignment for data-parallel training. Each prompt produces `n_samples_per_prompt` samples. For data-parallel training we care that the total number of samples is nicely splittable @@ -374,18 +433,36 @@ def _remove_tail_data(self, entries: List[Any]) -> List[Any]: n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt - # We want the largest m <= len(entries) such that: - # (m * n_samples_per_prompt) % lcm_dp_size == 0 + # A prompt count is valid only if: + # (len(entries) * n_samples_per_prompt) % lcm_dp_size == 0 # # Let g = gcd(lcm_dp_size, n_samples_per_prompt). Then this is equivalent - # to requiring m to be a multiple of (lcm_dp_size / g). + # to requiring len(entries) to be a multiple of (lcm_dp_size / g). stride = lcm_dp_size // math.gcd(lcm_dp_size, n_samples_per_prompt) + if len(entries) == 0: + raise ValueError( + self._format_prompt_batch_alignment_error( + len(entries), + kept_prompts=0, + lcm_dp_size=lcm_dp_size, + prompt_alignment_stride=stride, + ) + ) if stride <= 1: # Every prompt count is valid, keep all entries. return entries kept_prompts = (len(entries) // stride) * stride - return entries[:kept_prompts] + if kept_prompts != len(entries): + raise ValueError( + self._format_prompt_batch_alignment_error( + len(entries), + kept_prompts=kept_prompts, + lcm_dp_size=lcm_dp_size, + prompt_alignment_stride=stride, + ) + ) + return entries def build_models(self, PolicyWorker, CriticWorker, RefWorker): """ diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index e7ba602153..51dd9644de 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -6,8 +6,10 @@ import socket import sys import time +from dataclasses import dataclass from datetime import datetime from pathlib import Path +from typing import Optional import ray import torch @@ -55,6 +57,80 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self.update_dict[self.message] = self.update_dict.get(self.message, 0.0) + time.time() - self.start_time +@dataclass(frozen=True) +class BatchAlignmentInfo: + policy_dp_size: int + critic_dp_size: Optional[int] + ref_dp_size: Optional[int] + lcm_dp_size: int + prompt_alignment_stride: int + + +def _use_ref_model(cfg: SkyRLTrainConfig) -> bool: + return cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward + + +def _model_world_size(num_nodes: int, num_gpus_per_node: int) -> int: + return num_nodes * num_gpus_per_node + + +def _fsdp_style_dp_size(world_size: int, sequence_parallel_size: int) -> int: + return world_size // sequence_parallel_size + + +def _megatron_dp_size(world_size: int, megatron_config, model_name: str) -> int: + pp = megatron_config.pipeline_model_parallel_size + cp = megatron_config.context_parallel_size + tp = megatron_config.tensor_model_parallel_size + parallel_size = pp * cp * tp + assert world_size % parallel_size == 0, ( + f"{model_name}_world_size {world_size} should be divisible by (pp * cp * tp) {parallel_size}. " + "This ensures that the data parallel size is an integer." + ) + return world_size // parallel_size + + +def _policy_or_ref_dp_size(cfg: SkyRLTrainConfig, model_cfg, world_size: int, model_name: str) -> int: + if cfg.trainer.strategy == "megatron": + return _megatron_dp_size(world_size, model_cfg.megatron_config, model_name) + return _fsdp_style_dp_size(world_size, model_cfg.sequence_parallel_size) + + +def get_batch_alignment_info(cfg: SkyRLTrainConfig) -> BatchAlignmentInfo: + policy_world_size = _model_world_size( + cfg.trainer.placement.policy_num_nodes, cfg.trainer.placement.policy_num_gpus_per_node + ) + policy_dp_size = _policy_or_ref_dp_size(cfg, cfg.trainer.policy, policy_world_size, "policy") + + lcm_dp_size = policy_dp_size + + critic_dp_size = None + if cfg.trainer.critic.model.path is not None: + critic_world_size = _model_world_size( + cfg.trainer.placement.critic_num_nodes, cfg.trainer.placement.critic_num_gpus_per_node + ) + # Megatron critic is rejected by validate_megatron_cfg; keep the existing critic batch math here. + critic_dp_size = _fsdp_style_dp_size(critic_world_size, cfg.trainer.critic.sequence_parallel_size) + lcm_dp_size = math.lcm(lcm_dp_size, critic_dp_size) + + ref_dp_size = None + if _use_ref_model(cfg): + ref_world_size = _model_world_size( + cfg.trainer.placement.ref_num_nodes, cfg.trainer.placement.ref_num_gpus_per_node + ) + ref_dp_size = _policy_or_ref_dp_size(cfg, cfg.trainer.ref, ref_world_size, "ref") + lcm_dp_size = math.lcm(lcm_dp_size, ref_dp_size) + + prompt_alignment_stride = lcm_dp_size // math.gcd(lcm_dp_size, cfg.generator.n_samples_per_prompt) + return BatchAlignmentInfo( + policy_dp_size=policy_dp_size, + critic_dp_size=critic_dp_size, + ref_dp_size=ref_dp_size, + lcm_dp_size=lcm_dp_size, + prompt_alignment_stride=prompt_alignment_stride, + ) + + def validate_batch_sizes(cfg: SkyRLTrainConfig): """ Validate configured batch sizes. @@ -78,20 +154,8 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig): assert cfg.trainer.micro_train_batch_size_per_gpu > 0, "micro_train_batch_size_per_gpu must be greater than 0" assert cfg.trainer.micro_forward_batch_size_per_gpu > 0, "micro_forward_batch_size_per_gpu must be greater than 0" - # Validate policy mini batch size - policy_world_size = cfg.trainer.placement.policy_num_nodes * cfg.trainer.placement.policy_num_gpus_per_node - - if cfg.trainer.strategy == "megatron": - pp = cfg.trainer.policy.megatron_config.pipeline_model_parallel_size - cp = cfg.trainer.policy.megatron_config.context_parallel_size - tp = cfg.trainer.policy.megatron_config.tensor_model_parallel_size - assert policy_world_size % (pp * cp * tp) == 0, ( - f"policy_world_size {policy_world_size} should be divisible by (pp * cp * tp) {pp * cp * tp}. " - "This ensures that the data parallel size is an integer." - ) - policy_dp_size = policy_world_size // (pp * cp * tp) - else: - policy_dp_size = policy_world_size // cfg.trainer.policy.sequence_parallel_size + alignment_info = get_batch_alignment_info(cfg) + policy_dp_size = alignment_info.policy_dp_size assert cfg.trainer.train_batch_size % cfg.trainer.policy_mini_batch_size == 0, ( f"train_batch_size {cfg.trainer.train_batch_size} should be divisible by " @@ -128,11 +192,9 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig): f"(policy_mini_batch_size * n_samples_per_prompt // policy_dp_size) {policy_mini_batch_size_per_gpu}" ) - # Validate critic mini batch size - critic_world_size = cfg.trainer.placement.critic_num_nodes * cfg.trainer.placement.critic_num_gpus_per_node - critic_dp_size = critic_world_size // cfg.trainer.critic.sequence_parallel_size - if cfg.trainer.critic.model.path is not None: + critic_dp_size = alignment_info.critic_dp_size + assert critic_dp_size is not None assert cfg.trainer.train_batch_size % cfg.trainer.critic_mini_batch_size == 0, ( f"train_batch_size {cfg.trainer.train_batch_size} should be divisible by " f"critic_mini_batch_size {cfg.trainer.critic_mini_batch_size}" @@ -163,30 +225,15 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig): f"(critic_mini_batch_size * n_samples_per_prompt // critic_dp_size) {critic_mini_batch_size_per_gpu}" ) - # Validate training batch size is larger than the least common multiple of the DP sizes of policy (and ref if used). - lcm_dp_size = policy_dp_size - - use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward - if use_ref_model: - ref_world_size = cfg.trainer.placement.ref_num_nodes * cfg.trainer.placement.ref_num_gpus_per_node - if cfg.trainer.strategy == "megatron": - pp = cfg.trainer.ref.megatron_config.pipeline_model_parallel_size - cp = cfg.trainer.ref.megatron_config.context_parallel_size - tp = cfg.trainer.ref.megatron_config.tensor_model_parallel_size - assert ref_world_size % (pp * cp * tp) == 0, ( - f"ref_world_size {ref_world_size} should be divisible by (pp * cp * tp) {pp * cp * tp}. " - "This ensures that the data parallel size is an integer." - ) - ref_dp_size = ref_world_size // (pp * cp * tp) - else: - ref_dp_size = ref_world_size // cfg.trainer.ref.sequence_parallel_size - lcm_dp_size = math.lcm(lcm_dp_size, ref_dp_size) + # Validate training batch size is larger than the least common multiple of the DP sizes of enabled models. + lcm_dp_size = alignment_info.lcm_dp_size assert cfg.trainer.train_batch_size * cfg.generator.n_samples_per_prompt >= lcm_dp_size, ( f"train_batch_size * n_samples_per_prompt ({cfg.trainer.train_batch_size * cfg.generator.n_samples_per_prompt}) " f"should be larger than or equal to the least common multiple of the data parallel sizes of the enabled models: " - f"policy_dp_size={policy_dp_size}, " - f"ref_dp_size={ref_dp_size if use_ref_model else 'None'}, " + f"policy_dp_size={alignment_info.policy_dp_size}, " + f"critic_dp_size={alignment_info.critic_dp_size}, " + f"ref_dp_size={alignment_info.ref_dp_size}, " f"lcm_dp_size={lcm_dp_size}" ) diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py index 20f93495fb..212a10eb1f 100644 --- a/tests/train/test_trainer.py +++ b/tests/train/test_trainer.py @@ -10,12 +10,14 @@ from jaxtyping import Float, Integer from pytest import approx +from examples.train_scripts.full_context.trainer_full_ctx import FullCtxTrainer from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.backends.skyrl_train.workers.worker import CriticWorkerBase, PolicyWorkerBase from skyrl.backends.skyrl_train.workers.worker_utils import BatchIterator from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.fully_async_trainer import FullyAsyncRayPPOTrainer from skyrl.train.trainer import RayPPOTrainer -from skyrl.train.utils.utils import validate_batch_sizes +from skyrl.train.utils.utils import get_batch_alignment_info, validate_batch_sizes from tests.train.util import example_dummy_config @@ -62,6 +64,38 @@ def dummy_generator(): return MagicMock() +def _make_batch_alignment_config( + train_batch_size=2, + policy_dp=6, + ref_dp=6, + critic_dp=1, + include_ref=True, + include_critic=False, + n_samples_per_prompt=4, +) -> SkyRLTrainConfig: + cfg = SkyRLTrainConfig() + cfg.trainer.train_batch_size = train_batch_size + cfg.trainer.policy_mini_batch_size = train_batch_size + cfg.trainer.critic_mini_batch_size = train_batch_size + cfg.trainer.micro_train_batch_size_per_gpu = 1 + cfg.trainer.micro_forward_batch_size_per_gpu = 1 + cfg.trainer.placement.policy_num_nodes = 1 + cfg.trainer.placement.policy_num_gpus_per_node = policy_dp + cfg.trainer.placement.ref_num_nodes = 1 + cfg.trainer.placement.ref_num_gpus_per_node = ref_dp + cfg.trainer.placement.critic_num_nodes = 1 + cfg.trainer.placement.critic_num_gpus_per_node = critic_dp + cfg.trainer.policy.sequence_parallel_size = 1 + cfg.trainer.ref.sequence_parallel_size = 1 + cfg.trainer.critic.sequence_parallel_size = 1 + cfg.trainer.critic.model.path = "critic" if include_critic else None + cfg.trainer.algorithm.use_kl_loss = include_ref + cfg.trainer.algorithm.use_kl_in_reward = False + cfg.trainer.algorithm.policy_loss_type = "regular" + cfg.generator.n_samples_per_prompt = n_samples_per_prompt + return cfg + + def _get_test_data(trainer: RayPPOTrainer): trainer.critic_model = MagicMock() # pretend we're using a critic @@ -621,3 +655,148 @@ def create_config(train_batch_size, policy_dp, ref_dp, include_ref=True): # Pass: ref disabled -> requirement reduces to policy_dp. With policy_dp=2, tbs=2 is valid. cfg = create_config(train_batch_size=2, policy_dp=2, ref_dp=3, include_ref=False) validate_batch_sizes(cfg) + + +def test_batch_alignment_info_matches_issue_1609_repro(): + cfg = _make_batch_alignment_config( + train_batch_size=2, + policy_dp=6, + ref_dp=6, + include_ref=True, + include_critic=False, + n_samples_per_prompt=4, + ) + + alignment_info = get_batch_alignment_info(cfg) + + assert alignment_info.policy_dp_size == 6 + assert alignment_info.critic_dp_size is None + assert alignment_info.ref_dp_size == 6 + assert alignment_info.lcm_dp_size == 6 + assert alignment_info.prompt_alignment_stride == 3 + + # The existing effective-size validation is necessary but not enough: + # train_batch_size * n_samples_per_prompt is 8, which satisfies lcm_dp_size 6. + validate_batch_sizes(cfg) + + +def test_validate_batch_sizes_lcm_includes_critic_dp(): + cfg = _make_batch_alignment_config( + train_batch_size=5, + policy_dp=2, + critic_dp=3, + include_ref=False, + include_critic=True, + n_samples_per_prompt=1, + ) + + with pytest.raises( + AssertionError, + match=r"critic_dp_size=3.*lcm_dp_size=6", + ): + validate_batch_sizes(cfg) + + +def test_ray_trainer_prompt_batch_alignment_rejects_issue_1609(): + cfg = _make_batch_alignment_config( + train_batch_size=2, + policy_dp=6, + ref_dp=6, + include_ref=True, + include_critic=False, + n_samples_per_prompt=4, + ) + + with pytest.raises(ValueError, match="prompt DP alignment stride 3"): + RayPPOTrainer( + cfg=cfg, + tracker=None, + tokenizer=None, + train_dataset=DummyDataset(), + eval_dataset=None, + inference_engine_client=None, + generator=MagicMock(), + ) + + +def test_ray_trainer_prompt_batch_alignment_allows_issue_1609_workaround(): + cfg = _make_batch_alignment_config( + train_batch_size=3, + policy_dp=6, + ref_dp=6, + include_ref=True, + include_critic=False, + n_samples_per_prompt=4, + ) + + trainer = RayPPOTrainer( + cfg=cfg, + tracker=None, + tokenizer=None, + train_dataset=None, + eval_dataset=None, + inference_engine_client=None, + generator=MagicMock(), + ) + + trainer._validate_prompt_batch_alignment(cfg.trainer.train_batch_size) + + +def test_ray_trainer_prompt_batch_alignment_rejects_silent_non_empty_truncation(): + cfg = _make_batch_alignment_config( + train_batch_size=5, + policy_dp=6, + ref_dp=6, + include_ref=True, + include_critic=False, + n_samples_per_prompt=4, + ) + trainer = RayPPOTrainer( + cfg=cfg, + tracker=None, + tokenizer=None, + train_dataset=None, + eval_dataset=None, + inference_engine_client=None, + generator=MagicMock(), + ) + + with pytest.raises(ValueError, match="would keep 3 prompts"): + trainer._validate_prompt_batch_alignment(cfg.trainer.train_batch_size) + + +def test_non_truncating_trainers_do_not_require_sync_prompt_alignment(): + assert RayPPOTrainer.requires_prompt_batch_alignment is True + assert FullyAsyncRayPPOTrainer.requires_prompt_batch_alignment is False + assert FullCtxTrainer.requires_prompt_batch_alignment is False + + +def test_remove_tail_data_rejects_truncation(): + cfg = _make_batch_alignment_config( + train_batch_size=3, + policy_dp=6, + ref_dp=6, + include_ref=True, + include_critic=False, + n_samples_per_prompt=4, + ) + trainer = RayPPOTrainer( + cfg=cfg, + tracker=None, + tokenizer=None, + train_dataset=None, + eval_dataset=None, + inference_engine_client=None, + generator=MagicMock(), + ) + trainer.dispatch = MagicMock() + trainer.dispatch.get_lcm_dp_size.return_value = 6 + + with pytest.raises(ValueError, match="would keep 0 prompts"): + trainer._remove_tail_data(["p0", "p1"]) + with pytest.raises(ValueError, match="would keep 3 prompts"): + trainer._remove_tail_data(["p0", "p1", "p2", "p3", "p4"]) + with pytest.raises(ValueError, match="Prompt batch size 0"): + trainer._remove_tail_data([]) + + assert trainer._remove_tail_data(["p0", "p1", "p2"]) == ["p0", "p1", "p2"]