Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions examples/train_scripts/full_context/trainer_full_ctx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from skyrl.train.trainer import RayPPOTrainer
from loguru import logger
import random

from loguru import logger

from skyrl.train.trainer import RayPPOTrainer
from skyrl.train.utils.utils import Timer


Expand All @@ -14,6 +16,8 @@ class FullCtxTrainer(RayPPOTrainer):
This helps catch OOM issues early before running full training.
"""

requires_prompt_batch_alignment = False

async def train(self):
"""Run a few training steps with max sequence length."""
logger.info("Starting dummy training with max sequence length...")
Expand Down
1 change: 1 addition & 0 deletions skyrl/train/fully_async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def get_consumed_uids_list(self) -> List[str]:


class FullyAsyncRayPPOTrainer(RayPPOTrainer):
requires_prompt_batch_alignment = False

def __init__(self, *args, **kwargs):
# Extract cfg before base init so we can initialize async-specific knobs used by our overrides.
Expand Down
89 changes: 83 additions & 6 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,16 @@
validate_generator_output,
zero_variance_filter,
)
from skyrl.train.utils.utils import ResolvedPlacementGroup, configure_ray_worker_logging
from skyrl.train.utils.utils import (
ResolvedPlacementGroup,
configure_ray_worker_logging,
get_batch_alignment_info,
)


class RayPPOTrainer:
requires_prompt_batch_alignment = True

def __init__(
self,
cfg: SkyRLTrainConfig,
Expand All @@ -103,6 +109,8 @@ def __init__(
self.eval_dataset = eval_dataset
self.inference_engine_client = inference_engine_client
self.generator = generator
if self.train_dataset is not None and self.requires_prompt_batch_alignment:
self._validate_prompt_batch_alignment(self.cfg.trainer.train_batch_size)
self.train_dataloader = None
self.total_training_steps = None
self._build_train_dataloader_and_compute_training_steps()
Expand Down Expand Up @@ -363,8 +371,59 @@ async def train(self):
self.tracker.finish()
logger.info("Training done!")

def _format_prompt_batch_alignment_error(
    self,
    prompt_batch_size: int,
    kept_prompts: int,
    lcm_dp_size: Optional[int] = None,
    prompt_alignment_stride: Optional[int] = None,
) -> str:
    """Build the error message for a prompt batch size that breaks DP alignment.

    Args:
        prompt_batch_size: The offending number of prompts in the batch.
        kept_prompts: Number of prompts reported in the message as what
            ``_remove_tail_data`` would keep before generation.
        lcm_dp_size: Least common multiple of the data-parallel sizes. When
            ``None`` it is read from ``get_batch_alignment_info(self.cfg)``.
        prompt_alignment_stride: Multiple that the prompt count must satisfy.
            When ``None`` it is read from ``get_batch_alignment_info(self.cfg)``.

    Returns:
        A message naming the problem, the config values involved and the
        next valid prompt batch size.
    """
    # Always needed: the per-model DP sizes in the message come from here even
    # when both optional overrides are supplied by the caller.
    alignment_info = get_batch_alignment_info(self.cfg)
    if lcm_dp_size is None:
        lcm_dp_size = alignment_info.lcm_dp_size
    if prompt_alignment_stride is None:
        prompt_alignment_stride = alignment_info.prompt_alignment_stride

    # Smallest positive multiple of the stride that is >= the requested size.
    # Integer ceil-division stays exact where float division could round.
    at_least_one = max(prompt_batch_size, 1)
    next_valid_batch_size = -(-at_least_one // prompt_alignment_stride) * prompt_alignment_stride

    if prompt_batch_size <= 0:
        # Zero is divisible by every stride, so "not divisible" would be
        # misleading for empty / non-positive batches.
        problem = (
            f"Prompt batch size {prompt_batch_size} is invalid: it must be a positive multiple of the "
            f"prompt DP alignment stride {prompt_alignment_stride}. "
        )
    else:
        problem = (
            f"Prompt batch size {prompt_batch_size} is not divisible by the prompt DP alignment stride "
            f"{prompt_alignment_stride}. "
        )

    return (
        f"{problem}With trainer.train_batch_size={self.cfg.trainer.train_batch_size}, "
        f"generator.n_samples_per_prompt={self.cfg.generator.n_samples_per_prompt}, lcm_dp_size={lcm_dp_size} "
        f"(policy_dp_size={alignment_info.policy_dp_size}, critic_dp_size={alignment_info.critic_dp_size}, "
        f"ref_dp_size={alignment_info.ref_dp_size}), _remove_tail_data would keep {kept_prompts} prompts before "
        "generation. Use a prompt batch size that is a multiple of "
        f"{prompt_alignment_stride}, for example {next_valid_batch_size}, or adjust n_samples_per_prompt / "
        "placement settings so the stride is smaller."
    )

def _validate_prompt_batch_alignment(self, prompt_batch_size: int) -> None:
    """Reject prompt batch sizes that are not a positive multiple of the alignment stride.

    Args:
        prompt_batch_size: Number of prompts per batch to validate.

    Raises:
        ValueError: If ``prompt_batch_size`` is <= 0, or is not divisible by the
            ``prompt_alignment_stride`` reported by ``get_batch_alignment_info``.
    """
    info = get_batch_alignment_info(self.cfg)
    stride = info.prompt_alignment_stride

    if prompt_batch_size <= 0:
        # Nothing would survive truncation for an empty / negative batch.
        truncated_size = 0
    else:
        if prompt_batch_size % stride == 0:
            return
        # Largest multiple of the stride that fits in the requested size.
        truncated_size = (prompt_batch_size // stride) * stride

    message = self._format_prompt_batch_alignment_error(
        prompt_batch_size,
        kept_prompts=truncated_size,
        lcm_dp_size=info.lcm_dp_size,
        prompt_alignment_stride=stride,
    )
    raise ValueError(message)

def _remove_tail_data(self, entries: List[Any]) -> List[Any]:
"""Remove tail data to have even shards in terms of *effective* samples.
"""Validate prompt batch alignment for data-parallel training.

Each prompt produces `n_samples_per_prompt` samples. For data-parallel
training we care that the total number of samples is nicely splittable
Expand All @@ -374,18 +433,36 @@ def _remove_tail_data(self, entries: List[Any]) -> List[Any]:

n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt

# We want the largest m <= len(entries) such that:
# (m * n_samples_per_prompt) % lcm_dp_size == 0
# A prompt count is valid only if:
# (len(entries) * n_samples_per_prompt) % lcm_dp_size == 0
#
# Let g = gcd(lcm_dp_size, n_samples_per_prompt). Then this is equivalent
# to requiring m to be a multiple of (lcm_dp_size / g).
# to requiring len(entries) to be a multiple of (lcm_dp_size / g).
stride = lcm_dp_size // math.gcd(lcm_dp_size, n_samples_per_prompt)
if len(entries) == 0:
raise ValueError(
self._format_prompt_batch_alignment_error(
len(entries),
kept_prompts=0,
lcm_dp_size=lcm_dp_size,
prompt_alignment_stride=stride,
)
)
if stride <= 1:
# Every prompt count is valid, keep all entries.
return entries

kept_prompts = (len(entries) // stride) * stride
return entries[:kept_prompts]
if kept_prompts != len(entries):
raise ValueError(
self._format_prompt_batch_alignment_error(
len(entries),
kept_prompts=kept_prompts,
lcm_dp_size=lcm_dp_size,
prompt_alignment_stride=stride,
)
)
return entries

def build_models(self, PolicyWorker, CriticWorker, RefWorker):
"""
Expand Down
123 changes: 85 additions & 38 deletions skyrl/train/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import socket
import sys
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional

