
[skyrl] Preserve staged forward_backward loss_fn_outputs across DP ranks #1534

Open

taivu1998 wants to merge 2 commits into NovaSky-AI:main from taivu1998:tdv/issue-1520-staged-loss-fn-outputs

[skyrl] Preserve staged forward_backward loss_fn_outputs across DP ranks#1534
taivu1998 wants to merge 2 commits into
NovaSky-AI:mainfrom
taivu1998:tdv/issue-1520-staged-loss-fn-outputs

Conversation


taivu1998 commented Apr 19, 2026

Summary

This PR fixes the staged WorkerDispatch forward/backward path so it preserves full-batch loss_fn_outputs the same way as the existing unstaged path, and it cleans up the CI fallout that showed up on the initial PR run.

Specifically, it:

  • adds a shared status-merge helper in WorkerDispatch and uses it from both forward_backward(...) and forward_backward_from_staged(...)
  • ensures staged dispatch concatenates per-rank loss_fn_outputs in rank order instead of returning only rank 0's outputs
  • corrects the WorkerDispatch return annotations and docstrings so they no longer claim a scalar-only return
  • adds a focused CPU regression test covering both structured-output and scalar-only cases
  • makes WorkerDispatch's inference-client and actor-group imports type-only so the module is lighter to import in CPU test environments (a generic sketch of this pattern follows the list)
  • applies the required lint/format fixes surfaced by CI
  • makes gpu_skyrl.yaml fork-safe by skipping the Anyscale submission step when ANYSCALE_CLI_TOKEN is unavailable, instead of hard-failing in fork PR contexts before repo tests start
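
For reference, here is a minimal sketch of the type-only import pattern mentioned above, assuming the standard typing.TYPE_CHECKING guard; the module and class names are illustrative, not the actual WorkerDispatch imports.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers; never executed at runtime, so a
    # CPU-only test environment does not need the heavy dependency installed.
    from heavy_inference_lib import InferenceClient  # hypothetical module/class


def ping(client: "InferenceClient") -> str:
    # The annotation stays a string, so no runtime import is required.
    return "ok"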

Root Cause

Issue #1520 identified that WorkerDispatch.forward_backward(...) already had special handling to merge per-rank loss_fn_outputs, but WorkerDispatch.forward_backward_from_staged(...) returned statuses[0] directly.

That meant staged training silently dropped loss_fn_outputs from DP ranks >= 1 even though worker-side forward/backward already produced them correctly.
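
Concretely, with dp_size == 2 each rank's status carries its own shard of loss_fn_outputs, so returning statuses[0] keeps only half the logical batch. A minimal sketch of the before/after behavior (keys and values are made up for illustration; this is not the actual SkyRL status payload):

# Illustrative per-rank statuses from a dp_size == 2 run.
statuses = [
    {"policy_loss": 0.42, "loss_fn_outputs": [{"sample_id": 0}, {"sample_id": 1}]},  # rank 0
    {"policy_loss": 0.42, "loss_fn_outputs": [{"sample_id": 2}, {"sample_id": 3}]},  # rank 1
]

# Old staged path: only rank 0's shard survives.
assert len(statuses[0]["loss_fn_outputs"]) == 2

# Fixed behavior on both paths: scalar metrics come from the first, already-reduced
# status, and loss_fn_outputs are concatenated in rank order.
merged = dict(statuses[0])
merged["loss_fn_outputs"] = [o for s in statuses for o in s.get("loss_fn_outputs", [])]
assert len(merged["loss_fn_outputs"]) == 4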

The initial PR also surfaced two CI-specific follow-ups:

  • the local patch needed the exact Ruff/Black formatting expected by the repo hooks
  • the skyrl_gpu_tests workflow was not fork-safe and failed on missing Anyscale credentials before any repo code ran

Impact

After this change:

  • staged and unstaged WorkerDispatch forward/backward paths return the same logical output shape
  • callers that rely on loss_fn_outputs get the full logical batch on the staged path as well
  • scalar metrics remain unchanged because they are still sourced from the already-reduced first status dict
  • fork PRs without Anyscale secrets no longer fail the skyrl_gpu_tests job during external credential bootstrap

Validation

Passed locally:

  • PRE_COMMIT_HOME=/tmp/pre-commit-cache /Users/vuductai/Documents/Projects/SkyRL/.venv/bin/pre-commit run ruff --files skyrl/backends/skyrl_train/workers/worker_dispatch.py tests/backends/skyrl_train/workers/test_worker_dispatch.py --config .pre-commit-config.yaml
  • PRE_COMMIT_HOME=/tmp/pre-commit-cache /Users/vuductai/Documents/Projects/SkyRL/.venv/bin/pre-commit run black --files skyrl/backends/skyrl_train/workers/worker_dispatch.py tests/backends/skyrl_train/workers/test_worker_dispatch.py --config .pre-commit-config.yaml
  • /Users/vuductai/Documents/Projects/SkyRL/.venv/bin/python -m pytest tests/backends/skyrl_train/workers/test_worker_dispatch.py tests/backends/skyrl_train/distributed/test_dispatch.py -q
  • /Users/vuductai/Documents/Projects/SkyRL/.venv/bin/python -c "import yaml, pathlib; yaml.safe_load(pathlib.Path('.github/workflows/gpu_skyrl.yaml').read_text()); print('yaml-ok')"
  • git diff --check

Notes:

  • The original skyrl_gpu_tests failure on this PR was an Anyscale credential/bootstrap failure, not a trainer regression.
  • I did not rerun GPU product tests locally.

Issue

Closes #1520.

taivu1998 marked this pull request as ready for review April 19, 2026 23:46
gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a helper function _merge_forward_backward_statuses in WorkerDispatch to aggregate training metrics and loss function outputs across multiple ranks, alongside new type aliases and unit tests. The review feedback highlights potential runtime errors on older Python versions due to the use of modern type hinting syntax in type aliases and suggests an optimization for the merging logic in single-rank scenarios.

Comment on lines +44 to +46
LossFnOutput = dict[str, Any]
ForwardBackwardStatusValue = float | int | list[LossFnOutput]
ForwardBackwardStatus = dict[str, ForwardBackwardStatusValue]

Severity: high

The use of the | operator and lowercase dict/list in type aliases will cause a TypeError at runtime on Python versions earlier than 3.10 and 3.9 respectively. While from __future__ import annotations allows this syntax within annotations, type aliases are evaluated as expressions at import time. To ensure compatibility with older Python versions (as seen in other parts of the codebase using Union), it is safer to use typing.Union, typing.Dict, and typing.List.

Suggested change

- LossFnOutput = dict[str, Any]
- ForwardBackwardStatusValue = float | int | list[LossFnOutput]
- ForwardBackwardStatus = dict[str, ForwardBackwardStatusValue]
+ LossFnOutput = Dict[str, Any]
+ ForwardBackwardStatusValue = Union[float, int, List[LossFnOutput]]
+ ForwardBackwardStatus = Dict[str, ForwardBackwardStatusValue]
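
As a generic illustration of the distinction drawn above (not code from this PR; it runs as written only on Python 3.10+, and the alias lines are exactly what would fail on older interpreters):

from __future__ import annotations  # PEP 563: affects annotations only, not ordinary expressions

from typing import Any


def f(x: dict[str, Any] | None) -> None:
    # Fine even on Python 3.8: the annotation is stored as a string and never evaluated.
    pass


# Type aliases are plain assignments, evaluated at import time:
#   dict[str, Any]            -> TypeError on Python < 3.9  (PEP 585 builtin generics)
#   float | int | list[...]   -> TypeError on Python < 3.10 (PEP 604 unions)
LossFnOutput = dict[str, Any]
ForwardBackwardStatusValue = float | int | list[LossFnOutput]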


  from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

Severity: medium

The Union type hint is needed for the type aliases defined later in the file to maintain compatibility with Python versions earlier than 3.10, as the | operator for types is only supported at runtime in 3.10+.

Suggested change

- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

Comment on lines +60 to +71
result = dict(statuses[0])
if not any("loss_fn_outputs" in status for status in statuses):
    return result

all_loss_fn_outputs: list[LossFnOutput] = []
for status in statuses:
    loss_fn_outputs = status.get("loss_fn_outputs")
    if isinstance(loss_fn_outputs, list):
        all_loss_fn_outputs.extend(loss_fn_outputs)

result["loss_fn_outputs"] = all_loss_fn_outputs
return result

Severity: medium

This implementation can be optimized for the common case where only one status is returned (e.g., when dp_size=1) or when no merging is required. Returning statuses[0] directly in these cases avoids unnecessary dictionary copying and list rebuilding, restoring the efficient behavior of the previous implementation while still correctly handling the multi-rank merging logic when needed.

Suggested change

- result = dict(statuses[0])
- if not any("loss_fn_outputs" in status for status in statuses):
-     return result
- all_loss_fn_outputs: list[LossFnOutput] = []
- for status in statuses:
-     loss_fn_outputs = status.get("loss_fn_outputs")
-     if isinstance(loss_fn_outputs, list):
-         all_loss_fn_outputs.extend(loss_fn_outputs)
- result["loss_fn_outputs"] = all_loss_fn_outputs
- return result
+ if len(statuses) == 1 or not any("loss_fn_outputs" in status for status in statuses):
+     return statuses[0]
+ result = dict(statuses[0])
+ all_loss_fn_outputs: List[LossFnOutput] = []
+ for status in statuses:
+     loss_fn_outputs = status.get("loss_fn_outputs")
+     if isinstance(loss_fn_outputs, list):
+         all_loss_fn_outputs.extend(loss_fn_outputs)
+ result["loss_fn_outputs"] = all_loss_fn_outputs
+ return result


devin-ai-integration (bot) left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 3 additional findings.


taivu1998 marked this pull request as draft April 20, 2026 00:06
taivu1998 marked this pull request as ready for review April 20, 2026 01:23

gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a centralized utility, _merge_forward_backward_statuses, to aggregate training metrics and loss function outputs from distributed workers, and adds a new test suite to verify this behavior. The review feedback points out a significant issue where the current implementation would cause duplicate outputs in non-Data-Parallelism setups like Tensor or Sequence Parallelism. To address this, it is recommended to refactor the merging logic to only collect data from primary ranks by incorporating actor metadata.

Comment on lines +49 to +71
def _merge_forward_backward_statuses(statuses: List[ForwardBackwardStatus]) -> ForwardBackwardStatus:
    """Normalize forward/backward statuses from per-rank results into one dict.

    Scalar metrics are already reduced across data-parallel ranks on the workers,
    so they are taken from the first returned status. If workers emit
    ``loss_fn_outputs``, concatenate them in rank order to reconstruct the full
    logical batch.
    """

    assert statuses, "Expected at least one status from forward_backward"

    result = dict(statuses[0])
    if not any("loss_fn_outputs" in status for status in statuses):
        return result

    all_loss_fn_outputs: list[LossFnOutput] = []
    for status in statuses:
        loss_fn_outputs = status.get("loss_fn_outputs")
        if isinstance(loss_fn_outputs, list):
            all_loss_fn_outputs.extend(loss_fn_outputs)

    result["loss_fn_outputs"] = all_loss_fn_outputs
    return result

Severity: high

The current implementation of _merge_forward_backward_statuses concatenates loss_fn_outputs from all statuses in the list. This will lead to duplicated outputs when using non-Data-Parallelism (e.g., Tensor Parallelism or Sequence Parallelism), as multiple actors will process the same data chunk and return identical outputs.

To fix this, the merging logic should only collect outputs from primary ranks (where is_collection_dp_rank() is true), similar to how concatenate_outputs_after_mesh_dispatch handles TrainingOutputBatch. I suggest refactoring the function to accept actor_infos and filter the statuses accordingly.

def _merge_forward_backward_statuses(
    actor_infos: List[ActorInfo], statuses: List[ForwardBackwardStatus]
) -> ForwardBackwardStatus:
    """Normalize forward/backward statuses from per-rank results into one dict.

    Scalar metrics are already reduced across data-parallel ranks on the workers,
    so they are taken from the first returned status. If workers emit
    ``loss_fn_outputs``, concatenate them in rank order from primary ranks
    to reconstruct the full logical batch without duplicates from TP/SP/PP.
    """
    assert len(actor_infos) == len(statuses)
    assert statuses, "Expected at least one status from forward_backward"

    # Use the first status for scalar metrics (assumed all-reduced)
    result = dict(statuses[0])
    result.pop("loss_fn_outputs", None)

    # Collect loss_fn_outputs in DP rank order from collection ranks only
    dp_rank_to_outputs: Dict[int, List[LossFnOutput]] = {}
    for actor_info, status in zip(actor_infos, statuses):
        if actor_info.rank.is_collection_dp_rank():
            outputs = status.get("loss_fn_outputs")
            if isinstance(outputs, list):
                dp_rank_to_outputs[actor_info.rank.dp] = outputs

    if dp_rank_to_outputs:
        all_loss_fn_outputs: List[LossFnOutput] = []
        # Ensure we concatenate in DP rank order
        for i in range(actor_infos[0].rank.dp_size):
            if i in dp_rank_to_outputs:
                all_loss_fn_outputs.extend(dp_rank_to_outputs[i])
        result["loss_fn_outputs"] = all_loss_fn_outputs

    return result

  return result
- return statuses[0]
+ return _merge_forward_backward_statuses(statuses)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Severity: high

Update the call to _merge_forward_backward_statuses to include actor_infos as required by the suggested refactoring.

Suggested change

- return _merge_forward_backward_statuses(statuses)
+ return _merge_forward_backward_statuses(self._actor_groups[model].actor_infos, statuses)


  self._save_memory_snapshot(model, "forward_backward")
- return statuses[0]
+ return _merge_forward_backward_statuses(statuses)

Severity: high

Update the call to _merge_forward_backward_statuses to include actor_infos as required by the suggested refactoring.

Suggested change

- return _merge_forward_backward_statuses(statuses)
+ return _merge_forward_backward_statuses(self._actor_groups[model].actor_infos, statuses)


Linked issue (#1520): WorkerDispatch.forward_backward_from_staged drops DP-rank≥1 loss_fn_outputs and has the wrong return annotation
