diff --git a/.github/workflows/gpu_skyrl.yaml b/.github/workflows/gpu_skyrl.yaml index b9d9bb8552..21f188bbff 100644 --- a/.github/workflows/gpu_skyrl.yaml +++ b/.github/workflows/gpu_skyrl.yaml @@ -29,6 +29,9 @@ concurrency: jobs: skyrl_gpu_tests: runs-on: ubuntu-latest + env: + ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }} + ANYSCALE_HOST: https://console.anyscale.com defaults: run: shell: bash @@ -47,10 +50,11 @@ jobs: activate-environment: true - name: Install dependencies run: uv pip install anyscale==0.24.79 typer==0.9.0 + - name: Skip GPU tests when Anyscale credentials are unavailable + if: ${{ env.ANYSCALE_CLI_TOKEN == '' }} + run: echo "Skipping GPU tests because ANYSCALE_CLI_TOKEN is unavailable in this workflow context." - name: GPU tests - env: - ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }} - ANYSCALE_HOST: https://console.anyscale.com + if: ${{ env.ANYSCALE_CLI_TOKEN != '' }} run: | anyscale job submit -f ci/anyscale_gpu_ci.yaml --timeout 10000 anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-tx-gpu-ci --timeout 10000 diff --git a/skyrl/backends/skyrl_train/workers/worker_dispatch.py b/skyrl/backends/skyrl_train/workers/worker_dispatch.py index 91e14a356f..64e005e765 100644 --- a/skyrl/backends/skyrl_train/workers/worker_dispatch.py +++ b/skyrl/backends/skyrl_train/workers/worker_dispatch.py @@ -8,8 +8,10 @@ The trainer interacts with the worker dispatch if all models are always on GPU. """ +from __future__ import annotations + from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import ray from ray import ObjectRef @@ -18,16 +20,18 @@ MeshDispatch, concatenate_outputs_after_mesh_dispatch, ) -from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import ( - InferenceEngineClient, -) from skyrl.backends.skyrl_train.training_batch import ( TrainingInputBatch, TrainingOutputBatch, ) -from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup from skyrl.train.config import SkyRLTrainConfig +if TYPE_CHECKING: + from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import ( + InferenceEngineClient, + ) + from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup + @dataclass class GPUState: @@ -37,6 +41,36 @@ class GPUState: optimizer_on_gpu: bool = False +LossFnOutput = dict[str, Any] +ForwardBackwardStatusValue = float | int | list[LossFnOutput] +ForwardBackwardStatus = dict[str, ForwardBackwardStatusValue] + + +def _merge_forward_backward_statuses(statuses: List[ForwardBackwardStatus]) -> ForwardBackwardStatus: + """Normalize forward/backward statuses from per-rank results into one dict. + + Scalar metrics are already reduced across data-parallel ranks on the workers, + so they are taken from the first returned status. If workers emit + ``loss_fn_outputs``, concatenate them in rank order to reconstruct the full + logical batch. + """ + + assert statuses, "Expected at least one status from forward_backward" + + result = dict(statuses[0]) + if not any("loss_fn_outputs" in status for status in statuses): + return result + + all_loss_fn_outputs: list[LossFnOutput] = [] + for status in statuses: + loss_fn_outputs = status.get("loss_fn_outputs") + if isinstance(loss_fn_outputs, list): + all_loss_fn_outputs.extend(loss_fn_outputs) + + result["loss_fn_outputs"] = all_loss_fn_outputs + return result + + class WorkerDispatch: """ Unified dispatch layer that manages all actor groups (policy, critic, ref). @@ -191,7 +225,7 @@ def forward_backward( data: TrainingInputBatch, loss_fn: Optional[str] = None, loss_fn_config: Optional[Dict[str, Any]] = None, - ) -> Dict[str, float]: + ) -> ForwardBackwardStatus: """Run forward/backward pass. Needs model + optimizer. Args: @@ -203,7 +237,7 @@ def forward_backward( (e.g., {"clip_low_threshold": 0.9} for PPO) Returns: - Dictionary of training metrics + Reduced scalar metrics and optional ``loss_fn_outputs`` for the full batch. """ self._ensure_on_gpu(model, need_optimizer=True, need_model=True) @@ -218,19 +252,7 @@ def forward_backward( statuses = ray.get(refs) self._save_memory_snapshot(model, "forward_backward") - - # With DP>1, each rank returns loss_fn_outputs for its data chunk. - # Concatenate them in rank order to get the full batch's outputs. - # Scalar metrics (loss, lr) are already all-reduced, so use statuses[0] for those. - if len(statuses) > 1 and statuses[0] and "loss_fn_outputs" in statuses[0]: - all_loss_fn_outputs = [] - for status in statuses: - all_loss_fn_outputs.extend(status.pop("loss_fn_outputs", [])) - result = statuses[0] - result["loss_fn_outputs"] = all_loss_fn_outputs - return result - - return statuses[0] + return _merge_forward_backward_statuses(statuses) def forward_backward_from_staged( self, @@ -238,7 +260,7 @@ def forward_backward_from_staged( chunk_refs: List[ObjectRef], loss_fn: Optional[str] = None, loss_fn_config: Optional[Dict[str, Any]] = None, - ) -> Dict[str, float]: + ) -> ForwardBackwardStatus: """ Run forward/backward pass using pre-staged per-DP chunks. @@ -250,7 +272,7 @@ def forward_backward_from_staged( chunk_refs: Pre-staged ObjectRefs, one per DP rank (from ``stage_data``) Returns: - Aggregated metrics dict from training + Reduced scalar metrics and optional ``loss_fn_outputs`` for the full batch. """ self._ensure_on_gpu(model, need_optimizer=True, need_model=True) @@ -270,7 +292,7 @@ def forward_backward_from_staged( statuses = ray.get(refs) self._save_memory_snapshot(model, "forward_backward") - return statuses[0] + return _merge_forward_backward_statuses(statuses) def optim_step(self, model: str) -> Optional[float]: """Run optimizer step. Model should already be on GPU from forward_backward.""" diff --git a/tests/backends/skyrl_train/workers/test_worker_dispatch.py b/tests/backends/skyrl_train/workers/test_worker_dispatch.py new file mode 100644 index 0000000000..054d592718 --- /dev/null +++ b/tests/backends/skyrl_train/workers/test_worker_dispatch.py @@ -0,0 +1,103 @@ +""" +Tests for WorkerDispatch forward/backward status aggregation. + +uv run --isolated --extra skyrl-train --extra dev pytest tests/backends/skyrl_train/workers/test_worker_dispatch.py +""" + +import ray +import torch + +from skyrl.backends.skyrl_train.distributed.dispatch import ( + ActorInfo, + MeshDispatch, + MeshRank, +) +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch +from skyrl.train.config import SkyRLTrainConfig + + +@ray.remote +class FakeForwardBackwardWorker: + def __init__(self, rank: int, dp_rank: int): + self.rank = rank + self.dp_rank = dp_rank + + def forward_backward(self, data: TrainingInputBatch, loss_fn=None, loss_fn_config=None): + del loss_fn_config + + status = { + "policy_loss": 1.25, + "policy_lr": 0.5, + } + if loss_fn == "scalar_only": + return status + + status["loss_fn_outputs"] = [ + {"sample_id": sample_id, "dp_rank": self.dp_rank} for sample_id in data["sample_id"].tolist() + ] + return status + + def save_memory_snapshot(self, tag: str): + return tag + + +class StubActorGroup: + def __init__(self, dp_size: int = 2): + self.actors = [FakeForwardBackwardWorker.remote(rank=i, dp_rank=i) for i in range(dp_size)] + self.actor_infos = [ + ActorInfo( + actor, + MeshRank(dp=i, sp=0, tp=0, pp=0, world_size=dp_size, dp_size=dp_size, pp_size=1), + ) + for i, actor in enumerate(self.actors) + ] + + def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs): + if dispatch_type == "mesh": + return MeshDispatch.dispatch(self.actor_infos, method_name, *args, **kwargs) + if dispatch_type == "pass_through": + return [getattr(actor_info.handle, method_name).remote(*args, **kwargs) for actor_info in self.actor_infos] + raise AssertionError(f"Unsupported dispatch type: {dispatch_type}") + + +def _make_dispatch() -> WorkerDispatch: + cfg = SkyRLTrainConfig() + cfg.trainer.placement.colocate_all = False + cfg.trainer.placement.colocate_policy_ref = False + return WorkerDispatch(cfg, policy_actor_group=StubActorGroup()) + + +def _make_batch(batch_size: int = 4) -> TrainingInputBatch: + return TrainingInputBatch( + { + "sample_id": torch.arange(batch_size, dtype=torch.long), + "dummy": torch.arange(batch_size, dtype=torch.long), + } + ) + + +def test_forward_backward_from_staged_matches_unstaged_loss_fn_outputs(ray_init): + dispatch = _make_dispatch() + batch = _make_batch() + + unstaged = dispatch.forward_backward("policy", batch) + chunk_refs = dispatch.stage_data("policy", batch, [(0, len(batch))])[0] + staged = dispatch.forward_backward_from_staged("policy", chunk_refs) + + assert staged == unstaged + assert [output["sample_id"] for output in staged["loss_fn_outputs"]] == [0, 1, 2, 3] + assert [output["dp_rank"] for output in staged["loss_fn_outputs"]] == [0, 0, 1, 1] + + +def test_forward_backward_from_staged_preserves_scalar_only_status(ray_init): + dispatch = _make_dispatch() + batch = _make_batch() + + unstaged = dispatch.forward_backward("policy", batch, loss_fn="scalar_only") + chunk_refs = dispatch.stage_data("policy", batch, [(0, len(batch))])[0] + staged = dispatch.forward_backward_from_staged("policy", chunk_refs, loss_fn="scalar_only") + + assert staged == {"policy_loss": 1.25, "policy_lr": 0.5} + assert staged == unstaged + assert "loss_fn_outputs" not in staged