import ray
import torch
Expand Down Expand Up @@ -55,6 +57,80 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
self.update_dict[self.message] = self.update_dict.get(self.message, 0.0) + time.time() - self.start_time


@dataclass(frozen=True)
class BatchAlignmentInfo:
    """Immutable record of the data-parallel (DP) sizes that constrain prompt batch sizes."""

    # DP size of the policy model.
    policy_dp_size: int
    # DP size of the critic model; None when no critic model path is configured.
    critic_dp_size: Optional[int]
    # DP size of the reference model; None when no reference model is used.
    ref_dp_size: Optional[int]
    # Least common multiple of the DP sizes of all enabled models.
    lcm_dp_size: int
    # lcm_dp_size divided by gcd(lcm_dp_size, n_samples_per_prompt): the multiple
    # that the number of prompts in a batch has to satisfy.
    prompt_alignment_stride: int


def _use_ref_model(cfg: SkyRLTrainConfig) -> bool:
    """Return a truthy value when either KL option in ``cfg.trainer.algorithm`` is enabled."""
    algorithm_cfg = cfg.trainer.algorithm
    # A reference model is needed for the KL loss as well as for KL-in-reward.
    return algorithm_cfg.use_kl_loss or algorithm_cfg.use_kl_in_reward


def _model_world_size(num_nodes: int, num_gpus_per_node: int) -> int:
    """Return the world size of a model: node count times GPUs per node."""
    world_size = num_nodes * num_gpus_per_node
    return world_size


