10 changes: 7 additions & 3 deletions .github/workflows/gpu_skyrl.yaml
@@ -29,6 +29,9 @@ concurrency:
jobs:
skyrl_gpu_tests:
runs-on: ubuntu-latest
env:
ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }}
ANYSCALE_HOST: https://console.anyscale.com
defaults:
run:
shell: bash
@@ -47,10 +50,11 @@
activate-environment: true
- name: Install dependencies
run: uv pip install anyscale==0.24.79 typer==0.9.0
- name: Skip GPU tests when Anyscale credentials are unavailable
if: ${{ env.ANYSCALE_CLI_TOKEN == '' }}
run: echo "Skipping GPU tests because ANYSCALE_CLI_TOKEN is unavailable in this workflow context."
- name: GPU tests
env:
ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }}
ANYSCALE_HOST: https://console.anyscale.com
if: ${{ env.ANYSCALE_CLI_TOKEN != '' }}
run: |
anyscale job submit -f ci/anyscale_gpu_ci.yaml --timeout 10000
anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-tx-gpu-ci --timeout 10000
68 changes: 45 additions & 23 deletions skyrl/backends/skyrl_train/workers/worker_dispatch.py
@@ -8,8 +8,10 @@
The trainer interacts with the worker dispatch if all models are always on GPU.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
Contributor
medium

The Union type hint is needed for the type aliases defined later in the file to maintain compatibility with Python versions earlier than 3.10, as the | operator for types is only supported at runtime in 3.10+.

Suggested change
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union


import ray
from ray import ObjectRef
@@ -18,16 +20,18 @@
MeshDispatch,
concatenate_outputs_after_mesh_dispatch,
)
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import (
InferenceEngineClient,
)
from skyrl.backends.skyrl_train.training_batch import (
TrainingInputBatch,
TrainingOutputBatch,
)
from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
from skyrl.train.config import SkyRLTrainConfig

if TYPE_CHECKING:
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import (
InferenceEngineClient,
)
from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup


@dataclass
class GPUState:
@@ -37,6 +41,36 @@ class GPUState:
optimizer_on_gpu: bool = False


LossFnOutput = dict[str, Any]
ForwardBackwardStatusValue = float | int | list[LossFnOutput]
ForwardBackwardStatus = dict[str, ForwardBackwardStatusValue]
Comment on lines +44 to +46
Contributor
high

The use of the | operator and lowercase dict/list in type aliases will cause a TypeError at runtime on Python versions earlier than 3.10 and 3.9 respectively. While from __future__ import annotations allows this syntax within annotations, type aliases are evaluated as expressions at import time. To ensure compatibility with older Python versions (as seen in other parts of the codebase using Union), it is safer to use typing.Union, typing.Dict, and typing.List.

Suggested change
LossFnOutput = dict[str, Any]
ForwardBackwardStatusValue = float | int | list[LossFnOutput]
ForwardBackwardStatus = dict[str, ForwardBackwardStatusValue]
LossFnOutput = Dict[str, Any]
ForwardBackwardStatusValue = Union[float, int, List[LossFnOutput]]
ForwardBackwardStatus = Dict[str, ForwardBackwardStatusValue]
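A quick illustration of the distinction this comment draws: with from __future__ import annotations, annotations are stored as strings and never evaluated, while a module-level type alias is an ordinary expression that runs at import time. A minimal sketch, assuming a Python 3.9 interpreter (the names are illustrative only, not the module's real definitions):

from __future__ import annotations

from typing import Union

def forward(lr: float | int) -> None:  # fine on 3.9: the annotation stays an unevaluated string
    ...

# ForwardBackwardStatusValue = float | int    # TypeError on 3.9: alias expressions run at import time
ForwardBackwardStatusValue = Union[float, int]  # evaluates fine on 3.7+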



def _merge_forward_backward_statuses(statuses: List[ForwardBackwardStatus]) -> ForwardBackwardStatus:
"""Normalize forward/backward statuses from per-rank results into one dict.

Scalar metrics are already reduced across data-parallel ranks on the workers,
so they are taken from the first returned status. If workers emit
``loss_fn_outputs``, concatenate them in rank order to reconstruct the full
logical batch.
"""

assert statuses, "Expected at least one status from forward_backward"

result = dict(statuses[0])
if not any("loss_fn_outputs" in status for status in statuses):
return result

all_loss_fn_outputs: list[LossFnOutput] = []
for status in statuses:
loss_fn_outputs = status.get("loss_fn_outputs")
if isinstance(loss_fn_outputs, list):
all_loss_fn_outputs.extend(loss_fn_outputs)

result["loss_fn_outputs"] = all_loss_fn_outputs
return result
Comment on lines +60 to +71
Contributor
medium

This implementation can be optimized for the common case where only one status is returned (e.g., when dp_size=1) or when no merging is required. Returning statuses[0] directly in these cases avoids unnecessary dictionary copying and list rebuilding, restoring the efficient behavior of the previous implementation while still correctly handling the multi-rank merging logic when needed.

Suggested change
result = dict(statuses[0])
if not any("loss_fn_outputs" in status for status in statuses):
return result
all_loss_fn_outputs: list[LossFnOutput] = []
for status in statuses:
loss_fn_outputs = status.get("loss_fn_outputs")
if isinstance(loss_fn_outputs, list):
all_loss_fn_outputs.extend(loss_fn_outputs)
result["loss_fn_outputs"] = all_loss_fn_outputs
return result
if len(statuses) == 1 or not any("loss_fn_outputs" in status for status in statuses):
return statuses[0]
result = dict(statuses[0])
all_loss_fn_outputs: List[LossFnOutput] = []
for status in statuses:
loss_fn_outputs = status.get("loss_fn_outputs")
if isinstance(loss_fn_outputs, list):
all_loss_fn_outputs.extend(loss_fn_outputs)
result["loss_fn_outputs"] = all_loss_fn_outputs
return result

Comment on lines +49 to +71
Contributor
high

The current implementation of _merge_forward_backward_statuses concatenates loss_fn_outputs from all statuses in the list. This will lead to duplicated outputs when using non-Data-Parallelism (e.g., Tensor Parallelism or Sequence Parallelism), as multiple actors will process the same data chunk and return identical outputs.

To fix this, the merging logic should only collect outputs from primary ranks (where is_collection_dp_rank() is true), similar to how concatenate_outputs_after_mesh_dispatch handles TrainingOutputBatch. I suggest refactoring the function to accept actor_infos and filter the statuses accordingly.

def _merge_forward_backward_statuses(
    actor_infos: List[ActorInfo], statuses: List[ForwardBackwardStatus]
) -> ForwardBackwardStatus:
    """Normalize forward/backward statuses from per-rank results into one dict.

    Scalar metrics are already reduced across data-parallel ranks on the workers,
    so they are taken from the first returned status. If workers emit
    ``loss_fn_outputs``, concatenate them in rank order from primary ranks
    to reconstruct the full logical batch without duplicates from TP/SP/PP.
    """
    assert len(actor_infos) == len(statuses)
    assert statuses, "Expected at least one status from forward_backward"

    # Use the first status for scalar metrics (assumed all-reduced)
    result = dict(statuses[0])
    result.pop("loss_fn_outputs", None)

    # Collect loss_fn_outputs in DP rank order from collection ranks only
    dp_rank_to_outputs: Dict[int, List[LossFnOutput]] = {}
    for actor_info, status in zip(actor_infos, statuses):
        if actor_info.rank.is_collection_dp_rank():
            outputs = status.get("loss_fn_outputs")
            if isinstance(outputs, list):
                dp_rank_to_outputs[actor_info.rank.dp] = outputs

    if dp_rank_to_outputs:
        all_loss_fn_outputs: List[LossFnOutput] = []
        # Ensure we concatenate in DP rank order
        for i in range(actor_infos[0].rank.dp_size):
            if i in dp_rank_to_outputs:
                all_loss_fn_outputs.extend(dp_rank_to_outputs[i])
        result["loss_fn_outputs"] = all_loss_fn_outputs

    return result
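
A small, self-contained sketch of the dedup behaviour this suggestion is after. FakeRank, FakeActorInfo, and merge below are hypothetical stand-ins, not the actual skyrl ActorInfo/MeshRank classes: with a 2-DP x 2-TP mesh, only the tp=0 rank in each DP group contributes loss_fn_outputs, so the merged result holds four unique samples rather than eight.

from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class FakeRank:
    dp: int
    tp: int
    dp_size: int

    def is_collection_dp_rank(self) -> bool:
        # Only one rank per DP group (tp == 0 here) should contribute outputs.
        return self.tp == 0

@dataclass
class FakeActorInfo:
    rank: FakeRank

def merge(actor_infos: List[FakeActorInfo], statuses: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Scalar metrics come from the first status; outputs are collected per DP rank only.
    result = dict(statuses[0])
    result.pop("loss_fn_outputs", None)
    by_dp: Dict[int, List[Dict[str, Any]]] = {}
    for info, status in zip(actor_infos, statuses):
        if info.rank.is_collection_dp_rank() and isinstance(status.get("loss_fn_outputs"), list):
            by_dp[info.rank.dp] = status["loss_fn_outputs"]
    if by_dp:
        result["loss_fn_outputs"] = [
            out for dp in range(actor_infos[0].rank.dp_size) for out in by_dp.get(dp, [])
        ]
    return result

# 2 DP groups x 2 TP ranks: the tp=1 replicas return duplicate outputs that must be dropped.
infos = [FakeActorInfo(FakeRank(dp=d, tp=t, dp_size=2)) for d in range(2) for t in range(2)]
statuses = [
    {"policy_loss": 1.0, "loss_fn_outputs": [{"sample_id": 2 * d}, {"sample_id": 2 * d + 1}]}
    for d in range(2)
    for _ in range(2)
]
print(merge(infos, statuses)["loss_fn_outputs"])  # four unique samples, not eight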



class WorkerDispatch:
"""
Unified dispatch layer that manages all actor groups (policy, critic, ref).
@@ -191,7 +225,7 @@ def forward_backward(
data: TrainingInputBatch,
loss_fn: Optional[str] = None,
loss_fn_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, float]:
) -> ForwardBackwardStatus:
"""Run forward/backward pass. Needs model + optimizer.

Args:
Expand All @@ -203,7 +237,7 @@ def forward_backward(
(e.g., {"clip_low_threshold": 0.9} for PPO)

Returns:
Dictionary of training metrics
Reduced scalar metrics and optional ``loss_fn_outputs`` for the full batch.
"""
self._ensure_on_gpu(model, need_optimizer=True, need_model=True)

@@ -218,27 +252,15 @@
statuses = ray.get(refs)

self._save_memory_snapshot(model, "forward_backward")

# With DP>1, each rank returns loss_fn_outputs for its data chunk.
# Concatenate them in rank order to get the full batch's outputs.
# Scalar metrics (loss, lr) are already all-reduced, so use statuses[0] for those.
if len(statuses) > 1 and statuses[0] and "loss_fn_outputs" in statuses[0]:
all_loss_fn_outputs = []
for status in statuses:
all_loss_fn_outputs.extend(status.pop("loss_fn_outputs", []))
result = statuses[0]
result["loss_fn_outputs"] = all_loss_fn_outputs
return result

return statuses[0]
return _merge_forward_backward_statuses(statuses)
Contributor
high

Update the call to _merge_forward_backward_statuses to include actor_infos as required by the suggested refactoring.

Suggested change
return _merge_forward_backward_statuses(statuses)
return _merge_forward_backward_statuses(self._actor_groups[model].actor_infos, statuses)


def forward_backward_from_staged(
self,
model: str,
chunk_refs: List[ObjectRef],
loss_fn: Optional[str] = None,
loss_fn_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, float]:
) -> ForwardBackwardStatus:
"""
Run forward/backward pass using pre-staged per-DP chunks.

@@ -250,7 +272,7 @@ def forward_backward_from_staged(
chunk_refs: Pre-staged ObjectRefs, one per DP rank (from ``stage_data``)

Returns:
Aggregated metrics dict from training
Reduced scalar metrics and optional ``loss_fn_outputs`` for the full batch.
"""
self._ensure_on_gpu(model, need_optimizer=True, need_model=True)

@@ -270,7 +292,7 @@ def forward_backward_from_staged(
statuses = ray.get(refs)

self._save_memory_snapshot(model, "forward_backward")
return statuses[0]
return _merge_forward_backward_statuses(statuses)
Contributor
high

Update the call to _merge_forward_backward_statuses to include actor_infos as required by the suggested refactoring.

Suggested change
return _merge_forward_backward_statuses(statuses)
return _merge_forward_backward_statuses(self._actor_groups[model].actor_infos, statuses)


def optim_step(self, model: str) -> Optional[float]:
"""Run optimizer step. Model should already be on GPU from forward_backward."""
103 changes: 103 additions & 0 deletions tests/backends/skyrl_train/workers/test_worker_dispatch.py
@@ -0,0 +1,103 @@
"""
Tests for WorkerDispatch forward/backward status aggregation.

uv run --isolated --extra skyrl-train --extra dev pytest tests/backends/skyrl_train/workers/test_worker_dispatch.py
"""

import ray
import torch

from skyrl.backends.skyrl_train.distributed.dispatch import (
ActorInfo,
MeshDispatch,
MeshRank,
)
from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
from skyrl.train.config import SkyRLTrainConfig


@ray.remote
class FakeForwardBackwardWorker:
def __init__(self, rank: int, dp_rank: int):
self.rank = rank
self.dp_rank = dp_rank

def forward_backward(self, data: TrainingInputBatch, loss_fn=None, loss_fn_config=None):
del loss_fn_config

status = {
"policy_loss": 1.25,
"policy_lr": 0.5,
}
if loss_fn == "scalar_only":
return status

status["loss_fn_outputs"] = [
{"sample_id": sample_id, "dp_rank": self.dp_rank} for sample_id in data["sample_id"].tolist()
]
return status

def save_memory_snapshot(self, tag: str):
return tag


class StubActorGroup:
def __init__(self, dp_size: int = 2):
self.actors = [FakeForwardBackwardWorker.remote(rank=i, dp_rank=i) for i in range(dp_size)]
self.actor_infos = [
ActorInfo(
actor,
MeshRank(dp=i, sp=0, tp=0, pp=0, world_size=dp_size, dp_size=dp_size, pp_size=1),
)
for i, actor in enumerate(self.actors)
]

def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs):
if dispatch_type == "mesh":
return MeshDispatch.dispatch(self.actor_infos, method_name, *args, **kwargs)
if dispatch_type == "pass_through":
return [getattr(actor_info.handle, method_name).remote(*args, **kwargs) for actor_info in self.actor_infos]
raise AssertionError(f"Unsupported dispatch type: {dispatch_type}")


def _make_dispatch() -> WorkerDispatch:
cfg = SkyRLTrainConfig()
cfg.trainer.placement.colocate_all = False
cfg.trainer.placement.colocate_policy_ref = False
return WorkerDispatch(cfg, policy_actor_group=StubActorGroup())


def _make_batch(batch_size: int = 4) -> TrainingInputBatch:
return TrainingInputBatch(
{
"sample_id": torch.arange(batch_size, dtype=torch.long),
"dummy": torch.arange(batch_size, dtype=torch.long),
}
)


def test_forward_backward_from_staged_matches_unstaged_loss_fn_outputs(ray_init):
dispatch = _make_dispatch()
batch = _make_batch()

unstaged = dispatch.forward_backward("policy", batch)
chunk_refs = dispatch.stage_data("policy", batch, [(0, len(batch))])[0]
staged = dispatch.forward_backward_from_staged("policy", chunk_refs)

assert staged == unstaged
assert [output["sample_id"] for output in staged["loss_fn_outputs"]] == [0, 1, 2, 3]
assert [output["dp_rank"] for output in staged["loss_fn_outputs"]] == [0, 0, 1, 1]


def test_forward_backward_from_staged_preserves_scalar_only_status(ray_init):
dispatch = _make_dispatch()
batch = _make_batch()

unstaged = dispatch.forward_backward("policy", batch, loss_fn="scalar_only")
chunk_refs = dispatch.stage_data("policy", batch, [(0, len(batch))])[0]
staged = dispatch.forward_backward_from_staged("policy", chunk_refs, loss_fn="scalar_only")

assert staged == {"policy_loss": 1.25, "policy_lr": 0.5}
assert staged == unstaged
assert "loss_fn_outputs" not in staged