[train] Fix rollout metrics for step-wise and custom generators (sync / fully async)#1556
[train] Fix rollout metrics for step-wise and custom generators (sync / fully async)#1556CharlieFRuan wants to merge 1 commit into
Conversation
- `SkyRLGymGenerator.get_rollout_metrics()` now uses the last step's (prompt + response) as the trajectory-length proxy for step-wise training, matching the approach used by HarborGenerator (#1542). - Trainer + fully-async trainer pop `rollout_metrics` from the `GeneratorOutput` after recording them, so downstream `concatenate_generator_outputs()` and `merge_stepwise_output()` do not accidentally re-aggregate stale metrics. - `concatenate_generator_outputs()` no longer re-runs `get_rollout_metrics` (which ignored step-wise semantics and dropped custom-generator metrics like Harbor's `num_timeout_trajectories`). Instead it re-aggregates the per-group `rollout_metrics` by key name: avg/mean → mean, min/max → min/max, std → mean-of-stds (approximation; true pooled std needs per-group mean + count), others → sum. - `merge_stepwise_output()` now drops `rollout_metrics` on the merged output and documents that they must be recorded before calling it. - Fixes a regression where `None`-valued list fields (e.g. `rollout_logprobs` when disabled) were silently dropped by the concat loop, causing KeyError in `validate_generator_output`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
| if "avg" in k or "mean" in k: | ||
| # Works because each generator output's rollout metric is computed based on the same | ||
| # number of trajectories (even for step-wise). | ||
| rollout_metrics[k] = sum(values) / len(values) |
There was a problem hiding this comment.
🟡 Key-name heuristic incorrectly averages subset-based metrics like avg_tokens_non_zero_rewards
The new rollout metrics aggregation in concatenate_generator_outputs uses a key-name heuristic: any key containing "avg" or "mean" is averaged across groups (line 270-273). The comment on line 271-272 claims this works because "each generator output's rollout metric is computed based on the same number of trajectories." However, metrics like generate/avg_tokens_non_zero_rewards and generate/avg_tokens_zero_rewards (produced by get_rollout_metrics at skyrl/train/generators/utils.py:351-363) are averages over a subset of trajectories (those with non-zero or zero rewards respectively), and each group can have a different number of trajectories in that subset. Simple averaging of these per-group means produces an incorrect result compared to the properly-weighted mean.
Example showing incorrect aggregation
Group 1: 10 trajectories, 2 with non-zero rewards → avg_tokens_non_zero_rewards = 100
Group 2: 10 trajectories, 8 with non-zero rewards → avg_tokens_non_zero_rewards = 50
Heuristic result: (100+50)/2 = 75
Correct weighted result: (2×100 + 8×50) / 10 = 60
Was this helpful? React with 👍 or 👎 to provide feedback.
There was a problem hiding this comment.
Code Review
This pull request refactors the handling and aggregation of rollout metrics across trainers and generators. Key changes include updating metric calculations for step-wise trajectories in SkyRLGymGenerator, refactoring concatenate_generator_outputs to aggregate metrics based on key name patterns (e.g., mean, min, max), and ensuring metrics are removed from generator outputs after being recorded to prevent redundancy. The review feedback identifies a critical TypeError when instantiating a TypedDict and suggests improvements for robustness, such as supporting NumPy numeric types and implementing case-insensitive key matching during metric aggregation.
| result = {} | ||
| for key in first: | ||
| if isinstance(first[key], list): | ||
| result[key] = _flatten_field(generator_outputs, key) | ||
| elif first[key] is None: | ||
| result[key] = None | ||
| result = GeneratorOutput(result) |
There was a problem hiding this comment.
The line result = GeneratorOutput(result) will raise a TypeError at runtime because TypedDict classes do not support being called with a dictionary as a positional argument (they only support keyword arguments for instantiation). Since result is already a dictionary and the function's return type hint handles the GeneratorOutput requirement, this line is both redundant and incorrect.
| result = {} | |
| for key in first: | |
| if isinstance(first[key], list): | |
| result[key] = _flatten_field(generator_outputs, key) | |
| elif first[key] is None: | |
| result[key] = None | |
| result = GeneratorOutput(result) | |
| result = {} | |
| for key in first: | |
| if isinstance(first[key], list): | |
| result[key] = _flatten_field(generator_outputs, key) | |
| elif first[key] is None: | |
| result[key] = None |
| for go in generator_outputs: | ||
| per_group = go.get("rollout_metrics") or {} | ||
| for k, v in per_group.items(): | ||
| if isinstance(v, (int, float)): |
There was a problem hiding this comment.
To better support custom generators that might return NumPy numeric types (e.g., np.float64), it is safer to include np.number in the type check. This ensures that metrics using NumPy scalars are not silently skipped during concatenation.
| if isinstance(v, (int, float)): | |
| if isinstance(v, (int, float, np.number)): |
| for k, values in rollout_metric_keys.items(): | ||
| if "avg" in k or "mean" in k: | ||
| # Works because each generator output's rollout metric is computed based on the same | ||
| # number of trajectories (even for step-wise). | ||
| rollout_metrics[k] = sum(values) / len(values) | ||
| elif "min" in k: | ||
| rollout_metrics[k] = min(values) | ||
| elif "max" in k: | ||
| rollout_metrics[k] = max(values) | ||
| elif "std" in k: | ||
| # Approximation: mean of per-group stds. True pooled std requires per-group mean and | ||
| # count, which we don't carry through. | ||
| rollout_metrics[k] = sum(values) / len(values) | ||
| else: | ||
| rollout_metrics[k] = sum(values) |
There was a problem hiding this comment.
The metric aggregation logic relies on string matching within key names. To make this more robust against variations in naming (e.g., capitalized keys like Generate/Avg_Tokens), it is recommended to perform case-insensitive matching using .lower().
| for k, values in rollout_metric_keys.items(): | |
| if "avg" in k or "mean" in k: | |
| # Works because each generator output's rollout metric is computed based on the same | |
| # number of trajectories (even for step-wise). | |
| rollout_metrics[k] = sum(values) / len(values) | |
| elif "min" in k: | |
| rollout_metrics[k] = min(values) | |
| elif "max" in k: | |
| rollout_metrics[k] = max(values) | |
| elif "std" in k: | |
| # Approximation: mean of per-group stds. True pooled std requires per-group mean and | |
| # count, which we don't carry through. | |
| rollout_metrics[k] = sum(values) / len(values) | |
| else: | |
| rollout_metrics[k] = sum(values) | |
| for k, values in rollout_metric_keys.items(): | |
| k_lower = k.lower() | |
| if "avg" in k_lower or "mean" in k_lower: | |
| # Works because each generator output's rollout metric is computed based on the same | |
| # number of trajectories (even for step-wise). | |
| rollout_metrics[k] = sum(values) / len(values) | |
| elif "min" in k_lower: | |
| rollout_metrics[k] = min(values) | |
| elif "max" in k_lower: | |
| rollout_metrics[k] = max(values) | |
| elif "std" in k_lower: | |
| # Approximation: mean of per-group stds. True pooled std requires per-group mean and | |
| # count, which we don't carry through. | |
| rollout_metrics[k] = sum(values) / len(values) | |
| else: | |
| rollout_metrics[k] = sum(values) |
Motivation
Addressing three related issues in how
rollout_metricsflow through the generator → trainer pipeline:merge_stepwise_output()androllout_metricsinteraction was unclear. After tracingfully_async_trainer.pyandtrainer.py, by the timemerge_stepwise_output()is called the trainer has already consumedrollout_metrics. The old code still propagated a stale dict through slice/merge/concat. This PR makes the contract explicit in the docstring and drops the field on the merged output.get_rollout_metrics()was wrong forSkyRLGymGeneratorin step-wise training. In step-wise moderesponse_ids[i]is a single turn's response, so token-length metrics measured per-turn, not per-trajectory. This PR follows the approach from HarborGenerator ([harbor][step-wise] Make Harbor use step-wise training #1542): use the last turn'sprompt_token_ids[i] + response_ids[i]as the trajectory-length proxy (the last step's prompt already accumulates all prior turns).concatenate_generator_outputs()re-ranget_rollout_metrics()on the concatenated data. This (a) undid the step-wise handling from point 2, and (b) silently dropped custom per-generator metrics (e.g. Harbor'snum_timeout_trajectories). This PR replaces the re-computation with a best-effort re-aggregation by inferring from the key name: avg/mean → mean, min/max → min/max, std → mean-of-stds (approximation — true pooled std would need per-group mean + count; see inline comment), others → sum.Changes
skyrl/train/generators/skyrl_gym_generator.py— step-wise metrics use last-step(prompt + response)rows.skyrl/train/generators/utils.pyconcatenate_generator_outputsre-aggregatesrollout_metricsper-key by name; preservesNone-valued list fields (fixes a regression that KeyError'd invalidate_generator_outputwhen logprobs were disabled).merge_stepwise_outputdropsrollout_metricson output; docstring clarifies that metrics must be recorded before calling it._slice_generator_output/_merge_single_trajectorystop propagating stalerollout_metrics.skyrl/train/trainer.py,skyrl/train/fully_async_trainer.py— poprollout_metricsafter recording, so downstream code doesn't accidentally re-aggregate it.tests/train/generators/test_generator_output_utils.py— updatedtest_generator_output_concatenationto cover every aggregation branch (avg/mean, min, max, std, fallback sum); new tests for missing/None rollout_metrics, non-numeric values, and step-wise + Harbor-style custom metrics.Test plan
uv run --isolated --extra skyrl-train --extra dev --extra fsdp pytest tests/train/— all 337 tests pass locally, including the 21 intest_generator_output_utils.py.generate/avg_num_tokensreflects trajectory-length rather than per-turn length.🤖 Generated with Claude Code