def _fsdp_style_dp_size(world_size: int, sequence_parallel_size: int) -> int:
    """Return the data-parallel size for a sequence-parallel layout.

    The result is ``world_size`` floor-divided by ``sequence_parallel_size``;
    any remainder is discarded rather than reported.
    """
    dp_size, _remainder = divmod(world_size, sequence_parallel_size)
    return dp_size


def _megatron_dp_size(world_size: int, megatron_config, model_name: str) -> int:
    """Return the data-parallel size left over after Megatron model parallelism.

    Args:
        world_size: World size of the model.
        megatron_config: Object exposing ``pipeline_model_parallel_size``,
            ``context_parallel_size`` and ``tensor_model_parallel_size``.
        model_name: Label used as the prefix of the error message, e.g. ``"policy"``.

    Returns:
        ``world_size // (pp * cp * tp)``.

    Raises:
        AssertionError: If ``world_size`` is not divisible by ``pp * cp * tp``.
    """
    pp = megatron_config.pipeline_model_parallel_size
    cp = megatron_config.context_parallel_size
    tp = megatron_config.tensor_model_parallel_size
    parallel_size = pp * cp * tp
    # Spell the product out as an equality so the number in the message is
    # unambiguous to the reader.
    assert world_size % parallel_size == 0, (
        f"{model_name}_world_size {world_size} should be divisible by the product of parallel dimensions "
        f"(pp * cp * tp = {parallel_size}). "
        "This ensures that the data parallel size is an integer."
    )
    return world_size // parallel_size


def _policy_or_ref_dp_size(cfg: SkyRLTrainConfig, model_cfg, world_size: int, model_name: str) -> int:
    """Return the data-parallel size of a policy or ref model for the configured strategy.

    The "megatron" strategy derives the size from ``model_cfg.megatron_config``;
    every other strategy divides by ``model_cfg.sequence_parallel_size``.
    """
    uses_megatron = cfg.trainer.strategy == "megatron"
    if not uses_megatron:
        return _fsdp_style_dp_size(world_size, model_cfg.sequence_parallel_size)
    return _megatron_dp_size(world_size, model_cfg.megatron_config, model_name)


def get_batch_alignment_info(cfg: SkyRLTrainConfig) -> BatchAlignmentInfo:
    """Collect the data-parallel sizes of the enabled models and the resulting prompt stride.

    The policy is always included. The critic is included when
    ``cfg.trainer.critic.model.path`` is set, and the ref model when
    ``_use_ref_model(cfg)`` is truthy.

    Returns:
        A ``BatchAlignmentInfo`` with the per-model DP sizes (``None`` for
        disabled models), their least common multiple, and the stride that the
        number of prompts per batch must be a multiple of.
    """
    trainer_cfg = cfg.trainer
    placement = trainer_cfg.placement

    policy_world_size = _model_world_size(placement.policy_num_nodes, placement.policy_num_gpus_per_node)
    policy_dp_size = _policy_or_ref_dp_size(cfg, trainer_cfg.policy, policy_world_size, "policy")

    critic_dp_size = None
    if trainer_cfg.critic.model.path is not None:
        critic_world_size = _model_world_size(placement.critic_num_nodes, placement.critic_num_gpus_per_node)
        # Megatron critic is rejected by validate_megatron_cfg; keep the existing critic batch math here.
        critic_dp_size = _fsdp_style_dp_size(critic_world_size, trainer_cfg.critic.sequence_parallel_size)

    ref_dp_size = None
    if _use_ref_model(cfg):
        ref_world_size = _model_world_size(placement.ref_num_nodes, placement.ref_num_gpus_per_node)
        ref_dp_size = _policy_or_ref_dp_size(cfg, trainer_cfg.ref, ref_world_size, "ref")

    # Fold the optional models into the LCM, starting from the policy.
    lcm_dp_size = policy_dp_size
    for optional_dp_size in (critic_dp_size, ref_dp_size):
        if optional_dp_size is not None:
            lcm_dp_size = math.lcm(lcm_dp_size, optional_dp_size)

    # (prompts * n_samples_per_prompt) % lcm == 0  <=>  prompts % (lcm / gcd(lcm, n)) == 0
    shared_factor = math.gcd(lcm_dp_size, cfg.generator.n_samples_per_prompt)
    return BatchAlignmentInfo(
        policy_dp_size=policy_dp_size,
        critic_dp_size=critic_dp_size,
        ref_dp_size=ref_dp_size,
        lcm_dp_size=lcm_dp_size,
        prompt_alignment_stride=lcm_dp_size // shared_factor,
    )


