Skip to content

[train] Fix rollout metrics for step-wise and custom generators (sync / fully async)#1556

Draft
CharlieFRuan wants to merge 1 commit into
mainfrom
charlie/rollout-metrics-stepwise-fix
Draft

[train] Fix rollout metrics for step-wise and custom generators (sync / fully async)#1556
CharlieFRuan wants to merge 1 commit into
mainfrom
charlie/rollout-metrics-stepwise-fix

Conversation

@CharlieFRuan
Copy link
Copy Markdown
Member

@CharlieFRuan CharlieFRuan commented Apr 22, 2026

Motivation

Addressing three related issues in how rollout_metrics flow through the generator → trainer pipeline:

  1. merge_stepwise_output() and rollout_metrics interaction was unclear. After tracing fully_async_trainer.py and trainer.py, by the time merge_stepwise_output() is called the trainer has already consumed rollout_metrics. The old code still propagated a stale dict through slice/merge/concat. This PR makes the contract explicit in the docstring and drops the field on the merged output.

  2. get_rollout_metrics() was wrong for SkyRLGymGenerator in step-wise training. In step-wise mode response_ids[i] is a single turn's response, so token-length metrics measured per-turn, not per-trajectory. This PR follows the approach from HarborGenerator ([harbor][step-wise] Make Harbor use step-wise training #1542): use the last turn's prompt_token_ids[i] + response_ids[i] as the trajectory-length proxy (the last step's prompt already accumulates all prior turns).

  3. concatenate_generator_outputs() re-ran get_rollout_metrics() on the concatenated data. This (a) undid the step-wise handling from point 2, and (b) silently dropped custom per-generator metrics (e.g. Harbor's num_timeout_trajectories). This PR replaces the re-computation with a best-effort re-aggregation by inferring from the key name: avg/mean → mean, min/max → min/max, std → mean-of-stds (approximation — true pooled std would need per-group mean + count; see inline comment), others → sum.

Changes

  • skyrl/train/generators/skyrl_gym_generator.py — step-wise metrics use last-step (prompt + response) rows.
  • skyrl/train/generators/utils.py
    • concatenate_generator_outputs re-aggregates rollout_metrics per-key by name; preserves None-valued list fields (fixes a regression that KeyError'd in validate_generator_output when logprobs were disabled).
    • merge_stepwise_output drops rollout_metrics on output; docstring clarifies that metrics must be recorded before calling it.
    • _slice_generator_output / _merge_single_trajectory stop propagating stale rollout_metrics.
  • skyrl/train/trainer.py, skyrl/train/fully_async_trainer.py — pop rollout_metrics after recording, so downstream code doesn't accidentally re-aggregate it.
  • tests/train/generators/test_generator_output_utils.py — updated test_generator_output_concatenation to cover every aggregation branch (avg/mean, min, max, std, fallback sum); new tests for missing/None rollout_metrics, non-numeric values, and step-wise + Harbor-style custom metrics.

Test plan

  • uv run --isolated --extra skyrl-train --extra dev --extra fsdp pytest tests/train/ — all 337 tests pass locally, including the 21 in test_generator_output_utils.py.
  • Smoke-test a step-wise training run and verify generate/avg_num_tokens reflects trajectory-length rather than per-turn length.
  • Smoke-test a fully-async run with a custom generator (e.g. Harbor) and confirm custom metric keys survive concat.

🤖 Generated with Claude Code


Open in Devin Review

- `SkyRLGymGenerator.get_rollout_metrics()` now uses the last step's
  (prompt + response) as the trajectory-length proxy for step-wise
  training, matching the approach used by HarborGenerator (#1542).
- Trainer + fully-async trainer pop `rollout_metrics` from the
  `GeneratorOutput` after recording them, so downstream
  `concatenate_generator_outputs()` and `merge_stepwise_output()` do not
  accidentally re-aggregate stale metrics.
- `concatenate_generator_outputs()` no longer re-runs
  `get_rollout_metrics` (which ignored step-wise semantics and dropped
  custom-generator metrics like Harbor's `num_timeout_trajectories`).
  Instead it re-aggregates the per-group `rollout_metrics` by key name:
  avg/mean → mean, min/max → min/max, std → mean-of-stds (approximation;
  true pooled std needs per-group mean + count), others → sum.
- `merge_stepwise_output()` now drops `rollout_metrics` on the merged
  output and documents that they must be recorded before calling it.
- Fixes a regression where `None`-valued list fields
  (e.g. `rollout_logprobs` when disabled) were silently dropped by the
  concat loop, causing KeyError in `validate_generator_output`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@CharlieFRuan CharlieFRuan changed the title [train] Fix rollout metrics for step-wise and custom generators [train] Fix rollout metrics for step-wise and custom generators (sync / fully async) Apr 22, 2026
Copy link
Copy Markdown
Contributor

@devin-ai-integration devin-ai-integration Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Devin Review found 1 potential issue.

View 4 additional findings in Devin Review.

Open in Devin Review

Comment on lines +270 to +273
if "avg" in k or "mean" in k:
# Works because each generator output's rollout metric is computed based on the same
# number of trajectories (even for step-wise).
rollout_metrics[k] = sum(values) / len(values)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Key-name heuristic incorrectly averages subset-based metrics like avg_tokens_non_zero_rewards

The new rollout metrics aggregation in concatenate_generator_outputs uses a key-name heuristic: any key containing "avg" or "mean" is averaged across groups (line 270-273). The comment on line 271-272 claims this works because "each generator output's rollout metric is computed based on the same number of trajectories." However, metrics like generate/avg_tokens_non_zero_rewards and generate/avg_tokens_zero_rewards (produced by get_rollout_metrics at skyrl/train/generators/utils.py:351-363) are averages over a subset of trajectories (those with non-zero or zero rewards respectively), and each group can have a different number of trajectories in that subset. Simple averaging of these per-group means produces an incorrect result compared to the properly-weighted mean.

Example showing incorrect aggregation

Group 1: 10 trajectories, 2 with non-zero rewards → avg_tokens_non_zero_rewards = 100
Group 2: 10 trajectories, 8 with non-zero rewards → avg_tokens_non_zero_rewards = 50

Heuristic result: (100+50)/2 = 75
Correct weighted result: (2×100 + 8×50) / 10 = 60

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the handling and aggregation of rollout metrics across trainers and generators. Key changes include updating metric calculations for step-wise trajectories in SkyRLGymGenerator, refactoring concatenate_generator_outputs to aggregate metrics based on key name patterns (e.g., mean, min, max), and ensuring metrics are removed from generator outputs after being recorded to prevent redundancy. The review feedback identifies a critical TypeError when instantiating a TypedDict and suggests improvements for robustness, such as supporting NumPy numeric types and implementing case-insensitive key matching during metric aggregation.

Comment on lines +251 to +257
result = {}
for key in first:
if isinstance(first[key], list):
result[key] = _flatten_field(generator_outputs, key)
elif first[key] is None:
result[key] = None
result = GeneratorOutput(result)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The line result = GeneratorOutput(result) will raise a TypeError at runtime because TypedDict classes do not support being called with a dictionary as a positional argument (they only support keyword arguments for instantiation). Since result is already a dictionary and the function's return type hint handles the GeneratorOutput requirement, this line is both redundant and incorrect.

Suggested change
result = {}
for key in first:
if isinstance(first[key], list):
result[key] = _flatten_field(generator_outputs, key)
elif first[key] is None:
result[key] = None
result = GeneratorOutput(result)
result = {}
for key in first:
if isinstance(first[key], list):
result[key] = _flatten_field(generator_outputs, key)
elif first[key] is None:
result[key] = None

for go in generator_outputs:
per_group = go.get("rollout_metrics") or {}
for k, v in per_group.items():
if isinstance(v, (int, float)):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To better support custom generators that might return NumPy numeric types (e.g., np.float64), it is safer to include np.number in the type check. This ensures that metrics using NumPy scalars are not silently skipped during concatenation.

Suggested change
if isinstance(v, (int, float)):
if isinstance(v, (int, float, np.number)):

Comment on lines +269 to +283
for k, values in rollout_metric_keys.items():
if "avg" in k or "mean" in k:
# Works because each generator output's rollout metric is computed based on the same
# number of trajectories (even for step-wise).
rollout_metrics[k] = sum(values) / len(values)
elif "min" in k:
rollout_metrics[k] = min(values)
elif "max" in k:
rollout_metrics[k] = max(values)
elif "std" in k:
# Approximation: mean of per-group stds. True pooled std requires per-group mean and
# count, which we don't carry through.
rollout_metrics[k] = sum(values) / len(values)
else:
rollout_metrics[k] = sum(values)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The metric aggregation logic relies on string matching within key names. To make this more robust against variations in naming (e.g., capitalized keys like Generate/Avg_Tokens), it is recommended to perform case-insensitive matching using .lower().

Suggested change
for k, values in rollout_metric_keys.items():
if "avg" in k or "mean" in k:
# Works because each generator output's rollout metric is computed based on the same
# number of trajectories (even for step-wise).
rollout_metrics[k] = sum(values) / len(values)
elif "min" in k:
rollout_metrics[k] = min(values)
elif "max" in k:
rollout_metrics[k] = max(values)
elif "std" in k:
# Approximation: mean of per-group stds. True pooled std requires per-group mean and
# count, which we don't carry through.
rollout_metrics[k] = sum(values) / len(values)
else:
rollout_metrics[k] = sum(values)
for k, values in rollout_metric_keys.items():
k_lower = k.lower()
if "avg" in k_lower or "mean" in k_lower:
# Works because each generator output's rollout metric is computed based on the same
# number of trajectories (even for step-wise).
rollout_metrics[k] = sum(values) / len(values)
elif "min" in k_lower:
rollout_metrics[k] = min(values)
elif "max" in k_lower:
rollout_metrics[k] = max(values)
elif "std" in k_lower:
# Approximation: mean of per-group stds. True pooled std requires per-group mean and
# count, which we don't carry through.
rollout_metrics[k] = sum(values) / len(values)
else:
rollout_metrics[k] = sum(values)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant