# [skyrl] Preserve staged forward_backward loss_fn_outputs across DP ranks #1534

taivu1998 wants to merge 2 commits into

## Conversation
**Code Review**

This pull request introduces a helper function `_merge_forward_backward_statuses` in `WorkerDispatch` to aggregate training metrics and loss function outputs across multiple ranks, alongside new type aliases and unit tests. The review feedback highlights potential runtime errors on older Python versions due to modern type-hinting syntax in the type aliases, and suggests an optimization for the merging logic in single-rank scenarios.
```python
LossFnOutput = dict[str, Any]
ForwardBackwardStatusValue = float | int | list[LossFnOutput]
ForwardBackwardStatus = dict[str, ForwardBackwardStatusValue]
```
The use of the `|` operator and lowercase `dict`/`list` in type aliases will raise a `TypeError` at runtime on Python versions earlier than 3.10 and 3.9, respectively. While `from __future__ import annotations` allows this syntax inside annotations, type aliases are evaluated as ordinary expressions at import time. To stay compatible with older Python versions (as other parts of the codebase do by using `Union`), it is safer to use `typing.Union`, `typing.Dict`, and `typing.List`.
```diff
-LossFnOutput = dict[str, Any]
-ForwardBackwardStatusValue = float | int | list[LossFnOutput]
-ForwardBackwardStatus = dict[str, ForwardBackwardStatusValue]
+LossFnOutput = Dict[str, Any]
+ForwardBackwardStatusValue = Union[float, int, List[LossFnOutput]]
+ForwardBackwardStatus = Dict[str, ForwardBackwardStatusValue]
```
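A minimal sketch of the failure mode, assuming a pre-3.10 interpreter (this snippet is illustrative and not part of the PR):

```python
# Assumption: running under Python 3.8/3.9. The future import only defers
# evaluation of *annotations*; alias assignments are ordinary expressions
# that run at import time.
from __future__ import annotations
from typing import Any

def f(x: dict[str, Any] | None) -> None:  # fine: this annotation is never evaluated
    ...

LossFnOutput = dict[str, Any]  # TypeError on 3.8: 'type' object is not subscriptable
Alias = float | int            # TypeError on 3.8/3.9: unsupported operand type(s) for |
```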
```diff
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
```
The `Union` import is needed for the type aliases defined later in the file to stay compatible with Python versions earlier than 3.10, since the `|` operator for types is only supported at runtime in 3.10+.
```diff
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
```
```python
result = dict(statuses[0])
if not any("loss_fn_outputs" in status for status in statuses):
    return result

all_loss_fn_outputs: list[LossFnOutput] = []
for status in statuses:
    loss_fn_outputs = status.get("loss_fn_outputs")
    if isinstance(loss_fn_outputs, list):
        all_loss_fn_outputs.extend(loss_fn_outputs)

result["loss_fn_outputs"] = all_loss_fn_outputs
return result
```
This implementation can be optimized for the common case where only one status is returned (e.g., when `dp_size=1`) or when no merging is required. Returning `statuses[0]` directly in these cases avoids unnecessary dictionary copying and list rebuilding, restoring the efficient behavior of the previous implementation while still handling the multi-rank merging correctly when needed.
```diff
-result = dict(statuses[0])
-if not any("loss_fn_outputs" in status for status in statuses):
-    return result
-all_loss_fn_outputs: list[LossFnOutput] = []
+if len(statuses) == 1 or not any("loss_fn_outputs" in status for status in statuses):
+    return statuses[0]
+result = dict(statuses[0])
+all_loss_fn_outputs: List[LossFnOutput] = []
 for status in statuses:
     loss_fn_outputs = status.get("loss_fn_outputs")
     if isinstance(loss_fn_outputs, list):
         all_loss_fn_outputs.extend(loss_fn_outputs)
 result["loss_fn_outputs"] = all_loss_fn_outputs
 return result
```
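If the fast path above is adopted, it is observable from tests: the helper hands back the input dict itself instead of a copy. A hypothetical pytest sketch (import path inferred from the files touched by this PR; test names are illustrative):

```python
# Hypothetical test sketch; assumes the optimized fast path is in place and
# that the helper is importable at module level.
from skyrl.backends.skyrl_train.workers.worker_dispatch import (
    _merge_forward_backward_statuses,
)

def test_single_status_skips_copy():
    status = {"loss": 0.5, "loss_fn_outputs": [{"token_ids": [1, 2]}]}
    # With one status there is nothing to merge, so the fast path should
    # return the input dict itself rather than a copy.
    assert _merge_forward_backward_statuses([status]) is status

def test_no_outputs_skips_copy():
    statuses = [{"loss": 0.5}, {"loss": 0.5}]
    # No rank emitted loss_fn_outputs, so no merging is required either.
    assert _merge_forward_backward_statuses(statuses) is statuses[0]
```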
**Code Review**

This pull request introduces a centralized utility, `_merge_forward_backward_statuses`, to aggregate training metrics and loss function outputs from distributed workers, and adds a new test suite to verify this behavior. The review feedback points out a significant issue where the current implementation would produce duplicate outputs in non-Data-Parallelism setups such as Tensor or Sequence Parallelism. To address this, it recommends refactoring the merging logic to collect data only from primary ranks by incorporating actor metadata.
```python
def _merge_forward_backward_statuses(statuses: List[ForwardBackwardStatus]) -> ForwardBackwardStatus:
    """Normalize forward/backward statuses from per-rank results into one dict.

    Scalar metrics are already reduced across data-parallel ranks on the workers,
    so they are taken from the first returned status. If workers emit
    ``loss_fn_outputs``, concatenate them in rank order to reconstruct the full
    logical batch.
    """
    assert statuses, "Expected at least one status from forward_backward"

    result = dict(statuses[0])
    if not any("loss_fn_outputs" in status for status in statuses):
        return result

    all_loss_fn_outputs: list[LossFnOutput] = []
    for status in statuses:
        loss_fn_outputs = status.get("loss_fn_outputs")
        if isinstance(loss_fn_outputs, list):
            all_loss_fn_outputs.extend(loss_fn_outputs)

    result["loss_fn_outputs"] = all_loss_fn_outputs
    return result
```
The current implementation of `_merge_forward_backward_statuses` concatenates `loss_fn_outputs` from all statuses in the list. This will lead to duplicated outputs when using non-Data-Parallelism (e.g., Tensor Parallelism or Sequence Parallelism), as multiple actors will process the same data chunk and return identical outputs.

To fix this, the merging logic should only collect outputs from primary ranks (where `is_collection_dp_rank()` is true), similar to how `concatenate_outputs_after_mesh_dispatch` handles `TrainingOutputBatch`. I suggest refactoring the function to accept `actor_infos` and filter the statuses accordingly:
```python
def _merge_forward_backward_statuses(
    actor_infos: List[ActorInfo], statuses: List[ForwardBackwardStatus]
) -> ForwardBackwardStatus:
    """Normalize forward/backward statuses from per-rank results into one dict.

    Scalar metrics are already reduced across data-parallel ranks on the workers,
    so they are taken from the first returned status. If workers emit
    ``loss_fn_outputs``, concatenate them in rank order from primary ranks
    to reconstruct the full logical batch without duplicates from TP/SP/PP.
    """
    assert len(actor_infos) == len(statuses)
    assert statuses, "Expected at least one status from forward_backward"

    # Use the first status for scalar metrics (assumed all-reduced)
    result = dict(statuses[0])
    result.pop("loss_fn_outputs", None)

    # Collect loss_fn_outputs in DP rank order from collection ranks only
    dp_rank_to_outputs: Dict[int, List[LossFnOutput]] = {}
    for actor_info, status in zip(actor_infos, statuses):
        if actor_info.rank.is_collection_dp_rank():
            outputs = status.get("loss_fn_outputs")
            if isinstance(outputs, list):
                dp_rank_to_outputs[actor_info.rank.dp] = outputs

    if dp_rank_to_outputs:
        all_loss_fn_outputs: List[LossFnOutput] = []
        # Ensure we concatenate in DP rank order
        for i in range(actor_infos[0].rank.dp_size):
            if i in dp_rank_to_outputs:
                all_loss_fn_outputs.extend(dp_rank_to_outputs[i])
        result["loss_fn_outputs"] = all_loss_fn_outputs

    return result
```
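To make the intended dedup behavior concrete, here is a self-contained sketch with stub rank objects; the real `ActorInfo`/rank types live in SkyRL, and the stub field names (`dp`, `tp`, `dp_size`, `is_collection_dp_rank`) simply mirror the suggestion above:

```python
# Illustrative stubs only; assumes the refactored helper above is in scope.
from dataclasses import dataclass

@dataclass
class StubRank:
    dp: int
    tp: int
    dp_size: int

    def is_collection_dp_rank(self) -> bool:
        # Only TP rank 0 of each DP group contributes loss_fn_outputs.
        return self.tp == 0

@dataclass
class StubActorInfo:
    rank: StubRank

# dp_size=2, tp_size=2: four actors, but only two collection ranks.
actor_infos = [StubActorInfo(StubRank(dp, tp, 2)) for dp in range(2) for tp in range(2)]
statuses = [
    {"loss": 0.1, "loss_fn_outputs": [{"dp": info.rank.dp}]} for info in actor_infos
]

merged = _merge_forward_backward_statuses(actor_infos, statuses)
# TP replicas are filtered out, so each DP rank's outputs appear exactly once.
assert merged["loss_fn_outputs"] == [{"dp": 0}, {"dp": 1}]
```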
```diff
-return statuses[0]
+return _merge_forward_backward_statuses(statuses)
```
```diff
 self._save_memory_snapshot(model, "forward_backward")
-return statuses[0]
+return _merge_forward_backward_statuses(statuses)
```
### Summary

This PR fixes the staged `WorkerDispatch` forward/backward path so it preserves full-batch `loss_fn_outputs` the same way as the existing unstaged path, and it cleans up the CI fallout that showed up on the initial PR run. Specifically, it:

- Adds a shared `_merge_forward_backward_statuses` helper to `WorkerDispatch` and uses it from both `forward_backward(...)` and `forward_backward_from_staged(...)`
- Concatenates per-rank `loss_fn_outputs` in rank order instead of returning only rank 0's outputs
- Updates `WorkerDispatch` return annotations and docstrings so they no longer claim a scalar-only return
- Makes `WorkerDispatch`'s inference-client and actor-group imports type-only so the module is lighter to import in CPU test environments
- Makes `gpu_skyrl.yaml` fork-safe by skipping the Anyscale submission step when `ANYSCALE_CLI_TOKEN` is unavailable, instead of hard-failing in fork PR contexts before repo tests start
### Root Cause

Issue #1520 identified that `WorkerDispatch.forward_backward(...)` already had special handling to merge per-rank `loss_fn_outputs`, but `WorkerDispatch.forward_backward_from_staged(...)` returned `statuses[0]` directly. That meant staged training silently dropped `loss_fn_outputs` from DP ranks >= 1 even though worker-side forward/backward already produced them correctly; a small illustration follows the list below.

The initial PR also surfaced two CI-specific follow-ups:

- `WorkerDispatch`'s module-level inference-client and actor-group imports made the module heavy to import in CPU-only test environments
- The `skyrl_gpu_tests` workflow was not fork-safe and failed on missing Anyscale credentials before any repo code ran
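A small illustration of the dropped-outputs behavior (hypothetical payloads; assumes the single-argument `_merge_forward_backward_statuses` helper from this PR is in scope):

```python
# Illustrative only: two DP ranks, each returning half of the logical batch.
statuses = [
    {"loss": 0.1, "loss_fn_outputs": [{"sample": 0}, {"sample": 1}]},  # rank 0
    {"loss": 0.1, "loss_fn_outputs": [{"sample": 2}, {"sample": 3}]},  # rank 1
]

# Old staged path: `return statuses[0]` silently drops rank 1's outputs.
assert statuses[0]["loss_fn_outputs"] == [{"sample": 0}, {"sample": 1}]

# New shared helper: reconstructs the full logical batch in rank order.
merged = _merge_forward_backward_statuses(statuses)
assert merged["loss_fn_outputs"] == [{"sample": s} for s in range(4)]
```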
### Impact

After this change:

- Both `WorkerDispatch` forward/backward paths return the same logical output shape
- Callers consuming `loss_fn_outputs` get the full logical batch on the staged path as well
- Fork PRs no longer fail the `skyrl_gpu_tests` job during external credential bootstrap
### Validation

Passed locally:

- `PRE_COMMIT_HOME=/tmp/pre-commit-cache /Users/vuductai/Documents/Projects/SkyRL/.venv/bin/pre-commit run ruff --files skyrl/backends/skyrl_train/workers/worker_dispatch.py tests/backends/skyrl_train/workers/test_worker_dispatch.py --config .pre-commit-config.yaml`
- `PRE_COMMIT_HOME=/tmp/pre-commit-cache /Users/vuductai/Documents/Projects/SkyRL/.venv/bin/pre-commit run black --files skyrl/backends/skyrl_train/workers/worker_dispatch.py tests/backends/skyrl_train/workers/test_worker_dispatch.py --config .pre-commit-config.yaml`
- `/Users/vuductai/Documents/Projects/SkyRL/.venv/bin/python -m pytest tests/backends/skyrl_train/workers/test_worker_dispatch.py tests/backends/skyrl_train/distributed/test_dispatch.py -q`
- `/Users/vuductai/Documents/Projects/SkyRL/.venv/bin/python -c "import yaml, pathlib; yaml.safe_load(pathlib.Path('.github/workflows/gpu_skyrl.yaml').read_text()); print('yaml-ok')"`
- `git diff --check`
Notes:

- The `skyrl_gpu_tests` failure on this PR was an Anyscale credential/bootstrap failure, not a trainer regression.

### Issue

Closes #1520.