def validate_batch_sizes(cfg: SkyRLTrainConfig):
"""
Validate configured batch sizes.
Expand All @@ -78,20 +154,8 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig):
assert cfg.trainer.micro_train_batch_size_per_gpu > 0, "micro_train_batch_size_per_gpu must be greater than 0"
assert cfg.trainer.micro_forward_batch_size_per_gpu > 0, "micro_forward_batch_size_per_gpu must be greater than 0"

# Validate policy mini batch size
policy_world_size = cfg.trainer.placement.policy_num_nodes * cfg.trainer.placement.policy_num_gpus_per_node

if cfg.trainer.strategy == "megatron":
pp = cfg.trainer.policy.megatron_config.pipeline_model_parallel_size
cp = cfg.trainer.policy.megatron_config.context_parallel_size
tp = cfg.trainer.policy.megatron_config.tensor_model_parallel_size
assert policy_world_size % (pp * cp * tp) == 0, (
f"policy_world_size {policy_world_size} should be divisible by (pp * cp * tp) {pp * cp * tp}. "
"This ensures that the data parallel size is an integer."
)
policy_dp_size = policy_world_size // (pp * cp * tp)
else:
policy_dp_size = policy_world_size // cfg.trainer.policy.sequence_parallel_size
alignment_info = get_batch_alignment_info(cfg)
policy_dp_size = alignment_info.policy_dp_size

assert cfg.trainer.train_batch_size % cfg.trainer.policy_mini_batch_size == 0, (
f"train_batch_size {cfg.trainer.train_batch_size} should be divisible by "
Expand Down Expand Up @@ -128,11 +192,9 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig):
f"(policy_mini_batch_size * n_samples_per_prompt // policy_dp_size) {policy_mini_batch_size_per_gpu}"
)

# Validate critic mini batch size
critic_world_size = cfg.trainer.placement.critic_num_nodes * cfg.trainer.placement.critic_num_gpus_per_node
critic_dp_size = critic_world_size // cfg.trainer.critic.sequence_parallel_size

if cfg.trainer.critic.model.path is not None:
critic_dp_size = alignment_info.critic_dp_size
assert critic_dp_size is not None
assert cfg.trainer.train_batch_size % cfg.trainer.critic_mini_batch_size == 0, (
f"train_batch_size {cfg.trainer.train_batch_size} should be divisible by "
f"critic_mini_batch_size {cfg.trainer.critic_mini_batch_size}"
Expand Down Expand Up @@ -163,30 +225,15 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig):
f"(critic_mini_batch_size * n_samples_per_prompt // critic_dp_size) {critic_mini_batch_size_per_gpu}"
)

# Validate training batch size is larger than the least common multiple of the DP sizes of policy (and ref if used).
lcm_dp_size = policy_dp_size

use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward
if use_ref_model:
ref_world_size = cfg.trainer.placement.ref_num_nodes * cfg.trainer.placement.ref_num_gpus_per_node
if cfg.trainer.strategy == "megatron":
pp = cfg.trainer.ref.megatron_config.pipeline_model_parallel_size
cp = cfg.trainer.ref.megatron_config.context_parallel_size
tp = cfg.trainer.ref.megatron_config.tensor_model_parallel_size
assert ref_world_size % (pp * cp * tp) == 0, (
f"ref_world_size {ref_world_size} should be divisible by (pp * cp * tp) {pp * cp * tp}. "
"This ensures that the data parallel size is an integer."
)
ref_dp_size = ref_world_size // (pp * cp * tp)
else:
ref_dp_size = ref_world_size // cfg.trainer.ref.sequence_parallel_size
lcm_dp_size = math.lcm(lcm_dp_size, ref_dp_size)
# Validate that the total sample count (train_batch_size * n_samples_per_prompt) is at least the least common multiple of the DP sizes of the enabled models.
lcm_dp_size = alignment_info.lcm_dp_size

assert cfg.trainer.train_batch_size * cfg.generator.n_samples_per_prompt >= lcm_dp_size, (
f"train_batch_size * n_samples_per_prompt ({cfg.trainer.train_batch_size * cfg.generator.n_samples_per_prompt}) "
f"should be larger than or equal to the least common multiple of the data parallel sizes of the enabled models: "
f"policy_dp_size={policy_dp_size}, "
f"ref_dp_size={ref_dp_size if use_ref_model else 'None'}, "
f"policy_dp_size={alignment_info.policy_dp_size}, "
f"critic_dp_size={alignment_info.critic_dp_size}, "
f"ref_dp_size={alignment_info.ref_dp_size}, "
f"lcm_dp_size={lcm_dp_size}"
)

Expand Down
Loading
Loading