-
Notifications
You must be signed in to change notification settings - Fork 323
[skyrl] Preserve staged forward_backward loss_fn_outputs across DP ranks #1534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,8 +8,10 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||
| The trainer interacts with the worker dispatch if all models are always on GPU. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Dict, List, Optional, Tuple | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| import ray | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from ray import ObjectRef | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -18,16 +20,18 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||
| MeshDispatch, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| concatenate_outputs_after_mesh_dispatch, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| InferenceEngineClient, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.training_batch import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| TrainingInputBatch, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| TrainingOutputBatch, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from skyrl.train.config import SkyRLTrainConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| InferenceEngineClient, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||||||||||||||||||
| class GPUState: | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -37,6 +41,36 @@ class GPUState: | |||||||||||||||||||||||||||||||||||||||||||||||||
| optimizer_on_gpu: bool = False | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| LossFnOutput = dict[str, Any] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ForwardBackwardStatusValue = float | int | list[LossFnOutput] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ForwardBackwardStatus = dict[str, ForwardBackwardStatusValue] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+44
to
+46
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of the
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def _merge_forward_backward_statuses(statuses: List[ForwardBackwardStatus]) -> ForwardBackwardStatus: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Normalize forward/backward statuses from per-rank results into one dict. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Scalar metrics are already reduced across data-parallel ranks on the workers, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| so they are taken from the first returned status. If workers emit | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ``loss_fn_outputs``, concatenate them in rank order to reconstruct the full | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logical batch. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| assert statuses, "Expected at least one status from forward_backward" | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| result = dict(statuses[0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if not any("loss_fn_outputs" in status for status in statuses): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return result | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| all_loss_fn_outputs: list[LossFnOutput] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| for status in statuses: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss_fn_outputs = status.get("loss_fn_outputs") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(loss_fn_outputs, list): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| all_loss_fn_outputs.extend(loss_fn_outputs) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| result["loss_fn_outputs"] = all_loss_fn_outputs | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return result | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+60
to
+71
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This implementation can be optimized for the common case where only one status is returned (e.g., when
Suggested change
Comment on lines
+49
to
+71
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation of To fix this, the merging logic should only collect outputs from primary ranks (where def _merge_forward_backward_statuses(
actor_infos: List[ActorInfo], statuses: List[ForwardBackwardStatus]
) -> ForwardBackwardStatus:
"""Normalize forward/backward statuses from per-rank results into one dict.
Scalar metrics are already reduced across data-parallel ranks on the workers,
so they are taken from the first returned status. If workers emit
``loss_fn_outputs``, concatenate them in rank order from primary ranks
to reconstruct the full logical batch without duplicates from TP/SP/PP.
"""
assert len(actor_infos) == len(statuses)
assert statuses, "Expected at least one status from forward_backward"
# Use the first status for scalar metrics (assumed all-reduced)
result = dict(statuses[0])
result.pop("loss_fn_outputs", None)
# Collect loss_fn_outputs in DP rank order from collection ranks only
dp_rank_to_outputs: Dict[int, List[LossFnOutput]] = {}
for actor_info, status in zip(actor_infos, statuses):
if actor_info.rank.is_collection_dp_rank():
outputs = status.get("loss_fn_outputs")
if isinstance(outputs, list):
dp_rank_to_outputs[actor_info.rank.dp] = outputs
if dp_rank_to_outputs:
all_loss_fn_outputs: List[LossFnOutput] = []
# Ensure we concatenate in DP rank order
for i in range(actor_infos[0].rank.dp_size):
if i in dp_rank_to_outputs:
all_loss_fn_outputs.extend(dp_rank_to_outputs[i])
result["loss_fn_outputs"] = all_loss_fn_outputs
return result |
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| class WorkerDispatch: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Unified dispatch layer that manages all actor groups (policy, critic, ref). | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -191,7 +225,7 @@ def forward_backward( | |||||||||||||||||||||||||||||||||||||||||||||||||
| data: TrainingInputBatch, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss_fn: Optional[str] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss_fn_config: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Dict[str, float]: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> ForwardBackwardStatus: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Run forward/backward pass. Needs model + optimizer. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -203,7 +237,7 @@ def forward_backward( | |||||||||||||||||||||||||||||||||||||||||||||||||
| (e.g., {"clip_low_threshold": 0.9} for PPO) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Dictionary of training metrics | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Reduced scalar metrics and optional ``loss_fn_outputs`` for the full batch. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._ensure_on_gpu(model, need_optimizer=True, need_model=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -218,27 +252,15 @@ def forward_backward( | |||||||||||||||||||||||||||||||||||||||||||||||||
| statuses = ray.get(refs) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| self._save_memory_snapshot(model, "forward_backward") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # With DP>1, each rank returns loss_fn_outputs for its data chunk. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Concatenate them in rank order to get the full batch's outputs. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Scalar metrics (loss, lr) are already all-reduced, so use statuses[0] for those. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(statuses) > 1 and statuses[0] and "loss_fn_outputs" in statuses[0]: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| all_loss_fn_outputs = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| for status in statuses: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| all_loss_fn_outputs.extend(status.pop("loss_fn_outputs", [])) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| result = statuses[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| result["loss_fn_outputs"] = all_loss_fn_outputs | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return result | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| return statuses[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return _merge_forward_backward_statuses(statuses) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def forward_backward_from_staged( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| model: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_refs: List[ObjectRef], | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss_fn: Optional[str] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loss_fn_config: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Dict[str, float]: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> ForwardBackwardStatus: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Run forward/backward pass using pre-staged per-DP chunks. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -250,7 +272,7 @@ def forward_backward_from_staged( | |||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_refs: Pre-staged ObjectRefs, one per DP rank (from ``stage_data``) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Aggregated metrics dict from training | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Reduced scalar metrics and optional ``loss_fn_outputs`` for the full batch. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._ensure_on_gpu(model, need_optimizer=True, need_model=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -270,7 +292,7 @@ def forward_backward_from_staged( | |||||||||||||||||||||||||||||||||||||||||||||||||
| statuses = ray.get(refs) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| self._save_memory_snapshot(model, "forward_backward") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return statuses[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return _merge_forward_backward_statuses(statuses) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def optim_step(self, model: str) -> Optional[float]: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Run optimizer step. Model should already be on GPU from forward_backward.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| """ | ||
| Tests for WorkerDispatch forward/backward status aggregation. | ||
|
|
||
| uv run --isolated --extra skyrl-train --extra dev pytest tests/backends/skyrl_train/workers/test_worker_dispatch.py | ||
| """ | ||
|
|
||
| import ray | ||
| import torch | ||
|
|
||
| from skyrl.backends.skyrl_train.distributed.dispatch import ( | ||
| ActorInfo, | ||
| MeshDispatch, | ||
| MeshRank, | ||
| ) | ||
| from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch | ||
| from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch | ||
| from skyrl.train.config import SkyRLTrainConfig | ||
|
|
||
|
|
||
| @ray.remote | ||
| class FakeForwardBackwardWorker: | ||
| def __init__(self, rank: int, dp_rank: int): | ||
| self.rank = rank | ||
| self.dp_rank = dp_rank | ||
|
|
||
| def forward_backward(self, data: TrainingInputBatch, loss_fn=None, loss_fn_config=None): | ||
| del loss_fn_config | ||
|
|
||
| status = { | ||
| "policy_loss": 1.25, | ||
| "policy_lr": 0.5, | ||
| } | ||
| if loss_fn == "scalar_only": | ||
| return status | ||
|
|
||
| status["loss_fn_outputs"] = [ | ||
| {"sample_id": sample_id, "dp_rank": self.dp_rank} for sample_id in data["sample_id"].tolist() | ||
| ] | ||
| return status | ||
|
|
||
| def save_memory_snapshot(self, tag: str): | ||
| return tag | ||
|
|
||
|
|
||
| class StubActorGroup: | ||
| def __init__(self, dp_size: int = 2): | ||
| self.actors = [FakeForwardBackwardWorker.remote(rank=i, dp_rank=i) for i in range(dp_size)] | ||
| self.actor_infos = [ | ||
| ActorInfo( | ||
| actor, | ||
| MeshRank(dp=i, sp=0, tp=0, pp=0, world_size=dp_size, dp_size=dp_size, pp_size=1), | ||
| ) | ||
| for i, actor in enumerate(self.actors) | ||
| ] | ||
|
|
||
| def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs): | ||
| if dispatch_type == "mesh": | ||
| return MeshDispatch.dispatch(self.actor_infos, method_name, *args, **kwargs) | ||
| if dispatch_type == "pass_through": | ||
| return [getattr(actor_info.handle, method_name).remote(*args, **kwargs) for actor_info in self.actor_infos] | ||
| raise AssertionError(f"Unsupported dispatch type: {dispatch_type}") | ||
|
|
||
|
|
||
| def _make_dispatch() -> WorkerDispatch: | ||
| cfg = SkyRLTrainConfig() | ||
| cfg.trainer.placement.colocate_all = False | ||
| cfg.trainer.placement.colocate_policy_ref = False | ||
| return WorkerDispatch(cfg, policy_actor_group=StubActorGroup()) | ||
|
|
||
|
|
||
| def _make_batch(batch_size: int = 4) -> TrainingInputBatch: | ||
| return TrainingInputBatch( | ||
| { | ||
| "sample_id": torch.arange(batch_size, dtype=torch.long), | ||
| "dummy": torch.arange(batch_size, dtype=torch.long), | ||
| } | ||
| ) | ||
|
|
||
|
|
||
| def test_forward_backward_from_staged_matches_unstaged_loss_fn_outputs(ray_init): | ||
| dispatch = _make_dispatch() | ||
| batch = _make_batch() | ||
|
|
||
| unstaged = dispatch.forward_backward("policy", batch) | ||
| chunk_refs = dispatch.stage_data("policy", batch, [(0, len(batch))])[0] | ||
| staged = dispatch.forward_backward_from_staged("policy", chunk_refs) | ||
|
|
||
| assert staged == unstaged | ||
| assert [output["sample_id"] for output in staged["loss_fn_outputs"]] == [0, 1, 2, 3] | ||
| assert [output["dp_rank"] for output in staged["loss_fn_outputs"]] == [0, 0, 1, 1] | ||
|
|
||
|
|
||
| def test_forward_backward_from_staged_preserves_scalar_only_status(ray_init): | ||
| dispatch = _make_dispatch() | ||
| batch = _make_batch() | ||
|
|
||
| unstaged = dispatch.forward_backward("policy", batch, loss_fn="scalar_only") | ||
| chunk_refs = dispatch.stage_data("policy", batch, [(0, len(batch))])[0] | ||
| staged = dispatch.forward_backward_from_staged("policy", chunk_refs, loss_fn="scalar_only") | ||
|
|
||
| assert staged == {"policy_loss": 1.25, "policy_lr": 0.5} | ||
| assert staged == unstaged | ||
| assert "loss_fn_outputs" not in staged |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
Uniontype hint is needed for the type aliases defined later in the file to maintain compatibility with Python versions earlier than 3.10, as the|operator for types is only supported at runtime in 3.10+.