-
Notifications
You must be signed in to change notification settings - Fork 323
Fix prompt batch alignment validation for synchronous PPO #1615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
taivu1998
wants to merge
1
commit into
NovaSky-AI:main
Choose a base branch
from
taivu1998:tdv/issue-1609-batch-alignment
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -6,8 +6,10 @@ | |||||
| import socket | ||||||
| import sys | ||||||
| import time | ||||||
| from dataclasses import dataclass | ||||||
| from datetime import datetime | ||||||
| from pathlib import Path | ||||||
| from typing import Optional | ||||||
|
|
||||||
| import ray | ||||||
| import torch | ||||||
|
|
@@ -55,6 +57,80 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): | |||||
| self.update_dict[self.message] = self.update_dict.get(self.message, 0.0) + time.time() - self.start_time | ||||||
|
|
||||||
|
|
||||||
| @dataclass(frozen=True) | ||||||
| class BatchAlignmentInfo: | ||||||
| policy_dp_size: int | ||||||
| critic_dp_size: Optional[int] | ||||||
| ref_dp_size: Optional[int] | ||||||
| lcm_dp_size: int | ||||||
| prompt_alignment_stride: int | ||||||
|
|
||||||
|
|
||||||
| def _use_ref_model(cfg: SkyRLTrainConfig) -> bool: | ||||||
| return cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward | ||||||
|
|
||||||
|
|
||||||
| def _model_world_size(num_nodes: int, num_gpus_per_node: int) -> int: | ||||||
| return num_nodes * num_gpus_per_node | ||||||
|
|
||||||
|
|
||||||
| def _fsdp_style_dp_size(world_size: int, sequence_parallel_size: int) -> int: | ||||||
| return world_size // sequence_parallel_size | ||||||
|
|
||||||
|
|
||||||
| def _megatron_dp_size(world_size: int, megatron_config, model_name: str) -> int: | ||||||
| pp = megatron_config.pipeline_model_parallel_size | ||||||
| cp = megatron_config.context_parallel_size | ||||||
| tp = megatron_config.tensor_model_parallel_size | ||||||
| parallel_size = pp * cp * tp | ||||||
| assert world_size % parallel_size == 0, ( | ||||||
| f"{model_name}_world_size {world_size} should be divisible by (pp * cp * tp) {parallel_size}. " | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The f-string here is slightly redundant as it repeats the value of
Suggested change
|
||||||
| "This ensures that the data parallel size is an integer." | ||||||
| ) | ||||||
| return world_size // parallel_size | ||||||
|
|
||||||
|
|
||||||
| def _policy_or_ref_dp_size(cfg: SkyRLTrainConfig, model_cfg, world_size: int, model_name: str) -> int: | ||||||
| if cfg.trainer.strategy == "megatron": | ||||||
| return _megatron_dp_size(world_size, model_cfg.megatron_config, model_name) | ||||||
| return _fsdp_style_dp_size(world_size, model_cfg.sequence_parallel_size) | ||||||
|
|
||||||
|
|
||||||
| def get_batch_alignment_info(cfg: SkyRLTrainConfig) -> BatchAlignmentInfo: | ||||||
| policy_world_size = _model_world_size( | ||||||
| cfg.trainer.placement.policy_num_nodes, cfg.trainer.placement.policy_num_gpus_per_node | ||||||
| ) | ||||||
| policy_dp_size = _policy_or_ref_dp_size(cfg, cfg.trainer.policy, policy_world_size, "policy") | ||||||
|
|
||||||
| lcm_dp_size = policy_dp_size | ||||||
|
|
||||||
| critic_dp_size = None | ||||||
| if cfg.trainer.critic.model.path is not None: | ||||||
| critic_world_size = _model_world_size( | ||||||
| cfg.trainer.placement.critic_num_nodes, cfg.trainer.placement.critic_num_gpus_per_node | ||||||
| ) | ||||||
| # Megatron critic is rejected by validate_megatron_cfg; keep the existing critic batch math here. | ||||||
| critic_dp_size = _fsdp_style_dp_size(critic_world_size, cfg.trainer.critic.sequence_parallel_size) | ||||||
| lcm_dp_size = math.lcm(lcm_dp_size, critic_dp_size) | ||||||
|
|
||||||
| ref_dp_size = None | ||||||
| if _use_ref_model(cfg): | ||||||
| ref_world_size = _model_world_size( | ||||||
| cfg.trainer.placement.ref_num_nodes, cfg.trainer.placement.ref_num_gpus_per_node | ||||||
| ) | ||||||
| ref_dp_size = _policy_or_ref_dp_size(cfg, cfg.trainer.ref, ref_world_size, "ref") | ||||||
| lcm_dp_size = math.lcm(lcm_dp_size, ref_dp_size) | ||||||
|
|
||||||
| prompt_alignment_stride = lcm_dp_size // math.gcd(lcm_dp_size, cfg.generator.n_samples_per_prompt) | ||||||
| return BatchAlignmentInfo( | ||||||
| policy_dp_size=policy_dp_size, | ||||||
| critic_dp_size=critic_dp_size, | ||||||
| ref_dp_size=ref_dp_size, | ||||||
| lcm_dp_size=lcm_dp_size, | ||||||
| prompt_alignment_stride=prompt_alignment_stride, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def validate_batch_sizes(cfg: SkyRLTrainConfig): | ||||||
| """ | ||||||
| Validate configured batch sizes. | ||||||
|
|
@@ -78,20 +154,8 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig): | |||||
| assert cfg.trainer.micro_train_batch_size_per_gpu > 0, "micro_train_batch_size_per_gpu must be greater than 0" | ||||||
| assert cfg.trainer.micro_forward_batch_size_per_gpu > 0, "micro_forward_batch_size_per_gpu must be greater than 0" | ||||||
|
|
||||||
| # Validate policy mini batch size | ||||||
| policy_world_size = cfg.trainer.placement.policy_num_nodes * cfg.trainer.placement.policy_num_gpus_per_node | ||||||
|
|
||||||
| if cfg.trainer.strategy == "megatron": | ||||||
| pp = cfg.trainer.policy.megatron_config.pipeline_model_parallel_size | ||||||
| cp = cfg.trainer.policy.megatron_config.context_parallel_size | ||||||
| tp = cfg.trainer.policy.megatron_config.tensor_model_parallel_size | ||||||
| assert policy_world_size % (pp * cp * tp) == 0, ( | ||||||
| f"policy_world_size {policy_world_size} should be divisible by (pp * cp * tp) {pp * cp * tp}. " | ||||||
| "This ensures that the data parallel size is an integer." | ||||||
| ) | ||||||
| policy_dp_size = policy_world_size // (pp * cp * tp) | ||||||
| else: | ||||||
| policy_dp_size = policy_world_size // cfg.trainer.policy.sequence_parallel_size | ||||||
| alignment_info = get_batch_alignment_info(cfg) | ||||||
| policy_dp_size = alignment_info.policy_dp_size | ||||||
|
|
||||||
| assert cfg.trainer.train_batch_size % cfg.trainer.policy_mini_batch_size == 0, ( | ||||||
| f"train_batch_size {cfg.trainer.train_batch_size} should be divisible by " | ||||||
|
|
@@ -128,11 +192,9 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig): | |||||
| f"(policy_mini_batch_size * n_samples_per_prompt // policy_dp_size) {policy_mini_batch_size_per_gpu}" | ||||||
| ) | ||||||
|
|
||||||
| # Validate critic mini batch size | ||||||
| critic_world_size = cfg.trainer.placement.critic_num_nodes * cfg.trainer.placement.critic_num_gpus_per_node | ||||||
| critic_dp_size = critic_world_size // cfg.trainer.critic.sequence_parallel_size | ||||||
|
|
||||||
| if cfg.trainer.critic.model.path is not None: | ||||||
| critic_dp_size = alignment_info.critic_dp_size | ||||||
| assert critic_dp_size is not None | ||||||
| assert cfg.trainer.train_batch_size % cfg.trainer.critic_mini_batch_size == 0, ( | ||||||
| f"train_batch_size {cfg.trainer.train_batch_size} should be divisible by " | ||||||
| f"critic_mini_batch_size {cfg.trainer.critic_mini_batch_size}" | ||||||
|
|
@@ -163,30 +225,15 @@ def validate_batch_sizes(cfg: SkyRLTrainConfig): | |||||
| f"(critic_mini_batch_size * n_samples_per_prompt // critic_dp_size) {critic_mini_batch_size_per_gpu}" | ||||||
| ) | ||||||
|
|
||||||
| # Validate training batch size is larger than the least common multiple of the DP sizes of policy (and ref if used). | ||||||
| lcm_dp_size = policy_dp_size | ||||||
|
|
||||||
| use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward | ||||||
| if use_ref_model: | ||||||
| ref_world_size = cfg.trainer.placement.ref_num_nodes * cfg.trainer.placement.ref_num_gpus_per_node | ||||||
| if cfg.trainer.strategy == "megatron": | ||||||
| pp = cfg.trainer.ref.megatron_config.pipeline_model_parallel_size | ||||||
| cp = cfg.trainer.ref.megatron_config.context_parallel_size | ||||||
| tp = cfg.trainer.ref.megatron_config.tensor_model_parallel_size | ||||||
| assert ref_world_size % (pp * cp * tp) == 0, ( | ||||||
| f"ref_world_size {ref_world_size} should be divisible by (pp * cp * tp) {pp * cp * tp}. " | ||||||
| "This ensures that the data parallel size is an integer." | ||||||
| ) | ||||||
| ref_dp_size = ref_world_size // (pp * cp * tp) | ||||||
| else: | ||||||
| ref_dp_size = ref_world_size // cfg.trainer.ref.sequence_parallel_size | ||||||
| lcm_dp_size = math.lcm(lcm_dp_size, ref_dp_size) | ||||||
| # Validate training batch size is larger than the least common multiple of the DP sizes of enabled models. | ||||||
| lcm_dp_size = alignment_info.lcm_dp_size | ||||||
|
|
||||||
| assert cfg.trainer.train_batch_size * cfg.generator.n_samples_per_prompt >= lcm_dp_size, ( | ||||||
| f"train_batch_size * n_samples_per_prompt ({cfg.trainer.train_batch_size * cfg.generator.n_samples_per_prompt}) " | ||||||
| f"should be larger than or equal to the least common multiple of the data parallel sizes of the enabled models: " | ||||||
| f"policy_dp_size={policy_dp_size}, " | ||||||
| f"ref_dp_size={ref_dp_size if use_ref_model else 'None'}, " | ||||||
| f"policy_dp_size={alignment_info.policy_dp_size}, " | ||||||
| f"critic_dp_size={alignment_info.critic_dp_size}, " | ||||||
| f"ref_dp_size={alignment_info.ref_dp_size}, " | ||||||
| f"lcm_dp_size={lcm_dp_size}" | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error message states that the prompt batch size is 'not divisible' by the stride. However, this method is also called when
prompt_batch_sizeis 0 (line 405), and 0 is technically divisible by any non-zero integer. While the suggestion for the next valid batch size is correct, the phrasing 'is not divisible' is slightly inaccurate for the zero case. Consider rephrasing to something like 'is invalid or not divisible' or handling the zero case with a specific message.