diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..5ab9ef4a90 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,42 @@ +# SkyRL-v2 (fleet-ai/SkyRL-v2) + +Fork of SkyRL with Fleet-specific optimizations for multi-node FSDP2 training at scale. + +## Fleet Integration + +Fleet-specific changes, fixes, and context are documented in: +- **[integrations/fleet/CHANGELOG.md](integrations/fleet/CHANGELOG.md)** — detailed changelog with root causes and fixes + +Always consult the changelog before modifying Fleet training paths (`fsdp_worker.py`, `worker.py`, `model_wrapper.py`, `dispatch.py`, `fleet-*.sh`). + +## Key Differences from Upstream SkyRL + +1. **Multi-node FSDP2 stability**: Synchronous ref model offload/backload with `torch.distributed.barrier()` in `fsdp_worker.py`. Required because cross-node colocated training has no shared CUDA context. + +2. **Chunked lm_head forward**: `model_wrapper.py` has `loss_chunk_size` support ported from the old fork. Avoids materializing full `(B, S, vocab_size)` logits — critical for 35B with 131K vocab at 97K sequence length. Without it, OOM/Xid 31 during training forward. + +3. **CUDA memory management for 35B**: `torch.cuda.empty_cache()` before backward pass in `worker.py` (policy + critic). Prevents OOM from fragmentation. + +4. **Reduced sequence length (72K) for 35B**: `fleet-35b-run.sh` uses `MAX_INPUT_LENGTH=72000` (down from 96000) with `--no-pytorch-alloc-conf` (disables `expandable_segments` which conflicts with vLLM 0.18.0's `CuMemAllocator`). At 97K, SDPA OOM'd and flash_attn hit Xid 31 in GatedDeltaNet. At 72K, flash_attn=true + chunked lm_head + empty_cache fits without expandable_segments. + +5. **`stage_chunks` pre-staging**: `dispatch.py` has a `stage_chunks` optimization (not in upstream) that pre-stages mini-batch chunks in Ray object store. Includes dynamic `mini_batch_size` adjustment for hint augmentation's variable batch sizes. + +## Training Scripts + +- `scripts/fleet-common-run.sh` — shared infra (Ray, NCCL, gIB detection, deps). Used by all runs. +- `scripts/fleet-35b-run.sh` — Qwen3.5-35B config. Calls `fleet-common-run.sh`. +- `scripts/fleet-9b-run.sh` — Qwen3.5-9B config. Calls `fleet-common-run.sh`. + +All training flags live in these scripts. Never duplicate flags in SkyPilot YAMLs or fleet-research scripts. + +## Task-Gen Metrics + +When reporting task-gen training metrics, distinguish between: +- **pass@8 / avg_raw_reward**: includes `base_quality=0.1` for passing sandbox+judge. Misleading — inflated by gate-passing alone. +- **binary variance reward**: the actual learning signal. `1.0` when solver rollouts are mixed (at least 1 pass + 1 fail), `0.0` otherwise. This is what matters. + +Report binary variance reward count (how many tasks got `reward >= 1.0`) separately from gate-pass count. Check `EVAL` log lines for `total=1.0000` vs `total=0.0000`. + +## Branch + +Primary development branch: `main` diff --git a/docs/taste/LAUNCH.md b/docs/taste/LAUNCH.md new file mode 100644 index 0000000000..af5bd75ae4 --- /dev/null +++ b/docs/taste/LAUNCH.md @@ -0,0 +1,147 @@ +# Taste-Judge GRPO Launch Recipe + +Wires `research/judge/judge.py` into the SkyRL Fleet GRPO training loop. +Reward shape is **GATED TASTE**: + +``` +effective_taste = max(taste_floor, taste_score) # 1.0 if judge fails / None +reward = verifier_reward * effective_taste +``` + +Blended only on the terminal step of each rollout, with a 10s judge timeout +and verifier-only fallback (`effective_taste = 1.0`, so reward collapses to +`verifier_reward`) on timeout/exception/None. + +### Why gated > additive + +The previous additive shape `R = alpha * verifier + (1-alpha) * taste` +rewarded "pretty failures" — a trajectory that fails the verifier (v=0) +but narrates clean intent (t high) earned `(1-alpha) * t > 0`, which +incentivized the policy to learn good-looking failure modes. Gated taste +closes this hack: `verifier=0` forces `reward=0` regardless of taste, so +there is zero gradient toward pretty-failure mimicry. Among successes, +ugly successes still earn `floor * verifier` (default `floor=0.1`) so GRPO +sees within-group taste variance and can prefer pretty successes; setting +`floor=1.0` collapses the shape to pure verifier and serves as a clean +ablation baseline. **The floor is set to 0.1 (not 0.3) because offline +analysis showed mean rescaled taste of verifier=1 trajectories is ~0.13; +floor=0.3 would clip nearly all successes and kill within-group variance. +Re-tune floor after a 50-100 step pilot using the empirical effective_taste +P25 logged in WandB.** + +## One-block launch + +```bash +# 0. From your machine: +cd /tmp && rm -rf skyrl-fleet && git clone https://github.com/fleet-ai/skyrl-fleet.git +cd /tmp/skyrl-fleet + +# 1. Apply the env patch (adds taste_floor config, _apply_taste_reward helper, +# and updates the three terminal returns + get_metrics). +git apply /Users/alliegu/Desktop/fleet/integration/env.py.diff + +# 2. Vendor the taste-judge package into the workdir Python path. +cp -r /Users/alliegu/Desktop/fleet/integration/skyrl_taste skyrl-gym/skyrl_taste +cp -r /Users/alliegu/Desktop/fleet/research/judge research/judge + +# 3. Drop the new YAML config into tasks/. +cp /Users/alliegu/Desktop/fleet/integration/configs/openenv-fleet-grpo-vl-taste.yaml \ + tasks/openenv-fleet-grpo-vl-taste.yaml + +# 4. Sky launch with the new yaml + new env vars (judge keys are NEW; the rest +# are unchanged from the existing VL launch). +sky launch tasks/openenv-fleet-grpo-vl-taste.yaml \ + --env FLEET_API_KEY="$FLEET_API_KEY" \ + --env WANDB_API_KEY="$WANDB_API_KEY" \ + --env AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" \ + --env AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \ + --env ANTHROPIC_API_KEY="$ANTHROPIC_API_KEY" \ + --env OPENAI_API_KEY="$OPENAI_API_KEY" +``` + +## Required env vars + +- `ANTHROPIC_API_KEY` — **required**. Default judge backend (Claude via + `research/judge/judge.py`). Without it the judge import fails and the env + silently falls back to verifier-only reward (you'll see + `taste_judge_failed=True` in WandB). +- `OPENAI_API_KEY` — **only required if running inter-rater agreement + passes** (GPT-4o judge for cross-checking Claude scores during eval). Not + needed for the standard training run. +- `FLEET_API_KEY`, `WANDB_API_KEY`, `AWS_ACCESS_KEY_ID`, + `AWS_SECRET_ACCESS_KEY` — same as the upstream VL launch. + +**Important:** Invoke `judge.py` with `blind_outcome=True` at training time +to suppress outcome bleed (Stream 4 finding — when the judge sees the +verifier outcome, taste scores correlate ~0.7 with verifier and the +shaping signal collapses to a noisy duplicate of the binary reward). The +async wrapper in `skyrl_taste/judge.py` handles this; double-check the +flag is forwarded if you swap the wrapper. + +## WandB metrics to watch + +- `reward/train/mean` — gated reward; bounded above by verifier mean. +- `env/taste_reward` — judge's [0,1] raw score per trajectory. +- `env/effective_taste` — `max(floor, taste_reward)`; what actually + multiplies the verifier. +- `env/verifier_reward` — raw binary verifier per trajectory. +- `env/taste_floor` — the configured floor; sanity-check. +- `env/taste_judge_failed` — should stay near 0; spikes mean Anthropic + outage or judge parse failures (auto-fallback to pure verifier engaged). +- **Cross-check**: in within-group runs, plot Pearson(`taste_reward`, + `verifier_reward`). If correlation collapses below ~0.3, the judge is + scoring a different signal than the verifier — that's the expected case + and where the shaped-reward gradient comes from. If it climbs above + ~0.7, suspect outcome bleed (re-verify `blind_outcome=True`). +- `reward/train/variance_per_prompt` and `signal_ratio` (from + `integrations/fleet/reward_metrics.py`) should *increase* relative to a + verifier-only baseline on groups with mixed pretty/ugly successes. + +## Rollback + +**Runtime kill switch (no redeploy):** +```bash +sky exec "echo SKYRL_TASTE_DISABLED=1 >> ~/.bashrc && pkill -HUP -f main_fleet" +# or update the SkyPilot env block and re-launch with --env SKYRL_TASTE_DISABLED=1 +``` +This makes `score_trajectory_async` return `None`, the env's +`effective_taste` becomes `1.0`, and reward collapses to pure verifier. + +**Full revert (uncheck-out the patch):** +```bash +cd /tmp/skyrl-fleet +git apply -R /Users/alliegu/Desktop/fleet/integration/env.py.diff +rm -rf skyrl-gym/skyrl_taste research/judge +``` + +## Two-knob ablation (floor x grpo_norm_by_std) + +| floor \ grpo_norm_by_std | true (default) | false (recommended w/ gated taste) | +|---|---|---| +| 0.0 (pure multiplicative) | Ugly successes get R=0; group std collapses on all-ugly groups. Heavy gradient damping; expect slow learning. | Same dynamics, undamped; risk of policy ignoring ugly successes entirely. | +| 0.1 | Tiny within-success variance; std-norm wipes most of the gradient. | Tight bonus for pretty successes; conservative shaping. | +| 0.1 (default) | Tiny within-success variance from floor itself; std-norm still wipes most of the gradient. | **Headline candidate.** Multiplicative-with-cushion; closes hack and matches the empirical taste distribution. | +| 0.3 | Within-success std damped; offline data shows nearly all successes clip to floor — kills the signal. | Heavier shaping; only sensible if live taste distribution skews high. | +| 0.5 | Floor close to pretty-mid; less taste differentiation among successes. | Shallower shaping; useful as sensitivity check. | +| 1.0 (pure verifier) | **Identical to upstream baseline.** A/B control, no taste in std. | Identical to upstream too (no taste in std). | + +Recommended order: run cell `(0.1, false)` first as the headline candidate, +then `(0.1, true)` to measure the std-norm effect, then `(1.0, true)` as +the upstream baseline. `(0.0, false)` is a diagnostic: confirms the gate +itself bites (ugly successes get zero) without floor compensation. + +## Risks / gotchas + +- **Judge latency budget**: 10s timeout x `n_samples_per_prompt=4` at + `train_batch_size=50` = ~200 concurrent judge calls per training step. + Anthropic rate limits will throttle you before the GPU does. Watch + `taste_judge_failed` — sustained >10% means raise the limit or batch. +- **Reward range**: gated reward is in `[0, 1]` — same as verifier — so + pass@n threshold (`reward >= 1.0` in `reward_metrics.py:79-82`) only + triggers on `(verifier=1, taste=1.0)`. With `floor=0.1` and `verifier=1`, + blended max is 1.0 only when `taste_score=1.0`. **Pass@n will look + worse than verifier-only**; report it alongside the new gated-reward + mean, and consider plotting `verifier_reward >= 1.0` as a separate + pass@n line for direct comparison to the baseline. +- **Outcome bleed**: confirmed Stream 4 risk if the judge ever sees the + verifier outcome. Keep `blind_outcome=True` in `score_trajectory_async`. diff --git a/docs/taste/integration_map.md b/docs/taste/integration_map.md new file mode 100644 index 0000000000..9c6ef7fc00 --- /dev/null +++ b/docs/taste/integration_map.md @@ -0,0 +1,219 @@ +# Fleet GRPO Reward Integration Map + +Repo: `https://github.com/fleet-ai/skyrl-fleet` (cloned to `/tmp/skyrl-fleet-2` in sandbox; `git clone` into `/sessions/.../outputs` failed because the existing mount blocked write to `.git/`, so we cloned to `/tmp/skyrl-fleet-2`). + +The `skyrl-train` package has been merged into `skyrl/` (per `skyrl-train/README.md`). Modern code paths live under `skyrl/train/...`. + +--- + +## Reward emit point + +The Fleet env returns reward in **two places**, both in `skyrl-gym/skyrl_gym/envs/fleet_task/env.py`: + +### Per-step reward — `step_async()` returns +File: `skyrl-gym/skyrl_gym/envs/fleet_task/env.py` + +The reward is initialized to `0.0` at line **552**, populated from OpenEnv at lines **590–592** and **615–617**, and finally emitted on the `BaseTextEnvStepOutput` returns at lines **674, 708, 762**. + +``` +552 reward = 0.0 +... +588 try: +589 mcp_start = time.time() +590 obs, reward, done, info = ( +591 await self.openenv_task_env.step_async(openenv_action) +592 ) +... +613 try: +614 mcp_start = time.time() +615 obs, reward, done, info = ( +616 await self.openenv_task_env.step_async(openenv_action) +617 ) +... +672 return BaseTextEnvStepOutput( +673 observations=[], +674 reward=reward, +675 done=True, +676 metadata={...}, +677 ) +... +706 return BaseTextEnvStepOutput( +707 observations=[new_obs], +708 reward=reward, +709 done=episode_done, +710 metadata=metadata, +711 ) +... +760 return BaseTextEnvStepOutput( +761 observations=[new_obs], +762 reward=reward, +763 done=episode_done, +764 metadata=metadata, +765 ) +``` + +### Final reward fallback — `close()` / `close_async()` +For trajectories that get terminated by SkyRL (context overflow, timeout) **without** the agent emitting ``, OpenEnv's verifier is run inside `close()` / `close_async()` and the result is stashed on `self.last_reward` (lines **789–790** and **805–806**). It then surfaces via `get_metrics()` as `final_reward` (line **824**): + +``` +784 def close(self): +785 """Close the Fleet environment and cleanup resources.""" +786 if self.openenv_task_env: +787 try: +788 self.openenv_task_env.close() +789 if self.openenv_task_env.final_reward is not None: +790 self.last_reward = self.openenv_task_env.final_reward +... +796 async def close_async(self): +... +802 if self.openenv_task_env: +803 try: +804 await self.openenv_task_env.close_async() +805 if self.openenv_task_env.final_reward is not None: +806 self.last_reward = self.openenv_task_env.final_reward +``` + +The terminal reward used by the GRPO trainer comes from the last step where `done=True`, i.e. one of the three return sites above. **The clean place to inject `taste_score` is inside `step_async()` immediately before each of those three returns, when `episode_done is True`.** + +--- + +## Verifier source + +The binary `0.0 / 1.0` reward is **not computed inside this repo**. It comes back from OpenEnv's `FleetTaskEnv.step_async()` (and `close_async()`) at: + +- `skyrl-gym/skyrl_gym/envs/fleet_task/env.py:590-592` — happy path during the step where the agent submits its tool call +- `skyrl-gym/skyrl_gym/envs/fleet_task/env.py:615-617` — when the agent emits `` with no tool call +- `skyrl-gym/skyrl_gym/envs/fleet_task/env.py:788-790, 804-806` — orphaned-trajectory fallback via `openenv_task_env.final_reward` + +OpenEnv runs a programmatic Python verifier server-side; the Fleet wrapper only consumes its return value. There is also a **partial-reward** mode (not binary) toggled by `env_config.partial_reward` (constructor lines **176–181**); the VL launch script enables it (`scripts/fleet-vl-run.sh:42` — `environment.skyrl_gym.fleet_task.partial_reward=true`). Per `reward_metrics.py:79-82`, only `reward >= 1.0` counts as a "pass" in pass@n, so partial values land in `(0,1)`. + +For task-generation runs there is also `integrations/fleet/task_gen_reward.py` which applies a derived "mixed result" reward — orthogonal to the browser-use loop but worth noting because it's a precedent for shaping rewards inside this repo. + +--- + +## LLM-as-judge example + +Path: `examples/train/llm_as_a_judge/`. Four files (5 with `__init__.py`): + +| File | Purpose | +|---|---| +| `llm_judge_env.py` | The env: `GSM8kLLMJudgeEnv(BaseTextEnv)` with a synchronous `step()` that calls the OpenAI client to score an answer. | +| `main_llm_judge.py` | Ray entrypoint that registers the env id `"llm_as_a_judge"` and calls `BasePPOExp(cfg).run()`. | +| `gsm8k_dataset_judge.py` | Dataset prep: emits parquet with `env_class="llm_as_a_judge"` and `reward_spec.ground_truth`. | +| `run_llm_judge.sh` | GRPO launch (Qwen2.5-1.5B-Instruct, 4× GPU). Sets `environment.skyrl_gym.llm_as_a_judge.model="gpt-4o-mini"`. | + +What the env actually does — quoting the **only** reward-relevant section of `llm_judge_env.py`: + +```python +def _get_reward(self, action: str) -> float: + message = PROMPT + f"\n\nGOLD SOLUTION:\n{self.ground_truth}\n\nPREDICTED SOLUTION:\n{action}\n\nAnswer:" + try: + response = self.llm_judge_client.chat.completions.create( + model=self.model, messages=[{"role": "user", "content": message}] + ) + reply = response.choices[0].message.content.strip() + match = re.search(r"### Final Score:\s*([01](?:\.0)?)", reply) + if match: + return float(match.group(1)) + if reply.strip() in {"1", "0"}: + return float(reply.strip()) + return 0.0 + except Exception as e: + print(f"LLM Judge error: {type(e).__name__}: {e}") + return 0.0 + +def step(self, action: str) -> BaseTextEnvStepOutput: + done = True + reward = self._get_reward(action) + return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={}) +``` + +Properties: + +- **Synchronous and blocking.** `step()` is sync and uses the `openai.OpenAI` blocking client. Each rollout's `step()` call blocks the worker thread for the full judge latency. +- **Single-turn.** The env always returns `done=True` on the first step, so the judge runs exactly once per trajectory at the very end. +- **No batching.** One judge call per rollout, no aggregation across the GRPO group of `n_samples_per_prompt` trajectories. +- **No async / no thread pool / no retry / no timeout.** Errors swallow to `0.0` (silent failure mode). +- **Caller is the env itself.** Reward is computed inline in `step()` — the trainer never knows there is an LLM judge in the loop. + +The reason this is acceptable in the GSM8k example: it is **single-turn**, run on cheap CPU-side I/O, with a tiny batch (the script uses `train_batch_size=32`, `n_samples_per_prompt=5`), and rollouts in SkyRL's generator already run concurrently via the async generator + Ray, so blocking calls in different envs proceed in parallel. The pattern does **not** scale cleanly to long multi-turn browser_use rollouts where you don't want to hold the env alive for an extra 1–3 s × group_size at the very end. + +--- + +## Async strategy + +**Recommendation: post-hoc, parallel, and out-of-step.** Specifically: + +1. **Do not call the judge inside `step_async()` per turn.** Browser-use trajectories have 50–80 turns (`MAX_TURNS=80` in the YAML); judging every step is wasteful and the judge can't reasonably score before the trajectory is done anyway. +2. **At episode end** (the `episode_done` branch in `step_async()` and inside `close_async()`), kick off the judge call **as an awaitable**. Two options, in order of cleanliness: + - **Option A (preferred):** make `score_trajectory` an `async def` that uses `httpx.AsyncClient` or `openai.AsyncOpenAI`, with `asyncio.wait_for(..., timeout=judge_timeout_s)`. SkyRL's generator already runs `step_async` inside an asyncio task per rollout, so judge calls across the entire GRPO group naturally overlap. With `n_samples_per_prompt=4` and ~50 prompts, you get 200 judge calls running concurrently and the wall-clock cost collapses to ~max(judge_latency). + - **Option B (escape hatch):** wrap the sync judge in `asyncio.to_thread(...)` (Python 3.9+) so the existing sync OpenAI/Anthropic client doesn't block the event loop. Slightly worse than A under load but a one-line change. +3. **Use `asyncio.gather` or `asyncio.wait_for` with a hard timeout** of e.g. 10 s. On timeout/exception, log a warning and fall back to `verifier_reward` only (i.e. effectively `alpha = 1.0` for that trajectory). This keeps a slow Anthropic API outage from stalling a training step. +4. **Do not gate trajectory cleanup on the judge.** Resolve the judge future, attach the score to the final `BaseTextEnvStepOutput`, and let `close_async()` proceed independently if the judge is still pending. (In practice, since `step_async` returns the terminal `done=True` step, you must either `await` the judge before the final return or do post-hoc reward attribution at the trainer level.) +5. **Optional optimization — batch by prompt-group at the trainer layer.** A more invasive variant: store the trajectory transcripts in `metadata`, then have the trainer call the judge once per GRPO group (with all `n` trajectories in one prompt) before computing advantages. This gives the judge cross-trajectory context for relative ranking and is what most production RLAIF setups do. Requires patching the trainer's reward post-processing path (where `flatten_rewards` in `integrations/fleet/reward_metrics.py` is called), not the env. Out of scope for the minimal patch but worth flagging. + +**The existing `llm_as_a_judge` example uses none of these**: it is sync, inline, single-call, single-turn, no timeout, no retry. **Do not copy it as-is for browser_use** — copy the *interface shape* (judge runs inside the env at episode end and emits a scalar in `[0,1]`) and rewrite the call to be async + timed-out. + +--- + +## GRPO config knobs + +Defaults in `skyrl/train/config/ppo_base_config.yaml` (lines 96–124), with VL overrides from `scripts/fleet-vl-run.sh`: + +| Knob | Default | VL launch override | Interaction with shaped reward | +|---|---|---|---| +| `trainer.algorithm.advantage_estimator` | `"grpo"` | `grpo` | Computes per-prompt-group advantages from raw rewards. A continuous `taste_score` increases within-group variance and produces non-zero advantages even when all trajectories pass/fail the binary verifier — exactly the desired effect. | +| `trainer.algorithm.grpo_norm_by_std` | `true` | (default) | GRPO divides advantage by group-level reward std. With binary rewards, std is 0 when the whole group passes/fails; mixing in `taste_score` raises std, which **also damps the advantage magnitude**. Watch for: groups where verifier is unanimous now produce small but non-zero advantages — the gradient signal will be tiny. May want `grpo_norm_by_std=false` once shaped reward is on. | +| `trainer.algorithm.zero_variance_filter` | `false` | `true` (line 73) | Currently masks out prompts where all rewards are identical (no signal). With shaped reward this filter would fire **far less often** since `taste_score` is approximately continuous → almost every prompt now contributes a gradient. This is good for sample efficiency but may also amplify judge noise into the policy. Consider keeping it on but with a tolerance threshold. | +| `trainer.algorithm.use_kl_loss` | `true` | `true` | KL is on the policy loss, so it is independent of reward scale. Good. | +| `trainer.algorithm.kl_loss_coef` | `0.001` | (default) | Independent of reward, no change needed. | +| `trainer.algorithm.use_kl_in_reward` | `false` | (default, mutually exclusive with `use_kl_loss`) | If you ever flip to `use_kl_in_reward=true`, the KL term gets *added to the reward* and competes directly with `taste_score`. Keep this `false`. | +| `trainer.algorithm.eps_clip_low / eps_clip_high` | `0.2 / 0.2` | (default) | PPO ratio clip. Independent of reward magnitude (operates on log-prob ratio), so safe. | +| `trainer.algorithm.advantage_batch_normalize` | `false` | (default) | If turned on, would re-normalize advantages across the whole batch. Consider enabling if the taste_score's scale + verifier mix produces unstable cross-prompt advantage magnitudes. | +| `trainer.algorithm.loss_reduction` | `"token_mean"` | `"sequence_mean"` (line 47) | Doesn't touch reward, but `sequence_mean` is what's used for VL — keep aware that gradient is per-trajectory averaged. | + +**Concrete suggestions:** +- Start with `alpha=0.5` (balanced). +- Keep `grpo_norm_by_std=true` initially; if you observe gradient norm collapse, set it to `false`. +- Bound `taste_score` to `[0,1]` (same range as verifier) so the mixed reward stays in `[0,1]` and existing pass@n / signal-ratio metrics in `integrations/fleet/reward_metrics.py` still parse correctly. +- Consider reporting `verifier_reward` and `taste_reward` separately as wandb metrics so you can disentangle their contributions — fits naturally into the existing metric schema. + +--- + +## Existing evals + +**Eval entrypoint:** `integrations/fleet/entrypoints/main_eval.py` — `FleetEvalExp(BasePPOExp).run()`. Resumes FSDP weights from S3, calls `await trainer.eval()` once (line 125), logs the dict via `trainer.tracker.log(...)`, and (optionally) uploads dump to S3. + +**Metric computation:** `integrations/fleet/reward_metrics.py` exposes: +- `flatten_rewards(rewards)` — collapses token-level rewards to scalars. +- `compute_pass_at_n(rewards, uids)` — fraction of unique prompts with **at least one rollout `>= 1.0`**. +- The module's docstring documents the wandb naming convention: `reward/{group}/pass_at_n`, `reward/{group}/variance_per_prompt`, `reward/{group}/signal_ratio`, `reward/{group}/mean_positive_reward`. + +**What gets measured today:** +- Final reward distribution (pass@n with threshold ≥ 1.0). +- Within-prompt reward variance (the GRPO learning-signal proxy). +- Signal ratio (% prompts with non-zero variance). +- Mean positive reward. +- Per-env metrics emitted from `FleetTaskEnv.get_metrics()` at lines 812–835: `task_key`, `env_key`, `turns`, `tool_calls`, `tool_errors`, `is_hinted`, `final_reward`, `verifier_stdout`, `verifier_error`, `tool_error_messages`, `chat_history`. + +**How to add a new metric (e.g. `taste_reward_mean`):** +1. In `step_async()`'s terminal returns, also stash `self.last_taste_reward` and `self.last_verifier_reward` on the env. +2. Append both to the metadata dict and to `get_metrics()` output (line 814 onward) so they flow into the trainer's metric aggregator alongside `final_reward`. +3. The trainer's `_get_response_level_rewards`/eval-dump path picks up env metadata — no further patching needed if the keys are scalar-typed. For aggregated metrics (group-level), add a function to `reward_metrics.py` modeled on `compute_pass_at_n` and call it from wherever `pass_at_n` is logged in the trainer (search `compute_pass_at_n` to find the call sites — they live in `skyrl/train/trainer.py` and `integrations/fleet/entrypoints/main_fleet_tinker.py`). +4. Test path: `integrations/fleet/tests/test_main_eval.py`. + +--- + +## Files referenced (absolute, in cloned repo) + +- `/tmp/skyrl-fleet-2/skyrl-gym/skyrl_gym/envs/fleet_task/env.py` — env, reward emit point (lines 525–765, 784–810). +- `/tmp/skyrl-fleet-2/examples/train/llm_as_a_judge/llm_judge_env.py` — sync judge example. +- `/tmp/skyrl-fleet-2/examples/train/llm_as_a_judge/main_llm_judge.py` +- `/tmp/skyrl-fleet-2/examples/train/llm_as_a_judge/gsm8k_dataset_judge.py` +- `/tmp/skyrl-fleet-2/examples/train/llm_as_a_judge/run_llm_judge.sh` +- `/tmp/skyrl-fleet-2/tasks/openenv-fleet-grpo-vl.yaml` — VL launch task. +- `/tmp/skyrl-fleet-2/scripts/fleet-vl-run.sh` — actual GRPO CLI args. +- `/tmp/skyrl-fleet-2/skyrl/train/config/ppo_base_config.yaml` — GRPO/PPO defaults (lines 96–124). +- `/tmp/skyrl-fleet-2/integrations/fleet/reward_metrics.py` — metric helpers. +- `/tmp/skyrl-fleet-2/integrations/fleet/entrypoints/main_eval.py` — eval entrypoint. +- `/tmp/skyrl-fleet-2/integrations/fleet/task_gen_reward.py` — precedent for shaping rewards inside this repo (task-gen, not browser-use). diff --git a/docs/taste/smoke_test.py b/docs/taste/smoke_test.py new file mode 100644 index 0000000000..2e7002fd17 --- /dev/null +++ b/docs/taste/smoke_test.py @@ -0,0 +1,219 @@ +"""Smoke test for the patched FleetTaskEnv reward-gating logic. + +Runs WITHOUT a real Fleet env: we duplicate the small `_apply_taste_reward` +helper that the diff installs on FleetTaskEnv (lifted verbatim from the diff +body) and exercise it against stubbed `score_trajectory_async` callables. + +Reward shape under test: + effective_taste = max(taste_floor, taste_score) (1.0 on judge fail/None) + reward = verifier_reward * effective_taste + +Cases: + (a) success + pretty taste (v=1, t=1.0, floor=0.1) -> R=1.0 + (b) success + mid taste (v=1, t=0.5, floor=0.1) -> R=0.5 + (c) success + ugly taste (v=1, t=0.0, floor=0.1) -> R=0.1 (floor) + (d) failure + pretty taste (v=0, t=1.0) -> R=0.0 (gated to 0) + (e) failure + ugly taste (v=0, t=0.0) -> R=0.0 + (f) judge timeout + success -> R=verifier (1.0) + (g) judge exception + success -> R=verifier (1.0) + (h) SKYRL_TASTE_DISABLED=1 + success -> R=verifier (1.0) + +Prints PASS/FAIL per test. Exits 0 if all pass, 1 otherwise. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from typing import Any, Optional + + +# ----------------------------------------------------------------------------- +# Reproduce the helper installed by env.py.diff. Keep this in sync with the +# diff body (search for "async def _apply_taste_reward" in env.py.diff). +# ----------------------------------------------------------------------------- + +logger = logging.getLogger("smoke_test") + + +class _FakeFleetTaskEnv: + """Minimal stand-in for FleetTaskEnv with the attributes the helper reads.""" + + def __init__(self, floor: float, timeout_s: float, judge_async): + self.taste_floor = floor + self.taste_judge_timeout_s = timeout_s + self.task_key = "smoke-task-1" + self.task_config = {"prompt": "Send an email to bob@example.com saying hi"} + self.chat_history = [ + {"role": "system", "content": "you are a CU agent"}, + {"role": "user", "content": "Send an email..."}, + {"role": "assistant", "content": "I will click Compose."}, + {"role": "user", "content": "ok"}, + {"role": "assistant", "content": "Now I type the address."}, + ] + self.last_verifier_reward: Optional[float] = None + self.last_taste_reward: Optional[float] = None + self.last_effective_taste: Optional[float] = None + self.last_taste_judge_failed: bool = False + # Inject the stubbed judge in place of the real package. + self._judge_async = judge_async + + async def _apply_taste_reward(self, verifier_reward: float, episode_done: bool) -> float: + # Body lifted from env.py.diff (kept tight). + if not episode_done: + return verifier_reward + + self.last_verifier_reward = float(verifier_reward) + self.last_taste_reward = None + self.last_effective_taste = None + self.last_taste_judge_failed = False + + score_trajectory_async = self._judge_async + + actions = [ + {"role": m.get("role"), "content": m.get("content")} + for m in self.chat_history + if m.get("role") == "assistant" + ] + task_text = self.task_config.get("prompt", "") + outcome = bool(self.last_verifier_reward >= 1.0) + + taste_score: Optional[float] + try: + taste_score = await asyncio.wait_for( + score_trajectory_async(task_text, actions, outcome), + timeout=self.taste_judge_timeout_s, + ) + except asyncio.TimeoutError: + self.last_taste_judge_failed = True + taste_score = None + except Exception: + self.last_taste_judge_failed = True + taste_score = None + + if taste_score is None: + self.last_effective_taste = 1.0 + return verifier_reward + + taste_score = max(0.0, min(1.0, float(taste_score))) + self.last_taste_reward = taste_score + effective_taste = max(self.taste_floor, taste_score) + self.last_effective_taste = effective_taste + return verifier_reward * effective_taste + + +# ----------------------------------------------------------------------------- +# Stubbed judges +# ----------------------------------------------------------------------------- + + +def _judge_returning(value: float): + async def _inner(task: str, actions, outcome: bool) -> float: + return value + return _inner + + +async def _judge_returns_none_if_disabled(task: str, actions, outcome: bool) -> Optional[float]: + # Mimics the SKYRL_TASTE_DISABLED=1 short-circuit in skyrl_taste.judge. + if os.environ.get("SKYRL_TASTE_DISABLED") == "1": + return None + return 1.0 + + +async def _judge_slow(task: str, actions, outcome: bool) -> float: + await asyncio.sleep(5.0) + return 0.9 + + +async def _judge_raises(task: str, actions, outcome: bool) -> float: + raise RuntimeError("simulated API outage") + + +# ----------------------------------------------------------------------------- +# Test cases +# ----------------------------------------------------------------------------- + + +def _ok(name: str) -> None: + print(f"PASS: {name}") + + +def _fail(name: str, msg: str) -> None: + print(f"FAIL: {name} -> {msg}") + + +async def _check(name: str, env: _FakeFleetTaskEnv, verifier: float, expected: float, + *, expect_failed: bool = False) -> int: + r = await env._apply_taste_reward(verifier_reward=verifier, episode_done=True) + ok = abs(r - expected) < 1e-9 and env.last_taste_judge_failed is expect_failed + if ok: + _ok(name) + return 0 + _fail(name, f"r={r} expected={expected} failed={env.last_taste_judge_failed} " + f"verifier={env.last_verifier_reward} taste={env.last_taste_reward} " + f"effective={env.last_effective_taste}") + return 1 + + +async def run() -> int: + failures = 0 + floor = 0.1 + + # (a) success + pretty taste -> 1.0 + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(1.0)) + failures += await _check("a_success_pretty_v1_t1_floor0.1_R1.0", env, 1.0, 1.0) + + # (b) success + mid taste -> 0.5 + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(0.5)) + failures += await _check("b_success_mid_v1_t0.5_floor0.1_R0.5", env, 1.0, 0.5) + + # (c) success + ugly taste -> floor (0.1) + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(0.0)) + failures += await _check("c_success_ugly_v1_t0_floor0.1_R0.1", env, 1.0, 0.1) + + # (d) failure + pretty taste -> 0.0 (the hack is closed) + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(1.0)) + failures += await _check("d_failure_pretty_v0_t1_R0.0_HACK_CLOSED", env, 0.0, 0.0) + + # (e) failure + ugly taste -> 0.0 + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(0.0)) + failures += await _check("e_failure_ugly_v0_t0_R0.0", env, 0.0, 0.0) + + # (f) judge timeout + success -> verifier (1.0), failed=True + env = _FakeFleetTaskEnv(floor=floor, timeout_s=0.05, judge_async=_judge_slow) + failures += await _check("f_timeout_success_R_eq_verifier_1.0", env, 1.0, 1.0, + expect_failed=True) + + # (g) judge exception + success -> verifier (1.0), failed=True + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_raises) + failures += await _check("g_exception_success_R_eq_verifier_1.0", env, 1.0, 1.0, + expect_failed=True) + + # (h) SKYRL_TASTE_DISABLED=1 + success -> verifier (1.0), failed=False + os.environ["SKYRL_TASTE_DISABLED"] = "1" + try: + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, + judge_async=_judge_returns_none_if_disabled) + failures += await _check("h_disabled_env_var_R_eq_verifier_1.0", env, 1.0, 1.0, + expect_failed=False) + # Extra invariant: effective_taste should be 1.0 in the disabled path. + if env.last_effective_taste != 1.0: + _fail("h_disabled_env_var_R_eq_verifier_1.0", + f"effective_taste={env.last_effective_taste} expected 1.0") + failures += 1 + finally: + del os.environ["SKYRL_TASTE_DISABLED"] + + return failures + + +if __name__ == "__main__": + failures = asyncio.run(run()) + if failures == 0: + print("\nALL SMOKE TESTS PASSED (8/8)") + sys.exit(0) + else: + print(f"\n{failures} TEST(S) FAILED") + sys.exit(1) diff --git a/integrations/__init__.py b/integrations/__init__.py new file mode 100644 index 0000000000..2c8aa36f02 --- /dev/null +++ b/integrations/__init__.py @@ -0,0 +1 @@ +# Fleet integrations for SkyRL diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md new file mode 100644 index 0000000000..38ebbb2211 --- /dev/null +++ b/integrations/fleet/CHANGELOG.md @@ -0,0 +1,84 @@ +# Fleet Integration Changelog + +## 2026-03-29: Multi-node 35B training parity with old SkyRL fork + +Fixes for 2-node (16 GPU) Qwen3.5-35B GRPO training on GCP H200. Ported from fleet-ai/SkyRL PR #328 and PR #333, plus new fixes for SkyRL-v2-specific issues. + +### Problems + +2-node training crashed with: +1. `cudaErrorIllegalAddress` during FSDP ref model offload/backload (multi-node race) +2. OOM / Xid 31 FAULT_PDE during policy training forward+backward (missing chunked lm_head) +3. OOM / Xid 31 at 97K sequence length — SDPA too memory-hungry, flash_attn triggers GatedDeltaNet crash +4. `AssertionError: data batch size must be divisible by mini_batch_size, got 160 and 128` (hint augmentation) + +### Root causes and fixes + +#### 1. Synchronous ref offload + barrier (`fsdp_worker.py`) + +**Where:** `FSDPRefWorkerBase.offload_to_cpu()` and `backload_to_gpu()` + +**Problem:** With colocated models, the trainer cycles: ref on GPU → ref offload to CPU → policy on GPU. With `non_blocking=True`, the CPU←GPU transfer is *queued* but returns immediately. On a single node, CUDA stream ordering serializes this naturally. Across nodes, there's no shared CUDA context — node 0's policy worker can start touching GPU memory while node 1's ref worker is still mid-transfer. Result: `cudaErrorIllegalAddress`. + +**Fix:** `non_blocking=False` (wait for transfer) + `torch.distributed.barrier()` (all ranks synchronize). Guarantees every GPU finishes offloading before any policy worker starts backloading. + +**Why the old fork doesn't need this:** Designed for single-node where all workers share the same CUDA context and stream ordering prevents races. + +#### 2. Port chunked lm_head forward (`model_wrapper.py`, `fsdp_worker.py`) + +**Where:** `HFModelWrapper.forward()` and `HFModelWrapper._chunked_lm_head_forward()` + +**Problem:** SkyRL-v2's `HFModelWrapper` was missing `loss_chunk_size` support entirely — the parameter existed in config but was never passed through `fsdp_worker.py` to the model wrapper. Without it, the model materializes the full `(B, S, 131072)` logits tensor during forward pass (~10 GB for 97K-length sequences on Qwen3.5-35B with vocab_size=131072). This consumed so much GPU memory that the subsequent training forward pass (with gradients enabled) hit OOM or Xid 31 FAULT_PDE when FSDP tried to unshard parameters. + +**Fix:** Ported the chunked lm_head implementation from the old fork: +- Added `loss_chunk_size` parameter to `HFModelWrapper.__init__` +- Pass `loss_chunk_size` from `fsdp_worker.py` for both policy and ref model init +- During forward, replace `lm_head` with an identity module so the model returns hidden states `(B, S, 8192)` instead of logits `(B, S, 131072)` — 16x smaller +- Compute logits in chunks of 4096 tokens with gradient checkpointing, never materializing full logits + +**Why the old fork doesn't have this problem:** It already has `loss_chunk_size` support and passes it correctly. + +#### 3. `empty_cache` before backward (`worker.py`) + +**Where:** `PolicyWorkerBase._forward_backward_micro()` (both SFT and RL paths) and `CriticWorkerBase._forward_backward_micro()` + +**Problem:** After the forward pass, freed intermediate tensors stay in PyTorch's CUDA cache as scattered blocks. The backward pass needs large contiguous allocations for gradients. On the 35B model with tight GPU memory margins, the fragmented cache can't satisfy these allocations → OOM, even though total free memory is sufficient. + +**Fix:** `torch.cuda.empty_cache()` before `strategy.backward()`. Returns cached blocks to CUDA which coalesces them into contiguous allocations. This is especially important because `expandable_segments:True` cannot be used (see fix #4). + +**Why the old fork doesn't need this:** Targets smaller models (8B) with enough GPU headroom that fragmentation doesn't matter. + +#### 4. Reduce sequence length to 72K and disable `expandable_segments` (`fleet-35b-run.sh`) + +**Where:** `fleet-35b-run.sh` — `MAX_INPUT_LENGTH` and `--no-pytorch-alloc-conf` flag. + +**Problem:** At 97K sequences (96000 input + 4096 generate), memory was too tight even with chunked lm_head and `empty_cache`: +- `flash_attn=false` (SDPA): OOM requesting 5.95 GiB during backward — SDPA's O(n²) attention memory is too large at 97K. +- `flash_attn=true`: Xid 31 FAULT_PDE in GatedDeltaNet layers during ref model forward — reproduced at both 97K and 72K. Not a memory issue; vLLM 0.18.0's CuMemAllocator corrupts CUDA memory mappings that FSDP2 DTensor operations later touch. +- `expandable_segments:True` would help with fragmentation but conflicts with vLLM 0.18.0's `CuMemAllocator` (`cuMemCreate`/`cuMemMap`). + +**Fix:** Reduce `MAX_INPUT_LENGTH` from 96000 to 72000 (total seq ~76K) and use `flash_attn=false` (SDPA). At 72K, SDPA's O(n²) memory is ~55% of what it was at 97K — enough to fit with chunked lm_head + `empty_cache`. The `--no-pytorch-alloc-conf` flag passed to `fleet-common-run.sh` skips the default `export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, avoiding the vLLM 0.18.0 CuMemAllocator conflict. The 9B VL script (`fleet-vl-run.sh`) also passes this flag for the same reason. + +**Verified working:** 10 steps completed on GCP spot 2×H200:8 (asia-south1-b) with zero GPU errors over 12 hours. Step timing: generation ~7 min, ref forward ~8 min, policy backward ~44 min, total step ~70 min avg. Checkpoint saved to S3 at step 10. SDPA is slower than flash_attn but stable. WandB: `fleet_qwen35_35b_tool_use_2c0e13b7` (run ID `f6kw15p2`). + +#### 5. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) + +**Where:** `MeshDispatch.stage_chunks()` + +**Problem:** `mini_batch_size` is computed as `policy_mini_batch_size * n_samples_per_prompt` (e.g., 16 × 8 = 128). But hint augmentation appends extra samples: 16 prompts × 2 hints = 32 additional, total batch = 160. The `stage_chunks` method asserted `160 % 128 == 0` → crash. + +The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) silently dropped the 32 hint samples — no crash, but hint training was wasted. + +**Fix:** When batch size isn't divisible by mini_batch_size, step down mini_batch_size (by `dp_size` increments to stay DP-divisible) until it divides evenly. For 160 samples with dp_size=16: adjusts from 128 → 80, giving 2 mini-batches of 80. All 160 samples (including hints) are trained on. + +**Why upstream SkyRL doesn't have this:** Upstream uses a simple `for` loop with `//` division (no `stage_chunks` optimization). The `stage_chunks` pre-staging is a SkyRL-v2 optimization that added a strict assert the old code path never had. + +### Files changed + +| File | Change | +|------|--------| +| `skyrl/backends/skyrl_train/workers/model_wrapper.py` | Port chunked lm_head forward (loss_chunk_size) | +| `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Pass loss_chunk_size to HFModelWrapper; synchronous ref offload + barrier | +| `skyrl/backends/skyrl_train/workers/worker.py` | empty_cache before backward (3 sites) | +| `scripts/fleet-35b-run.sh` | Reduce seq length to 72K, flash_attn=false, --no-pytorch-alloc-conf, wandb project rename | +| `skyrl/backends/skyrl_train/distributed/dispatch.py` | Dynamic mini_batch_size adjustment | diff --git a/integrations/fleet/__init__.py b/integrations/fleet/__init__.py new file mode 100644 index 0000000000..88f3247b76 --- /dev/null +++ b/integrations/fleet/__init__.py @@ -0,0 +1,15 @@ +# Fleet Task Environment Integration for SkyRL +# +# This module provides a SkyRL-compatible environment wrapper for Fleet-hosted tasks. +# It uses OpenEnv's FleetTaskEnv as the abstraction layer. + +__all__ = ["FleetTaskEnv"] + + +def __getattr__(name: str): + """Lazy import to avoid import errors when dependencies are not installed.""" + if name == "FleetTaskEnv": + from skyrl_gym.envs.fleet_task.env import FleetTaskEnv + + return FleetTaskEnv + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/integrations/fleet/entrypoints/__init__.py b/integrations/fleet/entrypoints/__init__.py new file mode 100644 index 0000000000..3ba7648c80 --- /dev/null +++ b/integrations/fleet/entrypoints/__init__.py @@ -0,0 +1 @@ +# Fleet entrypoints diff --git a/integrations/fleet/entrypoints/main_eval.py b/integrations/fleet/entrypoints/main_eval.py new file mode 100644 index 0000000000..d70fd7249a --- /dev/null +++ b/integrations/fleet/entrypoints/main_eval.py @@ -0,0 +1,221 @@ +""" +Fleet Task Eval-Only Entrypoint for SkyRL. + +Resumes a Fleet GRPO checkpoint from S3 (FSDP shards on every node), runs a +single evaluation pass over the eval dataset, and uploads the dumped eval +results to S3. No training loop, no optimizer state. + +Mirrors the resume + weight-sync path used by `main_fleet.py:FleetPPOExp.run()` +so the same FSDP checkpoints can be replayed against the same eval set on a +fresh cluster (e.g. for variance bars across seeds). + +Usage: + python -m integrations.fleet.entrypoints.main_eval \ + environment.env_class=fleet_task \ + environment.skyrl_gym.fleet_task.tasks_file=/path/to/tasks.json \ + data.val_data=['/path/to/validation.parquet'] \ + trainer.policy.model.path=Qwen/Qwen3.5-9B \ + trainer.run_name=my_eval_run \ + trainer.dump_eval_results=true + +Environment Variables for S3 Checkpoint Management: + AWS_ACCESS_KEY_ID: AWS access key + AWS_SECRET_ACCESS_KEY: AWS secret key + AWS_REGION: AWS region (default: us-east-1) + S3_CHECKPOINT_BUCKET: S3 bucket for FSDP checkpoints (default: skyrl-checkpoints) + S3_TRAJECTORY_BUCKET: S3 bucket for eval result uploads (default: skyrl-trajectories) + RESUME_RUN_NAME: Run name to resume from. If unset, evaluates the base + weights at trainer.policy.model.path with no FSDP load. +""" + +import asyncio +import logging +import os +import sys +from pathlib import Path + +import ray +from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.entrypoints.main_base import BasePPOExp +from skyrl.train.utils import validate_cfg +from skyrl.train.utils.utils import initialize_ray + +logger = logging.getLogger(__name__) + + +def _strip_hydra_prefixes(args: list[str]) -> list[str]: + """Strip Hydra ++ and + prefixes from CLI args. + + Matches `main_fleet.py`. `from_cli_overrides` rejects +/++ prefixed args, + but our run scripts use them for environment-specific config keys that + now exist in the dataclass — so we can strip the prefix safely. + """ + cleaned = [] + for arg in args: + if arg.startswith("++"): + cleaned.append(arg[2:]) + elif arg.startswith("+"): + cleaned.append(arg[1:]) + else: + cleaned.append(arg) + return cleaned + + +class FleetEvalExp(BasePPOExp): + """Fleet eval-only experiment with optional S3 checkpoint resume. + + Reuses the trainer's FSDP weight loading and inference-engine weight sync, + then calls `trainer.eval()` once. `trainer.eval()` already handles local + eval dump and S3 upload when `trainer.dump_eval_results=true`, so this + entrypoint just needs to wire up resume + run a single eval pass. + """ + + def get_train_dataset(self): + """No train dataset is needed for eval-only runs.""" + return None + + def run(self): + trainer = self._setup_trainer() + assert trainer.eval_dataloader is not None, ( + "FleetEvalExp requires an eval dataset. Set `data.val_data` " + "and `trainer.eval_interval > 0`." + ) + + # Optional S3 resume: download FSDP shards on this VM and broadcast + # to the rest of the cluster. Mirrors FleetPPOExp.run(). + resume_run_name = os.environ.get("RESUME_RUN_NAME", "") + if resume_run_name: + try: + from integrations.fleet.s3_checkpoints import ( + broadcast_checkpoint_to_workers, + download_checkpoint_from_s3, + ) + + ckpt_path = trainer.cfg.trainer.ckpt_path + model_path = getattr(trainer.cfg.trainer.policy.model, "path", "unknown-model") + model_name = Path(model_path).name + project_name = getattr(trainer.cfg.trainer, "project_name", "skyrl") + download_checkpoint_from_s3( + ckpt_path=ckpt_path, + run_name=resume_run_name, + project_name=project_name, + model_name=model_name, + ) + broadcast_checkpoint_to_workers(ckpt_path) + except Exception as e: + logger.warning(f"Failed to download checkpoint from S3: {e}") + + asyncio.run(self._run_eval(trainer)) + + async def _run_eval(self, trainer): + """Initialize weight sync, load policy weights, and run eval once.""" + trainer.init_weight_sync_state() + + # Load only the policy FSDP shards. We bypass `trainer.load_checkpoints()` + # because it also restores `train_dataloader.state_dict()`, which is None + # in eval-only mode. Optimizer / lr scheduler state are skipped too. + self._load_policy_only(trainer) + + # Push fresh weights to the inference engine for evaluation. + await trainer.dispatch.save_weights_for_sampler() + + # `trainer.eval()` runs the eval loop and uploads to S3 when + # `dump_eval_results=true`. The S3 prefix uses `trainer.global_step`, + # which `_load_policy_only` sets from the resumed checkpoint. + eval_metrics = await trainer.eval() + trainer.tracker.log(eval_metrics, step=trainer.global_step, commit=True) + trainer.tracker.finish() + logger.info(f"Eval-only metrics: {eval_metrics}") + + def _load_policy_only(self, trainer): + """Load only the policy FSDP shards from a `global_step_` directory. + + Resolves the checkpoint path the same way `trainer.load_checkpoints()` + does (LATEST via `latest_ckpt_global_step.txt`, or FROM_PATH via + `cfg.trainer.resume_path`), then calls `dispatch.load_checkpoint` + with optimizer / scheduler state disabled. Sets `trainer.global_step` + so downstream eval dumps and S3 uploads land under the correct step. + + TODO: This duplicates the path-resolution half of + `RayPPOTrainer.load_checkpoints()`. The reason for the duplication is + that `load_checkpoints()` unconditionally calls + `self.train_dataloader.load_state_dict(...)`, which crashes when + `train_dataloader is None` (eval-only). If trainer ever grows a + `skip_dataloader_state` / `policy_only` flag, drop this helper and + call `trainer.load_checkpoints(...)` directly. + """ + from skyrl.backends.skyrl_train.utils.io import io + from skyrl.train.utils.trainer_utils import ( + GLOBAL_STEP_PREFIX, + ResumeMode, + extract_step_from_path, + validate_consistency_for_latest_checkpoint, + ) + + if trainer.resume_mode == ResumeMode.NONE: + logger.info("resume_mode=none; evaluating base model weights") + return + + if trainer.resume_mode == ResumeMode.LATEST: + latest_file = os.path.join( + trainer.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt" + ) + if not io.exists(latest_file): + logger.warning( + "resume_mode=latest but no checkpoint found at " + f"{trainer.cfg.trainer.ckpt_path}; using base weights" + ) + return + with io.open_file(latest_file, "r") as f: + step = int(f.read().strip()) + ckpt_dir = os.path.join( + trainer.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}" + ) + validate_consistency_for_latest_checkpoint( + trainer.cfg.trainer.ckpt_path, + step, + ckpt_dir, + latest_file, + trainer.cfg.trainer.ckpt_interval, + ) + else: # ResumeMode.FROM_PATH + ckpt_dir = str(trainer.cfg.trainer.resume_path) + + if not io.exists(ckpt_dir): + raise FileNotFoundError(f"Checkpoint path not found: {ckpt_dir}") + + global_step = extract_step_from_path(Path(ckpt_dir)) + if global_step == -1: + raise ValueError(f"Checkpoint path is not a valid global_step dir: {ckpt_dir}") + trainer.global_step = global_step + + policy_ckpt_dir = os.path.join(ckpt_dir, "policy") + logger.info(f"Loading policy checkpoint from {policy_ckpt_dir} (step {global_step})") + trainer.dispatch.load_checkpoint( + "policy", + policy_ckpt_dir, + load_optimizer_states=False, + load_lr_scheduler_states=False, + ) + logger.info("Successfully loaded policy checkpoint for eval") + + +@ray.remote(num_cpus=1) +def skyrl_eval_entrypoint(cfg: SkyRLTrainConfig): + """Ray remote function that runs Fleet eval-only.""" + # fleet_task env is auto-registered by skyrl_gym.envs.__init__ + exp = FleetEvalExp(cfg) + exp.run() + + +def main() -> None: + """Main entry point for Fleet task eval-only.""" + args = _strip_hydra_prefixes(sys.argv[1:]) + cfg = SkyRLTrainConfig.from_cli_overrides(args) + validate_cfg(cfg) + initialize_ray(cfg) + ray.get(skyrl_eval_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/entrypoints/main_fleet.py b/integrations/fleet/entrypoints/main_fleet.py new file mode 100644 index 0000000000..940b24e1c4 --- /dev/null +++ b/integrations/fleet/entrypoints/main_fleet.py @@ -0,0 +1,116 @@ +""" +Fleet Task Training Entrypoint for SkyRL. + +Registers the FleetTaskEnv and runs GRPO training on Fleet-hosted environments +with S3 checkpoint management. + +Usage: + python -m integrations.fleet.entrypoints.main_fleet \ + environment.env_class=fleet_task \ + environment.skyrl_gym.fleet_task.tasks_file=/path/to/tasks.json \ + data.train_data=./data/fleet/train.parquet \ + data.val_data=./data/fleet/validation.parquet + +Environment Variables for S3 Checkpoint Management: + AWS_ACCESS_KEY_ID: AWS access key + AWS_SECRET_ACCESS_KEY: AWS secret key + AWS_REGION: AWS region (default: us-east-1) + S3_CHECKPOINT_BUCKET: S3 bucket name (default: skyrl-checkpoints) + RESUME_RUN_NAME: Run name to resume from (downloads checkpoint from S3) +""" + +import asyncio +import logging +import os +import sys +from pathlib import Path + +import ray +from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.entrypoints.main_base import BasePPOExp +from skyrl.train.utils import validate_cfg +from skyrl.train.utils.utils import initialize_ray + +logger = logging.getLogger(__name__) + + +def _strip_hydra_prefixes(args: list[str]) -> list[str]: + """Strip Hydra ++ and + prefixes from CLI args. + + from_cli_overrides rejects +/++ prefixed args, but our run scripts use + them for environment-specific config (e.g. ++environment.skyrl_gym.task_gen.*). + Since these fields now exist in the dataclass, we can strip the prefix. + """ + cleaned = [] + for arg in args: + if arg.startswith("++"): + cleaned.append(arg[2:]) + elif arg.startswith("+"): + cleaned.append(arg[1:]) + else: + cleaned.append(arg) + return cleaned + + +class FleetPPOExp(BasePPOExp): + """Fleet-specific PPO experiment with S3 checkpoint management.""" + + def run(self): + trainer = self._setup_trainer() + + # Download checkpoint from S3 if RESUME_RUN_NAME is set (cross-VM resume) + resume_run_name = os.environ.get("RESUME_RUN_NAME", "") + if resume_run_name: + try: + from integrations.fleet.s3_checkpoints import download_checkpoint_from_s3 + + ckpt_path = trainer.cfg.trainer.ckpt_path + model_path = getattr(trainer.cfg.trainer.policy.model, "path", "unknown-model") + model_name = Path(model_path).name + project_name = getattr(trainer.cfg.trainer, "project_name", "skyrl") + download_checkpoint_from_s3( + ckpt_path=ckpt_path, + run_name=resume_run_name, + project_name=project_name, + model_name=model_name, + ) + + # Broadcast checkpoint to worker nodes (FSDP requires shards on every node) + from integrations.fleet.s3_checkpoints import broadcast_checkpoint_to_workers + broadcast_checkpoint_to_workers(ckpt_path) + except Exception as e: + logger.warning(f"Failed to download checkpoint from S3: {e}") + + # Wrap trainer for checkpoint management (cleanup + S3 upload) + try: + from integrations.fleet.s3_checkpoints import wrap_trainer_with_s3_upload + + trainer = wrap_trainer_with_s3_upload(trainer) + except Exception as e: + logger.warning(f"Failed to setup checkpoint management: {e}") + + asyncio.run(trainer.train()) + + +@ray.remote(num_cpus=1) +def skyrl_entrypoint(cfg: SkyRLTrainConfig): + """Ray remote function that runs Fleet training.""" + # fleet_task env is auto-registered by skyrl_gym.envs.__init__ + exp = FleetPPOExp(cfg) + exp.run() + + +def main() -> None: + """Main entry point for Fleet task training.""" + # Strip ++/+ prefixes from CLI args (used for env-specific config keys + # that now have proper dataclass fields) + args = _strip_hydra_prefixes(sys.argv[1:]) + # Build typed dataclass config (handles legacy flat→nested translation) + cfg = SkyRLTrainConfig.from_cli_overrides(args) + validate_cfg(cfg) + initialize_ray(cfg) + ray.get(skyrl_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/entrypoints/main_fleet_tinker.py b/integrations/fleet/entrypoints/main_fleet_tinker.py new file mode 100644 index 0000000000..440bd620be --- /dev/null +++ b/integrations/fleet/entrypoints/main_fleet_tinker.py @@ -0,0 +1,1015 @@ +""" +Fleet Task Training with Tinker Backend. + +This entrypoint uses Tinker (hosted) for training and inference, +combined with Fleet environments via OpenEnv for rollout collection. + +Usage: + python -m integrations.fleet.entrypoints.main_fleet_tinker \ + --model-name Qwen/Qwen3-VL-30B-A3B-Instruct \ + --tasks-file /path/to/tasks.json \ + --dataset-file /path/to/train.parquet \ + --eval-dataset-file /path/to/validation.parquet + +Environment Variables: + TINKER_API_KEY: Tinker API key for authentication (required) + TINKER_API_URL: Tinker service URL (optional, SDK uses default if not set) + FLEET_API_KEY: Fleet API key for environment access + WANDB_API_KEY: Weights & Biases API key for logging + +Architecture: + 1. Load tasks from JSON file (same format as SkyRL Fleet integration) + 2. For each training step: + a. Save current model weights for sampling + b. Create SamplingClient from Tinker + c. Collect rollouts using FleetTaskEnv (OpenEnv) + Tinker inference + d. Compute GRPO advantages + e. Train using Tinker's forward_backward + optim_step + 3. Checkpoints saved via Tinker API + +Metrics (matching SkyRL): + - reward/avg_pass_at_{n}: Pass@k across all prompts + - reward/variance_per_prompt: Mean within-prompt reward variance (GRPO learning signal) + - reward/{env_key}/pass_at_{n}: Per-environment pass@k + - reward/{env_key}/variance_per_prompt: Per-environment variance (learning signal) + - eval/all/pass_at_{n}: Evaluation pass@k + - eval/{env_key}/pass_at_{n}: Per-environment eval pass@k +""" + +import asyncio +import logging +import os +import random +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Any, Dict, List, Optional + +import numpy as np +from pydantic import BaseModel +from tqdm import tqdm +import tinker +import torch +import wandb +from tinker import types +from tinker.types.tensor_data import TensorData +from transformers import AutoTokenizer +from datasets import load_dataset +from torch.utils.data import DataLoader + +# Use SkyRL's FleetTaskEnv wrapper (now supports async via init_async/step_async) +from omegaconf import OmegaConf +from skyrl_gym.envs.fleet_task.env import FleetTaskEnv + +# Import SkyRL's overlong filtering for parity +from skyrl.train.generators.utils import apply_overlong_filtering + +# Import shared metrics module for consistent metric calculation with SkyRL trainer +from integrations.fleet.reward_metrics import ( + compute_pass_at_n as _compute_pass_at_n, + compute_reward_metrics, + compute_per_group_metrics, + sanitize_metric_key, +) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger(__name__) +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("mcp").setLevel(logging.WARNING) + +# Thread pool for env operations - isolates MCP connections per thread (like SkyRL) +_env_executor: ThreadPoolExecutor = None + + +def _get_env_executor(max_workers: int = 16) -> ThreadPoolExecutor: + global _env_executor + if _env_executor is None: + _env_executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="fleet-env-") + return _env_executor + + +async def _run_in_executor(func, *args): + """Run sync function in thread pool - each thread gets isolated event loop/connections.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(_get_env_executor(), func, *args) + + +class RolloutOutput(BaseModel): + """Output from a single rollout collection.""" + + prompt_ids: List[int] + response_ids: List[int] + logprobs: List[float] + loss_mask: List[int] + reward: float + task_key: str + env_key: str + turns: int + tool_calls: int + tool_errors: int = 0 # Count of tool call errors in this rollout + stop_reason: str + duration: float + # Timing breakdown for WandB + total_gen_time: float = 0.0 # Total Tinker generation time + total_step_time: float = 0.0 # Total MCP/Fleet step time + total_tokens: int = 0 # Total tokens generated + error: Optional[str] = None + + class Config: + frozen = True + + +def set_seed(seed: int): + """Set random seed for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def normalize_advantages(advantages: List[float]) -> List[float]: + """Normalize advantages to have mean 0 and std 1.""" + if not advantages or len(advantages) == 1: + return advantages + mean = np.mean(advantages) + std = np.std(advantages) + if std < 1e-8: + return [0.0] * len(advantages) + return [(a - mean) / (std + 1e-8) for a in advantages] + + +def compute_advantages_grpo( + rewards: List[float], + group_size: int = None, + normalize: bool = True, +) -> List[float]: + """ + GRPO (Group Relative Policy Optimization) advantage estimation. + + For each group of trajectories from the same prompt, compute advantages + as deviations from the group mean. + """ + rewards = np.array(rewards) + + if group_size is None: + group_size = len(rewards) + + n_groups = len(rewards) // group_size + advantages = [] + + for i in range(n_groups): + start_idx = i * group_size + end_idx = start_idx + group_size + group_rewards = rewards[start_idx:end_idx] + group_mean = group_rewards.mean() + group_advantages = group_rewards - group_mean + advantages.extend(group_advantages.tolist()) + + remaining = len(rewards) % group_size + if remaining > 0: + remaining_rewards = rewards[-remaining:] + remaining_mean = remaining_rewards.mean() + advantages.extend((remaining_rewards - remaining_mean).tolist()) + + if normalize: + advantages = normalize_advantages(advantages) + + return advantages + + +def compute_pass_at_n(rollouts: List[Dict[str, Any]], n_samples_per_prompt: int) -> float: + """ + Compute pass@n metric using the shared metrics module. + + For each unique prompt (task_key), if ANY of the n trajectories has reward > 0, + that counts as a "pass". + + This function is a thin wrapper around the shared compute_pass_at_n for backward + compatibility with the rollout dict format. + """ + rewards = [r.get("reward", 0.0) for r in rollouts] + uids = [r.get("task_key", "unknown") for r in rollouts] + return _compute_pass_at_n(rewards, uids) + + +def compute_per_env_metrics(rollouts: List[Dict[str, Any]], n_samples_per_prompt: int) -> Dict[str, float]: + """ + Compute per-environment metrics using the shared metrics module. + + This function is a thin wrapper around the shared compute_per_group_metrics for + backward compatibility with the rollout dict format. + """ + rewards = [r.get("reward", 0.0) for r in rollouts] + uids = [r.get("task_key", "unknown") for r in rollouts] + env_keys = [r.get("env_key", "unknown") for r in rollouts] + + return compute_per_group_metrics( + rewards=rewards, + uids=uids, + groups=env_keys, + n_samples_per_prompt=n_samples_per_prompt, + prefix="reward", + ) + + +def compute_rollout_metrics( + rollouts: List[Dict[str, Any]], + valid_rollouts: List[Dict[str, Any]], + rewards: List[float], + advantages: List[float], + n_samples_per_prompt: int, +) -> Dict[str, Any]: + """ + Compute all rollout metrics using the shared metrics module. + + Args: + rollouts: All rollouts (including failed ones) + valid_rollouts: Only valid rollouts + rewards: Rewards for valid rollouts + advantages: GRPO advantages for valid rollouts + n_samples_per_prompt: Number of samples per prompt + + Returns: + Dict of metrics for wandb logging + """ + metrics = {} + + # Extract data for shared module + uids = [r.get("task_key", "unknown") for r in valid_rollouts] + env_keys = [r.get("env_key", "unknown") for r in valid_rollouts] + + # Core reward metrics using shared module + core_metrics = compute_reward_metrics(rewards, uids, n_samples_per_prompt) + metrics[f"reward/avg_pass_at_{n_samples_per_prompt}"] = core_metrics[f"pass_at_{n_samples_per_prompt}"] + metrics["reward/avg_raw_reward"] = np.mean(rewards) + metrics["reward/variance_per_prompt"] = core_metrics["variance_per_prompt"] + metrics["reward/mean_positive_reward"] = core_metrics["mean_positive_reward"] + + # Advantage metrics (Tinker-specific) + metrics["advantage/mean"] = np.mean(advantages) + metrics["advantage/std"] = np.std(advantages) + metrics["rollouts/valid"] = len(valid_rollouts) + metrics["rollouts/total"] = len(rollouts) + + # Per-environment reward metrics using shared module + per_env_metrics = compute_per_group_metrics( + rewards=rewards, + uids=uids, + groups=env_keys, + n_samples_per_prompt=n_samples_per_prompt, + prefix="reward", + ) + metrics.update(per_env_metrics) + + # Per-environment rollout stats (turns, tool_calls, tool_errors, duration) - Tinker-specific + rollout_stats = defaultdict(list) + for r in valid_rollouts: + env_key = sanitize_metric_key(r.get("env_key", "unknown")) + rollout_stats[f"rollout/{env_key}/turns"].append(r.get("turns", 0)) + rollout_stats[f"rollout/{env_key}/tool_calls"].append(r.get("tool_calls", 0)) + rollout_stats[f"rollout/{env_key}/tool_errors"].append(r.get("tool_errors", 0)) + rollout_stats[f"rollout/{env_key}/duration"].append(r.get("duration", 0.0)) + + for key, values in rollout_stats.items(): + metrics[key] = np.mean(values) + + # Compute tool error rate per environment + env_keys_seen = set() + for r in valid_rollouts: + env_keys_seen.add(sanitize_metric_key(r.get("env_key", "unknown"))) + for env_key in env_keys_seen: + total_calls = sum(rollout_stats[f"rollout/{env_key}/tool_calls"]) + total_errors = sum(rollout_stats[f"rollout/{env_key}/tool_errors"]) + if total_calls > 0: + metrics[f"rollout/{env_key}/tool_error_rate"] = total_errors / total_calls + else: + metrics[f"rollout/{env_key}/tool_error_rate"] = 0.0 + + # Overall rollout duration stats + durations = [r.get("duration", 0.0) for r in valid_rollouts] + metrics["rollout/avg_duration"] = np.mean(durations) + metrics["rollout/max_duration"] = np.max(durations) + metrics["rollout/min_duration"] = np.min(durations) + + return metrics + + +def prepare_training_data( + rollouts: List[Dict[str, Any]], + advantages: List[float], + tokenizer: AutoTokenizer, + max_sequence_length: int, +) -> tuple: + """ + Prepare training data from rollouts (matching SkyRL's generate_batched pattern). + + Applies: + 1. DAPO overlong filtering (zero loss mask if response doesn't end with EOS) + 2. Sequence truncation for max_sequence_length + 3. Builds Tinker Datum objects for training + + Args: + rollouts: List of rollout dicts with prompt_ids, response_ids, logprobs, loss_mask + advantages: GRPO advantages for each rollout + tokenizer: Tokenizer for EOS token ID + max_sequence_length: Maximum sequence length for training + + Returns: + Tuple of (training_datums, truncated_count) + """ + # Apply DAPO overlong filtering (zero out loss mask for truncated responses) + all_loss_masks = [r.loss_mask for r in rollouts] + stop_reasons = [r.stop_reason for r in rollouts] + filtered_loss_masks = apply_overlong_filtering(all_loss_masks, stop_reasons) + + training_datums = [] + truncated_count = 0 + + for idx, rollout in enumerate(rollouts): + prompt_ids = rollout.prompt_ids + response_ids = rollout.response_ids + logprobs = rollout.logprobs + loss_mask_data = filtered_loss_masks[idx] + + full_sequence = prompt_ids + response_ids + prompt_len = len(prompt_ids) + + # Truncate sequences exceeding model's max length for Tinker API + if len(full_sequence) > max_sequence_length: + truncated_count += 1 + full_sequence = full_sequence[:max_sequence_length] + response_len = len(full_sequence) - prompt_len + response_ids = response_ids[:response_len] + logprobs = logprobs[:response_len] if logprobs else [] + loss_mask_data = loss_mask_data[:response_len] + + # Ensure logprobs and response_ids are in sync before building training data + if len(logprobs) != len(response_ids): + logger.warning( + f"Datum {idx}: logprobs ({len(logprobs)}) != response_ids ({len(response_ids)}), fixing" + ) + if len(logprobs) > len(response_ids): + logprobs = logprobs[: len(response_ids)] + else: + logprobs = logprobs + [0.0] * (len(response_ids) - len(logprobs)) + + # Target tokens (shifted by 1) + target_tokens = full_sequence[1:] + seq_len = len(target_tokens) + + # Logprobs (0 for prompt, actual for response) + full_logprobs = [0.0] * prompt_len + logprobs + full_logprobs = full_logprobs[1:] + + # Loss mask (0 for prompt, actual for response) + full_mask = [0] * prompt_len + loss_mask_data + full_mask = full_mask[1:] + + # Safety: ensure all arrays match target_tokens length + full_logprobs = full_logprobs[:seq_len] + [0.0] * max(0, seq_len - len(full_logprobs)) + full_mask = full_mask[:seq_len] + [0] * max(0, seq_len - len(full_mask)) + + # Advantages (apply only where loss mask is 1) + advantage_value = advantages[idx] + full_advantages = torch.zeros(len(full_sequence)) + for i in range(prompt_len, len(full_sequence)): + if i - 1 < len(full_mask) and full_mask[i - 1] > 0: + full_advantages[i] = advantage_value + full_advantages = full_advantages[1:] + + datum = types.Datum( + model_input=types.ModelInput.from_ints(tokens=full_sequence[:-1]), + loss_fn_inputs={ + "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)), + "logprobs": TensorData.from_torch(torch.tensor(full_logprobs)), + "advantages": TensorData.from_torch(full_advantages), + }, + ) + training_datums.append(datum) + + return training_datums, truncated_count + + +def tokenize_chat(tokenizer: AutoTokenizer, chat_history: List[Dict], add_generation_prompt: bool = True) -> List[int]: + """ + Tokenize chat history and ensure we get a plain list of token IDs. + + apply_chat_template can return different types depending on the tokenizer: + - List[int] for some tokenizers + - BatchEncoding dict with 'input_ids' key for others + + Tinker's ModelInput.from_ints() requires a plain list of integers. + """ + result = tokenizer.apply_chat_template(chat_history, add_generation_prompt=add_generation_prompt, tokenize=True) + # Handle BatchEncoding (dict-like) vs plain list + if hasattr(result, "input_ids"): + return list(result.input_ids) + elif isinstance(result, dict) and "input_ids" in result: + return list(result["input_ids"]) + else: + return list(result) + + +async def collect_fleet_rollout( + task_config: Dict[str, Any], + tasks_file: str, + sampling_client: tinker.SamplingClient, + tokenizer: AutoTokenizer, + max_turns: int = 50, + max_generate_length: int = 2048, + max_input_length: int = 30720, + temperature: float = 1.0, + top_p: float = 1.0, + stop_sequences: List[str] = None, +) -> Dict[str, Any]: + """ + Collect a single trajectory using Fleet environment and Tinker inference. + + Uses SkyRL's FleetTaskEnv wrapper with async methods for environment interaction. + + Args: + max_generate_length: Max tokens per generation step. + max_input_length: Max context length before ending rollout (matching SkyRL). + """ + rollout_start = time.time() + + task_key = task_config.get("task_key") or task_config.get("key") + + # Create SkyRL FleetTaskEnv wrapper + # TTL of 2 hours - some rollouts with many turns can take 30+ minutes + env_config = OmegaConf.create({"tasks_file": tasks_file, "ttl_seconds": 7200}) + extras = {"task_key": task_key, "max_turns": max_turns} + + env = FleetTaskEnv(env_config=env_config, extras=extras) + + try: + # Initialize environment in thread pool - isolates MCP connections + chat_history, metadata = await _run_in_executor(env.init, []) + env_key = metadata.get("env_key", "unknown") + + # Tokenize initial prompt + prompt_ids = tokenize_chat(tokenizer, chat_history, add_generation_prompt=True) + + all_response_ids = [] + all_logprobs = [] + loss_mask = [] + done = False + total_reward = 0.0 + stop_reason = "stop" + # Timing accumulators for WandB + total_gen_time = 0.0 + total_step_time = 0.0 + total_tokens = 0 + + while not done and env.turns < max_turns: + turn_num = env.turns + 1 # 1-indexed for logging + + # Prepare input for Tinker (use env's chat_history) + input_ids = tokenize_chat(tokenizer, env.chat_history, add_generation_prompt=True) + + # Check context length limit (matching SkyRL's skyrl_gym_generator.py:274) + if len(input_ids) > max_input_length: + logger.info( + f"[{task_key}] Turn {turn_num}: context length ({len(input_ids)}) exceeds max ({max_input_length}), ending" + ) + stop_reason = "length" + break + + # Generate with Tinker + gen_start = time.time() + sampling_params_kwargs = { + "max_tokens": max_generate_length, + "temperature": temperature, + "top_p": top_p, + } + if stop_sequences: + sampling_params_kwargs["stop"] = stop_sequences + sampling_params = types.SamplingParams(**sampling_params_kwargs) + + # Use async sampling to avoid blocking the event loop + result = await sampling_client.sample_async( + prompt=types.ModelInput.from_ints(tokens=input_ids), + num_samples=1, + sampling_params=sampling_params, + ) + gen_time = time.time() - gen_start + total_gen_time += gen_time + + if not result.sequences or len(result.sequences) == 0: + logger.warning(f"[{task_key}] Turn {turn_num}: no sequences returned from Tinker") + break + + sequence = result.sequences[0] + output_ids = sequence.tokens + output_logprobs = sequence.logprobs if sequence.logprobs else [] + + # Guard: logprobs must match token count (Tinker may return different lengths) + if output_logprobs and len(output_logprobs) != len(output_ids): + logger.warning( + f"[{task_key}] Turn {turn_num}: logprobs length ({len(output_logprobs)}) != tokens length ({len(output_ids)}), truncating/padding" + ) + if len(output_logprobs) > len(output_ids): + output_logprobs = output_logprobs[: len(output_ids)] + else: + output_logprobs = output_logprobs + [0.0] * (len(output_ids) - len(output_logprobs)) + + # Decode output + output_text = tokenizer.decode(output_ids, skip_special_tokens=True) + + # Collect trajectory data (assistant response tokens - trainable) + all_response_ids.extend(output_ids) + if output_logprobs: + all_logprobs.extend(output_logprobs) + else: + all_logprobs.extend([0.0] * len(output_ids)) + loss_mask.extend([1] * len(output_ids)) + + # Step environment in thread pool - isolates MCP connections + step_start = time.time() + step_output = await _run_in_executor(env.step, output_text) + step_time = time.time() - step_start + total_step_time += step_time + total_tokens += len(output_ids) + + # Get observation content for tokenization (masked out for loss) + # Note: BaseTextEnvStepOutput is a TypedDict, use dict access + if step_output["observations"]: + obs_content = step_output["observations"][0].get("content", "") + obs_ids = tokenizer.encode(obs_content, add_special_tokens=False) + all_response_ids.extend(obs_ids) + all_logprobs.extend([0.0] * len(obs_ids)) + loss_mask.extend([0] * len(obs_ids)) + + total_reward = step_output["reward"] + done = step_output["done"] + + return RolloutOutput( + prompt_ids=prompt_ids, + response_ids=all_response_ids, + logprobs=all_logprobs, + loss_mask=loss_mask, + reward=total_reward, + task_key=task_key, + env_key=env_key, + turns=env.turns, + tool_calls=env.tool_calls, + tool_errors=env.tool_errors, + stop_reason=stop_reason, + duration=time.time() - rollout_start, + total_gen_time=total_gen_time, + total_step_time=total_step_time, + total_tokens=total_tokens, + ) + + finally: + env.close() + + +async def collect_batch_rollouts( + batch: List[Dict[str, Any]], + tasks_file: str, + sampling_client: tinker.SamplingClient, + tokenizer: AutoTokenizer, + max_turns: int = 50, + max_generate_length: int = 2048, + max_input_length: int = 30720, + n_samples_per_prompt: int = 1, + max_concurrent: int = 8, + temperature: float = 1.0, + top_p: float = 1.0, + stop_sequences: List[str] = None, +) -> List[Dict[str, Any]]: + """Collect rollouts for a batch of tasks with limited concurrency. + + Args: + max_concurrent: Maximum number of concurrent Fleet environment connections. + Now safe to increase since ThreadPoolExecutor isolates connections. + """ + # Semaphore to limit concurrent Fleet environment connections + semaphore = asyncio.Semaphore(max_concurrent) + + async def collect_single_rollout(task_config: Dict[str, Any], index: int) -> tuple: + """Wrapper to collect a single rollout with error handling and concurrency limit.""" + async with semaphore: + rollout_start = time.time() + try: + rollout = await collect_fleet_rollout( + task_config=task_config, + tasks_file=tasks_file, + sampling_client=sampling_client, + tokenizer=tokenizer, + max_turns=max_turns, + max_generate_length=max_generate_length, + max_input_length=max_input_length, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, + ) + return index, rollout + except Exception as e: + logger.error(f"Failed to collect rollout for {task_config.get('task_key')}: {e}") + return index, RolloutOutput( + prompt_ids=[], + response_ids=[], + logprobs=[], + loss_mask=[], + reward=0.0, + task_key=task_config.get("task_key", "unknown"), + env_key=task_config.get("env_key", "unknown"), + turns=0, + tool_calls=0, + tool_errors=0, + stop_reason="error", + error=str(e), + duration=time.time() - rollout_start, + ) + + # Create all rollout tasks (batch_size * n_samples_per_prompt) + tasks = [] + index = 0 + for task_config in batch: + for _ in range(n_samples_per_prompt): + tasks.append(collect_single_rollout(task_config, index)) + index += 1 + + total = len(tasks) + logger.info(f" Collecting {total} rollouts (max {max_concurrent} concurrent)...") + rollouts = [None] * total + completed = 0 + last_logged = 0 + log_interval = max(1, total // 4) # Log at ~25%, 50%, 75%, 100% + + # Run rollouts with limited concurrency via semaphore + for coro in asyncio.as_completed(tasks): + idx, rollout = await coro + rollouts[idx] = rollout + completed += 1 + + # Log progress at intervals + if completed - last_logged >= log_interval or completed == total: + logger.info(f" Progress: {completed}/{total} rollouts completed") + last_logged = completed + + return rollouts + + +def collate_fn(batch): + """Return batch as-is without tensor collation.""" + return batch + + +async def main( + model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct", + tasks_file: str = None, + dataset_file: str = None, + eval_dataset_file: str = None, + batch_size: int = 8, + eval_batch_size: int = 32, + learning_rate: float = 4e-5, + lora_rank: int = 16, + max_steps: int = 200, + max_turns: int = 50, + max_generate_length: int = 2048, + max_input_length: int = 30720, + max_sequence_length: int = 32768, + n_samples_per_prompt: int = 4, + eval_every: int = 20, + seed: int = 42, + wandb_project: str = "fleet-tinker-grpo", + wandb_name: str = None, + temperature: float = 1.0, + top_p: float = 1.0, + stop_sequences: List[str] = None, + loss_fn: str = "ppo", +): + """ + Main training loop using Tinker for training/inference and Fleet for environments. + """ + set_seed(seed) + + # Setup WandB run name + if wandb_name is None: + wandb_name = f"{model_name.split('/')[-1]}_{datetime.now().strftime('%m%d_%H%M')}" + + # Initialize WandB + if stop_sequences is None: + stop_sequences = [] + + wandb.init( + project=wandb_project, + name=wandb_name, + config={ + "model_name": model_name, + "batch_size": batch_size, + "learning_rate": learning_rate, + "lora_rank": lora_rank, + "max_turns": max_turns, + "max_generate_length": max_generate_length, + "max_input_length": max_input_length, + "max_sequence_length": max_sequence_length, + "n_samples_per_prompt": n_samples_per_prompt, + "temperature": temperature, + "top_p": top_p, + "stop_sequences": stop_sequences, + "loss_fn": loss_fn, + }, + ) + + # Load datasets + train_dataset = load_dataset("parquet", data_files=dataset_file)["train"] + eval_dataset = load_dataset("parquet", data_files=eval_dataset_file)["train"] if eval_dataset_file else None + + logger.info(f"Loaded {len(train_dataset)} training samples") + if eval_dataset: + logger.info(f"Loaded {len(eval_dataset)} eval samples") + + # Setup Tinker + tinker_url = os.environ.get("TINKER_API_URL") + tinker_api_key = os.environ.get("TINKER_API_KEY") + + service_client_kwargs = {} + if tinker_url: + service_client_kwargs["base_url"] = tinker_url + if tinker_api_key: + service_client_kwargs["api_key"] = tinker_api_key + + service_client = tinker.ServiceClient(**service_client_kwargs) + training_client = await service_client.create_lora_training_client_async(base_model=model_name, rank=lora_rank) + + adam_params = types.AdamParams(learning_rate=learning_rate, beta1=0.9, beta2=0.95, eps=1e-8) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Create dataloader + def create_dataloader(epoch: int): + return DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=collate_fn, + generator=torch.Generator().manual_seed(seed + epoch), + ) + + steps_per_epoch = (len(train_dataset) + batch_size - 1) // batch_size + current_epoch = 0 + train_dataloader = create_dataloader(current_epoch) + train_iterator = iter(train_dataloader) + + # Training loop + pbar = tqdm(range(max_steps), desc="Training", unit="step") + for step in pbar: + step_start = time.time() + metrics = {"step": step, "epoch": step // steps_per_epoch} + + # Get sampler weights for rollout inference + sampling_path = training_client.save_weights_for_sampler(name=f"step_{step:06d}").result().path + sampling_client = service_client.create_sampling_client(model_path=sampling_path) + + # Get batch + try: + batch = next(train_iterator) + except StopIteration: + current_epoch += 1 + train_dataloader = create_dataloader(current_epoch) + train_iterator = iter(train_dataloader) + batch = next(train_iterator) + + # Collect rollouts + logger.info(f"Step {step}: Collecting rollouts for {len(batch)} tasks...") + rollout_start = time.time() + + rollouts = await collect_batch_rollouts( + batch=batch, + tasks_file=tasks_file, + sampling_client=sampling_client, + tokenizer=tokenizer, + max_turns=max_turns, + max_generate_length=max_generate_length, + max_input_length=max_input_length, + n_samples_per_prompt=n_samples_per_prompt, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, + ) + + metrics["time/rollout"] = time.time() - rollout_start + + # Filter valid rollouts and log invalid ones + # Note: rollouts are RolloutOutput Pydantic objects - use attribute access + valid_rollouts = [] + invalid_rollouts = [] + for r in rollouts: + if r.response_ids and not r.error: + valid_rollouts.append(r) + else: + invalid_rollouts.append(r) + + if invalid_rollouts: + for r in invalid_rollouts: + task_key = r.task_key + error = r.error or "no response_ids" + stop_reason = r.stop_reason + logger.warning(f"Step {step}: Invalid rollout for {task_key}: {error} (stop_reason={stop_reason})") + metrics["rollouts/invalid"] = len(invalid_rollouts) + + if not valid_rollouts: + logger.warning(f"Step {step}: No valid rollouts, skipping") + continue + + # Compute GRPO advantages + rewards = [r.reward for r in valid_rollouts] + advantages = compute_advantages_grpo(rewards, group_size=n_samples_per_prompt, normalize=True) + + # Compute all rollout metrics (convert to dicts for metrics functions) + rollout_metrics = compute_rollout_metrics( + rollouts=[r.model_dump() for r in rollouts], + valid_rollouts=[r.model_dump() for r in valid_rollouts], + rewards=rewards, + advantages=advantages, + n_samples_per_prompt=n_samples_per_prompt, + ) + metrics.update(rollout_metrics) + + # Compute timing metrics from valid rollouts + gen_times = [r.total_gen_time for r in valid_rollouts] + step_times = [r.total_step_time for r in valid_rollouts] + tokens = [r.total_tokens for r in valid_rollouts] + durations = [r.duration for r in valid_rollouts] + + metrics["time/gen_total"] = sum(gen_times) + metrics["time/gen_mean"] = np.mean(gen_times) + metrics["time/step_total"] = sum(step_times) + metrics["time/step_mean"] = np.mean(step_times) + metrics["time/gen_pct"] = 100 * sum(gen_times) / sum(durations) if sum(durations) > 0 else 0 + metrics["time/step_pct"] = 100 * sum(step_times) / sum(durations) if sum(durations) > 0 else 0 + metrics["throughput/tokens_total"] = sum(tokens) + metrics["throughput/tokens_per_sec_gen"] = sum(tokens) / sum(gen_times) if sum(gen_times) > 0 else 0 + metrics["throughput/tokens_per_sec_effective"] = sum(tokens) / sum(durations) if sum(durations) > 0 else 0 + + # Prepare training data (DAPO filtering + truncation + datum creation) + training_datums, truncated_count = prepare_training_data( + rollouts=valid_rollouts, + advantages=advantages, + tokenizer=tokenizer, + max_sequence_length=max_sequence_length, + ) + + metrics["rollouts/truncated_overlong"] = truncated_count + if truncated_count > 0: + logger.info(f"Step {step}: Truncated {truncated_count} overlong sequences") + + if not training_datums: + logger.warning(f"Step {step}: No valid training sequences after filtering, skipping") + continue + + # Training step + logger.info(f"Step {step}: Training on {len(training_datums)} sequences...") + train_start = time.time() + + fwd_bwd_future = training_client.forward_backward(training_datums, loss_fn=loss_fn) + optim_step_future = training_client.optim_step(adam_params) + + fwd_bwd_future.result() + optim_step_future.result() + + metrics["time/train"] = time.time() - train_start + metrics["time/total"] = time.time() - step_start + + # Log metrics (commit=True forces immediate sync) + wandb.log(metrics, step=step, commit=True) + pbar.set_postfix( + { + f"pass@{n_samples_per_prompt}": f"{metrics[f'reward/avg_pass_at_{n_samples_per_prompt}']:.3f}", + "reward": f"{metrics['reward/avg_raw_reward']:.3f}", + "time": f"{metrics['time/total']:.1f}s", + } + ) + + # Evaluation + if eval_every > 0 and eval_dataset and step % eval_every == 0: + logger.info(f"Step {step}: Running evaluation...") + eval_dataloader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn) + + all_eval_rollouts = [] + for eval_batch in eval_dataloader: + eval_rollouts = await collect_batch_rollouts( + batch=eval_batch, + tasks_file=tasks_file, + sampling_client=sampling_client, + tokenizer=tokenizer, + max_turns=max_turns, + max_generate_length=max_generate_length, + max_input_length=max_input_length, + n_samples_per_prompt=1, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, + ) + all_eval_rollouts.extend([r for r in eval_rollouts if not r.error]) + + if all_eval_rollouts: + eval_rewards = [r.reward for r in all_eval_rollouts] + # Convert to dicts for metrics functions + eval_rollouts_dicts = [r.model_dump() for r in all_eval_rollouts] + eval_pass_at_1 = compute_pass_at_n(eval_rollouts_dicts, 1) + eval_per_env = compute_per_env_metrics(eval_rollouts_dicts, 1) + + eval_metrics = { + "eval/all/pass_at_1": eval_pass_at_1, + "eval/all/mean_positive_reward": ( + np.mean([r for r in eval_rewards if r > 0]) if any(r > 0 for r in eval_rewards) else 0.0 + ), + "eval/num_samples": len(all_eval_rollouts), + } + # Add per-env eval metrics (rename from reward/ to eval/) + for key, value in eval_per_env.items(): + eval_key = key.replace("reward/", "eval/") + eval_metrics[eval_key] = value + + wandb.log(eval_metrics, step=step, commit=True) + logger.info(f"Step {step}: eval pass@1={eval_pass_at_1:.3f}, num_samples={len(all_eval_rollouts)}") + + wandb.finish() + logger.info("Training completed!") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Fleet Task Training with Tinker") + parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct") + parser.add_argument("--tasks-file", type=str, required=True, help="Path to tasks JSON file") + parser.add_argument("--dataset-file", type=str, required=True, help="Path to training parquet") + parser.add_argument("--eval-dataset-file", type=str, default=None, help="Path to eval parquet") + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--eval-batch-size", type=int, default=32) + parser.add_argument("--learning-rate", type=float, default=4e-5) + parser.add_argument("--lora-rank", type=int, default=16) + parser.add_argument("--max-steps", type=int, default=200) + parser.add_argument("--max-turns", type=int, default=50) + parser.add_argument("--max-generate-length", type=int, default=2048, help="Max tokens per generation") + parser.add_argument("--max-input-length", type=int, default=30720, help="Max context length before ending rollout") + parser.add_argument("--max-sequence-length", type=int, default=32768, help="Max sequence length for training") + parser.add_argument("--n-samples-per-prompt", type=int, default=4) + parser.add_argument("--eval-every", type=int, default=20) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--wandb-project", type=str, default="fleet-tinker-grpo") + parser.add_argument("--wandb-name", type=str, default=None) + parser.add_argument( + "--track-extra-gradient-metrics", + type=bool, + default=False, + help="Track additional gradient metrics (for parity with SkyRL config)", + ) + parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature") + parser.add_argument("--top-p", type=float, default=1.0, help="Top-p (nucleus) sampling") + parser.add_argument( + "--stop-sequences", + type=str, + default="[]", + help="JSON list of stop sequences (e.g. '[\"\"]')", + ) + parser.add_argument( + "--loss-fn", + type=str, + default="ppo", + help="Loss function for Tinker forward_backward (e.g. ppo, grpo)", + ) + + args = parser.parse_args() + + import json as _json + + stop_sequences = _json.loads(args.stop_sequences) + + asyncio.run( + main( + model_name=args.model_name, + tasks_file=args.tasks_file, + dataset_file=args.dataset_file, + eval_dataset_file=args.eval_dataset_file, + batch_size=args.batch_size, + eval_batch_size=args.eval_batch_size, + learning_rate=args.learning_rate, + lora_rank=args.lora_rank, + max_steps=args.max_steps, + max_turns=args.max_turns, + max_generate_length=args.max_generate_length, + max_input_length=args.max_input_length, + max_sequence_length=args.max_sequence_length, + n_samples_per_prompt=args.n_samples_per_prompt, + eval_every=args.eval_every, + seed=args.seed, + wandb_project=args.wandb_project, + wandb_name=args.wandb_name, + temperature=args.temperature, + top_p=args.top_p, + stop_sequences=stop_sequences, + loss_fn=args.loss_fn, + ) + ) diff --git a/integrations/fleet/entrypoints/main_task_gen.py b/integrations/fleet/entrypoints/main_task_gen.py new file mode 100644 index 0000000000..1c19fabd63 --- /dev/null +++ b/integrations/fleet/entrypoints/main_task_gen.py @@ -0,0 +1,83 @@ +""" +Task Generation Training Entrypoint for SkyRL. + +Registers the TaskGenEnv and runs GRPO training for task generation +with S3 checkpoint management. + +Usage: + python -m integrations.fleet.entrypoints.main_task_gen \ + environment.env_class=task_gen \ + data.train_data=./data/task_gen/train.parquet \ + data.val_data=./data/task_gen/validation.parquet +""" + +import asyncio +import logging +import os +import sys +from pathlib import Path + +import ray +from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.entrypoints.main_base import BasePPOExp +from skyrl.train.utils import validate_cfg +from skyrl.train.utils.utils import initialize_ray + +logger = logging.getLogger(__name__) + + +class FleetPPOExp(BasePPOExp): + """Fleet-specific PPO experiment with S3 checkpoint management.""" + + def run(self): + trainer = self._setup_trainer() + + resume_run_name = os.environ.get("RESUME_RUN_NAME", "") + if resume_run_name: + try: + from integrations.fleet.s3_checkpoints import download_checkpoint_from_s3 + + ckpt_path = trainer.cfg.trainer.ckpt_path + model_path = getattr(trainer.cfg.trainer.policy.model, "path", "unknown-model") + model_name = Path(model_path).name + project_name = getattr(trainer.cfg.trainer, "project_name", "skyrl") + download_checkpoint_from_s3( + ckpt_path=ckpt_path, + run_name=resume_run_name, + project_name=project_name, + model_name=model_name, + ) + except Exception as e: + logger.warning(f"Failed to download checkpoint from S3: {e}") + + try: + from integrations.fleet.s3_checkpoints import wrap_trainer_with_s3_upload + + trainer = wrap_trainer_with_s3_upload(trainer) + except Exception as e: + logger.warning(f"Failed to setup checkpoint management: {e}") + + asyncio.run(trainer.train()) + + +@ray.remote(num_cpus=1) +def skyrl_entrypoint(cfg: SkyRLTrainConfig): + """Ray remote function that registers TaskGenEnv and runs training.""" + # task_gen env is registered in skyrl_gym.envs.__init__ + exp = FleetPPOExp(cfg) + exp.run() + + +def main() -> None: + """Main entry point for task generation training.""" + from integrations.fleet.entrypoints.main_fleet import _strip_hydra_prefixes + + args = _strip_hydra_prefixes(sys.argv[1:]) + cfg = SkyRLTrainConfig.from_cli_overrides(args) + validate_cfg(cfg) + initialize_ray(cfg) + ray.get(skyrl_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/export_tasks.py b/integrations/fleet/export_tasks.py new file mode 100644 index 0000000000..7b1ed956fc --- /dev/null +++ b/integrations/fleet/export_tasks.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +"""Export tasks from Fleet API to JSON file. + +Usage: + python -m integrations.fleet.export_tasks --output ~/data/fleet/tasks.json + python -m integrations.fleet.export_tasks --output ~/data/fleet/tasks.json --env-key github +""" + +import argparse +import json +import os +import sys + + +def export_tasks(output_file: str, env_key: str | None = None, modality: str = "tool_use"): + """Export tasks from Fleet API to JSON file.""" + try: + from fleet import Fleet + except ImportError: + print("Fleet SDK not available. Install with: pip install fleet-python") + sys.exit(1) + + api_key = os.environ.get("FLEET_API_KEY") + if not api_key: + print("ERROR: FLEET_API_KEY environment variable not set") + sys.exit(1) + + fleet = Fleet(api_key=api_key) + + print(f"Loading tasks from Fleet API (env_key={env_key})...") + tasks = fleet.load_tasks(env_key=env_key) + print(f"Loaded {len(tasks)} tasks") + + # Convert to JSON format + task_dicts = [] + for task in tasks: + task_dicts.append( + { + "key": task.key, + "prompt": task.prompt, + "env_id": task.env_id, + "version": task.version, + "data_id": task.data_id, + "data_version": task.data_version, + "verifier_func": task.verifier_func, + "task_modality": modality, + } + ) + + # Ensure output directory exists + os.makedirs(os.path.dirname(os.path.expanduser(output_file)), exist_ok=True) + + output_path = os.path.expanduser(output_file) + with open(output_path, "w") as f: + json.dump(task_dicts, f, indent=2) + + print(f"Exported {len(task_dicts)} tasks to {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Export Fleet tasks to JSON") + parser.add_argument("--output", "-o", required=True, help="Output JSON file path") + parser.add_argument("--env-key", default=None, help="Filter by environment key") + parser.add_argument("--modality", default="tool_use", help="Task modality") + args = parser.parse_args() + + export_tasks(args.output, args.env_key, args.modality) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/prepare_dataset.py b/integrations/fleet/prepare_dataset.py new file mode 100644 index 0000000000..3bc8601ad9 --- /dev/null +++ b/integrations/fleet/prepare_dataset.py @@ -0,0 +1,633 @@ +""" +Prepare Fleet tasks for SkyRL training. + +Converts Fleet task JSON files to SkyRL parquet dataset format. + +Usage: + python -m integrations.fleet.prepare_dataset \ + --tasks-json /path/to/all_tool_use.json \ + --output-dir ./data/fleet \ + --modality tool_use + +Split Strategy: + - Stratified by environment (each env maintains train/eval ratio) + - Hash-based deterministic assignment (same task always goes to same split) + - 20% eval ratio, capped at 20 samples per env (MAX_EVAL_SAMPLES) + - Minimum 5 eval samples per env (otherwise all go to train) + - Held-out eval envs: instacart (computer_use only) + +v0.3.2 Changes: + - Increased eval_ratio from 10% to 20% to include carlisle/outlook in eval + - Result: 11 envs in eval (was 9), ~183 eval samples (was ~146) + +v0.3.1 Changes: + - Added MAX_ENV_TRAIN_RATIO=0.20 to prevent any single env from dominating + - Hash-based deterministic sampling for reproducibility + +v0.3.0 Changes: + - Increased eval_ratio from 2% to 10% + - Added MAX_EVAL_SAMPLES=30 cap per environment + - MIN_EVAL_SAMPLES stays at 5 + - Result: ticketmaster now gets ~22 eval samples for trace analysis +""" + +import argparse +import hashlib +import json +import os +from collections import defaultdict +from typing import Any, Dict, List, Optional + +from datasets import Dataset + +# Held-out environments for eval only (not used in train) +HELD_OUT_ENVS = { + "tool_use": [], # v0.3: all envs split normally (outlook now included in train) + "computer_use": [], + "browser_use": [], +} + +# Excluded environments (removed from both train and eval) +# v0.3.6: google-maps excluded due to broken MCP server (502 errors, "database is locked") +# v0.4.0: dropbox excluded due to broken env (instance creation timeouts) +EXCLUDED_ENVS = { + "tool_use": ["dropbox"], + "computer_use": ["dropbox"], + "browser_use": ["dropbox"], +} + +# Tasks excluded due to missing CURRENT_DATE in env_variables (v0.4.0) +# These tasks have partial dates (e.g., "January 30th" without year) but their +# tool calls require mm/dd/yy format. Without CURRENT_DATE, the model cannot +# compute the correct year, causing date validation failures. +# See: https://github.com/fleet-ai/SkyRL/pull/246 +TASKS_MISSING_CURRENT_DATE = { + "task_a44hx6crecg4_1769052238469_i7dxxtjvq", # zillow - February 1st + "task_a7rlslof7gdy_1768337837679_8be6pguu3", # zillow - March 11th + "task_axtmgwocana_1768544478249_k2ozcylyf", # zillow - January 21st + "task_b1fxgn0k3yms_1768542773490_ddbhj5bai", # zillow - January 30th + "task_b4v77hb3owof_1768546181946_efsedxv9g", # zillow - February 14th + "task_b5zt6ipf0nbl_1768346335430_i23gknp4t", # zillow - January 15th + "task_bafrpi5qgyzh_1768546181946_2cebmq91r", # zillow - February 14th + "task_bdmnfipwxlqv_1769052238469_4nglwjqfm", # zillow - February 1st + "task_bxqzfjc2dbte_1768337837679_2qvnm9rq7", # zillow - March 11th + "task_c3jwlxmfvbop_1768544478249_efo6hxylr", # zillow - January 21st + "task_c7o0c7ehhv9t_1768542773490_2t9w2l1z5", # zillow - January 30th + "task_ceqj4h9t0ygi_1768346335430_8j1w8w5xp", # zillow - January 15th + "task_cgpxfxp78bvp_1768346335430_6v4n8wlt8", # zillow - January 15th + "task_cgsz56tqjlv6_1768346335430_hqgsjy4wt", # zillow - January 15th + "task_dpv4bpdpz6db_1768542773490_f3g6w8e8g", # zillow - January 30th + "task_f7lgb6fxfwln_1768337837679_d1dxk6ahv", # zillow - March 11th + "task_fl1rq3d2wbj9_1768337837679_d2x4k8p93", # zillow - March 11th + "task_fn1k5mvjx6r1_1768544478249_1nfmnp6r2", # zillow - January 21st + "task_fnh5f0x7hv6w_1768544478249_8wptm6zqp", # zillow - January 21st + "task_g2dwb1rfx69c_1769052238469_bc1y9h9d7", # zillow - February 1st + "task_g3wpj1mcl0lf_1768546181946_59vtqn9fw", # zillow - February 14th +} + +# Minimum number of samples required to create an eval split for an env +MIN_EVAL_SAMPLES = 5 + +# Maximum number of eval samples per environment (v0.3.1: reduced from 30 to 20) +# Ensures small envs get eval traces without blowing up eval set size +MAX_EVAL_SAMPLES = 20 + +# Maximum fraction of training data any single environment can have (v0.3.1) +# Prevents dominant environments from skewing training +MAX_ENV_TRAIN_RATIO = 0.20 + +# Maximum total eval prompts across all environments (v0.3.2) +# With eval_n_samples_per_prompt=3 and 30s per trajectory: +# 96 prompts × 3 samples = 288 trajectories (~8 tasks/env × 12 envs) +MAX_EVAL_PROMPTS = 96 + + +def load_tasks_from_json(json_path: str) -> List[Dict[str, Any]]: + """Load tasks from JSON file (Fleet export format).""" + with open(json_path, "r") as f: + data = json.load(f) + + # Handle both formats: array or {"tasks": [...]} + if isinstance(data, list): + return data + elif isinstance(data, dict) and "tasks" in data: + return data["tasks"] + else: + raise ValueError("Invalid JSON format: expected array or object with 'tasks' key") + + +def hash_to_split(task_key: str, eval_ratio: float = 0.10) -> str: + """Deterministically assign task to train or eval based on hash. + + Uses MD5 hash of task_key to get a deterministic float in [0, 1). + This ensures the same task always goes to the same split. + """ + hash_bytes = hashlib.md5(task_key.encode()).digest() + hash_int = int.from_bytes(hash_bytes[:8], byteorder="big") + hash_float = hash_int / (2**64) + return "eval" if hash_float < eval_ratio else "train" + + +def hash_to_float(task_key: str) -> float: + """Convert task_key to deterministic float in [0, 1) for sampling.""" + hash_bytes = hashlib.md5(task_key.encode()).digest() + hash_int = int.from_bytes(hash_bytes[:8], byteorder="big") + return hash_int / (2**64) + + +def cap_training_distribution( + train_records: List[Dict[str, Any]], + max_env_ratio: float, +) -> tuple[List[Dict[str, Any]], Dict[str, Dict[str, int]]]: + """Cap each environment's contribution to training data. + + Uses hash-based deterministic sampling so the same tasks are always selected. + + Args: + train_records: List of training records with 'data_source' (env_key) and 'task_key' + max_env_ratio: Maximum fraction any single env can contribute (e.g., 0.20 = 20%) + + Returns: + Tuple of (capped_records, cap_stats) where cap_stats shows per-env before/after counts + """ + if max_env_ratio >= 1.0: + return train_records, {} + + total_train = len(train_records) + max_per_env = int(total_train * max_env_ratio) + + # Group by environment + records_by_env: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + for record in train_records: + env_key = record.get("data_source", "unknown") + records_by_env[env_key].append(record) + + # Cap each environment + capped_records = [] + cap_stats: Dict[str, Dict[str, int]] = {} + + for env_key, records in records_by_env.items(): + before_count = len(records) + + if before_count <= max_per_env: + # No capping needed + capped_records.extend(records) + cap_stats[env_key] = {"before": before_count, "after": before_count, "capped": False} + else: + # Sort by hash for deterministic selection + records_sorted = sorted(records, key=lambda r: hash_to_float(r.get("task_key", ""))) + selected = records_sorted[:max_per_env] + capped_records.extend(selected) + cap_stats[env_key] = {"before": before_count, "after": max_per_env, "capped": True} + + return capped_records, cap_stats + + +def prepare_fleet_dataset( + tasks_json: str, + output_dir: str, + modality: Optional[str] = "tool_use", + eval_ratio: float = 0.20, # v0.3.2: increased to 20% to include carlisle/outlook in eval + env_filter: Optional[str] = None, + difficulty_filter: Optional[str] = None, # v0.4.0: filter by difficulty (1=easy, 2=medium, 3=hard) + max_tasks: Optional[int] = None, + max_env_ratio: float = MAX_ENV_TRAIN_RATIO, # v0.3.1: cap dominant environments + max_eval_prompts: Optional[int] = MAX_EVAL_PROMPTS, # v0.3.2: cap total eval prompts + env_class: str = "fleet_task", # SkyRL env_class per record (fleet_task or task_gen) +): + """ + Convert Fleet tasks JSON to SkyRL parquet dataset. + + Args: + tasks_json: Path to Fleet tasks JSON file + output_dir: Output directory for parquet files + modality: Task modality filter ("tool_use" or "computer_use"), None for all + eval_ratio: Fraction of data for evaluation (default: 0.02) + env_filter: Optional env_key filter (e.g., "github", "booking") + max_tasks: Optional maximum number of tasks to include + max_env_ratio: Maximum fraction any single env can contribute to training (default: 0.20) + env_class: SkyRL env_class per record (default: "fleet_task", use "task_gen" for task generation) + """ + # Log applied filters at the start + print("\n=== Dataset Filters ===") + print(f" Source: {tasks_json}") + print(f" Modality: {modality or 'all'}") + print(f" Env filter: {env_filter or 'none'}") + print(f" Difficulty filter: {difficulty_filter or 'all (1,2,3)'}") + print(f" Max tasks: {max_tasks or 'unlimited'}") + print(f" Max env ratio: {max_env_ratio:.0%}") + print(f" Max eval prompts: {max_eval_prompts or 'unlimited'}") + print() + + print(f"Loading tasks from {tasks_json}...") + tasks = load_tasks_from_json(tasks_json) + print(f"Loaded {len(tasks)} tasks") + + # Filter by modality if specified + if modality: + tasks = [t for t in tasks if t.get("task_modality") == modality] + print(f"After modality filter ({modality}): {len(tasks)} tasks") + + # Filter by env_key(s) if specified - supports comma-separated list + if env_filter: + env_list = [e.strip() for e in env_filter.split(",") if e.strip()] + tasks = [t for t in tasks if t.get("env_key") in env_list or t.get("env_id") in env_list] + print(f"After env filter ({env_list}): {len(tasks)} tasks") + + # Filter by difficulty if specified - supports comma-separated list (e.g., "1,2" for easy+medium) + if difficulty_filter: + diff_list = [int(d.strip()) for d in difficulty_filter.split(",") if d.strip()] + tasks = [t for t in tasks if t.get("difficulty") in diff_list] + print(f"After difficulty filter ({diff_list}): {len(tasks)} tasks") + + # Limit tasks if specified + if max_tasks and len(tasks) > max_tasks: + tasks = tasks[:max_tasks] + print(f"Limited to {max_tasks} tasks") + + if not tasks: + print("No tasks remaining after filtering. Exiting.") + return + + # Deduplicate by task_key (keep first occurrence) + seen_task_keys: set = set() + unique_tasks = [] + duplicate_count = 0 + env_duplicate_counts: Dict[str, int] = defaultdict(int) + + for task in tasks: + task_key = task.get("key") or task.get("task_key") + if not task_key: + continue + if task_key in seen_task_keys: + duplicate_count += 1 + env_key = task.get("env_key") or task.get("env_id") or "unknown" + env_duplicate_counts[env_key] += 1 + else: + seen_task_keys.add(task_key) + unique_tasks.append(task) + + if duplicate_count > 0: + print(f"\n⚠️ WARNING: Removed {duplicate_count} duplicate task_keys") + print(" By environment:") + for env, count in sorted(env_duplicate_counts.items(), key=lambda x: -x[1]): + print(f" {env}: {count} duplicates removed") + print() + + tasks = unique_tasks + print(f"After deduplication: {len(tasks)} unique tasks") + + # Get excluded envs for this modality (removed entirely) + excluded_envs = set(EXCLUDED_ENVS.get(modality, [])) + if excluded_envs: + before_count = len(tasks) + tasks = [t for t in tasks if t.get("env_key") not in excluded_envs] + print(f"Excluded environments: {excluded_envs}") + print(f"After excluding envs: {len(tasks)} tasks (removed {before_count - len(tasks)})") + + # Exclude specific tasks missing CURRENT_DATE + if TASKS_MISSING_CURRENT_DATE: + before_count = len(tasks) + tasks = [t for t in tasks if (t.get("key") or t.get("task_key")) not in TASKS_MISSING_CURRENT_DATE] + removed = before_count - len(tasks) + if removed > 0: + print(f"Excluded tasks missing CURRENT_DATE: {removed} tasks") + + # Get held-out envs for this modality + held_out_envs = set(HELD_OUT_ENVS.get(modality, [])) + if held_out_envs: + print(f"Held-out test environments: {held_out_envs}") + + # Group tasks by environment + tasks_by_env: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + for task in tasks: + env_key = task.get("env_key") or task.get("env_id") or "unknown" + tasks_by_env[env_key].append(task) + + # Collect per-env metadata: representative env_variables and env_variable_keys + # (mirrors original SkyRL fork's _collect_env_metadata) + env_metadata: Dict[str, Dict[str, Any]] = {} + for env_key, env_tasks_list in tasks_by_env.items(): + all_var_keys: set = set() + representative_env_vars: Dict[str, Any] = {} + for t in env_tasks_list: + env_vars = t.get("env_variables") or {} + if isinstance(env_vars, str): + try: + env_vars = json.loads(env_vars) + except json.JSONDecodeError: + env_vars = {} + all_var_keys.update(env_vars.keys()) + if not representative_env_vars and env_vars: + representative_env_vars = dict(env_vars) + env_metadata[env_key] = { + "env_variable_keys": sorted(all_var_keys), + "env_variables": representative_env_vars, + } + print("\nEnvironment metadata:") + for ek in sorted(env_metadata): + meta = env_metadata[ek] + print(f" {ek}: env_vars={meta['env_variable_keys']}") + + # Prepare records with stratified split + train_records = [] + eval_records = [] + + # Track per-env counts for summary table + env_split_counts: Dict[str, Dict[str, int]] = {} + + print("\n=== Per-Environment Split ===") + for env_key in sorted(tasks_by_env.keys()): + env_tasks = tasks_by_env[env_key] + + # Check if this env is held out for eval only + if env_key in held_out_envs: + env_eval_count = 0 + for task in env_tasks: + record = _task_to_record(task, env_key, env_class=env_class, env_meta=env_metadata.get(env_key)) + if record: + eval_records.append(record) + env_eval_count += 1 + env_split_counts[env_key] = {"train": 0, "eval": env_eval_count} + print(f" {env_key}: {len(env_tasks)} -> EVAL only (held-out)") + continue + + # Calculate target eval size: use ratio but cap at MAX_EVAL_SAMPLES + target_eval_size = min(int(len(env_tasks) * eval_ratio), MAX_EVAL_SAMPLES) + + # If not enough samples for eval, put all in train + if target_eval_size < MIN_EVAL_SAMPLES: + env_train_count = 0 + for task in env_tasks: + record = _task_to_record(task, env_key, env_class=env_class, env_meta=env_metadata.get(env_key)) + if record: + train_records.append(record) + env_train_count += 1 + env_split_counts[env_key] = {"train": env_train_count, "eval": 0} + print(f" {env_key}: {len(env_tasks)} -> all TRAIN (< {MIN_EVAL_SAMPLES} eval samples)") + continue + + # Compute effective eval ratio to achieve target_eval_size (capped at MAX_EVAL_SAMPLES) + effective_eval_ratio = target_eval_size / len(env_tasks) + + # Stratified split using hash with effective ratio + env_train = 0 + env_eval = 0 + for task in env_tasks: + task_key = task.get("key") or task.get("task_key") + record = _task_to_record(task, env_key, env_class=env_class, env_meta=env_metadata.get(env_key)) + if not record: + continue + + split = hash_to_split(task_key, effective_eval_ratio) + if split == "eval": + eval_records.append(record) + env_eval += 1 + else: + train_records.append(record) + env_train += 1 + + env_split_counts[env_key] = {"train": env_train, "eval": env_eval} + print(f" {env_key}: {len(env_tasks)} -> {env_train} train, {env_eval} eval") + + print(f"\nTotal: {len(train_records)} train, {len(eval_records)} eval") + + # Apply total eval cap (v0.3.2) - stratified sampling across environments + if max_eval_prompts and len(eval_records) > max_eval_prompts: + print(f"\n=== Capping Eval Prompts ({max_eval_prompts} max total) ===") + + # Group by environment + eval_by_env: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + for record in eval_records: + eval_by_env[record.get("data_source", "unknown")].append(record) + + # Take min(8, available) from each env, then distribute remaining quota proportionally + min_per_env = 8 + capped_eval_records = [] + + for env_key, records in eval_by_env.items(): + # Sort by hash for deterministic selection + records.sort(key=lambda r: hash_to_float(r.get("task_key", ""))) + # Take at least min_per_env (or all if fewer available) + take = min(min_per_env, len(records)) + capped_eval_records.extend(records[:take]) + + # If we have budget remaining, distribute round-robin across envs + remaining_budget = max_eval_prompts - len(capped_eval_records) + if remaining_budget > 0: + # Records not yet selected (sorted by hash for determinism) + remaining_by_env = { + env: records[min_per_env:] for env, records in eval_by_env.items() if len(records) > min_per_env + } + + # Round-robin until budget exhausted + env_keys = sorted(remaining_by_env.keys()) + idx = 0 + while remaining_budget > 0 and any(remaining_by_env.values()): + env = env_keys[idx % len(env_keys)] + if remaining_by_env[env]: + capped_eval_records.append(remaining_by_env[env].pop(0)) + remaining_budget -= 1 + idx += 1 + + # Update env_split_counts + for env_key in eval_by_env: + count = sum(1 for r in capped_eval_records if r.get("data_source") == env_key) + if env_key in env_split_counts: + env_split_counts[env_key]["eval"] = count + print(f" {env_key}: {len(eval_by_env[env_key])} -> {count}") + + eval_records = capped_eval_records + print(f"\nAfter capping: {len(eval_records)} eval prompts") + + # Apply per-environment cap to training data (v0.3.1) + if max_env_ratio < 1.0 and train_records: + train_records, cap_stats = cap_training_distribution(train_records, max_env_ratio) + + # Print capping summary + capped_envs = [env for env, stats in cap_stats.items() if stats["capped"]] + if capped_envs: + print(f"\n=== Training Distribution Cap ({max_env_ratio:.0%} max per env) ===") + for env in sorted(capped_envs): + stats = cap_stats[env] + print(f" {env}: {stats['before']} -> {stats['after']} ({stats['before'] - stats['after']} removed)") + print(f"\nAfter capping: {len(train_records)} train") + + # Update env_split_counts with capped values + for env, stats in cap_stats.items(): + if env in env_split_counts: + env_split_counts[env]["train"] = stats["after"] + + # Create datasets + train_dataset = Dataset.from_list(train_records) if train_records else None + eval_dataset = Dataset.from_list(eval_records) if eval_records else None + + # Save to parquet + os.makedirs(output_dir, exist_ok=True) + + if train_dataset: + train_path = os.path.join(output_dir, "train.parquet") + train_dataset.to_parquet(train_path) + print(f"Saved train dataset to {train_path}") + + if eval_dataset: + eval_path = os.path.join(output_dir, "validation.parquet") + eval_dataset.to_parquet(eval_path) + print(f"Saved validation dataset to {eval_path}") + + # Print summary statistics + print("\n=== Dataset Summary ===") + print(f"Train: {len(train_records)}") + print(f"Eval: {len(eval_records)} (includes held-out: {held_out_envs or 'none'})") + + # Print per-environment breakdown table + print("\n=== Per-Environment Breakdown ===") + print(f"{'Environment':<20} {'Train':>8} {'Eval':>8} {'Total':>8}") + print("-" * 48) + for env_key in sorted(env_split_counts.keys()): + counts = env_split_counts[env_key] + total = counts["train"] + counts["eval"] + print(f"{env_key:<20} {counts['train']:>8} {counts['eval']:>8} {total:>8}") + print("-" * 48) + print( + f"{'TOTAL':<20} {len(train_records):>8} {len(eval_records):>8} " f"{len(train_records) + len(eval_records):>8}" + ) + + +def _task_to_record( + task: Dict[str, Any], + env_key: str, + env_class: str = "fleet_task", + env_meta: Optional[Dict[str, Any]] = None, +) -> Optional[Dict[str, Any]]: + """Convert a task dict to a dataset record. + + Args: + task: Task dict from Fleet JSON + env_key: Environment identifier + env_class: SkyRL env class (fleet_task or task_gen) + env_meta: Per-env metadata with representative env_variables and env_variable_keys + """ + task_key = task.get("key") or task.get("task_key") + prompt = task.get("prompt", "") + + if not task_key or not prompt: + return None + + # Use per-task env_variables if available, otherwise fall back to + # representative per-env values (some tasks lack env_variables) + task_env_vars = task.get("env_variables") or {} + if isinstance(task_env_vars, str): + try: + task_env_vars = json.loads(task_env_vars) + except json.JSONDecodeError: + task_env_vars = {} + if not task_env_vars and env_meta: + task_env_vars = env_meta.get("env_variables", {}) + + env_var_keys = (env_meta or {}).get("env_variable_keys", []) + + record = { + # Required fields for SkyRL + "prompt": [{"role": "user", "content": prompt}], + "env_class": env_class, + # Task identification (passed as env_extras) + "task_key": task_key, + # Data source for per-environment metrics in WandB + "data_source": env_key, + # Environment/data fields needed by TaskGenEnv for orchestrator provisioning + "data_key": task.get("data_key") or "", + "data_version": task.get("data_version") or "", + "env_version": task.get("env_version") or "", + "env_variables": json.dumps(task_env_vars), + "env_variable_keys": json.dumps(env_var_keys), + } + return record + + +def main(): + parser = argparse.ArgumentParser(description="Prepare Fleet tasks for SkyRL training") + parser.add_argument( + "--tasks-json", + type=str, + required=True, + help="Path to Fleet tasks JSON file", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./data/fleet", + help="Output directory for parquet files", + ) + parser.add_argument( + "--modality", + type=str, + default="tool_use", + choices=["tool_use", "computer_use", "browser_use", "all"], + help="Task modality filter ('all' for no filter)", + ) + parser.add_argument( + "--eval-ratio", + type=float, + default=0.20, + help="Fraction of data for evaluation (default: 0.20)", + ) + parser.add_argument( + "--env-filter", + type=str, + default=None, + help="Optional env_key filter (e.g., 'github', 'booking')", + ) + parser.add_argument( + "--difficulty-filter", + type=str, + default=None, + help="Optional difficulty filter: 1=easy, 2=medium, 3=hard (e.g., '1,2' for easy+medium)", + ) + parser.add_argument( + "--max-tasks", + type=int, + default=None, + help="Maximum number of tasks to include", + ) + parser.add_argument( + "--max-env-ratio", + type=float, + default=MAX_ENV_TRAIN_RATIO, + help=f"Maximum fraction of training data per environment (default: {MAX_ENV_TRAIN_RATIO})", + ) + parser.add_argument( + "--max-eval-prompts", + type=int, + default=MAX_EVAL_PROMPTS, + help=f"Maximum total eval prompts across all environments (default: {MAX_EVAL_PROMPTS})", + ) + parser.add_argument( + "--env-class", + type=str, + default="fleet_task", + choices=["fleet_task", "task_gen"], + help="SkyRL env_class per record (default: fleet_task, use task_gen for task generation)", + ) + + args = parser.parse_args() + + # Handle 'all' modality + modality = None if args.modality == "all" else args.modality + + prepare_fleet_dataset( + tasks_json=args.tasks_json, + output_dir=args.output_dir, + modality=modality, + eval_ratio=args.eval_ratio, + env_filter=args.env_filter, + difficulty_filter=args.difficulty_filter, + max_tasks=args.max_tasks, + max_env_ratio=args.max_env_ratio, + max_eval_prompts=args.max_eval_prompts, + env_class=args.env_class, + ) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/reward_metrics.py b/integrations/fleet/reward_metrics.py new file mode 100644 index 0000000000..3978ee1acd --- /dev/null +++ b/integrations/fleet/reward_metrics.py @@ -0,0 +1,253 @@ +"""Unified reward metrics for SkyRL and Tinker. + +This module provides shared metric calculation functions used by both: +- SkyRL trainer (skyrl_train/trainer.py, skyrl_train/utils/trainer_utils.py) +- Tinker integration (integrations/fleet/entrypoints/main_fleet_tinker.py) + +All metrics follow the same naming convention for WandB logging: +- reward/{group}/pass_at_{n} - Pass@n metric for group +- reward/{group}/variance_per_prompt - Mean within-prompt reward variance (GRPO learning signal) +- reward/{group}/signal_ratio - Fraction of prompts with non-zero variance (% with signal) +- reward/{group}/mean_positive_reward - Mean of positive rewards for group + +Rewards can be in two formats: +- Scalar rewards: List[float] - one reward per trajectory +- Token-level rewards: List[List[float]] - per-token rewards per trajectory (summed to scalar) +""" + +from collections import defaultdict +from typing import Any, Dict, List, Union + +import numpy as np + + +def flatten_rewards(rewards: Union[List[float], List[List[float]]]) -> List[float]: + """Flatten rewards to scalar format. + + Handles both scalar rewards (List[float]) and token-level rewards (List[List[float]]). + For token-level rewards, sums each trajectory's rewards into a single scalar. + + Args: + rewards: Either List[float] (scalar per trajectory) or + List[List[float]] (token-level per trajectory) + + Returns: + List[float]: Flattened scalar rewards, one per trajectory + """ + if not rewards: + return [] + + flat_rewards: List[float] = [] + for r in rewards: + if isinstance(r, list): + # Token-level rewards: sum to get trajectory reward + flat_rewards.append(float(sum(r))) + else: + flat_rewards.append(float(r)) + return flat_rewards + + +def sanitize_metric_key(key: str) -> str: + """Sanitize metric key for wandb (replace / with _). + + Args: + key: Raw metric key that may contain slashes + + Returns: + Sanitized key with slashes replaced by underscores + """ + return key.replace("/", "_") + + +def compute_pass_at_n( + rewards: Union[List[float], List[List[float]]], + uids: List[str], +) -> float: + """Compute pass@n: fraction of unique prompts with at least one fully successful rollout. + + For each unique prompt (identified by uid), if ANY of its rollouts achieves a + perfect reward (>= 1.0), that prompt counts as a "pass". This metric measures + how often the model can fully solve a task when given multiple attempts. + Partial rewards (e.g. 0.3 from partial_reward mode) do not count as a pass. + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs (one per rollout, same uid = same prompt) + + Returns: + Float between 0.0 and 1.0 representing the fraction of prompts that passed + """ + flat_rewards = flatten_rewards(rewards) + uid_to_rewards: Dict[str, List[float]] = defaultdict(list) + for uid, reward in zip(uids, flat_rewards): + uid_to_rewards[uid].append(reward) + + if not uid_to_rewards: + return 0.0 + + passed = sum(1 for r_list in uid_to_rewards.values() if any(r >= 1.0 for r in r_list)) + return passed / len(uid_to_rewards) + + +def compute_variance_per_prompt( + rewards: Union[List[float], List[List[float]]], + uids: List[str], +) -> float: + """Compute mean within-prompt reward variance (GRPO learning signal). + + For GRPO to learn, there must be variance in rewards within each prompt's rollouts. + If all rollouts for a prompt get the same reward, there's no learning signal. + + This metric computes the variance of rewards for each prompt, then returns the + mean variance across all prompts. + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs (one per rollout, same uid = same prompt) + + Returns: + Mean variance across prompts. Higher = more learning signal. + Returns 0.0 if no prompts or all prompts have single rollouts. + """ + flat_rewards = flatten_rewards(rewards) + uid_to_rewards: Dict[str, List[float]] = defaultdict(list) + for uid, reward in zip(uids, flat_rewards): + uid_to_rewards[uid].append(reward) + + if not uid_to_rewards: + return 0.0 + + # Compute variance for each prompt (need at least 2 samples for variance) + variances = [] + for r_list in uid_to_rewards.values(): + if len(r_list) >= 2: + variances.append(float(np.var(r_list))) + + return float(np.mean(variances)) if variances else 0.0 + + +def compute_signal_ratio( + rewards: Union[List[float], List[List[float]]], + uids: List[str], +) -> float: + """Compute fraction of prompts with non-zero variance (GRPO signal ratio). + + This metric shows what percentage of prompts have any learning signal at all. + A prompt has signal if at least one rollout differs from others (variance > 0). + + Unlike variance_per_prompt (which averages variance magnitudes), this metric + counts how many prompts contribute ANY signal, making it easier to interpret: + - 100% = every prompt has at least one differing rollout + - 0% = all prompts have identical rewards across rollouts (no learning possible) + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs (one per rollout, same uid = same prompt) + + Returns: + Float between 0.0 and 1.0 representing fraction of prompts with signal. + Returns 0.0 if no prompts or all prompts have single rollouts. + """ + flat_rewards = flatten_rewards(rewards) + uid_to_rewards: Dict[str, List[float]] = defaultdict(list) + for uid, reward in zip(uids, flat_rewards): + uid_to_rewards[uid].append(reward) + + if not uid_to_rewards: + return 0.0 + + # Count prompts with variance > 0 (need at least 2 samples) + prompts_with_signal = 0 + prompts_total = 0 + for r_list in uid_to_rewards.values(): + if len(r_list) >= 2: + prompts_total += 1 + if np.var(r_list) > 0: + prompts_with_signal += 1 + + return prompts_with_signal / prompts_total if prompts_total > 0 else 0.0 + + +def compute_reward_metrics( + rewards: Union[List[float], List[List[float]]], + uids: List[str], + n_samples_per_prompt: int, +) -> Dict[str, float]: + """Compute core reward metrics. + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs for pass@n grouping + n_samples_per_prompt: Number of samples per prompt (used in metric key name) + + Returns: + Dictionary with keys: + - "pass_at_{n}": Pass@n metric + - "variance_per_prompt": Mean within-prompt reward variance (GRPO learning signal) + - "signal_ratio": Fraction of prompts with non-zero variance (% with signal) + - "mean_positive_reward": Mean of positive rewards only + """ + # Flatten rewards once for efficiency (each sub-function would otherwise flatten again) + flat_rewards = flatten_rewards(rewards) + pass_at_n = compute_pass_at_n(flat_rewards, uids) + variance = compute_variance_per_prompt(flat_rewards, uids) + signal_ratio = compute_signal_ratio(flat_rewards, uids) + positive_rewards = [r for r in flat_rewards if r > 0] + mean_positive = float(np.mean(positive_rewards)) if positive_rewards else 0.0 + + return { + f"pass_at_{n_samples_per_prompt}": pass_at_n, + "variance_per_prompt": variance, + "signal_ratio": signal_ratio, + "mean_positive_reward": mean_positive, + } + + +def compute_per_group_metrics( + rewards: Union[List[float], List[List[float]]], + uids: List[str], + groups: List[str], + n_samples_per_prompt: int, + prefix: str = "reward", +) -> Dict[str, float]: + """Compute metrics grouped by a key (env_key, data_source, etc). + + This function computes reward metrics for each group separately, enabling + per-environment analysis in training and evaluation. + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs for pass@n grouping within each group + groups: List of group keys (e.g., env_key or data_source per rollout) + n_samples_per_prompt: Number of samples per prompt (used in metric key name) + prefix: Metric prefix ("reward" for training, "eval" for evaluation) + + Returns: + Dictionary with keys like: + - "{prefix}/{group}/avg_score" + - "{prefix}/{group}/pass_at_{n}" + - "{prefix}/{group}/mean_positive_reward" + """ + # Flatten rewards once before grouping + flat_rewards = flatten_rewards(rewards) + + # Group data by group key + group_data: Dict[str, Dict[str, List[Any]]] = defaultdict(lambda: {"rewards": [], "uids": []}) + for reward, uid, group in zip(flat_rewards, uids, groups): + group_key = group if group is not None else "unknown" + group_data[group_key]["rewards"].append(reward) + group_data[group_key]["uids"].append(uid) + + metrics: Dict[str, float] = {} + for group_key, data in group_data.items(): + sanitized = sanitize_metric_key(group_key) + group_metrics = compute_reward_metrics(data["rewards"], data["uids"], n_samples_per_prompt) + for metric_name, value in group_metrics.items(): + metrics[f"{prefix}/{sanitized}/{metric_name}"] = value + + return metrics diff --git a/integrations/fleet/s3_checkpoints.py b/integrations/fleet/s3_checkpoints.py new file mode 100644 index 0000000000..f6e75f8a7c --- /dev/null +++ b/integrations/fleet/s3_checkpoints.py @@ -0,0 +1,718 @@ +""" +S3 Checkpoint Management for SkyRL Training. + +Provides checkpoint upload to S3, download from S3 for resume, and local cleanup. + +Key behavior: +- Cleans up old local checkpoints BEFORE saving new one (prevents disk full) +- Uploads to S3 asynchronously (non-blocking, training continues) +- Downloads checkpoint from S3 before training for cross-VM resume +- Uploads eval results to S3 for persistence + +Usage: + from integrations.fleet.s3_checkpoints import ( + wrap_trainer_with_s3_upload, + download_checkpoint_from_s3, + upload_eval_results_to_s3, + ) + + # Download checkpoint before training (for resume on new VM) + download_checkpoint_from_s3(ckpt_path, run_name) + + trainer = wrap_trainer_with_s3_upload(trainer, bucket="skyrl-checkpoints") + upload_eval_results_to_s3(local_dir, run_name, global_step) + +Environment Variables: + AWS_ACCESS_KEY_ID: AWS access key + AWS_SECRET_ACCESS_KEY: AWS secret key + AWS_REGION: AWS region (default: us-east-1) + S3_CHECKPOINT_BUCKET: S3 bucket for checkpoints (default: skyrl-checkpoints) + S3_TRAJECTORY_BUCKET: S3 bucket for eval trajectories (default: skyrl-trajectories) +""" + +import os +import shutil +import threading +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Optional +import logging + +logger = logging.getLogger(__name__) + + +class S3CheckpointUploader: + """ + Uploads checkpoint directories to S3 asynchronously. + + Uses a background thread pool to avoid blocking training. + Deletes local checkpoints after successful upload. + """ + + def __init__( + self, + bucket: str, + prefix: str, + region: str = "us-east-1", + max_workers: int = 2, + ): + self.bucket = bucket + self.prefix = prefix + self.region = region + self._executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="s3-upload") + self._pending: set = set() + self._lock = threading.Lock() + + def _gather_from_workers(self, local_dir: str) -> None: + """Gather checkpoint shards from worker nodes before S3 upload. + + FSDP saves each rank's shards locally on its node. The head has ranks 0-N, + workers have ranks N+1-M. We rsync worker shards to the head so the S3 + upload gets all shards. + """ + import subprocess + import socket + + node_ips_str = os.environ.get("SKYPILOT_NODE_IPS", "").strip() + if node_ips_str: + node_ips = [ip.strip() for ip in node_ips_str.split("\n") if ip.strip()] + else: + try: + import ray + nodes = ray.nodes() + node_ips = sorted(set( + n["NodeManagerAddress"] for n in nodes + if n.get("Alive", False) + )) + except Exception: + return + + if len(node_ips) <= 1: + return + + head_ip = socket.gethostbyname(socket.gethostname()) + worker_ips = [ip for ip in node_ips if ip != head_ip] + if not worker_ips: + worker_ips = node_ips[1:] + if not worker_ips: + return + + ssh_key = None + for key_path in ["~/.ssh/sky-cluster-key", "~/.ssh/sky-key", "~/.ssh/id_rsa"]: + expanded = os.path.expanduser(key_path) + if os.path.exists(expanded): + ssh_key = expanded + break + ssh_cmd = f"ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -i {ssh_key}" if ssh_key else "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30" + + timeout = _estimate_rsync_timeout(local_dir) + + for worker_ip in worker_ips: + logger.info(f"Gathering checkpoint shards from worker {worker_ip} (timeout={timeout}s)...") + try: + subprocess.run( + [ + "rsync", "-az", + "-e", ssh_cmd, + f"gcpuser@{worker_ip}:{local_dir}/", + f"{local_dir}/", + ], + check=True, + timeout=timeout, + ) + logger.info(f"Gathered shards from {worker_ip}") + except subprocess.TimeoutExpired: + logger.warning(f"Gathering from {worker_ip} timed out ({timeout}s)") + except subprocess.CalledProcessError as e: + logger.warning(f"Gathering from {worker_ip} failed: {e}") + + def _upload_sync(self, local_dir: str) -> bool: + """Synchronous upload that runs in thread pool.""" + try: + # Gather shards from worker nodes before uploading + self._gather_from_workers(local_dir) + import boto3 + from botocore.config import Config + from boto3.s3.transfer import TransferConfig + + config = Config( + retries={"max_attempts": 3, "mode": "adaptive"}, + connect_timeout=30, + read_timeout=120, + ) + + s3 = boto3.client("s3", region_name=self.region, config=config) + + local_path = Path(local_dir) + if not local_path.exists(): + logger.warning(f"Checkpoint directory does not exist: {local_dir}") + return False + + checkpoint_name = local_path.name + s3_prefix = f"{self.prefix}/{checkpoint_name}" + + transfer_config = TransferConfig( + multipart_threshold=64 * 1024 * 1024, + multipart_chunksize=64 * 1024 * 1024, + max_concurrency=4, + use_threads=True, + ) + + uploaded_files = 0 + total_size = 0 + + for file_path in local_path.rglob("*"): + if file_path.is_file(): + relative_path = file_path.relative_to(local_path) + s3_key = f"{s3_prefix}/{relative_path}" + file_size = file_path.stat().st_size + total_size += file_size + + logger.info(f"Uploading {file_path.name} ({file_size / 1e6:.1f} MB)") + + s3.upload_file(str(file_path), self.bucket, s3_key, Config=transfer_config) + uploaded_files += 1 + + logger.info( + f"Uploaded {checkpoint_name}: {uploaded_files} files, {total_size / 1e9:.2f} GB to s3://{self.bucket}/{s3_prefix}/" + ) + + # Delete local after successful upload to free disk space + logger.info(f"Deleting local checkpoint after S3 upload: {local_dir}") + shutil.rmtree(local_dir) + + return True + + except Exception as e: + logger.error(f"S3 upload failed for {local_dir}: {e}") + return False + finally: + with self._lock: + self._pending.discard(local_dir) + + def upload_async(self, local_dir: str) -> None: + """Queue checkpoint for async upload. Non-blocking.""" + with self._lock: + if local_dir in self._pending: + return + self._pending.add(local_dir) + + logger.info(f"Queuing checkpoint for S3 upload: {local_dir}") + self._executor.submit(self._upload_sync, local_dir) + + def wait_for_uploads(self, timeout: Optional[float] = None) -> None: + """Wait for all pending uploads to complete.""" + self._executor.shutdown(wait=True) + self._executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="s3-upload") + + +def cleanup_old_local_checkpoints(ckpt_path: str, keep_n: int = 2) -> None: + """ + Delete old local checkpoints, keeping only the most recent N. + + Args: + ckpt_path: Base checkpoint directory + keep_n: Number of recent checkpoints to keep (default: 2 for safety) + """ + ckpt_dir = Path(ckpt_path) + if not ckpt_dir.exists(): + return + + checkpoint_dirs = sorted( + [d for d in ckpt_dir.iterdir() if d.is_dir() and d.name.startswith("global_step_")], + key=lambda x: int(x.name.split("_")[-1]), + reverse=True, + ) + + for old_dir in checkpoint_dirs[keep_n:]: + logger.info(f"Cleaning up old local checkpoint: {old_dir}") + try: + shutil.rmtree(old_dir) + except Exception as e: + logger.warning(f"Failed to delete {old_dir}: {e}") + + +def wrap_trainer_with_s3_upload( + trainer, + bucket: Optional[str] = None, + prefix: Optional[str] = None, + region: Optional[str] = None, +): + """ + Wrap a SkyRL trainer to: + 1. Clean up old checkpoints BEFORE saving (prevents disk full) + 2. Upload to S3 asynchronously AFTER saving (if credentials set) + 3. Delete local checkpoint after successful S3 upload (frees disk) + + Args: + trainer: SkyRL trainer instance + bucket: S3 bucket (default: from S3_CHECKPOINT_BUCKET env var) + prefix: S3 prefix (default: from trainer config) + region: AWS region (default: from AWS_REGION env var) + + Returns: + The trainer (modified in place) + """ + bucket = bucket or os.environ.get("S3_CHECKPOINT_BUCKET", "skyrl-checkpoints") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + # Build prefix from trainer config + if prefix is None: + run_name = getattr(trainer.cfg.trainer, "run_name", None) + project_name = getattr(trainer.cfg.trainer, "project_name", "skyrl") + model_path = getattr(trainer.cfg.trainer.policy.model, "path", "unknown-model") + model_name = Path(model_path).name + prefix = f"{project_name}/{model_name}/{run_name}" if run_name else f"{project_name}/{model_name}" + + # Check AWS credentials + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + s3_enabled = bool(aws_key and aws_secret) + + if s3_enabled: + logger.info(f"S3 checkpoint upload ENABLED: s3://{bucket}/{prefix}/") + uploader = S3CheckpointUploader(bucket=bucket, prefix=prefix, region=region) + else: + logger.warning( + "AWS credentials not found. S3 upload DISABLED. " + "Using aggressive local cleanup (keeping only 2 checkpoints). " + "Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to enable S3." + ) + uploader = None + + original_save_checkpoints = trainer.save_checkpoints + ckpt_path = trainer.cfg.trainer.ckpt_path + + def save_checkpoints_with_cleanup(): + """Wrapped save_checkpoints with pre-save cleanup and async S3 upload.""" + # CRITICAL: Clean up old checkpoints BEFORE saving to free disk space + # With S3: keep only 1 (we have S3 backup), allows room for new checkpoint + # Without S3: keep 2 for safety + keep_n = 1 if s3_enabled else 2 + cleanup_old_local_checkpoints(ckpt_path, keep_n=keep_n) + + # Now save the new checkpoint (disk has space) + original_save_checkpoints() + + # Queue async S3 upload (non-blocking) + if s3_enabled and uploader: + global_step = trainer.global_step + checkpoint_dir = os.path.join(ckpt_path, f"global_step_{global_step}") + if os.path.exists(checkpoint_dir): + uploader.upload_async(checkpoint_dir) + + trainer.save_checkpoints = save_checkpoints_with_cleanup + trainer._s3_uploader = uploader + + return trainer + + +def _estimate_rsync_timeout(path: str, min_timeout: int = 300) -> int: + """Estimate rsync timeout based on directory size. + + Assumes ~100MB/s conservative transfer speed + 60s buffer. + + Args: + path: Directory to measure. + min_timeout: Minimum timeout in seconds (default 5 min). + + Returns: + Timeout in seconds. + """ + try: + total_size = sum( + f.stat().st_size for f in Path(path).rglob("*") if f.is_file() + ) + timeout = max(min_timeout, int(total_size / (100 * 1024 * 1024)) + 60) + logger.info(f"Estimated rsync timeout for {total_size / 1e9:.1f}GB: {timeout}s") + return timeout + except Exception: + return min_timeout + + +def broadcast_checkpoint_to_workers(ckpt_path: str, timeout: Optional[int] = None) -> None: + """Broadcast checkpoint from head node to all worker nodes via rsync. + + FSDP requires checkpoint shards on every node. The S3 download only runs + on the head node, so we rsync the checkpoint directory to all workers. + + Discovers worker IPs from SKYPILOT_NODE_IPS (shell env) or Ray cluster + nodes (when running inside a Ray task). No-op on single-node. + + Args: + ckpt_path: Local checkpoint directory to broadcast. + timeout: Rsync timeout in seconds. If None, auto-calculated from checkpoint size. + """ + import subprocess + import socket + + # Try SKYPILOT_NODE_IPS first (set by SkyPilot run script) + node_ips_str = os.environ.get("SKYPILOT_NODE_IPS", "").strip() + if node_ips_str: + node_ips = [ip.strip() for ip in node_ips_str.split("\n") if ip.strip()] + else: + # Fall back to Ray cluster node discovery + try: + import ray + nodes = ray.nodes() + node_ips = sorted(set( + n["NodeManagerAddress"] for n in nodes + if n.get("Alive", False) + )) + logger.info(f"Discovered {len(node_ips)} nodes from Ray cluster") + except Exception as e: + logger.warning(f"Could not discover nodes: {e}") + return + + if len(node_ips) <= 1: + return # single node, nothing to broadcast + + # Head IP is the current node + head_ip = socket.gethostbyname(socket.gethostname()) + worker_ips = [ip for ip in node_ips if ip != head_ip] + + if not worker_ips: + # Try: head is first in the list + worker_ips = node_ips[1:] + + if not worker_ips: + logger.info("No worker nodes found, skipping checkpoint broadcast") + return + + # Find SSH key — SkyPilot uses ~/.ssh/sky-cluster-key on provisioned VMs + ssh_key = None + for key_path in ["~/.ssh/sky-cluster-key", "~/.ssh/sky-key", "~/.ssh/id_rsa"]: + expanded = os.path.expanduser(key_path) + if os.path.exists(expanded): + ssh_key = expanded + break + ssh_cmd = f"ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -i {ssh_key}" if ssh_key else "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30" + + if timeout is None: + timeout = _estimate_rsync_timeout(ckpt_path) + + for worker_ip in worker_ips: + logger.info(f"Broadcasting checkpoint to worker {worker_ip} (ssh key: {ssh_key}, timeout={timeout}s)...") + try: + # Create parent directory on worker (rsync can't create it) + subprocess.run( + ["ssh"] + ssh_cmd.split()[1:] + [f"gcpuser@{worker_ip}", f"mkdir -p {ckpt_path}"], + check=True, + timeout=30, + ) + subprocess.run( + [ + "rsync", "-az", + "-e", ssh_cmd, + f"{ckpt_path}/", + f"gcpuser@{worker_ip}:{ckpt_path}/", + ], + check=True, + timeout=timeout, + ) + logger.info(f"Checkpoint broadcast to {worker_ip} complete") + except subprocess.TimeoutExpired: + logger.warning(f"Checkpoint broadcast to {worker_ip} timed out") + except subprocess.CalledProcessError as e: + logger.warning(f"Checkpoint broadcast to {worker_ip} failed: {e}") + + +def download_checkpoint_from_s3( + ckpt_path: str, + run_name: str, + bucket: Optional[str] = None, + region: Optional[str] = None, + project_name: str = "fleet-task-grpo", + model_name: str = "Qwen3-32B", +) -> bool: + """ + Download the latest checkpoint from S3 for resume on a fresh VM. + + Looks for checkpoint directories under the S3 prefix matching the run_name, + downloads the latest one, and writes latest_ckpt_global_step.txt. + + Args: + ckpt_path: Local checkpoint directory (e.g., ~/ckpts/fleet_tool_use_32b) + run_name: W&B run name used as S3 prefix (e.g., fleet_tool_use_32b_d7167c1c) + bucket: S3 bucket (default: from S3_CHECKPOINT_BUCKET env var) + region: AWS region (default: from AWS_REGION env var) + project_name: Project name used in S3 prefix + model_name: Model name used in S3 prefix + + Returns: + True if checkpoint was downloaded, False otherwise + """ + bucket = bucket or os.environ.get("S3_CHECKPOINT_BUCKET", "skyrl-checkpoints") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + if not (aws_key and aws_secret): + logger.info("No AWS credentials, skipping S3 checkpoint download") + return False + + # Check if local checkpoint already exists + latest_file = os.path.join(ckpt_path, "latest_ckpt_global_step.txt") + if os.path.exists(latest_file): + with open(latest_file, "r") as f: + step = f.read().strip() + local_ckpt = os.path.join(ckpt_path, f"global_step_{step}") + if os.path.exists(local_ckpt): + logger.info(f"Local checkpoint already exists at step {step}, skipping S3 download") + return False + + try: + import boto3 + from botocore.config import Config + + config = Config( + retries={"max_attempts": 3, "mode": "adaptive"}, + connect_timeout=30, + read_timeout=120, + ) + s3 = boto3.client("s3", region_name=region, config=config) + + # S3 prefix matches what wrap_trainer_with_s3_upload builds + s3_prefix = f"{project_name}/{model_name}/{run_name}/" + + # List all checkpoint directories in S3 + paginator = s3.get_paginator("list_objects_v2") + checkpoint_steps = set() + for page in paginator.paginate(Bucket=bucket, Prefix=s3_prefix, Delimiter="/"): + for prefix_obj in page.get("CommonPrefixes", []): + dir_name = prefix_obj["Prefix"].rstrip("/").split("/")[-1] + if dir_name.startswith("global_step_"): + try: + step = int(dir_name.split("_")[-1]) + checkpoint_steps.add(step) + except ValueError: + pass + + if not checkpoint_steps: + logger.info(f"No checkpoints found in s3://{bucket}/{s3_prefix}") + return False + + latest_step = max(checkpoint_steps) + s3_ckpt_prefix = f"{s3_prefix}global_step_{latest_step}/" + local_ckpt_dir = os.path.join(ckpt_path, f"global_step_{latest_step}") + + logger.info(f"Downloading checkpoint step {latest_step} from s3://{bucket}/{s3_ckpt_prefix}") + + os.makedirs(local_ckpt_dir, exist_ok=True) + + downloaded_files = 0 + total_size = 0 + for page in paginator.paginate(Bucket=bucket, Prefix=s3_ckpt_prefix): + for obj in page.get("Contents", []): + s3_key = obj["Key"] + relative_path = s3_key[len(s3_ckpt_prefix) :] + if not relative_path: + continue + local_file = os.path.join(local_ckpt_dir, relative_path) + os.makedirs(os.path.dirname(local_file), exist_ok=True) + file_size = obj["Size"] + total_size += file_size + logger.info(f"Downloading {relative_path} ({file_size / 1e6:.1f} MB)") + s3.download_file(bucket, s3_key, local_file) + downloaded_files += 1 + + # Write latest_ckpt_global_step.txt so SkyRL's resume_mode=latest can find it + os.makedirs(ckpt_path, exist_ok=True) + with open(latest_file, "w") as f: + f.write(str(latest_step)) + + logger.info( + f"Downloaded checkpoint: {downloaded_files} files, {total_size / 1e9:.2f} GB " + f"from s3://{bucket}/{s3_ckpt_prefix} to {local_ckpt_dir}" + ) + return True + + except Exception as e: + logger.error(f"Failed to download checkpoint from S3: {e}") + return False + + +def upload_eval_results_to_s3( + local_dir: str, + run_name: str, + global_step: Optional[int] = None, + bucket: Optional[str] = None, + region: Optional[str] = None, + delete_local: bool = False, +) -> bool: + """ + Upload eval results directory to S3. + + Args: + local_dir: Local directory containing eval JSONL files + run_name: Run name for S3 prefix (e.g., "fleet_tool_use_abc123") + global_step: Global step number (for organizing in S3) + bucket: S3 bucket (default: from S3_TRAJECTORY_BUCKET env var) + region: AWS region (default: from AWS_REGION env var) + delete_local: If True, delete local files after upload + + Returns: + True if upload succeeded, False otherwise + """ + bucket = bucket or os.environ.get("S3_TRAJECTORY_BUCKET", "skyrl-trajectories") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + # Check AWS credentials + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + if not (aws_key and aws_secret): + logger.warning("AWS credentials not found. Skipping S3 upload for eval results.") + return False + + local_path = Path(local_dir) + if not local_path.exists(): + logger.warning(f"Eval directory does not exist: {local_dir}") + return False + + try: + import boto3 + from botocore.config import Config + + config = Config( + retries={"max_attempts": 3, "mode": "adaptive"}, + connect_timeout=30, + read_timeout=60, + ) + + s3 = boto3.client("s3", region_name=region, config=config) + + # Build S3 prefix: evals/{run_name}/global_step_{N}/ + step_suffix = f"global_step_{global_step}" if global_step is not None else "eval_only" + s3_prefix = f"evals/{run_name}/{step_suffix}" + + uploaded_files = 0 + total_size = 0 + + for file_path in local_path.rglob("*"): + if file_path.is_file(): + relative_path = file_path.relative_to(local_path) + s3_key = f"{s3_prefix}/{relative_path}" + file_size = file_path.stat().st_size + total_size += file_size + + s3.upload_file(str(file_path), bucket, s3_key) + uploaded_files += 1 + + logger.info( + f"Uploaded eval results: {uploaded_files} files, {total_size / 1e6:.2f} MB " + f"to s3://{bucket}/{s3_prefix}/" + ) + + if delete_local: + shutil.rmtree(local_dir) + logger.info(f"Deleted local eval directory: {local_dir}") + + return True + + except Exception as e: + logger.error(f"S3 upload failed for eval results {local_dir}: {e}") + return False + + +def upload_training_trajectories_to_s3( + local_path: str, + run_name: str, + global_step: int, + bucket: Optional[str] = None, + region: Optional[str] = None, +) -> bool: + """Upload a single training trajectory JSONL file to S3. + + Args: + local_path: Path to the JSONL file + run_name: Run name for S3 prefix + global_step: Global step number + bucket: S3 bucket (default: from S3_TRAJECTORY_BUCKET env var) + region: AWS region (default: from AWS_REGION env var) + + Returns: + True if upload succeeded + """ + bucket = bucket or os.environ.get("S3_TRAJECTORY_BUCKET", "skyrl-trajectories") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + if not (aws_key and aws_secret): + logger.warning("AWS credentials not found. Skipping training trajectory upload.") + return False + + if not os.path.exists(local_path): + logger.warning(f"Trajectory file does not exist: {local_path}") + return False + + try: + import boto3 + from botocore.config import Config + + config = Config(retries={"max_attempts": 3, "mode": "adaptive"}) + s3 = boto3.client("s3", region_name=region, config=config) + + s3_key = f"rollouts/{run_name}/global_step_{global_step}.jsonl" + s3.upload_file(local_path, bucket, s3_key) + logger.info(f"Uploaded training trajectories to s3://{bucket}/{s3_key}") + return True + + except Exception as e: + logger.error(f"S3 upload failed for training trajectories: {e}") + return False + + +def upload_reward_rollouts_to_s3( + rollout_dir: str, + run_name: str, + bucket: Optional[str] = None, + region: Optional[str] = None, +) -> bool: + """Upload reward rollout files to S3. + + Args: + rollout_dir: Local directory containing reward rollout JSONL files + run_name: Run name for S3 prefix + bucket: S3 bucket (default: from S3_TRAJECTORY_BUCKET env var) + region: AWS region (default: from AWS_REGION env var) + + Returns: + True if upload succeeded + """ + bucket = bucket or os.environ.get("S3_TRAJECTORY_BUCKET", "skyrl-trajectories") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + if not (aws_key and aws_secret): + logger.warning("AWS credentials not found. Skipping reward rollout upload.") + return False + + rollout_path = Path(rollout_dir) + if not rollout_path.exists(): + logger.info(f"No reward rollout directory at {rollout_dir}, skipping upload.") + return False + + try: + import boto3 + from botocore.config import Config + + config = Config(retries={"max_attempts": 3, "mode": "adaptive"}) + s3 = boto3.client("s3", region_name=region, config=config) + + uploaded = 0 + for file_path in rollout_path.rglob("*"): + if file_path.is_file(): + relative = file_path.relative_to(rollout_path) + s3_key = f"reward_rollouts/{run_name}/{relative}" + s3.upload_file(str(file_path), bucket, s3_key) + uploaded += 1 + + if uploaded: + logger.info(f"Uploaded {uploaded} reward rollout files to s3://{bucket}/reward_rollouts/{run_name}/") + return True + + except Exception as e: + logger.error(f"S3 upload failed for reward rollouts: {e}") + return False diff --git a/integrations/fleet/task_gen_reward.py b/integrations/fleet/task_gen_reward.py new file mode 100644 index 0000000000..55248ece7d --- /dev/null +++ b/integrations/fleet/task_gen_reward.py @@ -0,0 +1,41 @@ +""" +Reward functions for task generation RL. + +Binary reward: 1.0 if solver rollouts have mixed results (at least one pass +and one fail), 0.0 otherwise. Mixed results = the task is at the right +difficulty frontier. +""" + +from typing import Dict, List + + +def compute_variance(scores: List[float]) -> float: + if len(scores) < 2: + return 0.0 + mean = sum(scores) / len(scores) + return sum((s - mean) ** 2 for s in scores) / len(scores) + + +def compute_task_reward( + raw_scores: List[float], + hinted_scores: List[float], + validity: float = 1.0, + alpha: float = 1.0, +) -> Dict[str, float]: + """Binary reward: 1.0 if mixed solver results, 0.0 otherwise.""" + if not raw_scores: + return {"validity": validity, "p_raw": 0.0, "var_raw": 0.0, "total": 0.0} + + p_raw = sum(raw_scores) / len(raw_scores) + var_raw = compute_variance(raw_scores) + has_pass = any(s > 0 for s in raw_scores) + has_fail = any(s == 0 for s in raw_scores) + total = 1.0 if (has_pass and has_fail and validity > 0) else 0.0 + + return { + "validity": validity, + "p_raw": p_raw, + "var_raw": var_raw, + "hint_gap": 0.0, + "total": total, + } diff --git a/integrations/fleet/tests/__init__.py b/integrations/fleet/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/fleet/tests/test_main_eval.py b/integrations/fleet/tests/test_main_eval.py new file mode 100644 index 0000000000..58a95660ce --- /dev/null +++ b/integrations/fleet/tests/test_main_eval.py @@ -0,0 +1,177 @@ +"""Unit tests for the Fleet eval-only entrypoint. + +uv run --extra dev --extra skyrl-train pytest integrations/fleet/tests/test_main_eval.py +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from integrations.fleet.entrypoints.main_eval import ( + FleetEvalExp, + _strip_hydra_prefixes, +) + + +# --------------------------------------------------------------------------- +# _strip_hydra_prefixes +# --------------------------------------------------------------------------- + + +def test_strip_hydra_prefixes_handles_all_three_arg_shapes(): + args = [ + "trainer.run_name=my_run", + "+trainer.eval_interval=1", + "++environment.skyrl_gym.fleet_task.tasks_file=/tmp/tasks.json", + ] + out = _strip_hydra_prefixes(args) + assert out == [ + "trainer.run_name=my_run", + "trainer.eval_interval=1", + "environment.skyrl_gym.fleet_task.tasks_file=/tmp/tasks.json", + ] + + +def test_strip_hydra_prefixes_empty(): + assert _strip_hydra_prefixes([]) == [] + + +def test_strip_hydra_prefixes_double_plus_takes_precedence_over_single(): + # "++" matches startswith("++") first, so it strips two characters, not one. + assert _strip_hydra_prefixes(["++key=value"]) == ["key=value"] + + +# --------------------------------------------------------------------------- +# FleetEvalExp.get_train_dataset +# --------------------------------------------------------------------------- + + +def test_get_train_dataset_returns_none(): + # Bypass __init__ so we don't pull in tokenizer / placement group. + exp = FleetEvalExp.__new__(FleetEvalExp) + assert exp.get_train_dataset() is None + + +# --------------------------------------------------------------------------- +# FleetEvalExp._load_policy_only — path resolution + dispatch wiring +# --------------------------------------------------------------------------- + + +def _make_trainer_mock(resume_mode_value: str, ckpt_path: str, resume_path: str = "") -> MagicMock: + """Build a minimal trainer mock for _load_policy_only tests. + + Mirrors the attribute shape `_load_policy_only` reads: trainer.resume_mode, + trainer.cfg.trainer.{ckpt_path, ckpt_interval, resume_path}, trainer.dispatch. + """ + from skyrl.train.utils.trainer_utils import ResumeMode + + trainer = MagicMock() + trainer.resume_mode = ResumeMode(resume_mode_value) + trainer.cfg = SimpleNamespace( + trainer=SimpleNamespace( + ckpt_path=ckpt_path, + ckpt_interval=10, + resume_path=resume_path, + ) + ) + trainer.global_step = 0 + return trainer + + +def _make_exp() -> FleetEvalExp: + """Create a FleetEvalExp bypassing __init__ (which loads a tokenizer).""" + return FleetEvalExp.__new__(FleetEvalExp) + + +def test_load_policy_only_resume_none_is_noop(): + exp = _make_exp() + trainer = _make_trainer_mock("none", ckpt_path="/tmp/does-not-matter") + + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_not_called() + assert trainer.global_step == 0 + + +def test_load_policy_only_latest_with_no_marker_file_is_noop(tmp_path): + exp = _make_exp() + trainer = _make_trainer_mock("latest", ckpt_path=str(tmp_path)) + # No latest_ckpt_global_step.txt written → fall through, no load. + + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_not_called() + assert trainer.global_step == 0 + + +def test_load_policy_only_latest_loads_policy_and_sets_global_step(tmp_path): + # Build a realistic checkpoint layout that the resolver expects. + ckpt_dir = tmp_path / "global_step_30" + (ckpt_dir / "policy").mkdir(parents=True) + (tmp_path / "latest_ckpt_global_step.txt").write_text("30") + + exp = _make_exp() + trainer = _make_trainer_mock("latest", ckpt_path=str(tmp_path)) + + # The consistency validator hits the filesystem in non-trivial ways; stub + # it out so the test stays focused on this method's contract. + with patch( + "skyrl.train.utils.trainer_utils.validate_consistency_for_latest_checkpoint" + ) as validator: + exp._load_policy_only(trainer) + + validator.assert_called_once() + trainer.dispatch.load_checkpoint.assert_called_once_with( + "policy", + str(ckpt_dir / "policy"), + load_optimizer_states=False, + load_lr_scheduler_states=False, + ) + assert trainer.global_step == 30 + + +def test_load_policy_only_from_path_loads_specified_checkpoint(tmp_path): + ckpt_dir = tmp_path / "global_step_42" + (ckpt_dir / "policy").mkdir(parents=True) + + exp = _make_exp() + trainer = _make_trainer_mock("from_path", ckpt_path=str(tmp_path), resume_path=str(ckpt_dir)) + + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_called_once_with( + "policy", + str(ckpt_dir / "policy"), + load_optimizer_states=False, + load_lr_scheduler_states=False, + ) + assert trainer.global_step == 42 + + +def test_load_policy_only_from_path_missing_dir_raises(tmp_path): + exp = _make_exp() + trainer = _make_trainer_mock( + "from_path", + ckpt_path=str(tmp_path), + resume_path=str(tmp_path / "global_step_99"), # never created + ) + + with pytest.raises(FileNotFoundError): + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_not_called() + + +def test_load_policy_only_from_path_invalid_dir_name_raises(tmp_path): + # extract_step_from_path returns -1 when the dir name has no global_step prefix. + bad_dir = tmp_path / "not_a_step_dir" + (bad_dir / "policy").mkdir(parents=True) + + exp = _make_exp() + trainer = _make_trainer_mock("from_path", ckpt_path=str(tmp_path), resume_path=str(bad_dir)) + + with pytest.raises(ValueError, match="not a valid global_step dir"): + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_not_called() diff --git a/integrations/fleet/utils.py b/integrations/fleet/utils.py new file mode 100644 index 0000000000..caa64eaab0 --- /dev/null +++ b/integrations/fleet/utils.py @@ -0,0 +1,118 @@ +""" +Utility functions for Fleet task training with Tinker. + +These functions handle sequence truncation and loss mask filtering, +matching SkyRL's skyrl_gym_generator patterns. +""" + +from typing import List, Tuple + + +def truncate_sequence( + prompt_ids: List[int], + response_ids: List[int], + max_sequence_length: int, +) -> Tuple[List[int], List[int], int]: + """ + Truncate a sequence to fit within max_sequence_length. + + The prompt is preserved fully; only the response is truncated. + + Args: + prompt_ids: Token IDs for the prompt. + response_ids: Token IDs for the response. + max_sequence_length: Maximum total sequence length. + + Returns: + Tuple of (full_sequence, truncated_response_ids, response_len). + """ + full_sequence = prompt_ids + response_ids + prompt_len = len(prompt_ids) + + if len(full_sequence) > max_sequence_length: + full_sequence = full_sequence[:max_sequence_length] + response_len = len(full_sequence) - prompt_len + truncated_response_ids = response_ids[:response_len] + else: + response_len = len(response_ids) + truncated_response_ids = response_ids + + return full_sequence, truncated_response_ids, response_len + + +def truncate_auxiliary_data( + data: List, + response_len: int, +) -> List: + """ + Truncate auxiliary data (logprobs, loss_mask) to match truncated response length. + + Args: + data: List of values corresponding to response tokens. + response_len: Target length after truncation. + + Returns: + Truncated list. + """ + if len(data) > response_len: + return data[:response_len] + return data + + +def apply_overlong_filtering_simple( + loss_masks: List[List[int]], + response_ids: List[List[int]], + eos_token_id: int, +) -> List[List[int]]: + """ + Apply DAPO overlong filtering: zero out loss mask for responses not ending with EOS. + + This is a simplified version for testing - the actual SkyRL function is in + skyrl_train.generators.utils.apply_overlong_filtering. + + Args: + loss_masks: List of loss masks for each response. + response_ids: List of response token IDs for each response. + eos_token_id: The EOS token ID. + + Returns: + Filtered loss masks (zeroed if response doesn't end with EOS). + """ + filtered = [] + for mask, response in zip(loss_masks, response_ids): + # Empty response or doesn't end with EOS -> zero out mask + if not response or response[-1] != eos_token_id: + filtered.append([0] * len(mask)) + else: + filtered.append(list(mask)) + return filtered + + +def prepare_training_sequence( + prompt_ids: List[int], + response_ids: List[int], + logprobs: List[float], + loss_mask: List[int], + max_sequence_length: int, +) -> Tuple[List[int], List[float], List[int], bool]: + """ + Prepare a training sequence with truncation if needed. + + Args: + prompt_ids: Token IDs for the prompt. + response_ids: Token IDs for the response. + logprobs: Log probabilities for response tokens. + loss_mask: Loss mask for response tokens. + max_sequence_length: Maximum total sequence length. + + Returns: + Tuple of (full_sequence, truncated_logprobs, truncated_loss_mask, was_truncated). + """ + full_sequence, truncated_response, response_len = truncate_sequence(prompt_ids, response_ids, max_sequence_length) + + was_truncated = len(prompt_ids) + len(response_ids) > max_sequence_length + + truncated_logprobs = truncate_auxiliary_data(logprobs, response_len) + truncated_loss_mask = truncate_auxiliary_data(loss_mask, response_len) + + return full_sequence, truncated_logprobs, truncated_loss_mask, was_truncated diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh new file mode 100755 index 0000000000..0192acc262 --- /dev/null +++ b/scripts/fleet-35b-run.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +# Single source of truth for Qwen3.5-35B-A3B GRPO training config. +# Called by the SkyPilot YAML and by fleet-research run.sh. +# +# Required env vars: FLEET_API_KEY, WANDB_API_KEY, OPENROUTER_API_KEY +# Optional: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (for S3 checkpoints) +set -euo pipefail +cd "$(dirname "$0")/.." # cd to SkyRL root (scripts/ is directly under repo root) + +# Defaults for vars normally set by SkyPilot YAML envs block +export LOGGER="${LOGGER:-wandb}" +export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" +export DATA_VERSION="${DATA_VERSION:-v6}" +export MODALITY="${MODALITY:-tool_use}" +export NUM_EPOCHS="${NUM_EPOCHS:-20}" +export MAX_TURNS="${MAX_TURNS:-50}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-72000}" +export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" +export NUM_INFERENCE_ENGINES="${NUM_INFERENCE_ENGINES:-8}" +export ENV_KEYS="${ENV_KEYS:-}" +export DIFFICULTY="${DIFFICULTY:-}" +export RUN_ID="${RUN_ID:-}" +export MAX_TASKS="${MAX_TASKS:-}" +export RESUME_RUN_NAME="${RESUME_RUN_NAME:-}" +export AWS_REGION="${AWS_REGION:-us-east-1}" +export S3_DATASET_BUCKET="${S3_DATASET_BUCKET:-fleet-internal-datasets}" +export S3_CHECKPOINT_BUCKET="${S3_CHECKPOINT_BUCKET:-skyrl-checkpoints}" +export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" + +: "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" +: "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" +# OPENROUTER_API_KEY only needed when enable_hints=true (LLM hint synthesis) +export OPENROUTER_API_KEY="${OPENROUTER_API_KEY:-}" + +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit --no-pytorch-alloc-conf \ + --nccl-heartbeat 1800 -- \ + environment.skyrl_gym.fleet_task.ttl_seconds=900 \ + environment.skyrl_gym.fleet_task.partial_reward=true \ + environment.skyrl_gym.fleet_task.enable_hints=false \ + trainer.algorithm.advantage_estimator=grpo \ + trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ + trainer.flash_attn=false \ + trainer.loss_chunk_size=4096 \ + trainer.use_sample_packing=false \ + +generator.chat_template_kwargs='{enable_thinking:true}' \ + generator.inference_engine_tensor_parallel_size=2 \ + trainer.epochs=${NUM_EPOCHS} \ + trainer.eval_batch_size=8 \ + trainer.eval_before_train=true \ + trainer.eval_interval=10 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=16 \ + trainer.use_hybrid_env_sampling=true \ + trainer.min_samples_per_env=1 \ + trainer.policy_mini_batch_size=16 \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=10 \ + trainer.max_ckpts_to_keep=1 \ + trainer.max_prompt_length=2048 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.sampling_params.temperature=0.9 \ + generator.sampling_params.top_p=0.95 \ + 'generator.sampling_params.stop=[""]' \ + 'generator.eval_sampling_params.stop=[""]' \ + trainer.policy.optimizer_config.lr=5.0e-7 \ + trainer.algorithm.use_kl_loss=true \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=8 \ + generator.eval_n_samples_per_prompt=3 \ + generator.enforce_eager=false \ + generator.gpu_memory_utilization=0.65 \ + generator.inject_context_status=true \ + generator.context_warning_threshold=0.90 \ + trainer.logger="$LOGGER" \ + trainer.project_name="fleet-tool-use-grpo" \ + trainer.run_name="fleet_qwen35_35b_${MODALITY}_${RUN_ID:-$(head -c 4 /dev/urandom | xxd -p)}" \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$HOME/ckpts/fleet_qwen35_35b_${MODALITY}" \ + trainer.export_path="$HOME/exports" \ + trainer.dump_data_batch=true \ + "$@" diff --git a/scripts/fleet-common-run.sh b/scripts/fleet-common-run.sh new file mode 100755 index 0000000000..7e12cb7e73 --- /dev/null +++ b/scripts/fleet-common-run.sh @@ -0,0 +1,314 @@ +#!/usr/bin/env bash +# Fleet shared run: Ray cluster setup (multi-node aware) + training launch +# +# Usage (from SkyPilot YAML run block): +# bash skyrl-train/scripts/fleet-common-run.sh \ +# --use-python-direct --cuda-env "$HOME/.cuda_env" \ +# --set-ulimit --no-pytorch-alloc-conf -- \ +# trainer.policy.model.path="Qwen/Qwen3.5-9B" \ +# trainer.epochs=20 ... +# +# Multi-node: +# Rank 0 (head): starts Ray head, launches training +# Rank >0 (workers): joins Ray cluster, sleeps +# +# Required env vars: WANDB_API_KEY, MODALITY, INFERENCE_BACKEND, +# SKYPILOT_NUM_GPUS_PER_NODE, SKYPILOT_NODE_IPS +# Optional env vars: SKYPILOT_NUM_NODES, SKYPILOT_NODE_RANK +set -euo pipefail + +# Defaults +DATA_ROOT="" +CKPT_ROOT="" +USE_PYTHON_DIRECT=false +CUDA_ENV="" +SET_ULIMIT=false +NO_PYTORCH_ALLOC_CONF=false +NCCL_HEARTBEAT="" +ENTRYPOINT="integrations.fleet.entrypoints.main_fleet" +ENV_CLASS="fleet_task" +DATA_DIR_NAME="" +HYDRA_OVERRIDES=() + +# Parse args +while [[ $# -gt 0 ]]; do + case "$1" in + --data-root) DATA_ROOT="$2"; shift 2 ;; + --ckpt-root) CKPT_ROOT="$2"; shift 2 ;; + --use-python-direct) USE_PYTHON_DIRECT=true; shift ;; + --cuda-env) CUDA_ENV="$2"; shift 2 ;; + --set-ulimit) SET_ULIMIT=true; shift ;; + --no-pytorch-alloc-conf) NO_PYTORCH_ALLOC_CONF=true; shift ;; + --nccl-heartbeat) NCCL_HEARTBEAT="$2"; shift 2 ;; + --entrypoint) ENTRYPOINT="$2"; shift 2 ;; + --env-class) ENV_CLASS="$2"; shift 2 ;; + --data-dir-name) DATA_DIR_NAME="$2"; shift 2 ;; + --) shift; HYDRA_OVERRIDES=("$@"); break ;; + *) echo "ERROR: Unknown arg: $1"; exit 1 ;; + esac +done + +# Auto-detect data/ckpt root: /workspace if writable (RunPod), else $HOME (GCP, Lambda, etc.) +if [ -z "$DATA_ROOT" ]; then + if [ -d "/workspace" ] && [ -w "/workspace" ]; then + DATA_ROOT="/workspace" + else + DATA_ROOT="$HOME" + fi +fi +if [ -z "$CKPT_ROOT" ]; then + CKPT_ROOT="$DATA_ROOT" +fi +DATA_DIR_NAME="${DATA_DIR_NAME:-$MODALITY}" + +echo "=== Fleet Common Run ===" +echo "Entrypoint: $ENTRYPOINT" +echo "Env class: $ENV_CLASS" +echo "Data root: $DATA_ROOT" +echo "Data dir name: $DATA_DIR_NAME" +echo "Ckpt root: $CKPT_ROOT" + +# Activate venv from repo root (upstream SkyRL layout) +source .venv/bin/activate + +# --- Optional settings --- +if [ "$SET_ULIMIT" = true ]; then + # Set open files limit. Try 1M first, fall back to hard limit if too high. + ulimit -n 1048576 2>/dev/null || ulimit -n "$(ulimit -Hn)" 2>/dev/null || true +fi + +# vLLM TP>1 uses pidfd_getfd for CUDA IPC weight sync between Ray workers. +# This requires ptrace access, which is blocked by default (ptrace_scope=1). +sudo sysctl -w kernel.yama.ptrace_scope=0 2>/dev/null || true + +if [ -n "$CUDA_ENV" ]; then + source "$CUDA_ENV" 2>/dev/null || true +fi + +if [ "$NO_PYTORCH_ALLOC_CONF" = false ]; then + export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +fi + +if [ -n "$NCCL_HEARTBEAT" ]; then + export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC="$NCCL_HEARTBEAT" +fi + +TMP_DIR="${CKPT_ROOT}/skyrl-tmp" +mkdir -p "$TMP_DIR" +export TMPDIR="$TMP_DIR" + +TASKS_FILE="${DATA_ROOT}/data/fleet/tasks_${MODALITY}.json" +DATA_DIR="${DATA_ROOT}/data/fleet/${DATA_DIR_NAME}" + +# --- System diagnostics --- +echo "=== System Diagnostics ===" +free -h +nvidia-smi --query-gpu=name,driver_version,memory.total,memory.free --format=csv 2>/dev/null || true +echo "--- /dev/shm ---" +df -h /dev/shm 2>/dev/null || echo "/dev/shm not mounted" +ls -la /dev/shm/ 2>/dev/null | head -5 || true +echo "--- GPU Topology ---" +nvidia-smi topo -m 2>/dev/null || true +echo "--- cgroup memory limits ---" +cat /sys/fs/cgroup/memory.max 2>/dev/null || cat /sys/fs/cgroup/memory/memory.limit_in_bytes 2>/dev/null || echo "No cgroup memory limit found" +echo "--- ulimits ---" +ulimit -a 2>/dev/null || true +echo "--- NCCL env vars ---" +env | grep -i NCCL || echo "No NCCL env vars set" +echo "--- kernel overcommit ---" +cat /proc/sys/vm/overcommit_memory 2>/dev/null || true +echo "=== End Diagnostics ===" + +# --- wandb login --- +python3 -c "import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')" + +# --- Fabric Manager check (NVSwitch GPUs: B200, H200 SXM) --- +# On non-GCP clouds (RunPod, Lambda, etc.), Fabric Manager is required for NVLink +# P2P on NVSwitch systems. Without it, dist.broadcast() in FSDP causes SIGKILL. +# +# On GCP, NVSwitch is managed at the HOST level — the guest VM does not have +# NVSwitch devices, so FM reports "NV_WARN_NOTHING_TO_DO" and cannot start. +# This is EXPECTED. NVLink P2P works through GCP's host-managed fabric without FM. +# GCP also provides a custom NCCL shim (gIB) that manages all NCCL configuration. +# Do NOT set NCCL_P2P_DISABLE or NCCL_NVLS_ENABLE on GCP with RDMA — +# the shim's "Guest Config Checker" expects these to be unset. +# NCCL_CUMEM_ENABLE=0 is set below for GCP WITHOUT RDMA to disable multicast. +ON_GCP=false +if [ -d "/usr/local/gib" ]; then + ON_GCP=true +elif [ -f "/sys/class/dmi/id/product_name" ] && grep -qi "google" /sys/class/dmi/id/product_name 2>/dev/null; then + ON_GCP=true +fi + +FM_STATUS=$(systemctl is-active nvidia-fabricmanager 2>/dev/null || echo "unknown") +echo "Fabric Manager status: $FM_STATUS" +echo "On GCP: $ON_GCP" + +if [ "$ON_GCP" = true ]; then + echo "GCP detected — skipping Fabric Manager restart (host manages NVSwitch)" + + # GCP's deep learning images install /etc/profile.d/nccl_env.sh which auto-sources + # /usr/local/gib/scripts/set_nccl_env.sh and adds /usr/local/gib/lib64 to LD_LIBRARY_PATH. + # This sets NCCL_NET=gIB, forcing the gIB network plugin for RDMA/InfiniBand. + # + # Problem: gIB requires RDMA hardware (ConnectX NICs + multiple GPUDirect VPC networks). + # SkyPilot provisions VMs with a single management NIC — no RDMA networking. + # When NCCL_NET=gIB is forced but gIB can't init, NCCL fails with + # "Failed to initialize any NET plugin" → SIGKILL during dist.broadcast(). + # + # Fix: check for RDMA devices. If absent, strip gIB so NCCL falls back to + # NVLink P2P for intra-node communication. Multi-node uses GKE with RDMA. + if [ -d "/sys/class/infiniband" ] && [ "$(ls /sys/class/infiniband/ 2>/dev/null)" ]; then + echo "RDMA devices found — keeping gIB for GPUDirect RDMA" + else + echo "No RDMA devices — disabling gIB" + # Remove gIB from LD_LIBRARY_PATH (set by /etc/profile.d/nccl_env.sh) + export LD_LIBRARY_PATH=$(echo "${LD_LIBRARY_PATH:-}" | sed 's|/usr/local/gib/lib64:||g; s|:/usr/local/gib/lib64||g; s|/usr/local/gib/lib64||g') + # Unset NCCL_NET=gIB so NCCL can fall back to NVLink P2P + unset NCCL_NET + # Clear gIB-specific vars set by set_nccl_env.sh + unset NCCL_CROSS_NIC NCCL_NET_GDR_LEVEL NCCL_P2P_NET_CHUNKSIZE NCCL_NVLS_CHUNKSIZE + unset NCCL_IB_ADAPTIVE_ROUTING NCCL_IB_QPS_PER_CONNECTION NCCL_IB_TC NCCL_IB_FIFO_TC + unset NCCL_TUNER_CONFIG_PATH + # Disable CUDA multicast (requires NVSwitch fabric manager for GPU multicast + # team setup). Without this, vLLM TP>1 hangs on CUDASymmetricMemory init. + export NCCL_CUMEM_ENABLE=0 + echo "Cleared gIB NCCL env vars. Using NVLink P2P (intra-node)." + fi + echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" + echo "NCCL vars:" + env | grep -i NCCL || echo " (none)" + + # Ensure /dev/shm is large enough for NCCL IPC (some GCP images have small default) + SHM_SIZE=$(df --output=size /dev/shm 2>/dev/null | tail -1 | tr -d ' ') + echo "Current /dev/shm size: ${SHM_SIZE}K" + if [ -n "$SHM_SIZE" ] && [ "$SHM_SIZE" -lt 16777216 ]; then + echo "WARNING: /dev/shm is only ${SHM_SIZE}K — remounting to 16G for NCCL" + sudo mount -o remount,size=16G /dev/shm 2>&1 || echo "Failed to remount /dev/shm" + df -h /dev/shm + fi +elif [ "$FM_STATUS" != "active" ]; then + echo "WARNING: Fabric Manager not active. Attempting restart..." + sudo nvidia-smi -pm 1 2>&1 || true + sudo systemctl stop nvidia-fabricmanager 2>&1 || true + sleep 1 + sudo systemctl start nvidia-fabricmanager 2>&1 || true + sleep 5 + FM_STATUS=$(systemctl is-active nvidia-fabricmanager 2>/dev/null || echo "unknown") + echo "Fabric Manager status after restart: $FM_STATUS" + if [ "$FM_STATUS" != "active" ]; then + echo "=== WARNING: Fabric Manager failed to start ===" + echo "Training may fail if this system has NVSwitch GPUs." + sudo journalctl -u nvidia-fabricmanager --no-pager -n 10 2>&1 || true + fi +fi + +# --- Ray cluster setup (multi-node aware) --- +export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook +export RAY_object_store_memory=10000000000 +# Disable Ray's memory monitor to prevent spurious worker kills +export RAY_DISABLE_MEMORY_MONITOR=1 +# NOTE: On GCP VMs without RDMA, gIB NCCL vars are stripped above. +# On GKE with RDMA, gIB is preserved for inter-node GPUDirect. + +read -r head_ip _ <<< "$SKYPILOT_NODE_IPS" + +wait_for_ray() { + local address=$1 + for _ in $(seq 1 24); do + if ray status --address "$address" >/dev/null 2>&1; then + return 0 + fi + sleep 5 + done + echo "ERROR: Ray cluster at $address failed to become ready" >&2 + return 1 +} + +if [ "${SKYPILOT_NODE_RANK:-0}" = "0" ]; then + # === Head node: start Ray head + launch training === + if ! ray status --address 127.0.0.1:6479 >/dev/null 2>&1; then + ray start --head --disable-usage-stats --port 6479 --object-store-memory=10000000000 + fi + wait_for_ray 127.0.0.1:6479 + + TOTAL_GPUS=$((SKYPILOT_NUM_GPUS_PER_NODE * ${SKYPILOT_NUM_NODES:-1})) + export TOTAL_GPUS + # NUM_INFERENCE_ENGINES can be overridden via env var for TP>1 (engines = GPUs / TP) + NUM_INFERENCE_ENGINES=${NUM_INFERENCE_ENGINES:-$TOTAL_GPUS} + echo "=== Head node: $TOTAL_GPUS GPUs across ${SKYPILOT_NUM_NODES:-1} node(s), $NUM_INFERENCE_ENGINES inference engines ===" + + # Build training command + CMD_ARGS=() + if [ "$USE_PYTHON_DIRECT" = true ]; then + CMD_ARGS=(python -m "$ENTRYPOINT") + else + CMD_ARGS=(uv run --isolated --extra "$INFERENCE_BACKEND" -m "$ENTRYPOINT") + fi + + # Common hydra overrides (data paths, placement, strategy, checkpoints) + CMD_ARGS+=( + "data.train_data=['${DATA_DIR}/train.parquet']" + "data.val_data=['${DATA_DIR}/validation.parquet']" + "environment.env_class=$ENV_CLASS" + ) + + # fleet_task-specific: pass tasks_file path + if [ "$ENV_CLASS" = "fleet_task" ]; then + CMD_ARGS+=("environment.skyrl_gym.fleet_task.tasks_file=$TASKS_FILE") + fi + + CMD_ARGS+=( + trainer.placement.colocate_all=true + trainer.strategy=fsdp2 + "trainer.placement.policy_num_gpus_per_node=$SKYPILOT_NUM_GPUS_PER_NODE" + "trainer.placement.ref_num_gpus_per_node=$SKYPILOT_NUM_GPUS_PER_NODE" + "trainer.placement.policy_num_nodes=${SKYPILOT_NUM_NODES:-1}" + "trainer.placement.ref_num_nodes=${SKYPILOT_NUM_NODES:-1}" + "generator.num_inference_engines=$NUM_INFERENCE_ENGINES" + "trainer.ckpt_path=${CKPT_ROOT}/ckpts" + "trainer.export_path=${CKPT_ROOT}/exports" + trainer.dump_training_trajectories=true + ) + + # Append model-specific hydra overrides (passed after --) + if [ ${#HYDRA_OVERRIDES[@]} -gt 0 ]; then + CMD_ARGS+=("${HYDRA_OVERRIDES[@]}") + fi + + export HYDRA_FULL_ERROR=1 + echo "=== Launching Training ===" + set +e + "${CMD_ARGS[@]}" + EXIT_CODE=$? + set -e + + if [ $EXIT_CODE -ne 0 ]; then + echo "=== Training failed (exit code $EXIT_CODE) ===" + echo "--- dmesg (last 50 lines, unfiltered) ---" + sudo dmesg -T 2>/dev/null | tail -50 || true + echo "--- dmesg (OOM/kill/segfault) ---" + sudo dmesg -T 2>/dev/null | grep -iE "oom|kill|out of memory|segfault|sigsegv|general protection|cgroup" | tail -20 || true + echo "--- memory ---" + free -h + echo "--- GPU memory ---" + nvidia-smi --query-gpu=memory.used,memory.free --format=csv 2>/dev/null || true + echo "--- /dev/shm after crash ---" + df -h /dev/shm 2>/dev/null || true + echo "--- cgroup memory events ---" + cat /sys/fs/cgroup/memory.events 2>/dev/null || cat /sys/fs/cgroup/memory/memory.oom_control 2>/dev/null || true + echo "--- Ray worker logs (last errors) ---" + grep -r "SIGKILL\|SIGABRT\|SIGSEGV\|SYSTEM_ERROR\|RuntimeError\|NCCL" /tmp/ray/session_latest/logs/ 2>/dev/null | tail -30 || true + exit $EXIT_CODE + fi + +else + # === Worker node: join Ray cluster and wait === + echo "=== Worker node (rank ${SKYPILOT_NODE_RANK}), joining Ray cluster at $head_ip:6479 ===" + if ! ray status --address "$head_ip:6479" >/dev/null 2>&1; then + ray start --address "$head_ip:6479" --disable-usage-stats + fi + wait_for_ray "$head_ip:6479" + echo "Worker node joined. Sleeping..." + sleep infinity +fi diff --git a/scripts/fleet-common-setup.sh b/scripts/fleet-common-setup.sh new file mode 100755 index 0000000000..fe07c020fa --- /dev/null +++ b/scripts/fleet-common-setup.sh @@ -0,0 +1,132 @@ +#!/usr/bin/env bash +# Fleet shared setup: env validation, venv, dependencies, OpenEnv, dataset download +# +# Usage (from SkyPilot YAML setup block): +# bash skyrl-train/scripts/fleet-common-setup.sh \ +# --openenv-branch deniz/fleet_client \ +# --extra-setup skyrl-train/scripts/fleet-qwen35-extra-setup.sh +# +# Required env vars: FLEET_API_KEY, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, +# MODALITY, DATA_VERSION, S3_DATASET_BUCKET +# Optional env vars: ENV_KEYS, DIFFICULTY +set -euo pipefail + +# Defaults +OPENENV_BRANCH="deniz/fleet_client" +EXTRA_SETUP="" +DATA_ROOT="" +SKIP_UV_ISOLATED=false +EXTRA_PIP="" +SKIP_PREPARE=false +ENV_CLASS="fleet_task" + +# Parse args +while [[ $# -gt 0 ]]; do + case "$1" in + --openenv-branch) OPENENV_BRANCH="$2"; shift 2 ;; + --extra-setup) EXTRA_SETUP="$2"; shift 2 ;; + --data-root) DATA_ROOT="$2"; shift 2 ;; + --skip-uv-isolated) SKIP_UV_ISOLATED=true; shift ;; + --extra-pip) EXTRA_PIP="$2"; shift 2 ;; + --skip-prepare) SKIP_PREPARE=true; shift ;; + --env-class) ENV_CLASS="$2"; shift 2 ;; + *) echo "ERROR: Unknown arg: $1"; exit 1 ;; + esac +done + +# Auto-detect data root: /workspace if writable (RunPod), else $HOME (GCP, Lambda, etc.) +if [ -z "$DATA_ROOT" ]; then + if [ -d "/workspace" ] && [ -w "/workspace" ]; then + DATA_ROOT="/workspace" + else + DATA_ROOT="$HOME" + fi +fi + +# Resolve extra-setup path to absolute before cd (it's relative to repo root) +if [ -n "$EXTRA_SETUP" ]; then + EXTRA_SETUP="$(cd "$(dirname "$EXTRA_SETUP")" && pwd)/$(basename "$EXTRA_SETUP")" +fi + +# In upstream SkyRL, training packages live at repo root (skyrl/, skyrl-gym/, integrations/) +# No need to cd into skyrl-train/ — the venv and dependencies are at root level + +echo "=== Fleet Common Setup ===" +echo "OpenEnv branch: $OPENENV_BRANCH" +echo "Data root: $DATA_ROOT" +echo "Extra setup: ${EXTRA_SETUP:-none}" + +# --- Environment validation --- +echo "Validating environment variables..." +if [ -z "${FLEET_API_KEY:-}" ]; then + echo "ERROR: FLEET_API_KEY is required"; exit 1 +fi +if [ -z "${AWS_ACCESS_KEY_ID:-}" ] || [ -z "${AWS_SECRET_ACCESS_KEY:-}" ]; then + echo "ERROR: AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are required for S3 dataset download"; exit 1 +fi +if [ "${MODALITY:-}" != "tool_use" ] && [ "${MODALITY:-}" != "computer_use" ] && [ "${MODALITY:-}" != "browser_use" ]; then + echo "ERROR: MODALITY must be 'tool_use', 'computer_use', or 'browser_use', got: ${MODALITY:-unset}"; exit 1 +fi +echo "Environment validation passed" + +# --- Fix Ray binary permissions (some cloud images strip +x) --- +for f in .venv/bin/ray .venv/lib/python*/site-packages/ray/core/src/ray/raylet/raylet; do + [ -f "$f" ] && chmod +x "$f" 2>/dev/null || true +done + +# --- System dependencies (GCP images may lack build tools) --- +if ! command -v c++ &>/dev/null; then + echo "Installing build-essential (c++ compiler required for causal-conv1d)..." + sudo apt-get update -qq && sudo apt-get install -y --no-install-recommends build-essential +fi + +# --- Python environment --- +if [ -d ".venv" ]; then + echo "Virtual environment already exists, reusing" +else + uv venv --python 3.12 --seed +fi +source .venv/bin/activate +# vLLM 0.17.0 has native Qwen3.5 support (GDN via torch.ops.vllm.gdn_attention_core), +# FlashAttention 4, and PyTorch 2.10.0 +uv sync --extra fsdp +uv pip install wandb boto3 awscli +# Pin fleet-python<=0.2.119: 0.2.120+ has async BaseWrapper bug (missing jwt/team_id params) +uv pip install "litellm>=1.75.5" "fleet-python<=0.2.119" logfire "mcp>=1.0.0" + +# --- Extra pip packages (installed before extra-setup to avoid dependency downgrades) --- +if [ -n "$EXTRA_PIP" ]; then + echo "Installing extra pip packages: $EXTRA_PIP" + uv pip install $EXTRA_PIP +fi + +# --- Extra setup hook (model-specific dependencies) --- +if [ -n "$EXTRA_SETUP" ]; then + echo "Running extra setup: $EXTRA_SETUP" + source "$EXTRA_SETUP" +fi + +# --- OpenEnv (force reinstall for latest changes) --- +uv pip install --force-reinstall --no-cache-dir --no-deps "git+https://github.com/fleet-ai/OpenEnv.git@${OPENENV_BRANCH}" + +# --- Dataset download --- +mkdir -p "${DATA_ROOT}/data/fleet" +TASKS_FILE="${DATA_ROOT}/data/fleet/tasks_${MODALITY}.json" +S3_PATH="s3://${S3_DATASET_BUCKET}/${DATA_VERSION}/openenv/all_${MODALITY}.json" +echo "Downloading dataset from $S3_PATH..." +aws s3 cp "$S3_PATH" "$TASKS_FILE" +TASK_COUNT=$(python3 -c "import json; print(len(json.load(open('$TASKS_FILE'))['tasks']))") +echo "Downloaded $TASK_COUNT tasks for modality: $MODALITY" + +# --- Prepare dataset (parquet files) --- +if [ "$SKIP_PREPARE" = true ]; then + echo "Skipping prepare_dataset (--skip-prepare). Caller handles preparation." +else + DATA_DIR="${DATA_ROOT}/data/fleet/${MODALITY}" + PREPARE_CMD="python -m integrations.fleet.prepare_dataset --tasks-json $TASKS_FILE --output-dir $DATA_DIR --modality $MODALITY --env-class $ENV_CLASS" + [ -n "${ENV_KEYS:-}" ] && PREPARE_CMD="$PREPARE_CMD --env-filter $ENV_KEYS" + [ -n "${DIFFICULTY:-}" ] && PREPARE_CMD="$PREPARE_CMD --difficulty-filter $DIFFICULTY" + eval "$PREPARE_CMD" +fi + +echo "=== Fleet Common Setup Complete ===" diff --git a/scripts/fleet-eval-only-run.sh b/scripts/fleet-eval-only-run.sh new file mode 100644 index 0000000000..4d1a0b07d4 --- /dev/null +++ b/scripts/fleet-eval-only-run.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash +# Eval-only run on Fleet envs with optional S3 checkpoint resume. +# +# When RESUME_RUN_NAME is set, downloads the latest FSDP checkpoint from S3, +# broadcasts it to worker nodes, loads policy weights, and runs a single eval +# pass. Eval results are dumped locally and uploaded to S3. +# +# When RESUME_RUN_NAME is unset, evaluates the base model at trainer.policy.model.path. +# +# Required env vars: FLEET_API_KEY, WANDB_API_KEY +# Optional env vars: +# RESUME_RUN_NAME Run name to resume from (S3 prefix). Empty = base model eval. +# RESUME_CKPT_PATH Local checkpoint dir to download into. Default: $HOME/ckpts/eval_only +# AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (required for S3 resume / upload) +# MODEL_PATH HF model repo or path. Default: Qwen/Qwen3.5-9B +# PROJECT_NAME W&B / S3 project prefix. Default: fleet-tool-use-grpo +# RUN_NAME W&B run name + S3 eval upload prefix. Default: fleet_eval_only__ +# EVAL_N_SAMPLES pass@K samples per prompt. Default: 8 +# +set -euo pipefail +cd "$(dirname "$0")/.." # cd to SkyRL root + +export LOGGER="${LOGGER:-wandb}" +export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" +export MODALITY="${MODALITY:-tool_use}" +export MAX_TURNS="${MAX_TURNS:-50}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-72000}" +export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" +export NUM_INFERENCE_ENGINES="${NUM_INFERENCE_ENGINES:-8}" +export EVAL_N_SAMPLES="${EVAL_N_SAMPLES:-8}" +export AWS_REGION="${AWS_REGION:-us-east-1}" +export S3_DATASET_BUCKET="${S3_DATASET_BUCKET:-fleet-internal-datasets}" +export S3_CHECKPOINT_BUCKET="${S3_CHECKPOINT_BUCKET:-skyrl-checkpoints}" +export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" + +: "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" +: "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" +export OPENROUTER_API_KEY="${OPENROUTER_API_KEY:-}" + +MODEL_PATH="${MODEL_PATH:-Qwen/Qwen3.5-9B}" +PROJECT_NAME="${PROJECT_NAME:-fleet-tool-use-grpo}" +RUN_NAME="${RUN_NAME:-fleet_eval_only_${MODALITY}_pass_at_${EVAL_N_SAMPLES}}" +RESUME_CKPT_PATH="${RESUME_CKPT_PATH:-$HOME/ckpts/eval_only}" + +# resume_mode controls how main_eval picks the checkpoint inside RESUME_CKPT_PATH. +# latest = read latest_ckpt_global_step.txt (written by S3 download); none = base weights. +if [ -n "${RESUME_RUN_NAME:-}" ]; then + RESUME_MODE="${RESUME_MODE:-latest}" +else + RESUME_MODE="${RESUME_MODE:-none}" +fi +export RESUME_RUN_NAME="${RESUME_RUN_NAME:-}" + +DATA_ROOT="" +if [ -d "/workspace" ] && [ -w "/workspace" ]; then + DATA_ROOT="/workspace" +else + DATA_ROOT="$HOME" +fi + +EVAL_PARQUET="${DATA_ROOT}/data/fleet/${MODALITY}/validation.parquet" +TASKS_FILE="${DATA_ROOT}/data/fleet/tasks_${MODALITY}.json" + +echo "=== Fleet Eval-Only Run ===" +echo "Model: $MODEL_PATH" +echo "Project / Run: $PROJECT_NAME / $RUN_NAME" +echo "Resume run name: ${RESUME_RUN_NAME:-(none — base model eval)}" +echo "Resume mode: $RESUME_MODE" +echo "Ckpt path: $RESUME_CKPT_PATH" +echo "Eval data: $EVAL_PARQUET" +echo "Samples/prompt: $EVAL_N_SAMPLES" + +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit --no-pytorch-alloc-conf \ + --entrypoint integrations.fleet.entrypoints.main_eval \ + --nccl-heartbeat 1800 -- \ + environment.skyrl_gym.fleet_task.ttl_seconds=900 \ + environment.skyrl_gym.fleet_task.partial_reward=true \ + environment.skyrl_gym.fleet_task.enable_hints=false \ + trainer.policy.model.path="$MODEL_PATH" \ + trainer.flash_attn=false \ + trainer.use_sample_packing=false \ + trainer.resume_mode="$RESUME_MODE" \ + trainer.ckpt_path="$RESUME_CKPT_PATH" \ + trainer.eval_batch_size=4 \ + trainer.eval_interval=1 \ + trainer.max_prompt_length=2048 \ + trainer.dump_eval_results=true \ + trainer.export_path="$HOME/exports" \ + generator.chat_template_kwargs='{enable_thinking:true}' \ + generator.inference_engine_tensor_parallel_size=1 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.eval_sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.eval_sampling_params.temperature=0.9 \ + generator.eval_sampling_params.top_p=0.95 \ + 'generator.eval_sampling_params.stop=[""]' \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.use_conversation_multi_turn=true \ + generator.eval_n_samples_per_prompt=$EVAL_N_SAMPLES \ + generator.enforce_eager=false \ + generator.gpu_memory_utilization=0.65 \ + generator.inject_context_status=true \ + generator.context_warning_threshold=0.90 \ + trainer.logger="$LOGGER" \ + trainer.project_name="$PROJECT_NAME" \ + trainer.run_name="$RUN_NAME" \ + "data.val_data=['${EVAL_PARQUET}']" \ + "environment.skyrl_gym.fleet_task.tasks_file=${TASKS_FILE}" \ + "$@" diff --git a/scripts/fleet-qwen35-extra-setup.sh b/scripts/fleet-qwen35-extra-setup.sh new file mode 100755 index 0000000000..66e77038fc --- /dev/null +++ b/scripts/fleet-qwen35-extra-setup.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash +# Qwen3.5-specific dependencies (sourced by fleet-common-setup.sh via --extra-setup) +# +# Installs: transformers 5.3.0, flash-attn 2.8.3 wheel, CUDA toolkit (nvcc), causal-conv1d +# Writes: $HOME/.cuda_env (sourced at run time for FlashInfer JIT) + +# Upgrade transformers to 5.3.0 for Qwen3.5-MoE (model_type=qwen3_5_moe). +# - Qwen3.5 launched Feb 2026; all 4.x releases predate it. +# - 5.1.0 doesn't register qwen3_5_moe in AUTO_CONFIG_MAPPING. +# - 5.3.0 is the first stable release with full qwen3_5_moe support. +# - Do NOT install from git main (renamed layer_type_validation, breaks vLLM 0.17). +uv pip install -U "transformers==5.3.0" + +# flash-attn 2.8.3 prebuilt wheel for torch 2.10 + CUDA 12 (training forward/backward) +uv pip install "https://github.com/lesj0610/flash-attention/releases/download/v2.8.3-cu12-torch2.10-cp312/flash_attn-2.8.3%2Bcu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl" + +python -c "import torch; import torchvision; print(f'torch={torch.__version__}, torchvision={torchvision.__version__}')" + +# --- CUDA toolkit for FlashInfer JIT (GatedDeltaNet kernels) --- +# pip CUDA packages are incomplete (missing nv/target headers); use NVIDIA apt repo instead +CUDA_HOME="" +for d in /usr/local/cuda /usr/local/cuda-12.8 /usr/local/cuda-12.6 /usr/local/cuda-12.4; do + if [ -x "$d/bin/nvcc" ]; then + CUDA_HOME="$d" + break + fi +done +if [ -z "$CUDA_HOME" ] && command -v nvcc &>/dev/null; then + NVCC_PATH=$(command -v nvcc) + CUDA_HOME=$(dirname "$(dirname "$NVCC_PATH")") +fi +if [ -z "$CUDA_HOME" ]; then + echo "nvcc not found on system. Installing CUDA toolkit from NVIDIA apt repo..." + sudo apt-get update -qq + UBUNTU_VER=$(lsb_release -rs 2>/dev/null | tr -d '.' || echo "2204") + KEYRING_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VER}/x86_64/cuda-keyring_1.1-1_all.deb" + echo "Installing CUDA keyring from $KEYRING_URL" + wget -qO /tmp/cuda-keyring.deb "$KEYRING_URL" 2>&1 || curl -sLo /tmp/cuda-keyring.deb "$KEYRING_URL" + file /tmp/cuda-keyring.deb + sudo dpkg -i /tmp/cuda-keyring.deb + sudo apt-get update -qq + sudo apt-get install -y --no-install-recommends cuda-nvcc-12-8 libcublas-dev-12-8 cuda-nvrtc-dev-12-8 + CUDA_HOME="/usr/local/cuda-12.8" +fi +export CUDA_HOME +export PATH="$CUDA_HOME/bin:$PATH" +echo "CUDA_HOME=$CUDA_HOME" +"$CUDA_HOME/bin/nvcc" --version + +# Write cuda_env for run phase (fleet-common-run.sh sources this via --cuda-env) +echo "export CUDA_HOME=$CUDA_HOME" > "$HOME/.cuda_env" +echo "export PATH=$CUDA_HOME/bin:\$PATH" >> "$HOME/.cuda_env" + +# causal-conv1d: required for GatedDeltaNet fast CUDA kernels in Qwen3.5-MoE. +# Without it, fla-core falls back to a naive PyTorch implementation that crashes +# with cudaErrorIllegalAddress on multi-node FSDP2 (Xid 31 MMU fault). +# Must be built from source (needs nvcc + g++) — install AFTER CUDA toolkit setup. +# Build from source with --no-build-isolation so it finds torch from the venv. +# uv pip install can silently fail on CUDA extensions; use pip directly. +pip install --no-cache-dir --no-build-isolation "causal-conv1d>=1.6.0" +python -c "import causal_conv1d; print(f'causal-conv1d OK: {causal_conv1d.__version__}')" + +# Verify pinned packages survived dependency resolution +python -c "import transformers; assert transformers.__version__ == '5.3.0', f'Expected 5.3.0 got {transformers.__version__}'" +# Ensure torch 2.10.0 — uv pip install can downgrade it during transitive resolution +TORCH_VER=$(python -c "import torch; print(torch.__version__)") +echo "torch version after setup: $TORCH_VER" +if [[ "$TORCH_VER" != 2.10.0* ]]; then + echo "WARNING: torch was downgraded to $TORCH_VER, reinstalling 2.10.0+cu128" + pip install --force-reinstall --no-deps torch==2.10.0 --index-url https://download.pytorch.org/whl/cu128 +fi +python -c "import torch; assert torch.__version__.startswith('2.10.0'), f'Expected 2.10.0 got {torch.__version__}'" +python -c "import torch; import flash_attn_2_cuda; print('flash_attn CUDA extension OK')" diff --git a/scripts/fleet-task-gen-35b-run.sh b/scripts/fleet-task-gen-35b-run.sh new file mode 100755 index 0000000000..6cad5e5dab --- /dev/null +++ b/scripts/fleet-task-gen-35b-run.sh @@ -0,0 +1,113 @@ +#!/usr/bin/env bash +# Task-gen specific run for Qwen3.5-35B: calls common run with task-gen entrypoint +# and 35B-specific config (TP=2, flash_attn=false, 72K input, chunked lm_head). +# +# Usage (from SkyPilot YAML run block): +# bash scripts/fleet-task-gen-35b-run.sh +# +# Required env vars: WANDB_API_KEY, FLEET_API_KEY +# SkyPilot env vars: SKYPILOT_NUM_GPUS_PER_NODE, SKYPILOT_NODE_IPS +set -euo pipefail + +# Export RUN_NAME so task_gen_env can tag rollout dumps +export RUN_NAME="task_gen_35b_$(python3 -c 'import os; print(os.urandom(4).hex())')" + +# Defaults for vars normally set by SkyPilot YAML envs block +export LOGGER="${LOGGER:-wandb}" +export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" +export MODALITY="${MODALITY:-tool_use}" +export NUM_EPOCHS="${NUM_EPOCHS:-20}" +export MAX_TURNS="${MAX_TURNS:-10}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-72000}" +export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" +export NUM_INFERENCE_ENGINES="${NUM_INFERENCE_ENGINES:-8}" +export JUDGE_MODEL="${JUDGE_MODEL:-anthropic/claude-sonnet-4.5}" +export EVALUATOR_MODEL="${EVALUATOR_MODEL:-anthropic/claude-sonnet-4.5}" +export K_ROLLOUTS="${K_ROLLOUTS:-4}" +export ALPHA="${ALPHA:-1.0}" +export MAX_EVAL_STEPS="${MAX_EVAL_STEPS:-20}" + +: "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" +: "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" + +# Optional: per-env dataset filtering via TASK_GEN_ENV_CLASSES env var +ENV_FILTER_ARGS=() +if [ -n "${TASK_GEN_ENV_CLASSES:-}" ]; then + echo "=== env_filter: $TASK_GEN_ENV_CLASSES ===" + ENV_FILTER_ARGS+=("data.env_filter=$TASK_GEN_ENV_CLASSES") +fi + +# Task-gen GRPO training with 35B model +# --entrypoint: task-gen entrypoint (not main_fleet) +# --env-class: task_gen environment (not fleet_task) +# TP=2: 8 engines × 2 GPUs each across 2 nodes (16 GPUs total) +# flash_attn=false: SDPA to avoid Xid 31 in GatedDeltaNet with vLLM 0.18.0 +# loss_chunk_size=4096: chunked lm_head to avoid OOM on 131K vocab +# --no-pytorch-alloc-conf: disables expandable_segments (conflicts with vLLM CuMemAllocator) +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit --no-pytorch-alloc-conf \ + --nccl-heartbeat 1800 \ + --entrypoint integrations.fleet.entrypoints.main_task_gen \ + --env-class task_gen -- \ + trainer.algorithm.advantage_estimator="grpo" \ + trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ + trainer.flash_attn=false \ + trainer.loss_chunk_size=4096 \ + trainer.use_sample_packing=false \ + generator.inference_engine_tensor_parallel_size=2 \ + trainer.epochs=${NUM_EPOCHS} \ + trainer.eval_batch_size=8 \ + trainer.eval_before_train=false \ + trainer.eval_interval=10 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=12 \ + trainer.use_hybrid_env_sampling=true \ + trainer.min_samples_per_env=1 \ + trainer.policy_mini_batch_size=12 \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=10 \ + trainer.max_ckpts_to_keep=1 \ + trainer.max_prompt_length=4096 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.sampling_params.temperature=0.95 \ + generator.sampling_params.top_p=0.95 \ + 'generator.sampling_params.stop=["", ""]' \ + generator.eval_sampling_params.temperature=0.95 \ + generator.eval_sampling_params.top_p=0.95 \ + 'generator.eval_sampling_params.stop=["", ""]' \ + trainer.policy.optimizer_config.lr=5.0e-7 \ + trainer.algorithm.use_kl_loss=true \ + trainer.algorithm.kl_loss_coef=1.0 \ + trainer.algorithm.entropy_loss_coef=0.001 \ + trainer.algorithm.zero_variance_filter=true \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.trajectory_timeout_seconds=1800 \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=8 \ + generator.eval_n_samples_per_prompt=3 \ + generator.enforce_eager=false \ + generator.gpu_memory_utilization=0.65 \ + trainer.logger="$LOGGER" \ + trainer.project_name="fleet-task-gen" \ + trainer.run_name="$RUN_NAME" \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$HOME/ckpts/task_gen_35b" \ + trainer.dump_data_batch=true \ + ++environment.skyrl_gym.task_gen.max_turns=$MAX_TURNS \ + ++environment.skyrl_gym.task_gen.judge_model="$JUDGE_MODEL" \ + ++environment.skyrl_gym.task_gen.k_rollouts=$K_ROLLOUTS \ + ++environment.skyrl_gym.task_gen.alpha=$ALPHA \ + ++environment.skyrl_gym.task_gen.max_eval_steps=$MAX_EVAL_STEPS \ + ++environment.skyrl_gym.task_gen.evaluator_model="$EVALUATOR_MODEL" \ + ++environment.skyrl_gym.task_gen.eval_k_rollouts=8 \ + ++environment.skyrl_gym.task_gen.tool_call_reward_per_call=0.02 \ + "${ENV_FILTER_ARGS[@]}" \ + "$@" diff --git a/scripts/fleet-task-gen-launch-per-env.sh b/scripts/fleet-task-gen-launch-per-env.sh new file mode 100755 index 0000000000..a1d22970f8 --- /dev/null +++ b/scripts/fleet-task-gen-launch-per-env.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash +# Launch per-env task-gen experiments — one SkyPilot cluster per environment. +# Targets ~40 training steps per env by computing NUM_EPOCHS from seed counts. +# +# Usage: +# export FLEET_API_KEY=... WANDB_API_KEY=... OPENROUTER_API_KEY=... +# export AWS_ACCESS_KEY_ID=... AWS_SECRET_ACCESS_KEY=... +# bash scripts/fleet-task-gen-launch-per-env.sh [env1 env2 ...] +# +# If no envs specified, defaults to: ticketmaster zillow outlook +set -eo pipefail + +YAML="tasks/task-gen-grpo-qwen3_5-9b.yaml" +TARGET_STEPS=40 +BATCH_SIZE=12 + +# Required env vars +: "${FLEET_API_KEY:?set FLEET_API_KEY}" +: "${WANDB_API_KEY:?set WANDB_API_KEY}" +: "${OPENROUTER_API_KEY:?set OPENROUTER_API_KEY}" +: "${AWS_ACCESS_KEY_ID:?set AWS_ACCESS_KEY_ID}" +: "${AWS_SECRET_ACCESS_KEY:?set AWS_SECRET_ACCESS_KEY}" + +# Seed counts per env from v55 dataset (after EVAL_RATIO=0.05 split) +get_seeds() { + case "$1" in + booking) echo 539 ;; budget) echo 567 ;; carlisle) echo 336 ;; + outlook) echo 181 ;; reddit) echo 505 ;; rops-mail) echo 44 ;; + ticketmaster) echo 212 ;; zillow) echo 106 ;; *) echo 100 ;; + esac +} + +# Default envs if none specified on command line +if [[ $# -gt 0 ]]; then + ENVS=("$@") +else + ENVS=(ticketmaster zillow outlook) +fi + +for env in "${ENVS[@]}"; do + seeds=$(get_seeds "$env") + steps_per_epoch=$(( (seeds + BATCH_SIZE - 1) / BATCH_SIZE )) + num_epochs=$(( (TARGET_STEPS + steps_per_epoch - 1) / steps_per_epoch )) + total_steps=$(( steps_per_epoch * num_epochs )) + + echo "=== Launching task-gen-${env}: ${seeds} seeds, ${steps_per_epoch} steps/epoch, ${num_epochs} epochs (${total_steps} steps) ===" + sky launch -c "task-gen-${env}" "$YAML" \ + --env TASK_GEN_ENV_CLASSES="$env" \ + --env NUM_EPOCHS="$num_epochs" \ + --env FLEET_API_KEY="$FLEET_API_KEY" \ + --env WANDB_API_KEY="$WANDB_API_KEY" \ + --env OPENROUTER_API_KEY="$OPENROUTER_API_KEY" \ + --env AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" \ + --env AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \ + --yes --async +done + +echo "" +echo "Launched ${#ENVS[@]} clusters. Monitor with: sky status" diff --git a/scripts/fleet-task-gen-run.sh b/scripts/fleet-task-gen-run.sh new file mode 100755 index 0000000000..ca1717183a --- /dev/null +++ b/scripts/fleet-task-gen-run.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +# Task-gen specific run: calls common run with task-gen entrypoint and hydra overrides +# +# Usage (from SkyPilot YAML run block): +# bash skyrl-train/scripts/fleet-task-gen-run.sh +# +# Required env vars: WANDB_API_KEY, MODALITY, INFERENCE_BACKEND, LOGGER, +# MAX_TURNS, MAX_INPUT_LENGTH, MAX_GENERATE_LENGTH, NUM_EPOCHS, +# JUDGE_MODEL, K_ROLLOUTS, ALPHA, MAX_EVAL_STEPS +# SkyPilot env vars: SKYPILOT_NUM_GPUS_PER_NODE, SKYPILOT_NODE_IPS +set -euo pipefail + +# Export RUN_NAME so task_gen_env can tag rollout dumps +# Always use random hex suffix for unique run names +export RUN_NAME="task_gen_$(python3 -c 'import os; print(os.urandom(4).hex())')" + +# Optional: per-env dataset filtering via TASK_GEN_ENV_CLASSES env var +# e.g. TASK_GEN_ENV_CLASSES="outlook" or TASK_GEN_ENV_CLASSES="outlook,booking" +ENV_FILTER_ARGS=() +if [ -n "${TASK_GEN_ENV_CLASSES:-}" ]; then + echo "=== env_filter: $TASK_GEN_ENV_CLASSES ===" + ENV_FILTER_ARGS+=("data.env_filter=$TASK_GEN_ENV_CLASSES") +fi + +# Task-gen GRPO training via shared run script +# --entrypoint: task-gen entrypoint (not main_fleet) +# --env-class: task_gen environment (not fleet_task) +# --data-dir-name: parquet files are in data/fleet/task_gen/ (not data/fleet/tool_use/) +# TP=1: N engines × 1 GPU each (Qwen3.5-9B fits in single H200) +# num_inference_engines auto-detected from SKYPILOT_NUM_GPUS_PER_NODE by fleet-common-run.sh +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit --no-pytorch-alloc-conf \ + --entrypoint integrations.fleet.entrypoints.main_task_gen \ + --env-class task_gen -- \ + trainer.algorithm.advantage_estimator="grpo" \ + trainer.policy.model.path="Qwen/Qwen3.5-9B" \ + trainer.flash_attn=false \ + trainer.use_sample_packing=false \ + generator.inference_engine_tensor_parallel_size=1 \ + trainer.epochs=${NUM_EPOCHS} \ + trainer.eval_batch_size=12 \ + trainer.eval_before_train=false \ + trainer.eval_interval=10 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=12 \ + trainer.use_hybrid_env_sampling=true \ + trainer.min_samples_per_env=1 \ + trainer.policy_mini_batch_size=12 \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.loss_chunk_size=4096 \ + trainer.ckpt_interval=10 \ + trainer.max_prompt_length=4096 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.sampling_params.temperature=0.95 \ + generator.sampling_params.top_p=0.95 \ + 'generator.sampling_params.stop=["", ""]' \ + generator.eval_sampling_params.temperature=0.95 \ + generator.eval_sampling_params.top_p=0.95 \ + 'generator.eval_sampling_params.stop=["", ""]' \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + trainer.algorithm.use_kl_loss=true \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.trajectory_timeout_seconds=1800 \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=8 \ + generator.eval_n_samples_per_prompt=3 \ + generator.gpu_memory_utilization=0.75 \ + trainer.logger="$LOGGER" \ + trainer.project_name="fleet-task-gen" \ + trainer.run_name="$RUN_NAME" \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$HOME/ckpts/task_gen" \ + trainer.dump_data_batch=true \ + ++environment.skyrl_gym.task_gen.max_turns=$MAX_TURNS \ + ++environment.skyrl_gym.task_gen.judge_model="$JUDGE_MODEL" \ + ++environment.skyrl_gym.task_gen.k_rollouts=$K_ROLLOUTS \ + ++environment.skyrl_gym.task_gen.alpha=$ALPHA \ + ++environment.skyrl_gym.task_gen.max_eval_steps=$MAX_EVAL_STEPS \ + ++environment.skyrl_gym.task_gen.evaluator_model="${EVALUATOR_MODEL:-anthropic/claude-sonnet-4.5}" \ + ++environment.skyrl_gym.task_gen.eval_k_rollouts=8 \ + ++environment.skyrl_gym.task_gen.tool_call_reward_per_call=0.02 \ + "${ENV_FILTER_ARGS[@]}" \ + "$@" diff --git a/scripts/fleet-tinker-tool-use-run.sh b/scripts/fleet-tinker-tool-use-run.sh new file mode 100755 index 0000000000..380a67ff44 --- /dev/null +++ b/scripts/fleet-tinker-tool-use-run.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +# Launch Fleet tool-use training via Tinker hosted service. +# Mirrors fleet-35b-run.sh config but uses the Tinker backend. +# +# Required env vars: TINKER_API_KEY, FLEET_API_KEY, WANDB_API_KEY +# Optional: TINKER_API_URL (SDK uses default if not set) +set -euo pipefail + +export TINKER_API_KEY="${TINKER_API_KEY:?Set TINKER_API_KEY}" +export TINKER_API_URL="${TINKER_API_URL:-}" +export FLEET_API_KEY="${FLEET_API_KEY:?Set FLEET_API_KEY}" +export WANDB_API_KEY="${WANDB_API_KEY:?Set WANDB_API_KEY}" + +cd "$(dirname "$0")/.." # cd to SkyRL-v2 root + +python -m integrations.fleet.entrypoints.main_fleet_tinker \ + --model-name Qwen/Qwen3.5-35B-A3B \ + --tasks-file "${TASKS_FILE:?Set TASKS_FILE}" \ + --dataset-file "${DATASET_FILE:?Set DATASET_FILE}" \ + --batch-size 16 \ + --learning-rate 5.0e-7 \ + --lora-rank 16 \ + --max-steps 200 \ + --max-turns 50 \ + --max-generate-length 4096 \ + --max-input-length 96000 \ + --n-samples-per-prompt 8 \ + --eval-every 20 \ + --temperature 0.9 \ + --top-p 0.95 \ + --stop-sequences '[""]' \ + --loss-fn ppo \ + --wandb-project fleet-tinker-grpo \ + "$@" diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh new file mode 100755 index 0000000000..12e728fbcb --- /dev/null +++ b/scripts/fleet-vl-run.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash +# VL/CUA (Vision-Language / Computer Use Agent) GRPO training config. +# Called by the SkyPilot YAML and by fleet-research run.sh. +# +# Based on working config from SkyRL PR #288 (feat/vl-support-clean), +# adapted to SkyRL-v2's fleet-common-run.sh pattern. +# +# Model: Qwen/Qwen3.5-9B (9B params, natively multimodal, GatedDeltaNet) +# TP=1 (single GPU per engine, 8 engines on 8x H200) +# Modality: browser_use (screenshots + coordinate normalization) +# +# Required env vars: FLEET_API_KEY, WANDB_API_KEY +# Optional: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (for S3 checkpoints) +set -euo pipefail +cd "$(dirname "$0")/.." # cd to SkyRL root (scripts/ is directly under repo root) + +# Defaults for vars normally set by SkyPilot YAML envs block +export LOGGER="${LOGGER:-wandb}" +export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" +export DATA_VERSION="${DATA_VERSION:-v6}" +export MODALITY="${MODALITY:-browser_use}" +export NUM_EPOCHS="${NUM_EPOCHS:-10}" +export MAX_TURNS="${MAX_TURNS:-80}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-80000}" +export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" +export ENV_KEYS="${ENV_KEYS:-}" +export DIFFICULTY="${DIFFICULTY:-}" +export RUN_ID="${RUN_ID:-}" +export MAX_TASKS="${MAX_TASKS:-}" +export RESUME_RUN_NAME="${RESUME_RUN_NAME:-}" +export AWS_REGION="${AWS_REGION:-us-east-1}" +export S3_DATASET_BUCKET="${S3_DATASET_BUCKET:-fleet-internal-datasets}" +export S3_CHECKPOINT_BUCKET="${S3_CHECKPOINT_BUCKET:-skyrl-checkpoints}" +export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" + +: "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" +: "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" + +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit --no-pytorch-alloc-conf -- \ + environment.skyrl_gym.fleet_task.ttl_seconds=1800 \ + environment.skyrl_gym.fleet_task.partial_reward=true \ + environment.skyrl_gym.fleet_task.enable_hints=false \ + trainer.algorithm.advantage_estimator=grpo \ + trainer.policy.model.path="Qwen/Qwen3.5-9B" \ + trainer.flash_attn=false \ + trainer.loss_chunk_size=4096 \ + trainer.use_sample_packing=false \ + trainer.algorithm.loss_reduction="sequence_mean" \ + +generator.chat_template_kwargs='{enable_thinking:true}' \ + +generator.engine_init_kwargs.mm_processor_cache_gb=0 \ + generator.inference_engine_tensor_parallel_size=1 \ + trainer.epochs=${NUM_EPOCHS} \ + trainer.eval_batch_size=12 \ + trainer.eval_before_train=false \ + trainer.eval_interval=10 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=50 \ + trainer.use_hybrid_env_sampling=true \ + trainer.min_samples_per_env=2 \ + trainer.policy_mini_batch_size=50 \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=10 \ + trainer.max_prompt_length=2048 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.sampling_params.temperature=0.9 \ + generator.sampling_params.top_p=0.95 \ + 'generator.sampling_params.stop=[""]' \ + 'generator.eval_sampling_params.stop=[""]' \ + trainer.policy.optimizer_config.lr=5.0e-7 \ + trainer.algorithm.use_kl_loss=true \ + trainer.algorithm.zero_variance_filter=true \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=4 \ + generator.eval_n_samples_per_prompt=3 \ + generator.gpu_memory_utilization=0.80 \ + generator.trajectory_timeout_seconds=900 \ + trainer.logger="$LOGGER" \ + trainer.project_name="fleet-browser-use-grpo" \ + trainer.run_name="fleet_qwen35_${MODALITY}_${RUN_ID:-$(head -c 4 /dev/urandom | xxd -p)}" \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$HOME/ckpts/fleet_qwen35_${MODALITY}" \ + trainer.export_path="$HOME/exports" \ + trainer.dump_data_batch=true \ + "$@" diff --git a/skyrl-gym/pyproject.toml b/skyrl-gym/pyproject.toml index b7beb59d9f..88f42c68eb 100644 --- a/skyrl-gym/pyproject.toml +++ b/skyrl-gym/pyproject.toml @@ -25,12 +25,15 @@ dependencies = [ Repository = "https://github.com/NovaSky-AI/SkyRL" [tool.setuptools.packages.find] -include = ["skyrl_gym*"] +include = ["skyrl_gym*", "skyrl_taste*"] [project.optional-dependencies] dev = [ "pytest" ] +fleet = [ + "openenv[fleet]", +] [tool.black] line-length = 120 diff --git a/skyrl-gym/skyrl_gym/envs/__init__.py b/skyrl-gym/skyrl_gym/envs/__init__.py index 770b65e1e8..6a28c661e5 100644 --- a/skyrl-gym/skyrl_gym/envs/__init__.py +++ b/skyrl-gym/skyrl_gym/envs/__init__.py @@ -36,3 +36,13 @@ id="searchcode", entry_point="skyrl_gym.envs.searchcode.env:SearchCodeEnv", ) + +register( + id="fleet_task", + entry_point="skyrl_gym.envs.fleet_task.env:FleetTaskEnv", +) + +register( + id="task_gen", + entry_point="skyrl_gym.envs.task_gen.task_gen_env:TaskGenEnv", +) diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/__init__.py b/skyrl-gym/skyrl_gym/envs/fleet_task/__init__.py new file mode 100644 index 0000000000..922066c478 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/__init__.py @@ -0,0 +1,10 @@ +"""Fleet Task Environment for SkyRL-Gym. + +Provides a multi-turn tool-use environment backed by Fleet-hosted environments, +using OpenEnv's FleetTaskEnv as the abstraction layer. +""" + +from skyrl_gym.envs.fleet_task.env import FleetTaskEnv +from skyrl_gym.envs.fleet_task.tool_call_parser import parse_tool_call + +__all__ = ["FleetTaskEnv", "parse_tool_call"] diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py new file mode 100644 index 0000000000..1dc4471fb6 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py @@ -0,0 +1,1126 @@ +"""Fleet Task Environment for SkyRL-Gym. + +Provides a SkyRL-compatible environment wrapper for Fleet-hosted tasks. +Uses OpenEnv's FleetTaskEnv as the abstraction layer for Fleet environments, +keeping a clean separation between SkyRL's training interface and Fleet's +environment management. + +Multi-modal support: When the task modality is "computer_use" or "browser_use", step() returns +multimodal observations in OpenAI format (image_url content blocks). Upstream +SkyRL's generator already handles these via extract_images_from_conversation() +and passes them as multi_modal_data to vLLM — no upstream changes needed. +""" + +import ast +import asyncio +import json +import logging +import os +import re +import time +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +from skyrl_gym.envs.base_text_env import ( + BaseTextEnv, + BaseTextEnvStepOutput, + ConversationType, +) +from skyrl_gym.envs.fleet_task.tool_call_parser import parse_tool_call + +# Reduce MCP client log noise +try: + from loguru import logger as loguru_logger + + loguru_logger.disable("mcp") +except ImportError: + pass +logging.getLogger("mcp").setLevel(logging.WARNING) + +logger = logging.getLogger(__name__) + +# Global task cache to avoid reloading JSON for each env instance +_TASK_CACHE: Dict[str, Dict[str, Any]] = {} + + +def load_tasks_from_json(tasks_file: str) -> Dict[str, Any]: + """Load tasks from JSON file with caching. + + Returns a dict mapping task_key -> task_config dict. + """ + if tasks_file not in _TASK_CACHE: + expanded_path = os.path.expanduser(tasks_file) + if not os.path.exists(expanded_path): + raise FileNotFoundError(f"Tasks file not found: {expanded_path}") + + with open(expanded_path, "r") as f: + data = json.load(f) + + # Handle both formats: array or {"tasks": [...]} + if isinstance(data, list): + tasks = data + elif isinstance(data, dict) and "tasks" in data: + tasks = data["tasks"] + else: + raise ValueError( + f"Invalid JSON format in {tasks_file}: expected array or object with 'tasks' key" + ) + + if not tasks: + raise ValueError(f"No tasks found in {tasks_file}") + + # Index by task_key (support both 'key' and 'task_key' fields) + _TASK_CACHE[tasks_file] = { + t.get("key") or t.get("task_key"): t for t in tasks + } + + return _TASK_CACHE[tasks_file] + + +def clear_caches(): + """Clear global caches. Useful for testing.""" + global _TASK_CACHE + _TASK_CACHE = {} + + +class FleetTaskEnv(BaseTextEnv): + """SkyRL environment for Fleet-hosted tasks. + + Uses OpenEnv's FleetTaskEnv as the abstraction layer for Fleet environments. + This provides a clean separation between SkyRL's training interface and + Fleet's environment management. + + Constructor signature follows upstream convention: + __init__(self, env_config=None, extras={}) + + Where: + env_config: Dict or DictConfig from skyrl_gym_config YAML + extras: Per-sample data from the training dataset (task_key, max_turns, etc.) + """ + + _trace_config: Optional[Dict[str, str]] = None + + @classmethod + def set_trace_config(cls, job_id: str, model: str): + """Set trace config for uploading eval traces to Fleet.""" + cls._trace_config = {"job_id": job_id, "model": model} + + @classmethod + def clear_trace_config(cls): + """Clear trace config after eval is done.""" + cls._trace_config = None + + def __init__( + self, + env_config=None, + extras: Dict[str, Any] = {}, + ): + super().__init__() + + if env_config is None: + env_config = {} + + self.extras = extras + self.max_turns = extras.get("max_turns", 50) + + # Task configuration from extras (set by dataset) + self.task_key = extras.get("task_key") + self.tasks_file = ( + env_config.get("tasks_file") if hasattr(env_config, "get") else None + ) or extras.get("tasks_file") + + if not self.task_key: + raise ValueError("task_key must be provided in extras (from dataset)") + if not self.tasks_file: + raise ValueError( + "tasks_file must be provided in env_config or extras" + ) + + # Expand path + self.tasks_file = os.path.expanduser(self.tasks_file) + + # Load task config from JSON + tasks = load_tasks_from_json(self.tasks_file) + self.task_config = tasks.get(self.task_key) + if not self.task_config: + available_keys = list(tasks.keys())[:5] + raise ValueError( + f"Task '{self.task_key}' not found in {self.tasks_file}. " + f"Available keys (first 5): {available_keys}" + ) + + # API key + self.api_key = ( + env_config.get("api_key") if hasattr(env_config, "get") else None + ) or os.environ.get("FLEET_API_KEY") + if not self.api_key: + raise ValueError( + "FLEET_API_KEY must be set in env_config or environment" + ) + + # Logfire telemetry (no-op if LOGFIRE_TOKEN is not set) + logfire_token = os.environ.get("LOGFIRE_TOKEN") + if logfire_token: + try: + from envs.fleet_env import configure_fleet_telemetry + + configure_fleet_telemetry(token=logfire_token) + except ImportError: + pass + + # TTL for Fleet environment instances + self.ttl_seconds = ( + env_config.get("ttl_seconds") if hasattr(env_config, "get") else None + ) + + # Partial reward: use verifier accumulator counts instead of binary 0/1 + self.partial_reward = ( + env_config.get("partial_reward", False) + if hasattr(env_config, "get") + else False + ) + + # Taste judge (LLM-as-judge) GATED reward: + # effective_taste = max(taste_floor, taste_score) (1.0 on judge fail) + # final_reward = verifier_reward * effective_taste + self.taste_floor = float( + env_config.get("taste_floor", 0.1) + if hasattr(env_config, "get") + else 0.1 + ) + if not 0.0 <= self.taste_floor <= 1.0: + raise ValueError( + f"taste_floor must be in [0,1], got {self.taste_floor}" + ) + self.taste_judge_timeout_s = float( + env_config.get("taste_judge_timeout_s", 10.0) + if hasattr(env_config, "get") + else 10.0 + ) + self.last_verifier_reward: Optional[float] = None + self.last_taste_reward: Optional[float] = None + self.last_effective_taste: Optional[float] = None + self.last_taste_judge_failed: bool = False + + # Hint config + self.enable_hints = ( + env_config.get("enable_hints", False) + if hasattr(env_config, "get") + else False + ) + + # Environment state (initialized on init()) + self.openenv_task_env = None + self.chat_history: ConversationType = [] + self.turns = 0 + self.tool_calls = 0 + self.tool_errors = 0 + self.last_reward: Optional[float] = None + self.tools: List[Dict[str, Any]] = [] + + # Verifier feedback (captured at close time for hint generation) + self._verifier_stdout: Optional[str] = None + self._verifier_error: Optional[str] = None + self._tool_error_messages: List[str] = [] + + # Context management (uses OpenEnv's ContextManager) + self.enable_context_tools = ( + env_config.get("enable_context_tools", False) + if hasattr(env_config, "get") + else False + ) + self.context_manager = None + if self.enable_context_tools: + try: + from envs.fleet_env import ContextManager + + logger.info( + "Enabling context management tools with " + f"max_output_chars={extras.get('max_output_chars', 10000)}" + ) + self.context_manager = ContextManager( + max_output_chars=extras.get("max_output_chars", 10000) + ) + except ImportError: + logger.warning( + "ContextManager not available, disabling context tools" + ) + + def _adapt_computer_tool_for_qwen(self): + """Adapt computer tool description for Qwen VL's [0, 1000] coordinate space. + + Qwen3-VL/3.5 output coordinates in a normalized [0, 1000] grid regardless + of screen resolution. This rewrites tool descriptions to match, and + _convert_qwen_coordinates() converts back to pixels before MCP execution. + """ + for tool in self.tools: + func = tool.get("function", {}) + if func.get("name") != "computer": + continue + + desc = func.get("description", "") + + # Parse actual screen dimensions + res_match = re.search(r"Screen resolution:\s*(\d+)x(\d+)", desc) + if res_match: + self.screen_width = int(res_match.group(1)) + self.screen_height = int(res_match.group(2)) + else: + self.screen_width = 1366 + self.screen_height = 768 + + w, h = self.screen_width, self.screen_height + + # Rewrite description for Qwen's [0, 1000] coordinate space + desc = re.sub( + r"Screen resolution:\s*\d+x\d+\s*pixels\s*(\([^)]*\))?", + "Screen resolution: 1000x1000", + desc, + ) + desc = re.sub( + r"\(0, 0\) is top-left,\s*\(\d+, \d+\) is bottom-right", + "(0, 0) is top-left, (999, 999) is bottom-right", + desc, + ) + desc = re.sub( + r"valid range: x=0-\d+, y=0-\d+", + "valid range: x=0-999, y=0-999", + desc, + ) + desc = re.sub( + r"JPEG format at \d+x\d+", + "JPEG format at 1000x1000", + desc, + ) + func["description"] = desc + + logger.info( + f"Adapted computer tool for Qwen VL: actual_screen={w}x{h}, " + f"model coordinate space=[0, 1000]" + ) + break + + def _convert_qwen_coordinates(self, tool_call: Dict[str, Any]): + """Convert Qwen's [0, 1000] normalized coordinates to pixel coordinates. + + Modifies tool_call arguments in-place. + """ + if not getattr(self, "screen_width", None) or not getattr( + self, "screen_height", None + ): + return + args = tool_call.get("arguments", {}) + if not args or tool_call.get("name") != "computer": + return + for field in ("coordinate", "start_coordinate"): + coords = args.get(field) + if ( + coords + and isinstance(coords, (list, tuple)) + and len(coords) == 2 + ): + args[field] = [ + int(coords[0] / 1000 * self.screen_width), + int(coords[1] / 1000 * self.screen_height), + ] + + def _normalize_task_config(self) -> Dict[str, Any]: + """Normalize task config to OpenEnv's expected format.""" + config = self.task_config.copy() + + # Map field names if needed + if "key" in config and "task_key" not in config: + config["task_key"] = config["key"] + if "env_id" in config and "env_key" not in config: + config["env_key"] = config["env_id"] + if "version" in config and "env_version" not in config: + config["env_version"] = config["version"] + + return config + + async def init_async( + self, prompt: ConversationType + ) -> Tuple[ConversationType, Dict[str, Any]]: + """Initialize the Fleet environment and return initial observation. + + Creates Fleet environment via OpenEnv's FleetTaskEnv and returns + the task prompt with tool definitions. + """ + from envs.fleet_env import FleetTaskEnv as OpenEnvFleetTaskEnv + + # Close any existing environment + self.close() + + # Create OpenEnv's FleetTaskEnv with normalized config + task_config = self._normalize_task_config() + + try: + self.openenv_task_env = OpenEnvFleetTaskEnv( + task_config=task_config, + api_key=self.api_key, + ttl_seconds=self.ttl_seconds, + max_steps=self.max_turns, + partial_reward=self.partial_reward, + ) + except Exception as e: + raise RuntimeError( + f"Failed to create OpenEnv FleetTaskEnv: {e}" + ) from e + + # Reset episode state (tools are already cached from __init__) + obs = await self.openenv_task_env.reset_async() + + # Reset state + self.turns = 0 + self.tool_calls = 0 + self.tool_errors = 0 + self.last_reward = None + + # Reset context manager if enabled + if self.context_manager: + self.context_manager.reset() + + # Get tools from observation + self.tools = obs.get("tools", []) + + # Add context management tools if enabled + if self.context_manager: + self.tools = self.tools + self.context_manager.get_tools() + if not self.tools: + raise RuntimeError( + f"Task {self.task_key}: no tools found. Fleet env requires tools." + ) + + # VL: adapt computer tool for Qwen's normalized coordinate space + modality = self.task_config.get("task_modality", "tool_use") + if modality in ("computer_use", "browser_use"): + self._adapt_computer_tool_for_qwen() + + # Build initial prompt with task instruction + task_prompt = self.task_config.get("prompt", "") + + # Inject hint from previous failed attempt if provided + hint = self.extras.get("hint") + if hint: + task_prompt = ( + f"{task_prompt}\n\nHere is feedback from a previous attempt " + f"to help you:\n{hint}" + ) + + # Build system prompt with tool definitions + tools_json = json.dumps(self.tools, indent=2) + current_date = datetime.now().strftime("%Y-%m-%d") + + # Build environment context section from env_variables + env_context = "" + env_vars = self.task_config.get("env_variables", {}) + if env_vars: + env_lines = [] + if "LOGGED_IN_USER" in env_vars: + env_lines.append( + f"- Logged in user ID: {env_vars['LOGGED_IN_USER']}" + ) + if "LOGGED_IN_NAME" in env_vars: + env_lines.append( + f"- Logged in as: {env_vars['LOGGED_IN_NAME']}" + ) + for key, value in env_vars.items(): + if key not in ( + "LOGGED_IN_USER", + "LOGGED_IN_NAME", + "CURRENT_DATE", + ): + env_lines.append(f"- {key}: {value}") + if env_lines: + env_context = ( + "\n## Environment Context\n" + + "\n".join(env_lines) + + "\n" + ) + + # Add environment-specific hints + env_key = self.task_config.get("env_key") or self.task_config.get( + "env_id" + ) + env_hints = "" + if env_key == "fostgres": + env_hints = ( + "\n## Database Exploration\n" + "Before writing SQL queries, first explore the database schema:\n" + "- List tables: SELECT table_name FROM information_schema.tables " + "WHERE table_schema = 'public'\n" + "- List columns: SELECT column_name, data_type FROM " + "information_schema.columns WHERE table_name = 'your_table'\n" + ) + + # Computer-use hints for VL models + computer_use_hints = "" + if modality in ("computer_use", "browser_use"): + computer_use_hints = ( + "\n## Browser Interaction Strategy\n" + "You are controlling a web browser via screenshots. Follow this loop:\n\n" + "1. **Act**: Perform ONE action (click, type, scroll, etc.)\n" + "2. **Observe**: Take a screenshot to see the result\n" + "3. **Adapt**: If the screen hasn't changed, try a DIFFERENT action\n\n" + "Key rules:\n" + "- After clicking or typing, ALWAYS take a screenshot next to see what happened\n" + "- NEVER repeat the same action more than twice. If it didn't work, try something different:\n" + " - Can't find an element by scrolling? Use the search bar or navigation menu instead\n" + " - Page not loading after a click? Try refreshing with key(\"F5\") or clicking a different element\n" + " - Form not submitting? Check if required fields are missing\n" + "- Use wait() only ONCE after a page navigation, then screenshot to check. Do not wait repeatedly\n" + "- When the task is fully complete, say . Do not keep clicking after finishing\n" + ) + + tool_names = [ + t["function"]["name"] for t in self.tools if "function" in t + ] + tool_names_str = ", ".join(tool_names) + + system_content = ( + f"You are a helpful agent. Complete the task by calling tools.\n\n" + f"## Current Date\n" + f"Today's date is {current_date}. When dates are mentioned without " + f"a year, assume the current year ({datetime.now().year}) or a " + f"future date.\n" + f"{env_context}{env_hints}{computer_use_hints}\n" + f"## Available Tools\n{tools_json}\n\n" + f"## Tool Call Format\n" + f"Use the tools listed above by name ({tool_names_str}). " + f"Format each call as:\n" + f'{{"name": "", "arguments": ' + f"{{...}}}}\n\n" + f"## Error Handling\n" + f"If a tool call returns an error:\n" + f"- Read the error message carefully\n" + f"- Do NOT repeat the same call with identical arguments\n" + f"- Change your approach: use different parameters, try a different " + f"tool, or break the task into smaller steps\n\n" + f"## Response Format\n" + f"EVERY response MUST end with exactly ONE of:\n" + f"1. A tool call: ... - to perform an action\n" + f"2. Done signal: - ONLY when the task is fully complete\n\n" + f"IMPORTANT: When the task is complete, first output your final " + f"answer with the requested information, THEN say . Do not " + f"just say without providing the answer.\n\n" + f"NEVER respond with just a message. NEVER say \"feel free to ask\" " + f"or offer further help.\n" + f"If the task is complete, provide your answer then say . " + f"Otherwise, make a tool call." + ) + + system_message = {"role": "system", "content": system_content} + + # VL: include initial screenshot in multimodal user message + initial_screenshot = obs.get("initial_screenshot") + if initial_screenshot and isinstance(initial_screenshot, list): + user_content = [{"type": "text", "text": task_prompt}] + for item in initial_screenshot: + if isinstance(item, dict) and item.get("type") == "image_url": + user_content.append(item) + user_message = {"role": "user", "content": user_content} + else: + user_message = {"role": "user", "content": task_prompt} + + self.chat_history = [system_message, user_message] + + metadata = { + "task_key": self.task_key, + "env_key": env_key, + "tools": self.tools, + "modality": self.task_config.get("task_modality", "tool_use"), + } + + return self.chat_history.copy(), metadata + + def init( + self, prompt: ConversationType + ) -> Tuple[ConversationType, Dict[str, Any]]: + """Initialize the Fleet environment (sync wrapper). + + Uses asyncio.run() for sync contexts. For async contexts, the upstream + generator's _run_in_executor_if_available will call this in a thread pool, + where asyncio.run() is safe. + """ + return asyncio.run(self.init_async(prompt)) + + async def step_async(self, action: str) -> BaseTextEnvStepOutput: + """Execute one step in the Fleet environment. + + Parses the action for tool calls, executes via OpenEnv's FleetTaskEnv, + and returns observation. Reward is computed by the verifier on completion. + + For computer_use/browser_use modality, observations may include multimodal content + (image_url blocks with base64 screenshots). Upstream SkyRL's generator + handles these via extract_images_from_conversation(). + """ + step_start = time.time() + self.turns += 1 + assistant_msg = {"role": "assistant", "content": action} + self.chat_history.append(assistant_msg) + if self.context_manager: + self.context_manager.track_message(assistant_msg) + + max_turns_reached = self.turns >= self.max_turns + + # Check if agent signals completion + agent_done = "" in action.lower() or "[done]" in action.lower() + + # Parse tool call from LLM response + tool_call = parse_tool_call(action) + + tool_result = None + error = None + reward = 0.0 + mcp_time = 0.0 + + # VL: catch done signal wrapped in a computer tool call + if ( + not agent_done + and tool_call + and tool_call.get("arguments", {}).get("action") == "done" + ): + agent_done = True + tool_call = None + + # VL: convert Qwen [0,1000] coordinates to pixel coordinates + if tool_call and getattr(self, "screen_width", None): + self._convert_qwen_coordinates(tool_call) + + # Handle context management tools locally (no MCP call) + if ( + tool_call + and self.context_manager + and self.context_manager.is_context_tool(tool_call["name"]) + ): + tool_result, self.chat_history = self.context_manager.execute_tool( + tool_call["name"], + tool_call.get("arguments", {}), + self.chat_history, + ) + # Execute tool call if present via OpenEnv + elif tool_call and self.openenv_task_env: + self.tool_calls += 1 + openenv_action = { + "tool": tool_call["name"], + "params": tool_call.get("arguments", {}), + "done": agent_done, + } + + try: + mcp_start = time.time() + obs, reward, done, info = ( + await self.openenv_task_env.step_async(openenv_action) + ) + mcp_time = time.time() - mcp_start + tool_result = obs.get("observation") + if "tool_error" in info: + error = info["tool_error"] + + # Truncate long outputs if context management is enabled + if ( + tool_result + and isinstance(tool_result, str) + and self.context_manager + ): + tool_result = self.context_manager.truncate_output( + tool_result + ) + except Exception as e: + mcp_time = time.time() - mcp_start + error = str(e) + elif agent_done and self.openenv_task_env: + # Agent signaled done without tool call + openenv_action = {"done": True} + try: + mcp_start = time.time() + obs, reward, done, info = ( + await self.openenv_task_env.step_async(openenv_action) + ) + mcp_time = time.time() - mcp_start + except Exception as e: + mcp_time = time.time() - mcp_start + error = str(e) + + # Detect error patterns in tool_result + if not error and tool_result: + result_str = ( + str(tool_result) + if not isinstance(tool_result, str) + else tool_result + ) + if result_str.strip().startswith( + "Error:" + ) or result_str.strip().startswith("error:"): + error = result_str + tool_result = None + elif isinstance(tool_result, dict) and tool_result.get("error"): + error = tool_result["error"] + tool_result = None + + episode_done = agent_done or max_turns_reached + + # Upload trace at episode end if trace config is set + if episode_done and FleetTaskEnv._trace_config: + try: + from envs.fleet_env.trace import upload_trace + + inst_id = None + orch = getattr(self.openenv_task_env, "_orch", None) + if orch: + fleet_env = getattr(orch, "_fleet_env", None) + if fleet_env: + inst_id = getattr(fleet_env, "instance_id", None) + await upload_trace( + api_key=self.api_key, + job_id=FleetTaskEnv._trace_config["job_id"], + task_key=self.task_key, + model=FleetTaskEnv._trace_config["model"], + chat_history=self.chat_history, + reward=reward, + instance_id=inst_id, + metadata={ + "env_key": self.task_config.get("env_key"), + "turns": self.turns, + }, + ) + except Exception as e: + logger.warning( + f"Failed to upload trace for {self.task_key}: {e}" + ) + + # Apply taste reward gating at episode end + if episode_done: + reward = await self._apply_taste_reward(reward, episode_done) + + # Build observation message + if max_turns_reached: + metadata = { + "done_reason": "max_turns", + "task_key": self.task_key, + "taste_reward": self.last_taste_reward, + "effective_taste": self.last_effective_taste, + "taste_floor": self.taste_floor, + "taste_judge_failed": self.last_taste_judge_failed, + } + return BaseTextEnvStepOutput( + observations=[], + reward=reward, + done=True, + metadata=metadata, + ) + + # Build response observation + if error: + self.tool_errors += 1 + self._tool_error_messages.append(str(error)[:500]) + obs_content = f"Error: {error}" + elif tool_result: + # Handle multimodal results (list with image_url blocks) + if isinstance(tool_result, list): + # Multimodal: return as structured content for VL models + new_obs = {"role": "user", "content": tool_result} + self.chat_history.append(new_obs) + if self.context_manager: + self.context_manager.track_message(new_obs) + + step_time = time.time() - step_start + metadata = { + "task_key": self.task_key, + "turn": self.turns, + "tool_call": tool_call, + "error": None, + "done_reason": "agent_done" if agent_done else None, + "step_time": step_time, + "mcp_time": mcp_time, + } + if episode_done: + metadata["taste_reward"] = self.last_taste_reward + metadata["effective_taste"] = self.last_effective_taste + metadata["taste_floor"] = self.taste_floor + metadata["taste_judge_failed"] = self.last_taste_judge_failed + return BaseTextEnvStepOutput( + observations=[new_obs], + reward=reward, + done=episode_done, + metadata=metadata, + ) + elif isinstance(tool_result, dict): + obs_content = ( + f"Tool result:\n{json.dumps(tool_result, indent=2)}" + ) + else: + obs_content = f"Tool result:\n{tool_result}" + elif agent_done: + obs_content = "Task marked as complete." + elif not tool_call: + obs_content = ( + "No tool call found. Use " + '{"name": "...", "arguments": {...}} ' + "format." + ) + else: + obs_content = "Action executed." + + new_obs = {"role": "user", "content": obs_content} + self.chat_history.append(new_obs) + if self.context_manager: + self.context_manager.track_message(new_obs) + + step_time = time.time() - step_start + metadata = { + "task_key": self.task_key, + "turn": self.turns, + "tool_call": tool_call, + "tool_result": ( + tool_result[:200] + if isinstance(tool_result, str) and len(tool_result) > 200 + else tool_result + ), + "error": error, + "done_reason": "agent_done" if agent_done else None, + "step_time": step_time, + "mcp_time": mcp_time, + } + + # If context was modified, return full chat_history so the generator + # can replace its copy (required for stepwise training). + if ( + tool_call + and self.context_manager + and self.context_manager.is_context_tool(tool_call["name"]) + ): + if tool_call["name"] == "manage_context": + metadata["modified_chat_history"] = self.chat_history.copy() + + if episode_done: + metadata["taste_reward"] = self.last_taste_reward + metadata["effective_taste"] = self.last_effective_taste + metadata["taste_floor"] = self.taste_floor + metadata["taste_judge_failed"] = self.last_taste_judge_failed + + return BaseTextEnvStepOutput( + observations=[new_obs], + reward=reward, + done=episode_done, + metadata=metadata, + ) + + async def _apply_taste_reward( + self, verifier_reward: float, episode_done: bool + ) -> float: + """Gate the binary verifier reward by the taste-judge score. + + On non-terminal steps we pass through verifier_reward unchanged. + On terminal steps we call the judge with a hard timeout; on + timeout/exception/None we set effective_taste=1.0 (pure verifier). + """ + if not episode_done: + return verifier_reward + + self.last_verifier_reward = float(verifier_reward) + self.last_taste_reward = None + self.last_effective_taste = None + self.last_taste_judge_failed = False + + try: + from skyrl_taste.judge import score_trajectory_async + except Exception as e: + logger.warning( + "skyrl_taste import failed (%s); verifier-only reward", e + ) + self.last_taste_judge_failed = True + self.last_effective_taste = 1.0 + return verifier_reward + + actions = [ + {"role": m.get("role"), "content": m.get("content")} + for m in self.chat_history + if m.get("role") == "assistant" + ] + task_text = self.task_config.get("prompt", "") if self.task_config else "" + outcome = bool(self.last_verifier_reward >= 1.0) + + taste_score: Optional[float] + try: + taste_score = await asyncio.wait_for( + score_trajectory_async(task_text, actions, outcome), + timeout=self.taste_judge_timeout_s, + ) + except asyncio.TimeoutError: + self.last_taste_judge_failed = True + logger.warning( + "taste judge timed out after %.1fs for task_key=%s", + self.taste_judge_timeout_s, + getattr(self, "task_key", "?"), + ) + taste_score = None + except Exception as e: + self.last_taste_judge_failed = True + logger.warning( + "taste judge raised %s: %s for task_key=%s", + type(e).__name__, e, getattr(self, "task_key", "?"), + ) + taste_score = None + + if taste_score is None: + self.last_effective_taste = 1.0 + return verifier_reward + + taste_score = max(0.0, min(1.0, float(taste_score))) + self.last_taste_reward = taste_score + effective_taste = max(self.taste_floor, taste_score) + self.last_effective_taste = effective_taste + return verifier_reward * effective_taste + + def step(self, action: str) -> BaseTextEnvStepOutput: + """Execute one step in the Fleet environment (sync wrapper).""" + return asyncio.run(self.step_async(action)) + + def _capture_verifier_feedback(self): + """Capture verifier feedback from OpenEnv before nulling the env.""" + if self.openenv_task_env: + self._verifier_stdout = getattr( + self.openenv_task_env, "verifier_stdout", None + ) + self._verifier_error = getattr( + self.openenv_task_env, "verifier_error", None + ) + self._tool_error_messages = getattr( + self.openenv_task_env, "tool_errors_list", [] + ) + + def close(self): + """Close the Fleet environment and cleanup resources.""" + if self.openenv_task_env: + try: + self.openenv_task_env.close() + if self.openenv_task_env.final_reward is not None: + self.last_reward = self.openenv_task_env.final_reward + self._capture_verifier_feedback() + except Exception as e: + logger.warning(f"Failed to close Fleet environment: {e}") + self.openenv_task_env = None + + async def close_async(self): + """Close the Fleet environment (async version). + + Runs verifier via OpenEnv's close_async() to get actual reward for + orphaned rollouts (context overflow, early termination by SkyRL). + """ + if self.openenv_task_env: + try: + await self.openenv_task_env.close_async() + if self.openenv_task_env.final_reward is not None: + self.last_reward = self.openenv_task_env.final_reward + self._capture_verifier_feedback() + except Exception as e: + logger.warning(f"Failed to close Fleet environment: {e}") + self.openenv_task_env = None + + def get_metrics(self) -> Dict[str, Any]: + """Return environment metrics for this episode.""" + metrics = { + "task_key": self.task_key, + "env_key": self.task_config.get("env_key") + or self.task_config.get("env_id"), + "turns": self.turns, + "tool_calls": self.tool_calls, + "tool_errors": self.tool_errors, + "is_hinted": bool(self.extras.get("hint")), + } + if self.last_reward is not None: + metrics["final_reward"] = self.last_reward + # Taste judge metrics + if self.last_taste_reward is not None: + metrics["taste_reward"] = self.last_taste_reward + if self.last_effective_taste is not None: + metrics["effective_taste"] = self.last_effective_taste + metrics["taste_floor"] = self.taste_floor + metrics["taste_judge_failed"] = self.last_taste_judge_failed + # Include verifier feedback for hint generation + if self._verifier_stdout is not None: + metrics["verifier_stdout"] = self._verifier_stdout + if self._verifier_error is not None: + metrics["verifier_error"] = self._verifier_error + if self._tool_error_messages: + metrics["tool_error_messages"] = self._tool_error_messages + # Include chat_history for LLM hint synthesis (consumed then deleted by generator) + if self.chat_history: + metrics["chat_history"] = self.chat_history + return metrics + + @staticmethod + def build_hint_text( + verifier_stdout: Optional[str], + verifier_error: Optional[str], + tool_error_messages: Optional[List[str]], + ) -> str: + """Build hint text from verifier feedback. No LLM call. + + Parses ERROR_ACCUMULATOR / SUCCESS_ACCUMULATOR from verifier stdout + and formats tool errors into structured feedback for the next attempt. + """ + parts = [] + + if verifier_stdout: + err_match = re.search( + r">>> ERROR_ACCUMULATOR >>>\n(.+?)\n<<< ERROR_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + suc_match = re.search( + r">>> SUCCESS_ACCUMULATOR >>>\n(.+?)\n" + r"<<< SUCCESS_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + if err_match or suc_match: + try: + errors = ( + ast.literal_eval(err_match.group(1)) + if err_match + else [] + ) + successes = ( + ast.literal_eval(suc_match.group(1)) + if suc_match + else [] + ) + except Exception: + errors, successes = [], [] + if successes: + parts.append( + f"Checks passed ({len(successes)}): " + + ", ".join( + str(s)[:100] for s in successes[:5] + ) + ) + if errors: + parts.append( + f"Checks failed ({len(errors)}): " + + ", ".join(str(e)[:100] for e in errors[:5]) + ) + + if verifier_error: + parts.append(f"Verifier: {verifier_error}") + + if tool_error_messages: + unique = list(dict.fromkeys(tool_error_messages))[:5] + parts.append( + "Tool errors: " + "; ".join(e[:200] for e in unique) + ) + + return ( + "\n".join(parts) + if parts + else "The previous attempt failed. Try a different approach." + ) + + @staticmethod + def aggregate_metrics( + metrics: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """Aggregate metrics across episodes with per-env breakdown.""" + if not metrics: + return {} + + env_init_failures: Dict[str, int] = {} + total_init_failures = 0 + + env_data: Dict[str, Dict[str, List[int]]] = {} + for m in metrics: + # Check for init failure metrics + for key, value in m.items(): + if key.startswith("env_init_failed/"): + env_key = key.split("/", 1)[1] + env_init_failures[env_key] = ( + env_init_failures.get(env_key, 0) + int(value) + ) + total_init_failures += int(value) + + env_key = m.get("env_key") + if env_key: + if env_key not in env_data: + env_data[env_key] = { + "turns": [], + "tool_calls": [], + "tool_errors": [], + } + env_data[env_key]["turns"].append(m.get("turns", 0)) + env_data[env_key]["tool_calls"].append( + m.get("tool_calls", 0) + ) + env_data[env_key]["tool_errors"].append( + m.get("tool_errors", 0) + ) + + result: Dict[str, Any] = {} + total_turns = 0 + total_tool_calls = 0 + total_tool_errors = 0 + total_episodes = 0 + + for env_key, data in env_data.items(): + turns_list = data["turns"] + tool_calls_list = data["tool_calls"] + tool_errors_list = data["tool_errors"] + + avg_turns = sum(turns_list) / len(turns_list) + avg_tool_calls = sum(tool_calls_list) / len(tool_calls_list) + avg_tool_errors = sum(tool_errors_list) / len(tool_errors_list) + total_env_turns = sum(turns_list) + total_env_tool_calls = sum(tool_calls_list) + total_env_tool_errors = sum(tool_errors_list) + tool_calls_per_turn = ( + total_env_tool_calls / total_env_turns + if total_env_turns > 0 + else 0 + ) + tool_error_rate = ( + total_env_tool_errors / total_env_tool_calls + if total_env_tool_calls > 0 + else 0 + ) + + result[f"{env_key}/avg_turns"] = avg_turns + result[f"{env_key}/min_turns"] = min(turns_list) + result[f"{env_key}/max_turns"] = max(turns_list) + result[f"{env_key}/avg_tool_calls"] = avg_tool_calls + result[f"{env_key}/tool_calls_per_turn"] = tool_calls_per_turn + result[f"{env_key}/avg_tool_errors"] = avg_tool_errors + result[f"{env_key}/total_tool_errors"] = total_env_tool_errors + result[f"{env_key}/tool_error_rate"] = tool_error_rate + result[f"{env_key}/num_episodes"] = len(turns_list) + + total_turns += total_env_turns + total_tool_calls += total_env_tool_calls + total_tool_errors += total_env_tool_errors + total_episodes += len(turns_list) + + result["avg_turns"] = ( + total_turns / total_episodes if total_episodes > 0 else 0 + ) + result["avg_tool_calls"] = ( + total_tool_calls / total_episodes if total_episodes > 0 else 0 + ) + result["tool_calls_per_turn"] = ( + total_tool_calls / total_turns if total_turns > 0 else 0 + ) + result["avg_tool_errors"] = ( + total_tool_errors / total_episodes if total_episodes > 0 else 0 + ) + result["total_tool_errors"] = total_tool_errors + result["tool_error_rate"] = ( + total_tool_errors / total_tool_calls + if total_tool_calls > 0 + else 0 + ) + result["total_episodes"] = total_episodes + + for env_key, failures in env_init_failures.items(): + result[f"{env_key}/env_init_failed"] = failures + if total_init_failures > 0: + result["total_env_init_failed"] = total_init_failures + + return result diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py b/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py new file mode 100644 index 0000000000..f71edbac53 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py @@ -0,0 +1,262 @@ +"""LLM-synthesized hints for failed trajectories. + +Analyzes the full failed trajectory + verifier errors and produces actionable +guidance via an LLM (via litellm/OpenRouter). Falls back to static +build_hint_text() on failure. +""" + +import asyncio +import logging +import os +import time +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +# Category tag for LLM-synthesized hints +CATEGORY_LLM = "llm_synthesized" +CATEGORY_STATIC = "static_fallback" +CATEGORY_LLM_FAILED = "llm_failed_static_fallback" + +HINT_SYSTEM_PROMPT = """\ +You are a debugging assistant for an AI agent that failed a task. \ +Analyze the failed trajectory and verifier feedback, then provide \ +2-5 sentences of actionable guidance for the agent's next attempt. + +Rules: +- Be specific: reference exact actions that failed and why. +- Be actionable: tell the agent what to do differently, not just what went wrong. +- If the agent ran out of context/turns, suggest being more efficient (fewer unnecessary steps). +- If tool calls errored, explain the correct usage pattern. +- Do NOT repeat the task instructions verbatim. +- Do NOT say "the previous attempt failed" — the agent already knows that.""" + + +def format_trajectory_for_hint( + chat_history: List[Dict[str, Any]], + max_turns: int = 15, + max_msg_chars: int = 3000, + max_total_chars: int = 150_000, +) -> str: + """Format chat_history into readable text for LLM hint synthesis. + + Truncates to the last `max_turns` messages, caps individual messages, + and enforces a total character budget. + """ + if not chat_history: + return "(empty trajectory)" + + # Take last N turns + recent = chat_history[-max_turns:] + parts = [] + total = 0 + + for msg in recent: + role = msg.get("role", "unknown") + content = msg.get("content", "") + + # Handle list-type content (multimodal messages) + if isinstance(content, list): + text_parts = [] + for block in content: + if isinstance(block, dict): + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif block.get("type") == "image_url": + text_parts.append("[image]") + elif block.get("type") == "tool_use": + name = block.get("name", "unknown_tool") + inp = str(block.get("input", ""))[:500] + text_parts.append(f"[tool_use: {name}({inp})]") + elif block.get("type") == "tool_result": + text_parts.append(f"[tool_result: {str(block.get('content', ''))[:500]}]") + else: + text_parts.append(str(block)[:200]) + else: + text_parts.append(str(block)[:200]) + content = "\n".join(text_parts) + + if isinstance(content, str) and len(content) > max_msg_chars: + content = content[:max_msg_chars] + f"... [truncated, {len(content)} chars total]" + + line = f"[{role}]: {content}" + if total + len(line) > max_total_chars: + parts.append(f"... [trajectory truncated at {max_total_chars} chars]") + break + parts.append(line) + total += len(line) + + return "\n\n".join(parts) + + +def format_verifier_feedback( + verifier_stdout: Optional[str], + verifier_error: Optional[str], + tool_error_messages: Optional[List[str]], +) -> str: + """Extract verifier errors/successes and tool errors into readable text.""" + import ast + import re + + parts = [] + + if verifier_stdout: + err_match = re.search( + r">>> ERROR_ACCUMULATOR >>>\n(.+?)\n<<< ERROR_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + suc_match = re.search( + r">>> SUCCESS_ACCUMULATOR >>>\n(.+?)\n<<< SUCCESS_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + if err_match or suc_match: + try: + errors = ast.literal_eval(err_match.group(1)) if err_match else [] + successes = ast.literal_eval(suc_match.group(1)) if suc_match else [] + except Exception: + errors, successes = [], [] + if successes: + parts.append(f"Verifier checks PASSED ({len(successes)}):") + for s in successes[:10]: + parts.append(f" - {str(s)[:200]}") + if errors: + parts.append(f"Verifier checks FAILED ({len(errors)}):") + for e in errors[:10]: + parts.append(f" - {str(e)[:200]}") + + if verifier_error: + parts.append(f"Verifier error: {verifier_error[:500]}") + + if tool_error_messages: + unique = list(dict.fromkeys(tool_error_messages))[:10] + parts.append("Tool errors encountered:") + for e in unique: + parts.append(f" - {e[:300]}") + + return "\n".join(parts) if parts else "(no verifier feedback available)" + + +async def synthesize_hint( + task_prompt: str, + chat_history: List[Dict[str, Any]], + verifier_stdout: Optional[str], + verifier_error: Optional[str], + tool_error_messages: Optional[List[str]], + model: str = "openrouter/anthropic/claude-sonnet-4", + timeout: float = 30.0, + static_fallback_fn=None, +) -> Tuple[str, str]: + """Synthesize a hint from a failed trajectory using an LLM via litellm. + + Returns: + (hint_text, hint_category) where category is one of + CATEGORY_LLM, CATEGORY_STATIC, CATEGORY_LLM_FAILED. + """ + try: + from litellm import acompletion + except ImportError: + logger.warning("litellm not installed, falling back to static hints") + if static_fallback_fn: + return static_fallback_fn(verifier_stdout, verifier_error, tool_error_messages), CATEGORY_STATIC + return "The previous attempt failed. Try a different approach.", CATEGORY_STATIC + + trajectory_text = format_trajectory_for_hint(chat_history) + verifier_text = format_verifier_feedback(verifier_stdout, verifier_error, tool_error_messages) + + user_message = f"""## Task +{task_prompt[:5000]} + +## Agent Trajectory (last turns) +{trajectory_text} + +## Verifier Feedback +{verifier_text} + +Based on the trajectory and feedback above, provide 2-5 sentences of specific, actionable guidance for the agent's next attempt.""" + + try: + response = await asyncio.wait_for( + acompletion( + model=model, + max_tokens=300, + temperature=0.3, + messages=[ + {"role": "system", "content": HINT_SYSTEM_PROMPT}, + {"role": "user", "content": user_message}, + ], + ), + timeout=timeout, + ) + hint_text = response.choices[0].message.content.strip() + if hint_text: + return hint_text, CATEGORY_LLM + else: + logger.warning("LLM returned empty hint, falling back to static") + except asyncio.TimeoutError: + logger.warning(f"LLM hint synthesis timed out after {timeout}s") + except Exception as e: + logger.warning(f"LLM hint synthesis failed: {e}") + + # Fallback to static hint + if static_fallback_fn: + return static_fallback_fn(verifier_stdout, verifier_error, tool_error_messages), CATEGORY_LLM_FAILED + return "The previous attempt failed. Try a different approach.", CATEGORY_LLM_FAILED + + +async def synthesize_hints_batch( + hint_requests: List[Dict[str, Any]], + model: str = "openrouter/anthropic/claude-sonnet-4", + timeout: float = 30.0, + max_concurrency: int = 20, + static_fallback_fn=None, +) -> List[Tuple[str, str]]: + """Synthesize hints for a batch of failed trajectories concurrently. + + Args: + hint_requests: List of dicts with keys: + - task_prompt: str + - chat_history: List[Dict] + - verifier_stdout: Optional[str] + - verifier_error: Optional[str] + - tool_error_messages: Optional[List[str]] + - instance_id: str (for logging) + model: LLM model to use + timeout: per-request timeout + max_concurrency: max concurrent LLM calls + static_fallback_fn: fallback function for static hints + + Returns: + List of (hint_text, hint_category) tuples, one per request. + """ + if not hint_requests: + return [] + + sem = asyncio.Semaphore(max_concurrency) + start = time.monotonic() + + async def _synth(req: Dict[str, Any]) -> Tuple[str, str]: + async with sem: + return await synthesize_hint( + task_prompt=req["task_prompt"], + chat_history=req.get("chat_history", []), + verifier_stdout=req.get("verifier_stdout"), + verifier_error=req.get("verifier_error"), + tool_error_messages=req.get("tool_error_messages"), + model=model, + timeout=timeout, + static_fallback_fn=static_fallback_fn, + ) + + results = await asyncio.gather(*[_synth(req) for req in hint_requests]) + + elapsed = time.monotonic() - start + n_llm = sum(1 for _, cat in results if cat == CATEGORY_LLM) + n_fallback = sum(1 for _, cat in results if cat in (CATEGORY_STATIC, CATEGORY_LLM_FAILED)) + logger.info( + f"Hint synthesis batch: {len(results)} total, {n_llm} LLM-synthesized, " + f"{n_fallback} fallback, {elapsed:.1f}s elapsed" + ) + + return list(results) diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/tool_call_parser.py b/skyrl-gym/skyrl_gym/envs/fleet_task/tool_call_parser.py new file mode 100644 index 0000000000..bec243a9e4 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/tool_call_parser.py @@ -0,0 +1,68 @@ +"""Tool call parser for LLM-generated tool calls. + +Parses tool calls from various tag-based formats commonly produced by LLMs: +- {"name": "...", "arguments": {...}} +- {"name": "...", "arguments": {...}} + +Handles missing closing tags (e.g., when is the stop string) +and repairs common JSON issues like missing trailing braces. +""" + +import json +import re +from typing import Any, Dict, Optional + + +def _try_parse_json(raw: str) -> Optional[Dict[str, Any]]: + """Try to parse JSON, repairing missing trailing braces if needed.""" + raw = raw.strip() + try: + parsed = json.loads(raw) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, ValueError): + pass + + # Repair: models often drop trailing closing braces on nested JSON. + # Try appending up to 3 closing braces. + for extra in range(1, 4): + try: + parsed = json.loads(raw + "}" * extra) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, ValueError): + continue + + return None + + +def parse_tool_call(action: str) -> Optional[Dict[str, Any]]: + """Parse tool call from LLM response. + + Supports tag-based formats: + - {"name": "...", "arguments": {...}} + - {"name": "...", "arguments": {...}} + + Also handles cases where the closing tag is missing (e.g., when + is used as the stop string and not included in the output). + + Returns: + Dict with "name" and "arguments" keys, or None if no tool call found. + """ + for tag in ["tool_call", "function_call"]: + # First try with closing tag + match = re.search(rf"<{tag}>(.*?)", action, re.DOTALL) + if not match: + # Try without closing tag (for when is the stop string) + match = re.search(rf"<{tag}>(.*?)(?:<\||\Z)", action, re.DOTALL) + if match: + parsed = _try_parse_json(match.group(1)) + if parsed is None: + continue + # Normalize keys + name = parsed.get("name") or parsed.get("tool") + args = parsed.get("arguments") or parsed.get("params", {}) + if name: + return {"name": name, "arguments": args} + + return None diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/__init__.py b/skyrl-gym/skyrl_gym/envs/task_gen/__init__.py new file mode 100644 index 0000000000..b5c5a7e88c --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/task_gen/__init__.py @@ -0,0 +1,5 @@ +from skyrl_gym.envs.task_gen.task_gen_env import TaskGenEnv +from skyrl_gym.envs.task_gen.tool_call_parser import parse_tool_call +from skyrl_gym.envs.task_gen.verifier_sandbox import VerifierSandbox + +__all__ = ["TaskGenEnv", "VerifierSandbox", "parse_tool_call"] diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py new file mode 100644 index 0000000000..c966e133d0 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -0,0 +1,1472 @@ +""" +Task Generation Environment for SkyRL. + +Multi-turn BaseTextEnv where the LLM can explore the seed database via +``query_db`` meta-tool before generating a task. + +When ``max_turns > 1`` (the default), the model explores the DB first +and then produces a ```` block. When ``max_turns == 1`` it +behaves identically to the original single-turn variant. + +Reward: + + R(task) = base_quality + llm_validity * (alpha * var(raw_scores) + (p_hint - p_raw)) + + base_quality: Small reward for passing sandbox+judge (default 0.1) + llm_validity: Binary 0/1 from LLM-as-a-judge (is the task well-formed?) + var(raw_scores): Variance of k raw evaluator rollouts (difficulty calibration) + p_hint - p_raw: Hint gap — solvable with hints but not without (learnability) + alpha: Weight balancing variance vs hint gap (default 0.5) +""" + +import ast +import asyncio +import json +import logging +import os +import re +import time +import uuid +from typing import Any, Dict, List, Optional, Tuple + +from omegaconf import DictConfig + +from skyrl_gym.envs.base_text_env import ( + BaseTextEnv, + BaseTextEnvStepOutput, + ConversationType, +) +from skyrl_gym.envs.task_gen.tool_call_parser import parse_tool_call, parse_tool_calls +from skyrl_gym.envs.task_gen.verifier_sandbox import ( + VerifierSandbox, + parse_task_output, +) + +logger = logging.getLogger(__name__) + +def _format_compact_schema(describe_result: Any) -> str: + """Convert a DescribeResponse dict to compact 'table: col (type), ...' format.""" + if not isinstance(describe_result, dict): + return str(describe_result) if describe_result else "" + tables = describe_result.get("tables") + if not tables or not isinstance(tables, list): + return "" + lines = [] + for t in tables: + name = t.get("name", "") + cols = t.get("columns", []) + col_parts = [] + for c in cols: + col_name = c.get("name", "") + col_type = c.get("type", "").lower() + col_parts.append(f"{col_name} ({col_type})" if col_type else col_name) + lines.append(f"{name}: {', '.join(col_parts)}") + return "\n".join(lines) + + +# Meta-tools the model can call to explore the seed database. +_META_TOOLS = {"query_db"} + +# All callable tools = meta-tools + any MCP env tools discovered at init time. +# Populated per-instance in init_async(). + + +class TaskGenEnv(BaseTextEnv): + """Environment for RL-based task generation. + + The LLM generates (prompt, verifier) pairs for Fleet environments. + Supports multi-turn: the model can explore the seed DB via ``query_db`` + meta-tool before outputting a ```` block. Schema is in the prompt. + + Reward = llm_validity * (alpha * var(raw_scores) + (p_hint - p_raw)) + + Evaluation uses Fleet harness jobs (POST /v1/jobs) to run an LLM agent + against the generated task, rather than a stub evaluator. + + Constructor args (via extras, from dataset): + env_key, env_version, data_key, data_version + env_tools, env_tools_schema, env_variable_keys + + Constructor args (via env_config, from Hydra): + max_turns: Max turns before forced termination (default 10) + judge_model: Model ID for LLM-as-a-judge gate + k_rollouts: Number of rollouts per condition (raw/hinted, default 4) + max_eval_steps: Max agent steps per evaluator rollout (default 30) + evaluator_model: Fleet harness model for task evaluation (default anthropic/claude-sonnet-4.5) + base_quality_reward: Small reward for passing sandbox+judge (default 0.1). + Prevents GRPO zero-signal deadlock when all harness evals fail. + """ + + def __init__( + self, + env_config: DictConfig, + extras: Dict[str, Any] = {}, + ): + super().__init__() + + # Configurable multi-turn (default 10; set to 1 for single-turn) + self.max_turns = int(env_config.get("max_turns", 10)) if env_config else 10 + + # Fleet orchestrator for DB exploration (set in init_async) + self.orch = None + # MCP tools client for calling env tools (set in init_async) + self.mcp_tools = None + # Set of all callable tool names (meta-tools + MCP tools) + self.callable_tools = set(_META_TOOLS) + # Exploration sequence tracking (reset in init_async) + self.called_query_db = False + + # Environment context from dataset (extras) + self.env_key = extras.get("env_key") or extras.get("data_source", "unknown") + self.env_version = extras.get("env_version", "") + self.data_key = extras.get("data_key", "") + self.data_version = extras.get("data_version", "") + + # Parse env_tools_schema (full tool schemas for prompt building) + env_tools_schema_raw = extras.get("env_tools_schema", "[]") + if isinstance(env_tools_schema_raw, str): + try: + self.env_tools_schema: List[Dict[str, Any]] = json.loads(env_tools_schema_raw) + except json.JSONDecodeError: + self.env_tools_schema: List[Dict[str, Any]] = [] + else: + self.env_tools_schema: List[Dict[str, Any]] = env_tools_schema_raw or [] + + # Parse env_tools (tool name list for sandbox validation) + env_tools_raw = extras.get("env_tools", []) + if isinstance(env_tools_raw, str): + try: + self.env_tools: List[str] = json.loads(env_tools_raw) + except json.JSONDecodeError: + self.env_tools: List[str] = [] + else: + self.env_tools: List[str] = env_tools_raw or [] + + # If env_tools is empty but we have schemas, extract names from schemas + if not self.env_tools and self.env_tools_schema: + self.env_tools = [ + t["function"]["name"] for t in self.env_tools_schema if "function" in t and "name" in t["function"] + ] + + # Parse env_variable_keys (available context variables for this env) + env_var_keys_raw = extras.get("env_variable_keys", "[]") + if isinstance(env_var_keys_raw, str): + try: + self.env_variable_keys: List[str] = json.loads(env_var_keys_raw) + except json.JSONDecodeError: + self.env_variable_keys: List[str] = [] + else: + self.env_variable_keys: List[str] = env_var_keys_raw or [] + + # Parse env_variables (actual values for harness evaluation) + env_vars_raw = extras.get("env_variables", "{}") + if isinstance(env_vars_raw, str): + try: + self.env_variables: Dict[str, Any] = json.loads(env_vars_raw) + except json.JSONDecodeError: + self.env_variables: Dict[str, Any] = {} + else: + self.env_variables: Dict[str, Any] = env_vars_raw or {} + + # Parse env_schema (compact DB schema: table→columns) + self.env_schema: str = extras.get("env_schema", "") or "" + + # Verifier sandbox — filters out CUA-only tool "computer" from available tools + api_tools = set(self.env_tools) - {"computer"} if self.env_tools else None + self.sandbox = VerifierSandbox(available_tools=api_tools if api_tools else None) + + # Judge config (from Hydra env_config) + self.judge_model = str(env_config.get("judge_model", "")) if env_config else "" + + # Evaluator config (from Hydra env_config) + self.k_rollouts = int(env_config.get("k_rollouts", 4)) if env_config else 4 + self.max_eval_steps = int(env_config.get("max_eval_steps", 30)) if env_config else 30 + self.evaluator_model = ( + str(env_config.get("evaluator_model", "anthropic/claude-sonnet-4.5")) + if env_config + else "anthropic/claude-sonnet-4.5" + ) + + # API keys from environment variables (set by SkyPilot YAML) + self.openrouter_api_key = os.environ.get("OPENROUTER_API_KEY", "") + self.fleet_api_key = os.environ.get("FLEET_API_KEY", "") + + # Eval mode: k=8 raw only (no hints) + self.is_eval = extras.get("training_phase") == "eval" + self.eval_k_rollouts = int(env_config.get("eval_k_rollouts", 8)) if env_config else 8 + # Whether to run hinted evaluation jobs (2nd harness job with verifier feedback). + # Default off — hints were net negative in iter#11 (verifier code dump confused evaluator). + self.enable_hints = bool(env_config.get("enable_hints", False)) if env_config else False + + # Lazy-init Fleet SDK client for harness evaluation + self._fleet_client = None + + # Rollout dump directory (full prompt/verifier/scores per eval) + default_rollout_dir = os.path.join(os.path.expanduser("~"), "reward_rollouts") + self._rollout_dir = os.environ.get("REWARD_ROLLOUT_DIR", default_rollout_dir) + os.makedirs(self._rollout_dir, exist_ok=True) + + # Base quality reward for tasks passing sandbox + judge gate. + # Provides GRPO gradient signal even when all harness evals return 0. + self.base_quality_reward = float(env_config.get("base_quality_reward", 0.1)) if env_config else 0.1 + + # Small per-tool-call reward to incentivize DB exploration (query_db). + # Default 0.0 = off (no behavior change for existing runs). + self.tool_call_reward_per_call = float(env_config.get("tool_call_reward_per_call", 0.0)) if env_config else 0.0 + + logger.info( + f"TaskGenEnv: env={self.env_key}, max_turns={self.max_turns}, " + f"judge={self.judge_model or 'none'}, " + f"tools={len(self.env_tools)}, k={self.k_rollouts}, eval_k={self.eval_k_rollouts}, " + f"evaluator={self.evaluator_model}, is_eval={self.is_eval}, " + f"base_quality={self.base_quality_reward}, tool_call_reward={self.tool_call_reward_per_call}" + ) + + def _format_tool_schema(self, tool: Dict[str, Any]) -> str: + """Format a single tool schema for the system prompt.""" + func = tool.get("function", {}) + name = func.get("name", "unknown") + desc = func.get("description", "") + params = func.get("parameters", {}) + properties = params.get("properties", {}) + required = set(params.get("required", [])) + + lines = [f"**{name}**: {desc}"] + if properties: + lines.append(" Parameters:") + for pname, pschema in properties.items(): + ptype = pschema.get("type", "any") + pdesc = pschema.get("description", "") + req_marker = " (required)" if pname in required else "" + lines.append(f" - {pname} ({ptype}{req_marker}): {pdesc}") + + return "\n".join(lines) + + def _build_system_prompt(self) -> str: + """Build the system prompt with environment context and priors.""" + parts = [] + + parts.append(f'You are a task designer for the "{self.env_key}" environment.') + + # --- Date context (critical for date-sensitive environments) --- + current_date = self.env_variables.get("CURRENT_DATE", "") + if current_date: + parts.append( + f"\n**IMPORTANT — Current Date: {current_date}**\n" + f"The environment's current date is {current_date}. " + "All dates in generated tasks MUST be on or after this date. " + "Do NOT use past dates — the environment will reject them " + "(e.g., check-in dates, event dates, appointment dates must be in the future)." + ) + + # --- A. Environment context (from tool discovery) --- + parts.append(f"\n## Environment: {self.env_key}") + parts.append("\n### Available Tools") + + # Filter out CUA-only "computer" tool — task-gen is for tool-use APIs + api_schemas = [t for t in self.env_tools_schema if t.get("function", {}).get("name") != "computer"] + api_tool_names = [t for t in self.env_tools if t != "computer"] + + if api_schemas: + # Compact format: name + description only (no parameter schemas) + # Full schemas make the prompt too long for envs with many tools + for tool in api_schemas: + func = tool.get("function", {}) + name = func.get("name", "unknown") + desc = func.get("description", "") + parts.append(f"- **{name}**: {desc}") + elif api_tool_names: + parts.append("\n".join(f"- {t}" for t in api_tool_names)) + else: + parts.append("No tools discovered for this environment.") + + # Environment variables (user context available at task runtime) + if self.env_variables: + parts.append("\n### Environment Variables (embed as constants)") + parts.append( + "These variables describe the user/session context. " + "**Embed them directly as string constants** in your verifier code. " + "Do NOT use `env.env_variables` — it is not available at verifier runtime." + ) + for var_key, var_val in self.env_variables.items(): + parts.append(f'- `{var_key}` = `"{var_val}"`') + parts.append( + "\nExample usage in verifier:\n" + "```python\n" + f'LOGGED_IN_USER = "{self.env_variables.get("LOGGED_IN_USER", "user@example.com")}"\n' + f'# Use as: rows = current.table("users").eq("email", LOGGED_IN_USER).all()\n' + "```" + ) + elif self.env_variable_keys: + parts.append("\n### Environment Variables") + parts.append( + "These variables parameterize each environment instance. " + "Look up values from the database instead of using env.env_variables." + ) + for var_key in self.env_variable_keys: + parts.append(f"- `{var_key}`") + + # Database schema (table names and columns) + if self.env_schema: + parts.append("\n### Database Schema") + parts.append( + "Use these exact table and column names in verifiers " + '(e.g., `current.table("bookings").eq("guest_email", val).all()`):' + ) + parts.append(f"```\n{self.env_schema}\n```") + + # --- B. Priors (concise, static, same for all envs) --- + # Date awareness guidance (prevents past-date failures in booking/ticketmaster) + if current_date: + date_guidance = ( + f"### Date Awareness\n" + f"The environment's current date is **{current_date}**. " + f"ALL dates in your task MUST be on or after {current_date}. " + "Tasks with past dates will always fail because the environment " + "rejects them (e.g., 'checkIn date cannot be in the past'). " + "Use `query_db` to check what date ranges exist in the data, " + "and always generate future dates." + ) + else: + date_guidance = ( + "### Date Awareness\n" + "If the environment works with dates, verify what date ranges " + "are valid before generating tasks. Use `query_db` to check." + ) + + # NOTE: env.env_variables is NOT available at verifier runtime (Fleet harness bug). + # Model is instructed to embed env var values as constants instead. + + parts.append( + f""" +## Verifier Guidelines + +The verifier checks whether the agent completed the task by inspecting database state changes. + +Signature: `def validate_task(env: Environment, final_answer: str | None = None) -> int` + +**IMPORTANT**: The function MUST be named `validate_task` and return `TASK_FAILED_SCORE` (0) or `TASK_SUCCESSFUL_SCORE` (1). + +### Verifier API +```python +env.instance.load() # Load current state (call first) +seed = env.db("seed") # Original DB before agent acted +current = env.db("current") # Current DB after agent acted + +# Query tables — ALL results are Python dicts, use row["column"] NOT row.column: +rows = current.table("table_name").eq("column", value).all() # -> List[dict] +row = current.table("table_name").eq("column", value).first() # -> dict or None +rows = current.table("table_name").neq("column", value).all() # -> List[dict] +count = current.table("table_name").eq("column", value).count() # -> int +rows = current.table("table_name").select("col1", "col2").all() # -> List[dict] +# Access fields: row["id"], row["name"], row["email"] — NEVER row.id or row.name +# Only methods: .table(), .eq(), .neq(), .select(), .all(), .first(), .count() +# NO .like(), .gt(), .lt(), .contains(), .in_() — use Python filtering instead + +# Compare seed vs current to detect NEW entries: +def find_new_entries(seed, current, table_name, id_field="id", filter_conditions=None): + before_query = seed.table(table_name) + after_query = current.table(table_name) + if filter_conditions: + for key, value in filter_conditions.items(): + before_query = before_query.eq(key, value) + after_query = after_query.eq(key, value) + before_ids = {{entry[id_field] for entry in before_query.select(id_field).all()}} + return [e for e in after_query.all() if e[id_field] not in before_ids] +``` + +### Error Tracking (REQUIRED) +Every verifier MUST track errors and successes using accumulator lists, and print them +before returning. This enables automated feedback for hint-based evaluation. + +```python +error_accumulator = [] +success_accumulator = [] + +# ... check conditions ... +if condition_met: + success_accumulator.append("[C] Booking was created") +else: + error_accumulator.append("[X] Expected booking not found") + +# ALWAYS print accumulators before returning: +if error_accumulator: + print(">>> ERROR_ACCUMULATOR >>>") + print(error_accumulator) + print("<<< ERROR_ACCUMULATOR <<<") +if success_accumulator: + print(">>> SUCCESS_ACCUMULATOR >>>") + print(success_accumulator) + print("<<< SUCCESS_ACCUMULATOR <<<") +``` + +### Verifier Template (follow this structure) +```python +def validate_task(env: Environment, final_answer: str | None = None) -> int: + error_accumulator = [] + success_accumulator = [] + env.instance.load() + seed = env.db("seed") + current = env.db("current") + + def find_new_entries(table_name, id_field="id", filter_conditions=None): + \"\"\"Compare seed vs current to find rows added by the agent. + + Args: + table_name: Table to compare. + id_field: Primary key column (default "id"). + filter_conditions: Optional dict of {{column: value}} filters + applied to BOTH seed and current before comparison. + + Returns: + List[dict] — rows present in current but not in seed. + \"\"\" + before_query = seed.table(table_name) + after_query = current.table(table_name) + if filter_conditions: + for key, value in filter_conditions.items(): + before_query = before_query.eq(key, value) + after_query = after_query.eq(key, value) + before_ids = set(entry[id_field] for entry in before_query.select(id_field).all()) + return [e for e in after_query.all() if e[id_field] not in before_ids] + + # --- Validation: use SET-BASED comparison, never row-index --- + # GOOD: compare by content/ID sets, order-independent + # expected_ids = {{"id_1", "id_2"}} + # actual_ids = {{row["id"] for row in new_entries}} + # if not expected_ids.issubset(actual_ids): ... + # + # BAD: comparing by row index (fragile, order-dependent) + # if new_entries[0]["id"] == "id_1": ... + + # Check conditions... + # On early failure: + if critical_failure: + error_accumulator.append("[X] Critical check failed") + print(">>> ERROR_ACCUMULATOR >>>") + print(error_accumulator) + print("<<< ERROR_ACCUMULATOR <<<") + return TASK_FAILED_SCORE + + # Final result: + if error_accumulator: + print(">>> ERROR_ACCUMULATOR >>>") + print(error_accumulator) + print("<<< ERROR_ACCUMULATOR <<<") + return TASK_FAILED_SCORE + print(">>> SUCCESS_ACCUMULATOR >>>") + print(success_accumulator) + print("<<< SUCCESS_ACCUMULATOR <<<") + return TASK_SUCCESSFUL_SCORE +``` + +### Rules +- **NEVER hardcode database IDs** (user_id, hotel_id, etc.) — always query the DB to find them +- **NEVER use `env.env_variables`** — it is not available at runtime. Embed env var values as string constants at the top of your verifier (e.g., `LOGGED_IN_USER = "riley3318"`) +- **DB rows are dicts** — use `row["id"]`, `row["name"]`, NOT `row.id`, `row.name`. Using dot notation will crash with `AttributeError: 'dict' object has no attribute 'id'` +- **Only use supported query methods**: `.eq()`, `.neq()`, `.select()`, `.all()`, `.first()`, `.count()`. NO `.like()`, `.gt()`, `.lt()`, `.order()`, `.limit()`, `.contains()`, `.in_()` — filter and sort in Python instead (e.g., `sorted([r for r in rows if r["score"] > 8.0], key=lambda r: r["score"], reverse=True)[:5]`) +- **`.eq()` takes exactly 2 args**: `.eq(column, value)`. NO operator arg like `.eq("rating", ">", 8)` — use Python: `[r for r in rows if r["rating"] > 8]` +- **Use timezone-tolerant comparisons** for datetimes — the DB may store `"2025-08-08T14:00:00Z"` while you expect `"2025-08-08T14:00:00"`. Use `.startswith()` or strip the trailing `"Z"` before comparing +- **If you use `.select()`, only access the selected columns** — accessing other columns raises `KeyError`. Prefer `.all()` without `.select()` unless you specifically need to limit columns +- **Define `find_new_entries` inside your verifier function** — it is NOT a built-in. Copy it from the template above into your `validate_task()` function body. Do NOT call `find_new_entries()` without defining it first +- **List comprehensions produce tuples if you use tuple syntax** — `[(a, b) for ...]` creates tuples, not dicts. If you need dict-like access later, keep the original dicts: `[row for row in rows if condition]` +- **NEVER hardcode expected values the agent must create** — e.g., don't check for a specific phone number or email the agent would need to invent. Instead, check that the field was changed from its original value: `current_val != seed_val` +- Look up the logged-in user by name/email from the users table, don't assume an ID +- Compare `seed` (before) vs `current` (after) to detect what the agent did +- Must return `TASK_FAILED_SCORE` on a fresh environment (before agent acts) +- **NEVER call `.table("X").all()` without a preceding `.eq()` or `.neq()` filter** — unfiltered `.all()` fetches every row, which is wasteful and causes warm-pool saturation with large tables. Always filter first: `current.table("orders").eq("user_id", uid).all()`. The only exception is inside `find_new_entries` where `.select(id_field).all()` fetches just IDs for comparison +- **Use order-independent (set-based) comparison** — never compare results by row index or list position. Rows may be returned in any order. Use sets: `actual_ids = {{r["id"] for r in rows}}; assert expected_ids.issubset(actual_ids)`. NEVER do `rows[0]["id"] == expected` — it breaks when row order changes +- **Verifier MUST return 0 on unmodified DB** — the verifier must fail when the agent has not acted. Always compare `seed` vs `current` state. A verifier that only checks `current` without comparing to `seed` is permissive — it may return 1 even when the agent did nothing. Pattern: `new_entries = find_new_entries("table"); if not new_entries: return TASK_FAILED_SCORE` +- Use `final_answer` for tasks that require the agent to report a value +- Reference actual tool names from this environment + +## Task Design Guidelines + +Design tasks that maximize learnability: an ideal task is one that a capable agent can solve with effort, but not trivially. Tasks that are too easy (always solved) or too hard (never solved) produce no learning signal. + +{date_guidance} + +### Realism +Write prompts as a real user would — natural language, concrete parameters, plausible intent. The task should sound like something a person would actually ask, not a test case. + +BAD: "Call get_user with id=5, then call update_user to set email to test@example.com" +GOOD: "Update the email address for Jamie Chen to jamie.chen@newdomain.com" + +### Avoiding Underspecification +A prompt is underspecified when multiple valid solutions exist but the verifier only accepts one. This creates false negatives — the agent solves the task correctly but gets reward 0. + +BAD prompt: "Find a designer in Mexico" (3 designers exist, verifier checks for one specific one) +FIX option 1: Make the prompt specific: "Find the designer in Mexico City who joined after 2023" +FIX option 2: Make the verifier accept all valid answers: check that ANY designer in Mexico is returned + +Use `query_db` to check the actual data before writing the prompt. If a query returns multiple rows, either narrow the prompt or widen the verifier. Always verify your assumptions by querying — don't guess. + +### Avoiding Overspecification +A prompt is overspecified when it dictates HOW to accomplish the task rather than WHAT outcome is needed. This makes the task trivially easy (no learning signal) and doesn't test real problem-solving. + +BAD: "First call list_tables, then call get_bookings with check_in_date='2024-03-15', then count the results and call submit_answer with the count" +GOOD: "How many bookings have a check-in date of March 15, 2024?" + +The prompt should specify the desired outcome. The agent should figure out which tools to use and in what order. + +### Complexity +Aim for tasks solvable in 2-8 tool calls. Tasks requiring 1 tool call are too easy (no signal). Tasks requiring 15+ calls are too hard (agent gives up). The sweet spot is 3-6 calls with some reasoning required. + +### Diversity +Vary tasks across multiple dimensions: +- Operations: reads (lookup, search, aggregate) AND writes (create, update, delete) +- Complexity: simple (2-3 tool calls) through moderate (4-8 tool calls with dependencies) +- Reasoning: some tasks need multi-step logic (find X, use X to look up Y, modify Y based on Z) +- Data entities: use different tables, columns, and relationships in the schema + +### Verifier-Prompt Consistency +The verifier must check exactly what the prompt asks — no more, no less. Before writing, verify: +1. Is there exactly one correct outcome for this prompt? (If not, widen the verifier or narrow the prompt) +2. Does the verifier return 0.0 on a fresh environment? (It must — the agent hasn't acted yet) +3. Does the verifier avoid hardcoded values? (Query the DB instead) +4. Could a different valid approach fool the verifier? (If so, fix the verifier to accept it)""" + ) + + # --- C. Exploration tools (multi-turn only) --- + if self.max_turns > 1: + parts.append( + """ +## Exploration Tools + +The database schema is provided above. Use BOTH `query_db` AND environment API tools during exploration. + +### Database Tools +{"name": "query_db", "arguments": {"sql": "SELECT * FROM table_name LIMIT 5"}} +Runs a read-only SQL query against the seed database. + +### Environment API Tools +{"name": "tool_name", "arguments": {"param": "value"}} +Calls the environment API tool and returns its result. **You MUST call at least one API tool** (e.g., searchEvents, getAvailability) during exploration to understand what the solver agent will experience. The solver uses these API tools, not SQL — if you only explore via SQL, you won't know whether the API tools actually work for your task. + +### Workflow +1. **Inspect data**: Call `query_db` to inspect real data (values, ranges, row counts). +2. **Try API tools**: Call at least one environment API tool to understand its behavior, input/output format, and what data it returns. This is critical — your task must be achievable using these tools. +3. **Draft a task idea**: Based on the data AND tool behavior you've observed. +4. **Validate**: Before outputting, verify: + - Does the data your prompt references actually exist? (Query to confirm.) + - Is the task achievable using the available API tools? (You tested them.) + - Does your verifier check for a DB mutation (e.g., new order, new cart item)? If so, does the task actually cause that mutation? + - Will the verifier return 0 on the unmodified DB? (If it uses `find_new_entries`, the task MUST involve a write action like buy/reserve/create — NOT just search/list.) +5. **Output**: Only when confident, output the final task in the format below.""" + ) + + # --- D. Few-shot examples removed --- + # Few-shot examples were removed because they anchored the model to + # generate near-copies of the examples (especially booking/wishlist tasks), + # causing mode collapse and zero reward signal. The verifier template + + # guidelines above provide enough structure for the model to generate + # diverse tasks from the actual DB schema and tools. + + # --- E. Output format --- + parts.append( + """ +## Output Format + +Generate exactly ONE task. Output it in this format: + + + +[Natural language task instruction for the agent. Be specific about what needs to be done.] + + +[Python function: def validate_task(env, final_answer=None) -> int] + +""" + ) + + return "\n".join(parts) + + def _judge_task(self, prompt: str, verifier: str) -> float: + """LLM classifier gate: returns 0.0 (reject) or 1.0 (accept). + + Predicts whether the (prompt, verifier) pair will produce meaningful + evaluation signal. Optimized for very low false positive rate — only + rejects tasks that are near-certain to waste harness compute. + + Checks: + 1. Phantom tables: verifier references tables not in env schema + 2. Undefined references: calls to functions/constants not defined + 3. Vacuous checks: verifier only checks user existence or len>0 + """ + if not self.judge_model or not self.openrouter_api_key: + return 1.0 # No judge configured, pass through + + # Build context for the classifier + tool_names = [t for t in self.env_tools if t != "computer"] + tools_str = ", ".join(tool_names[:20]) if tool_names else "none discovered" + + schema_block = self.env_schema if self.env_schema else "Schema not available." + + judge_prompt = ( + "You are a verifier quality judge for an AI task-generation system. You evaluate " + "whether a generated verifier function can reliably determine if an AI agent " + "correctly completed a task.\n\n" + "## Context\n\n" + "The verifier has access to:\n" + "- `env.db(\"seed\")` — database state BEFORE the agent acted\n" + "- `env.db(\"current\")` — database state AFTER the agent acted\n" + "- `final_answer` — the agent's text response\n" + "- DB query methods: `.table(name)`, `.eq(col, val)`, `.first()`, `.all()`, " + "`.select()`, `.neq()`, `.gt()`, `.lt()`\n\n" + f"Database schema (valid tables and columns):\n```\n{schema_block}\n```\n\n" + f'Environment: "{self.env_key}"\n' + f"Available tools: {tools_str}\n\n" + "## Classification Criteria\n\n" + "### ACCEPT if the verifier does ANY of:\n\n" + "1. **Mutation verification**: Compares seed vs current database state to detect " + "that the agent created, modified, or deleted records.\n\n" + "2. **DB-grounded answer validation**: Queries the database for specific records " + "and validates that values FROM those records appear in `final_answer`. The " + "expected values must come from the database, not from hardcoded strings or " + "the task prompt.\n\n" + "3. **Specific record validation**: Looks up a record by ID or unique field and " + "checks its field values match expected values.\n\n" + "### REJECT if the verifier does ANY of:\n\n" + "1. **Generic keyword checking**: Checks if generic category words appear in " + "`final_answer` (e.g., \"event\", \"venue\", \"concert\", \"price\", \"bedroom\", " + "\"listing\"). These words appear in any topically-relevant response regardless " + "of task completion.\n\n" + "2. **Prompt echo checking**: Checks if values from the task prompt appear in " + "`final_answer` (e.g., \"Los Angeles\" when the prompt asked about LA events). " + "The agent could echo prompt values without doing any work.\n\n" + "3. **Exists-check-only**: Only checks `final_answer is not None` or " + "`len(answer) > 0`.\n\n" + "4. **Dead code DB queries**: Has `seed.table()` or `current.table()` calls but " + "never uses the query results in conditional logic that affects the return value.\n\n" + "5. **Nonexistent API access**: References `env.instance.tool_calls`, " + "`get_call_history()`, or `env.call_tool()` — these don't exist in the verifier " + "runtime.\n\n" + "6. **Cargo-cult DB**: Queries the DB only for user/account existence (which always " + "passes for pre-existing entities), then gates on keyword checks for actual " + "validation.\n\n" + "7. **Phantom tables**: The verifier calls `.table(\"X\")` where X does not exist " + "in the schema above.\n\n" + "8. **Undefined references**: The verifier calls functions or uses constants that " + "are not defined in the code and are not Python builtins.\n\n" + "### Edge Cases:\n\n" + "- Read-only tasks with DB-grounded keywords: ACCEPT — if the verifier queries a " + "DB table to get specific values then checks those values appear in `final_answer`.\n" + "- JSON structure validation without DB cross-reference: REJECT.\n" + "- Existence checks on initially-empty tables (e.g., orders after \"place order\"): " + "weak ACCEPT.\n\n" + f"## Generated Task\n\n" + f"Prompt:\n{prompt}\n\n" + f"Verifier:\n```python\n{verifier}\n```\n\n" + "Answer with exactly one word: ACCEPT or REJECT" + ) + + try: + import litellm + + response = litellm.completion( + model=f"openrouter/{self.judge_model}", + messages=[{"role": "user", "content": judge_prompt}], + temperature=0, + max_tokens=10, + api_key=self.openrouter_api_key, + ) + answer = response.choices[0].message.content.strip().upper() + accepted = "ACCEPT" in answer and "REJECT" not in answer + logger.info( + f"LLM classifier [{self.env_key}]: {answer} -> " + f"{'ACCEPT' if accepted else 'REJECT'}" + ) + return 1.0 if accepted else 0.0 + except Exception as e: + logger.warning(f"LLM classifier failed, defaulting to accept: {e}") + return 1.0 + + @staticmethod + def _build_hint_text( + verifier_stdout: Optional[str], + verifier_error: Optional[str], + tool_error_messages: Optional[List[str]], + ) -> str: + """Build hint text from verifier feedback. No LLM call. + + Parses ERROR_ACCUMULATOR / SUCCESS_ACCUMULATOR from verifier stdout + and formats tool errors into structured feedback for hinted rollouts. + """ + parts: List[str] = [] + + if verifier_stdout: + err_match = re.search( + r">>> ERROR_ACCUMULATOR >>>\n(.+?)\n<<< ERROR_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + suc_match = re.search( + r">>> SUCCESS_ACCUMULATOR >>>\n(.+?)\n<<< SUCCESS_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + if err_match or suc_match: + try: + errors = ast.literal_eval(err_match.group(1)) if err_match else [] + successes = ast.literal_eval(suc_match.group(1)) if suc_match else [] + except Exception: + errors, successes = [], [] + if successes: + parts.append(f"Checks passed ({len(successes)}): " + ", ".join(str(s)[:100] for s in successes[:5])) + if errors: + parts.append(f"Checks failed ({len(errors)}): " + ", ".join(str(e)[:100] for e in errors[:5])) + + if verifier_error: + parts.append(f"Verifier: {verifier_error}") + + if tool_error_messages: + unique = list(dict.fromkeys(tool_error_messages))[:5] + parts.append("Tool errors: " + "; ".join(e[:200] for e in unique)) + + return "\n".join(parts) if parts else "The previous attempt failed. Try a different approach." + + def _get_fleet_client(self): + """Lazy-init Fleet SDK client.""" + if self._fleet_client is None: + from fleet import Fleet + + self._fleet_client = Fleet(api_key=self.fleet_api_key) + return self._fleet_client + + async def _poll_job(self, fleet, job_id: str, poll_interval: int = 10, timeout: int = 600) -> str: + """Poll Fleet job until completion or timeout. + + Returns: + Final job status string. + """ + start = time.time() + while time.time() - start < timeout: + try: + job = fleet.get_job(job_id) + status = job.status + if status in ("completed", "cancelled", "errored"): + return status + except Exception as e: + logger.warning(f"Error polling job {job_id}: {e}") + await asyncio.sleep(poll_interval) + + logger.error(f"Job {job_id} timed out after {timeout}s") + return "timeout" + + def _query_supabase_scores(self, job_id: str) -> Dict[str, float]: + """Query Supabase for session verifier scores as fallback. + + When Fleet backend doesn't populate verifier_execution FK (regression + since 2026-03-23), the score is still available in session metadata. + + Returns: + Dict mapping session_id -> verifier_score. + """ + supabase_url = os.environ.get("SUPABASE_URL", "") + supabase_key = os.environ.get("SUPABASE_KEY", "") + if not supabase_url or not supabase_key: + return {} + try: + import httpx + + resp = httpx.get( + f"{supabase_url}/rest/v1/sessions", + params={"job_id": f"eq.{job_id}", "select": "id,metadata"}, + headers={ + "apikey": supabase_key, + "Authorization": f"Bearer {supabase_key}", + }, + timeout=10, + ) + if resp.status_code != 200: + logger.warning(f"Supabase query failed: {resp.status_code}") + return {} + scores = {} + for row in resp.json(): + meta = row.get("metadata") or {} + sid = row.get("id") + v_score = meta.get("verifier_score") + if sid and v_score is not None: + scores[sid] = float(v_score) + return scores + except Exception as e: + logger.warning(f"Supabase fallback failed: {e}") + return {} + + def _extract_job_results(self, fleet, job_id: str) -> List[Tuple[float, Optional[str], Optional[str]]]: + """Extract (score, verifier_stdout, verifier_error) from completed job sessions. + + Primary path: read from session.verifier_execution (Fleet SDK). + Fallback: query Supabase for metadata.verifier_score when VE is null + (Fleet backend regression since 2026-03-23 stopped populating VE FK). + + Returns: + List of (score, stdout, error) tuples per session. + """ + results: List[Tuple[float, Optional[str], Optional[str]]] = [] + sessions_response = fleet.list_job_sessions(job_id) + + # Check if any session has verifier_execution populated + all_ve_null = all(s.verifier_execution is None for tg in sessions_response.tasks for s in tg.sessions) + + # Fallback: query Supabase only when needed + supabase_scores: Dict[str, float] = {} + if all_ve_null: + supabase_scores = self._query_supabase_scores(job_id) + if supabase_scores: + logger.info(f"[{job_id[:8]}] Using Supabase fallback for {len(supabase_scores)} session scores") + + for task_group in sessions_response.tasks: + for session in task_group.sessions: + score = 0.0 + stdout = None + error = None + if session.verifier_execution: + if session.verifier_execution.score is not None: + score = float(session.verifier_execution.score) + elif session.verifier_execution.success: + score = 1.0 + stdout = getattr(session.verifier_execution, "stdout", None) + # Capture error from verifier crashes — error is nested in result.error + ve_result = getattr(session.verifier_execution, "result", None) + if ve_result: + ve_error = ( + ve_result.get("error") if isinstance(ve_result, dict) else getattr(ve_result, "error", None) + ) + if ve_error: + error = ( + ve_error.get("message", "") + if isinstance(ve_error, dict) + else getattr(ve_error, "message", "") + ) + traceback_str = ( + ve_error.get("traceback", "") + if isinstance(ve_error, dict) + else getattr(ve_error, "traceback", "") + ) + if traceback_str: + # Extract just the last line of traceback (the actual error) + error = traceback_str.strip().split("\n")[-1] if traceback_str else error + elif session.session_id in supabase_scores: + # Fallback: use Supabase metadata.verifier_score + score = supabase_scores[session.session_id] + results.append((score, stdout, error)) + return results + + async def _run_harness_job( + self, prompt: str, verifier: str, k: int + ) -> List[Tuple[float, Optional[str], Optional[str]]]: + """Run a single Fleet harness job and return per-session results + job ID. + + 1. Import task to Fleet + 2. Create harness job with pass_k=k + 3. Poll until completion + 4. Extract results + + Returns: + Tuple of (job_id, results) where results is a list of + (score, verifier_stdout, verifier_error) tuples. + job_id is None on failure. + """ + from fleet.tasks import Task + + fleet = self._get_fleet_client() + task_key = f"taskgen_{uuid.uuid4().hex[:12]}" + + task = Task( + key=task_key, + prompt=prompt, + env_id=self.env_key, + version=self.env_version or None, + verifier_func=verifier, + data_id=self.data_key or None, + data_version=self.data_version or None, + env_variables=self.env_variables or {}, + ) + + try: + import_response = fleet.import_single_task(task) + except Exception as e: + logger.error(f"[{task_key}] Failed to import task to Fleet: {e}") + return (None, [(0.0, None, None)] * k) + if import_response is None: + logger.error(f"[{task_key}] Failed to import task to Fleet (returned None, api_key set: {bool(self.fleet_api_key)})") + return (None, [(0.0, None, None)] * k) + + job_response = fleet.create_job( + models=[self.evaluator_model], + task_keys=[task_key], + pass_k=k, + max_steps=self.max_eval_steps, + mode="tool-use", + name=f"taskgen-eval-{task_key}", + ) + job_id = job_response.job_id + logger.info(f"[{task_key}] Harness job created: {job_id} (model={self.evaluator_model}, k={k})") + + status = await self._poll_job(fleet, job_id) + if status != "completed": + logger.warning(f"[{task_key}] Job {job_id} ended with status: {status}") + return (job_id, [(0.0, None, None)] * k) + + return (job_id, self._extract_job_results(fleet, job_id)) + + async def _evaluate_task(self, prompt: str, verifier: str) -> Dict[str, float]: + """Run hint-based evaluation via Fleet harness jobs. + + 1. Raw job: k rollouts without hints + 2. Build hint from first failing session's verifier stdout + 3. Hinted job: k rollouts with hint appended to prompt + 4. Compute reward via compute_task_reward() + + Returns: + Reward breakdown dict from compute_task_reward. + """ + from integrations.fleet.task_gen_reward import compute_task_reward + + zero_result = compute_task_reward([], [], validity=1.0) + + if not self.fleet_api_key: + return zero_result + + task_id = f"taskgen_{uuid.uuid4().hex[:8]}" + start = time.time() + + try: + # Eval: k=eval_k_rollouts for pass rate; Train: k=k_rollouts + eval_k = self.eval_k_rollouts if self.is_eval else self.k_rollouts + + # 1. Raw job: k rollouts without hints + raw_job_id, raw_results = await self._run_harness_job(prompt, verifier, k=eval_k) + raw_scores = [r[0] for r in raw_results] + + if self.enable_hints and not self.is_eval: + # Hinted training: k raw + k hinted for hint_gap signal + # 2. Build hint from first failing session's stdout/error + hint_stdout = None + hint_error = None + for score, stdout, error in raw_results: + if score < 1.0: + if stdout: + hint_stdout = stdout + if error: + hint_error = error + if hint_stdout or hint_error: + break + hint_text = self._build_hint_text(hint_stdout, hint_error, None) + + # Fallback: if hint is generic (no VE stdout due to backend regression), + # use the verifier source code as the hint. This tells the hinted agent + # exactly what checks to satisfy, creating hint_gap signal. + if hint_text == "The previous attempt failed. Try a different approach.": + # Truncate verifier to avoid blowing up prompt length + verifier_hint = verifier[:2000] + hint_text = ( + "Here is the verification function that will be used to check your work. " + "Make sure your actions satisfy all the checks:\n\n" + f"```python\n{verifier_hint}\n```" + ) + + # 3. Hinted job: k rollouts with hint + hinted_prompt = f"{prompt}\n\nHere is feedback from a previous attempt to help you:\n{hint_text}" + hinted_job_id, hinted_results = await self._run_harness_job(hinted_prompt, verifier, k=self.k_rollouts) + hinted_scores = [r[0] for r in hinted_results] + + # 4. Compute reward + result = compute_task_reward(raw_scores, hinted_scores, validity=1.0) + else: + # No hints — reward based on raw variance only + hinted_scores = [] + hinted_job_id = None + hint_text = "" + result = compute_task_reward(raw_scores, raw_scores, validity=1.0) + + duration = time.time() - start + + # --- Iron-clad eval logging --- + # Truncate prompt/verifier for log readability + prompt_log = prompt[:300].replace("\n", " ") + verifier_log = verifier[:200].replace("\n", " ") + hint_log = hint_text[:200].replace("\n", " ") + logger.info( + f"[{task_id}] EVAL | " + f"raw_job={raw_job_id} hinted_job={hinted_job_id} | " + f"raw={raw_scores} hinted={hinted_scores} | " + f"var={result['var_raw']:.4f} gap={result['hint_gap']:.4f} total={result['total']:.4f} | " + f"time={duration:.0f}s | " + f"prompt={prompt_log} | " + f"verifier={verifier_log} | " + f"hint={hint_log}" + ) + + # Save full rollout to local JSONL + self._save_rollout( + task_id=task_id, + env_key=self.env_key, + data_key=self.data_key, + prompt=prompt, + verifier=verifier, + hint=hint_text, + raw_scores=raw_scores, + hinted_scores=hinted_scores, + raw_job_id=raw_job_id, + hinted_job_id=hinted_job_id, + result=result, + duration=duration, + ) + + return result + + except Exception as e: + logger.error(f"[{task_id}] Evaluation failed: {e}") + return zero_result + + def _save_rollout( + self, + task_id, + env_key, + data_key, + prompt, + verifier, + hint, + raw_scores, + hinted_scores, + raw_job_id, + hinted_job_id, + result, + duration, + ): + """Append full rollout data to a local JSONL file.""" + try: + run_name = os.environ.get("RUN_NAME", "unknown") + path = os.path.join(self._rollout_dir, f"{run_name}.jsonl") + record = { + "task_id": task_id, + "env_key": env_key, + "data_key": data_key, + "prompt": prompt, + "verifier": verifier, + "hint": hint, + "raw_scores": raw_scores, + "hinted_scores": hinted_scores, + "raw_job_id": raw_job_id, + "hinted_job_id": hinted_job_id, + "var_raw": result["var_raw"], + "hint_gap": result["hint_gap"], + "total": result["total"], + "duration": duration, + "timestamp": time.time(), + } + with open(path, "a") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + except Exception as e: + logger.warning(f"[{task_id}] Failed to save rollout: {e}") + + async def _dryrun_verifier(self, verifier: str) -> Tuple[bool, str]: + """Run verifier against seed DB (no agent actions). Returns (ok, error_msg). + + A correct verifier should return 0 on unmodified DB (task not done yet). + Returns 1 → broken (permissive). Crashes → broken. + """ + if self.orch is None: + return True, "" # Can't dry-run without orchestrator, skip + try: + from fleet._async.tasks import Task as AsyncFleetTask + task = AsyncFleetTask( + key=f"dryrun_{self.env_key}", + prompt="dry-run", + env_id=self.env_key, + verifier_func=verifier, + ) + result = await task.verify_detailed_async(self.orch._fleet_env) + if result.success: + return False, "Verifier returned 1 on the unmodified database — it passes even when no agent has acted. Your verifier must return 0 on seed state. Check that your task involves a write/mutation action and your verifier checks for that mutation (e.g., find_new_entries)." + return True, "" + except Exception as e: + err_msg = str(e) + # Truncate long tracebacks + if len(err_msg) > 500: + err_msg = err_msg[:500] + "..." + return False, f"Verifier crashed on seed DB: {err_msg}" + + async def _handle_task_generation(self, action: str) -> BaseTextEnvStepOutput: + """Evaluate a generated task through the full pipeline. + + Pipeline: + 1. Parse output -> fail = reward 0 + 2. Sandbox validation -> fail = reward 0 + 3. Verifier dry-run on seed DB -> if broken, return feedback (retry) + 4. LLM-as-a-judge -> gate (0/1), fail = reward 0 + 5. Hint-based evaluation via Fleet harness (k raw + k hinted rollouts) + 6. R = base_quality + binary_eval_signal + + base_quality (default 0.1) rewards structural validity (sandbox+judge pass), + providing GRPO gradient signal even when harness evals return all zeros. + """ + metadata: Dict[str, Any] = {"env_key": self.env_key, "turn": self.turns} + max_turns_reached = self.turns >= self.max_turns + + # 1. Parse + parsed = parse_task_output(action) + if parsed is None: + metadata["error"] = "parse_failed" + metadata["reward_breakdown"] = {"total": 0.0} + return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) + + prompt = parsed["prompt"] + verifier = parsed["verifier"] + metadata["generated_prompt"] = prompt + metadata["generated_verifier"] = verifier + + # 2. Sandbox validation + validation = self.sandbox.validate(verifier, prompt) + metadata["validation"] = { + "valid": validation.valid, + "passed": validation.checks_passed, + "failed": validation.checks_failed, + "error": validation.error, + } + if not validation.valid: + if not max_turns_reached: + remaining = self.max_turns - self.turns + obs = {"role": "user", "content": f"Sandbox rejected your verifier: {', '.join(validation.checks_failed)}. Fix and resubmit. {remaining} turn(s) left."} + return BaseTextEnvStepOutput(observations=[obs], reward=0.0, done=False, metadata=metadata) + metadata["reward_breakdown"] = {"sandbox": 0.0, "total": 0.0} + return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) + + # 3. Verifier dry-run on seed DB + dryrun_ok, dryrun_error = await self._dryrun_verifier(verifier) + metadata["dryrun_ok"] = dryrun_ok + if not dryrun_ok: + logger.info(f"TaskGenEnv [{self.env_key}]: Verifier dry-run failed: {dryrun_error[:200]}") + if not max_turns_reached: + remaining = self.max_turns - self.turns + obs = {"role": "user", "content": f"⚠️ Verifier dry-run FAILED: {dryrun_error}\n\nFix your verifier and resubmit. {remaining} turn(s) left."} + return BaseTextEnvStepOutput(observations=[obs], reward=0.0, done=False, metadata=metadata) + metadata["reward_breakdown"] = {"dryrun": 0.0, "total": 0.0} + return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) + + # 4. LLM-as-a-judge gate + judge_gate = self._judge_task(prompt, verifier) + metadata["judge_gate"] = judge_gate + + if judge_gate == 0.0: + metadata["reward_breakdown"] = {"sandbox": 1.0, "judge": 0.0, "total": 0.0} + return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) + + # 5. Hint-based evaluation via Fleet harness + eval_result = await self._evaluate_task(prompt, verifier) + + # 6. R = base_quality + binary_eval_signal + base_quality = self.base_quality_reward + reward = base_quality + eval_result["total"] + + metadata["reward_breakdown"] = { + "sandbox": 1.0, + "dryrun": 1.0, + "judge": judge_gate, + "base_quality": base_quality, + **eval_result, + "total": reward, + } + + return BaseTextEnvStepOutput(observations=[], reward=reward, done=True, metadata=metadata) + + def step(self, action: str) -> BaseTextEnvStepOutput: + """Sync wrapper for step_async.""" + return asyncio.run(self.step_async(action)) + + async def step_async(self, action: str) -> BaseTextEnvStepOutput: + """Execute one step — tool call, task generation, or nudge. + + Multi-turn flow: + 1. block detected → evaluation pipeline (done=True) + 2. detected → execute query_db/MCP tools (done=False) + 3. Neither → nudge observation (done=False) + 4. max_turns reached → done=True, reward=0 + """ + self.turns += 1 + max_turns_reached = self.turns >= self.max_turns + + # 1. Check for block → evaluation pipeline + if "" in action: + # Exploration gate: in multi-turn mode, bounce back if model hasn't + # called query_db yet and still has turns remaining. Prevents + # single-turn collapse where model skips DB exploration entirely. + if self.max_turns > 1 and not self.called_query_db and not max_turns_reached: + remaining = self.max_turns - self.turns + nudge = ( + "You must explore the database with `query_db` before submitting a task. " + "Use SELECT queries to inspect actual data — table contents, value ranges, " + f"row counts — so your task and verifier are grounded in real data. " + f"You have {remaining} turn(s) remaining." + ) + observation = {"role": "user", "content": nudge} + return BaseTextEnvStepOutput( + observations=[observation], + reward=0.0, + done=False, + metadata={"env_key": self.env_key, "turn": self.turns, "exploration_gate": True}, + ) + return await self._handle_task_generation(action) + + # 2. Check for tool calls → execute all via Fleet orchestrator or MCP + tool_calls = parse_tool_calls(action) + tool_calls = [tc for tc in tool_calls if tc["name"] in self.callable_tools] + if tool_calls: + results = [] + for tc in tool_calls: + if tc["name"] in _META_TOOLS: + self.meta_tool_calls += 1 + if tc["name"] == "query_db": + self.called_query_db = True + result = await self._execute_meta_tool(tc) + else: + self.mcp_tool_calls += 1 + result = await self._execute_mcp_tool(tc) + results.append(f"[{tc['name']}] {result}") + + if max_turns_reached: + return BaseTextEnvStepOutput( + observations=[], + reward=0.0, + done=True, + metadata={"env_key": self.env_key, "turn": self.turns, "done_reason": "max_turns"}, + ) + + obs_content = "\n\n".join(results) + remaining = self.max_turns - self.turns + if remaining <= 3 and self.called_query_db: + obs_content += ( + f"\n\n⚠️ You have {remaining} turn(s) left. " + "You MUST output your block NOW. " + "Stop exploring and generate the task." + ) + observation = {"role": "user", "content": obs_content} + return BaseTextEnvStepOutput( + observations=[observation], + reward=0.0, + done=False, + metadata={"env_key": self.env_key, "turn": self.turns, "tool_calls": [tc["name"] for tc in tool_calls]}, + ) + + # 3. Neither task nor tool call → nudge + if max_turns_reached: + return BaseTextEnvStepOutput( + observations=[], + reward=0.0, + done=True, + metadata={ + "env_key": self.env_key, + "turn": self.turns, + "done_reason": "max_turns", + }, + ) + + remaining = self.max_turns - self.turns + if self.max_turns == 1: + nudge = "No block found. Output your task in ... format." + elif remaining <= 2: + nudge = ( + f"You have {remaining} turn(s) left. Output your block NOW or you will " + "get reward 0. Stop exploring and generate the task." + ) + else: + nudge = "Use to explore the database or call environment tools, then generate a block." + observation = {"role": "user", "content": nudge} + return BaseTextEnvStepOutput( + observations=[observation], + reward=0.0, + done=False, + metadata={"env_key": self.env_key, "turn": self.turns}, + ) + + async def _execute_meta_tool(self, tool_call: Dict[str, Any]) -> str: + """Execute a query_db meta-tool call via the Fleet orchestrator.""" + name = tool_call["name"] + args = tool_call.get("arguments", {}) + + if self.orch is None: + return "Error: Fleet environment not provisioned. Generate a directly." + + if name != "query_db": + return f"Error: Unknown meta-tool '{name}'." + + sql = args.get("sql", "") + if not sql: + return "Error: query_db requires a 'sql' argument." + + max_retries = 3 + for attempt in range(max_retries): + try: + result = await self.orch.query_db_async(sql=sql, db_name=args.get("db_name", "seed")) + if isinstance(result, dict): + # Truncate rows to save context — model only needs a sample + if "rows" in result and isinstance(result["rows"], list) and len(result["rows"]) > 5: + result["rows"] = result["rows"][:5] + result["message"] = f"Query returned more rows; showing first 5." + formatted = json.dumps(result, indent=2, default=str) + if len(formatted) > 3000: + formatted = formatted[:3000] + "\n... (truncated)" + return f"Tool result:\n{formatted}" + return f"Tool result:\n{str(result)[:3000]}" + except Exception as e: + if attempt < max_retries - 1 and ("closed" in str(e).lower() or "transport" in str(e).lower() or "connection" in str(e).lower()): + await asyncio.sleep(1) + continue + return f"Error: {e}" + + async def _execute_mcp_tool(self, tool_call: Dict[str, Any]) -> str: + """Execute an MCP tool call via FleetMCPTools.""" + name = tool_call["name"] + args = tool_call.get("arguments", {}) + + if self.mcp_tools is None: + return "Error: MCP tools not available. Use query_db or generate a ." + + try: + result = await self.mcp_tools.call_tool(name, args) + if isinstance(result, dict): + return f"Tool result:\n{json.dumps(result, indent=2, default=str)}" + return f"Tool result:\n{result}" + except Exception as e: + return f"Error calling {name}: {e}" + + async def init_async(self, prompt: ConversationType) -> Tuple[ConversationType, Dict[str, Any]]: + """Initialize the environment, optionally provisioning a Fleet env for DB exploration. + + When ``max_turns > 1``, provisions a Fleet environment via + ``FleetEnvClient.from_fleet_async`` so the model can call + ``query_db`` during exploration turns. + Falls back to single-turn if provisioning fails. + """ + self.turns = 0 + self.meta_tool_calls = 0 + self.mcp_tool_calls = 0 + self.called_query_db = False + self.orch = None + self.mcp_tools = None + self.callable_tools = set(_META_TOOLS) + + # Provision Fleet env for multi-turn exploration (DB + MCP tools) + if self.max_turns > 1 and self.fleet_api_key and self.data_key: + try: + from envs.fleet_env import FleetEnvClient + + self.orch, self.mcp_tools = await FleetEnvClient.from_fleet_async( + api_key=self.fleet_api_key, + env_key=self.env_key, + data_key=self.data_key, + data_version=self.data_version, + image_type="standard", + ttl_seconds=900, + ) + # Load instance resources so db("seed") works + # instance.load() is async — must await directly, not via to_thread + await self.orch._fleet_env.instance.load() + logger.info(f"TaskGenEnv [{self.env_key}]: Fleet env provisioned for DB + tool exploration") + + # Auto-populate env_schema from describe_db if not provided in dataset. + # Compact format: "table: col1 (type), col2 (type), ..." — one line per table. + if not self.env_schema: + try: + schema_result = await self.orch.describe_db_async(db_name="seed") + self.env_schema = _format_compact_schema(schema_result) + if self.env_schema: + logger.info(f"TaskGenEnv [{self.env_key}]: Auto-populated env_schema ({len(self.env_schema)} chars)") + except Exception as e: + logger.warning(f"TaskGenEnv [{self.env_key}]: Failed to auto-populate env_schema: {e}") + + # Discover MCP tools so the model can call them + if self.mcp_tools: + try: + tools_action = await self.mcp_tools.list_tools() + mcp_tools = [ + t for t in tools_action.tools if "function" in t and t["function"].get("name") != "computer" + ] + mcp_tool_names = {t["function"]["name"] for t in mcp_tools} + self.callable_tools = set(_META_TOOLS) | mcp_tool_names + # Update tool schemas for system prompt if dataset didn't have them + if not self.env_tools_schema: + self.env_tools_schema = mcp_tools + self.env_tools = [t["function"]["name"] for t in mcp_tools] + logger.info(f"TaskGenEnv [{self.env_key}]: {len(mcp_tool_names)} MCP tools available") + except Exception as e: + logger.warning(f"TaskGenEnv [{self.env_key}]: Failed to list MCP tools: {e}") + except Exception as e: + logger.warning( + f"TaskGenEnv [{self.env_key}]: Fleet provisioning failed, " f"falling back to single-turn: {e}" + ) + self.max_turns = 1 + + system_prompt = self._build_system_prompt() + + user_content = ( + f"Generate a task for the {self.env_key} environment. " + "First explore the database to understand the data, then draft a prompt and verifier. " + "Before outputting, query the DB to verify your assumptions are correct — " + "iterate on your draft until you're confident the data supports it." + if self.max_turns > 1 + else f"Generate a task for the {self.env_key} environment." + ) + + conversation = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + ] + + metadata = { + "env_key": self.env_key, + "env_version": self.env_version, + "num_tools": len(self.env_tools), + "multi_turn": self.max_turns > 1, + } + + return conversation, metadata + + def init(self, prompt: ConversationType) -> Tuple[ConversationType, Dict[str, Any]]: + """Sync wrapper for init_async.""" + return asyncio.run(self.init_async(prompt)) + + def close(self): + """Close the Fleet orchestrator if provisioned.""" + if self.orch is not None: + try: + self.orch.close() + except Exception: + pass + self.orch = None + + async def close_async(self): + """Async close — release Fleet orchestrator resources.""" + if self.orch is not None: + try: + await self.orch.close_async() + except Exception: + pass + self.orch = None + + def get_metrics(self) -> Dict[str, Any]: + """Return per-episode metrics.""" + return { + "env_key": self.env_key, + "turns": self.turns, + } + + @staticmethod + def aggregate_metrics(metrics: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate metrics across episodes.""" + if not metrics: + return {} + + # Group by env_key + env_counts: Dict[str, int] = {} + for m in metrics: + env_key = m.get("env_key", "unknown") + env_counts[env_key] = env_counts.get(env_key, 0) + 1 + + result = {"total_episodes": len(metrics)} + for env_key, count in env_counts.items(): + result[f"{env_key}/episodes"] = count + + return result diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py b/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py new file mode 100644 index 0000000000..95e21912fd --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py @@ -0,0 +1,83 @@ +""" +Tool call parser for task generation environment. + +Parses and tagged JSON from LLM responses. +Copied from skyrl-train/integrations/fleet/env.py to avoid cross-package imports. +""" + +import json +import re +from typing import Any, Dict, List, Optional + + +def _try_parse_json(raw: str) -> Optional[Dict[str, Any]]: + """Try to parse JSON, repairing missing trailing braces if needed.""" + raw = raw.strip() + try: + parsed = json.loads(raw) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, ValueError): + pass + + # Repair: models often drop trailing closing braces on nested JSON. + # Try appending up to 3 closing braces. + for extra in range(1, 4): + try: + parsed = json.loads(raw + "}" * extra) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, ValueError): + continue + + return None + + +def _parse_one(match_text: str) -> Optional[Dict[str, Any]]: + """Parse a single tool call from matched text.""" + parsed = _try_parse_json(match_text) + if parsed is None: + return None + name = parsed.get("name") or parsed.get("tool") + args = parsed.get("arguments") or parsed.get("params", {}) + if name: + return {"name": name, "arguments": args} + return None + + +def parse_tool_call(action: str) -> Optional[Dict[str, Any]]: + """Parse the first tool call from LLM response. Returns None if not found.""" + calls = parse_tool_calls(action) + return calls[0] if calls else None + + +def parse_tool_calls(action: str) -> List[Dict[str, Any]]: + """Parse all tool calls from LLM response. + + Supports tag-based formats: + - {"name": "...", "arguments": {...}} + - {"name": "...", "arguments": {...}} + + Also handles cases where the closing tag is missing (e.g., when + is used as the stop string and not included in the output). + + Returns list of dicts with "name" and "arguments" keys. + """ + results: List[Dict[str, Any]] = [] + + for tag in ["tool_call", "function_call"]: + # Find all with closing tag + for match in re.finditer(rf"<{tag}>(.*?)", action, re.DOTALL): + parsed = _parse_one(match.group(1)) + if parsed: + results.append(parsed) + + # If none found with closing tags, try without (stop string case) + if not results: + match = re.search(rf"<{tag}>(.*?)(?:<\||\Z)", action, re.DOTALL) + if match: + parsed = _parse_one(match.group(1)) + if parsed: + results.append(parsed) + + return results diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py b/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py new file mode 100644 index 0000000000..9d6c2ace44 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py @@ -0,0 +1,345 @@ +""" +Verifier sandbox for task generation. + +Validates generated verifier code via AST analysis and safe execution checks. +Used as the validity gate in the task generation reward: + R(task) = validity * (variance + alpha * separation) + +If validity returns 0, the entire reward is zeroed out. +""" + +import ast +import re +from dataclasses import dataclass, field +from typing import List, Optional, Set + + +@dataclass +class ValidationResult: + """Result of verifier validation.""" + + valid: bool + checks_passed: List[str] = field(default_factory=list) + checks_failed: List[str] = field(default_factory=list) + error: Optional[str] = None + + @property + def score(self) -> float: + """Return 1.0 if valid, 0.0 otherwise (multiplicative gate).""" + return 1.0 if self.valid else 0.0 + + +# Disallowed modules/builtins in verifier code +BLOCKED_IMPORTS = { + "os", + "sys", + "subprocess", + "shutil", + "pathlib", + "socket", + "http", + "urllib", + "requests", + "importlib", + "ctypes", + "signal", + "multiprocessing", + "threading", + "pickle", + "shelve", + "tempfile", + "glob", + "io", +} + +BLOCKED_BUILTINS = { + "exec", + "eval", + "compile", + "__import__", + "open", + "input", + "breakpoint", + "exit", + "quit", +} + +# Min/max AST node count for verifier complexity +MIN_AST_NODES = 5 # reject trivial verifiers like `return 1.0` +MAX_AST_NODES = 700 # reject overly complex verifiers + + +class VerifierSandbox: + """Validates and sandboxes generated verifier code. + + Performs static analysis to catch common issues before any execution: + 1. Python syntax validity (AST parsing) + 2. Function signature check (must be `async def verify(env, ...)`) + 3. Complexity bounds (not trivial, not overly complex) + 4. No dangerous imports or builtins + 5. Must reference env parameter (actually uses the environment) + 6. Prompt-verifier alignment (optional, LLM-based) + """ + + def __init__(self, available_tools: Optional[Set[str]] = None): + """ + Args: + available_tools: Set of tool names available in the target environment. + If provided, checks that verifier references at least one real tool. + """ + self.available_tools = available_tools or set() + + def validate( + self, + verifier_code: str, + prompt: Optional[str] = None, + ) -> ValidationResult: + """Run all validation checks on verifier code. + + Args: + verifier_code: The generated verifier Python code. + prompt: The associated task prompt (for alignment checks). + + Returns: + ValidationResult with pass/fail and details. + """ + result = ValidationResult(valid=True) + + # 1. Parse as valid Python + tree = self._check_syntax(verifier_code, result) + if tree is None: + result.valid = False + return result + + # 2. Check function signature + self._check_signature(tree, result) + + # 3. Check complexity bounds + self._check_complexity(tree, result) + + # 4. Check for dangerous imports/builtins + self._check_safety(tree, result) + + # 5. Check env usage + self._check_env_usage(tree, result) + + # 6. Check for hardcoded return values + self._check_hardcoded_returns(tree, result) + + # 7. Check for unfiltered .all() calls + self._check_unfiltered_all(tree, result) + + # 8. Check prompt length bounds (if prompt provided) + if prompt is not None: + self._check_prompt_bounds(prompt, result) + + # Any failed check -> invalid + if result.checks_failed: + result.valid = False + + return result + + def _check_syntax(self, code: str, result: ValidationResult) -> Optional[ast.AST]: + """Check that verifier code is valid Python.""" + try: + tree = ast.parse(code) + result.checks_passed.append("syntax") + return tree + except SyntaxError as e: + result.checks_failed.append("syntax") + result.error = f"SyntaxError: {e}" + return None + + def _check_signature(self, tree: ast.AST, result: ValidationResult): + """Check that verifier defines a valid function with env parameter. + + Accepts both `verify(env, ...)` and `validate_task(env, ...)` names, + both sync and async. + """ + valid_names = {"verify", "validate_task"} + for node in ast.walk(tree): + if isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef)): + if node.name in valid_names: + args = node.args + arg_names = [a.arg for a in args.args] + if "env" in arg_names: + result.checks_passed.append("signature") + return + else: + result.checks_failed.append("signature") + result.error = f"{node.name}() must have 'env' parameter, got: {arg_names}" + return + + result.checks_failed.append("signature") + result.error = "No verify(env, ...) or validate_task(env, ...) function found" + + def _check_complexity(self, tree: ast.AST, result: ValidationResult): + """Check AST node count is within bounds.""" + node_count = sum(1 for _ in ast.walk(tree)) + + if node_count < MIN_AST_NODES: + result.checks_failed.append("complexity_min") + result.error = f"Verifier too simple ({node_count} nodes < {MIN_AST_NODES})" + elif node_count > MAX_AST_NODES: + result.checks_failed.append("complexity_max") + result.error = f"Verifier too complex ({node_count} nodes > {MAX_AST_NODES})" + else: + result.checks_passed.append("complexity") + + def _check_safety(self, tree: ast.AST, result: ValidationResult): + """Check for dangerous imports and builtin calls.""" + for node in ast.walk(tree): + # Check imports + if isinstance(node, ast.Import): + for alias in node.names: + module = alias.name.split(".")[0] + if module in BLOCKED_IMPORTS: + result.checks_failed.append("safety_import") + result.error = f"Blocked import: {alias.name}" + return + + elif isinstance(node, ast.ImportFrom): + if node.module: + module = node.module.split(".")[0] + if module in BLOCKED_IMPORTS: + result.checks_failed.append("safety_import") + result.error = f"Blocked import from: {node.module}" + return + + # Check dangerous builtin calls + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + if node.func.id in BLOCKED_BUILTINS: + result.checks_failed.append("safety_builtin") + result.error = f"Blocked builtin call: {node.func.id}" + return + + result.checks_passed.append("safety") + + def _check_env_usage(self, tree: ast.AST, result: ValidationResult): + """Check that the verifier actually uses the env parameter.""" + # Look for attribute access on 'env' (e.g., env.list_issues, env.get_data) + # or 'env' passed as argument to await expressions + env_used = False + for node in ast.walk(tree): + if isinstance(node, ast.Attribute): + if isinstance(node.value, ast.Name) and node.value.id == "env": + env_used = True + break + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Name) and node.func.id == "env": + env_used = True + break + + if env_used: + result.checks_passed.append("env_usage") + else: + result.checks_failed.append("env_usage") + result.error = "Verifier does not use 'env' parameter" + + def _check_hardcoded_returns(self, tree: ast.AST, result: ValidationResult): + """Check that verifier isn't just `return 1.0` or `return 0.0`.""" + valid_names = {"verify", "validate_task"} + verify_func = None + for node in ast.walk(tree): + if isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef)): + if node.name in valid_names: + verify_func = node + break + + if verify_func is None: + return # Already caught by signature check + + # Check if all return statements are constant + returns = [n for n in ast.walk(verify_func) if isinstance(n, ast.Return)] + if not returns: + result.checks_failed.append("hardcoded_return") + result.error = "Verifier has no return statements" + return + + all_constant = all(isinstance(r.value, ast.Constant) for r in returns if r.value is not None) + + if all_constant and len(returns) == 1: + result.checks_failed.append("hardcoded_return") + result.error = "Verifier always returns a constant value" + else: + result.checks_passed.append("return_logic") + + def _check_unfiltered_all(self, tree: ast.AST, result: ValidationResult): + """Reject verifiers that call .table("X").all() without a filter. + + Unfiltered .all() fetches every row from a table, causing warm-pool + saturation with large tables (6.5k zombie verifiers in production). + + Allowed patterns (filter present in chain): + .table("X").eq("col", val).all() + .table("X").neq("col", val).all() + .table("X").select("col1").all() # ID-only in find_new_entries + + Rejected pattern: + .table("X").all() # no filter before .all() + """ + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + # Match .all() call + if not (isinstance(node.func, ast.Attribute) and node.func.attr == "all"): + continue + # Walk up the chain: .all() is called on some object + receiver = node.func.value + # Check if the receiver is a .table() call (direct: .table("X").all()) + if self._is_table_call(receiver): + result.checks_failed.append("unfiltered_all") + result.error = ( + 'Unfiltered .all() on table — use .eq()/.neq()/.select() ' + 'before .all() (e.g., table("X").eq("col", val).all())' + ) + return + + result.checks_passed.append("filtered_all") + + @staticmethod + def _is_table_call(node: ast.AST) -> bool: + """Check if an AST node is a .table("...") call.""" + return ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and node.func.attr == "table" + ) + + def _check_prompt_bounds(self, prompt: str, result: ValidationResult): + """Check that prompt is within reasonable length bounds.""" + word_count = len(prompt.split()) + + if word_count < 5: + result.checks_failed.append("prompt_length") + result.error = f"Prompt too short ({word_count} words < 5)" + elif word_count > 500: + result.checks_failed.append("prompt_length") + result.error = f"Prompt too long ({word_count} words > 500)" + else: + result.checks_passed.append("prompt_length") + + +def parse_task_output(action: str) -> Optional[dict]: + """Parse LLM output to extract task prompt and verifier code. + + Expected format: + + ... + ... + + + Returns: + Dict with 'prompt' and 'verifier' keys, or None if parsing fails. + """ + prompt_match = re.search(r"(.*?)", action, re.DOTALL) + verifier_match = re.search(r"(.*?)", action, re.DOTALL) + + if not prompt_match or not verifier_match: + return None + + return { + "prompt": prompt_match.group(1).strip(), + "verifier": verifier_match.group(1).strip(), + } diff --git a/skyrl-gym/skyrl_taste/__init__.py b/skyrl-gym/skyrl_taste/__init__.py new file mode 100644 index 0000000000..65aa9054ec --- /dev/null +++ b/skyrl-gym/skyrl_taste/__init__.py @@ -0,0 +1,13 @@ +"""skyrl_taste: thin async wrapper around the taste-judge for SkyRL GRPO. + +Public API: + score_trajectory_async(task, actions, outcome) -> Optional[float] + get_judge_provider_info() -> {"taste_judge_provider", "taste_judge_model"} + +Returns a value in [0, 1] (rescaled from the 1-5 weighted_total) or None +when the judge is disabled / errored. +""" + +from .judge import score_trajectory_async, get_judge_provider_info + +__all__ = ["score_trajectory_async", "get_judge_provider_info"] diff --git a/skyrl-gym/skyrl_taste/judge.py b/skyrl-gym/skyrl_taste/judge.py new file mode 100644 index 0000000000..3bd1459ead --- /dev/null +++ b/skyrl-gym/skyrl_taste/judge.py @@ -0,0 +1,177 @@ +"""skyrl_taste.judge +==================== + +Async wrapper around the synchronous taste judge defined in +`research/judge/judge.py`. Re-exposes the judge with the contract the +SkyRL Fleet env expects: + + async def score_trajectory_async(task, actions, outcome) -> Optional[float] + +Provider routing (env vars, read at call-time so swaps don't require a +restart of the process -- only a fresh rollout): +- ``SKYRL_TASTE_PROVIDER`` in {"anthropic", "openai", "openrouter"}. + Default: "openrouter" (cheapest production path). +- ``SKYRL_TASTE_MODEL``: model identifier. Default + "anthropic/claude-haiku-4.5" (an OpenRouter slug). For provider="anthropic" + this would be e.g. "claude-sonnet-4-6"; for provider="openai" e.g. + "gpt-4o-mini". +- ``SKYRL_TASTE_BLIND_OUTCOME``: "1" (default) suppresses the verifier + outcome from the judge prompt. Stream 4 found that exposing the outcome + causes taste scores to correlate ~0.7 with verifier (outcome bleed) and + collapses the shaping signal. + +Behavior: +- The underlying judges are *synchronous* and use blocking SDKs. We run + them in `asyncio.to_thread(...)` so they do not stall the event loop. + SkyRL's generator runs each rollout's `step_async` as its own asyncio + task, so judge calls across the GRPO group naturally overlap. +- Returns the rubric's `weighted_total`, rescaled from [1, 5] -> [0, 1] so + the blended reward stays in [0, 1] and existing pass@n / signal-ratio + metrics in `integrations/fleet/reward_metrics.py` keep working. +- Returns None on: + * `SKYRL_TASTE_DISABLED=1` (env-var bypass) + * The underlying judge returning a None-shaped result (parse / API failure) + The caller is expected to fall back to verifier-only reward when None. +- Screenshots are NOT passed in this version (text-only judge). Trade-off: + text-only keeps judge latency around 1-3 s/trajectory and avoids blowing + the judge's context with 50-80 base64 PNGs per browser_use rollout. We + lose direct ui_grounding signal, but the judge can still infer it from + action targets + tool errors. Re-enable screenshots later by sampling the + `tool_result` image_url blocks out of `chat_history` and threading them + through the judge call with `screenshots=...`. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from pathlib import Path +from typing import Any, Callable, Optional + +logger = logging.getLogger("skyrl_taste") + +# Make the research-side judge importable. In a packaged install this would +# be a sibling import; for the launch-ready integration we add the research +# tree to sys.path so we don't have to vendor it. +_RESEARCH_JUDGE_DIR = Path(__file__).resolve().parents[2] / "research" / "judge" +if _RESEARCH_JUDGE_DIR.is_dir() and str(_RESEARCH_JUDGE_DIR) not in sys.path: + sys.path.insert(0, str(_RESEARCH_JUDGE_DIR)) + +try: + from judge import ( # type: ignore + score_trajectory as _score_trajectory_anthropic, + score_trajectory_gpt4o as _score_trajectory_openai, + score_trajectory_openrouter as _score_trajectory_openrouter, + ) +except Exception as e: # pragma: no cover + logger.warning("could not import research judge: %s", e) + _score_trajectory_anthropic = None # type: ignore[assignment] + _score_trajectory_openai = None # type: ignore[assignment] + _score_trajectory_openrouter = None # type: ignore[assignment] + + +_DEFAULT_PROVIDER = "openrouter" +_DEFAULT_MODEL = "anthropic/claude-haiku-4.5" + + +def _resolve_provider() -> tuple[str, str, bool, Optional[Callable[..., dict]]]: + """Read SKYRL_TASTE_PROVIDER / SKYRL_TASTE_MODEL / SKYRL_TASTE_BLIND_OUTCOME + and return (provider, model, blind_outcome, callable). The callable is + None if the corresponding research-side function failed to import.""" + provider = os.environ.get("SKYRL_TASTE_PROVIDER", _DEFAULT_PROVIDER).strip().lower() + model = os.environ.get("SKYRL_TASTE_MODEL", _DEFAULT_MODEL) + blind_outcome = os.environ.get("SKYRL_TASTE_BLIND_OUTCOME", "1") == "1" + + if provider == "anthropic": + return provider, model, blind_outcome, _score_trajectory_anthropic + if provider == "openai": + return provider, model, blind_outcome, _score_trajectory_openai + if provider == "openrouter": + return provider, model, blind_outcome, _score_trajectory_openrouter + logger.warning( + "unknown SKYRL_TASTE_PROVIDER=%r; falling back to %s", + provider, + _DEFAULT_PROVIDER, + ) + return _DEFAULT_PROVIDER, model, blind_outcome, _score_trajectory_openrouter + + +def _rescale_to_unit_interval(weighted_total: Optional[float]) -> Optional[float]: + """Rescale weighted_total from [1, 5] (rubric) to [0, 1] (RL reward). + + Returns None passthrough; clips defensively. + """ + if weighted_total is None: + return None + try: + v = (float(weighted_total) - 1.0) / 4.0 + except (TypeError, ValueError): + return None + if v < 0.0: + return 0.0 + if v > 1.0: + return 1.0 + return v + + +def get_judge_provider_info() -> dict[str, str]: + """Return the resolved (provider, model) for run-once metric logging.""" + provider, model, _, _ = _resolve_provider() + return {"taste_judge_provider": provider, "taste_judge_model": model} + + +async def score_trajectory_async( + task: str, + actions: list[dict[str, Any]], + outcome: bool, +) -> Optional[float]: + """Async-friendly entrypoint to the taste judge. + + Args: + task: natural-language task description (`task_config["prompt"]`). + actions: ordered list of action dicts pulled from the trajectory. + outcome: bool from the verifier (verifier_reward >= 1.0). + + Returns: + A scalar in [0, 1] = rescaled `weighted_total`, or None if the + judge is disabled or failed. The caller must treat None as + "fall back to verifier-only reward". + """ + if os.environ.get("SKYRL_TASTE_DISABLED") == "1": + # Hard kill switch for runtime rollback. + return None + + provider, model, blind_outcome, fn = _resolve_provider() + if fn is None: + logger.warning( + "taste judge module unavailable for provider=%s; returning None", + provider, + ) + return None + + # Run the blocking judge in a thread so we don't stall the event loop. + # screenshots=None: see module docstring for the rationale. + try: + result = await asyncio.to_thread( + fn, + task, + actions, + outcome, + None, # screenshots + model, + blind_outcome, + ) + except Exception as e: + logger.warning("taste judge (%s) raised in thread: %s", provider, e) + return None + + if not isinstance(result, dict): + return None + + if result.get("error"): + # The judge already logged; signal fall-back. + return None + + return _rescale_to_unit_interval(result.get("weighted_total")) diff --git a/skyrl/backends/skyrl_train/distributed/dispatch.py b/skyrl/backends/skyrl_train/distributed/dispatch.py index 18e56666d1..eea4d6a479 100644 --- a/skyrl/backends/skyrl_train/distributed/dispatch.py +++ b/skyrl/backends/skyrl_train/distributed/dispatch.py @@ -1,10 +1,13 @@ """Defines dispatch and collect logic for distributed training""" import asyncio +import logging from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type +logger = logging.getLogger(__name__) + import ray from ray import ObjectRef from ray.actor import ActorHandle @@ -187,9 +190,22 @@ def stage_chunks( List of per-mini-batch chunk ref lists. ``result[i][dp_rank]`` is the ObjectRef for mini-batch *i*, DP rank *dp_rank*. """ - assert ( - len(data) % mini_batch_size == 0 - ), f"data batch size must be divisible by mini_batch_size, got {len(data)} and {mini_batch_size}" + # Hint augmentation can produce variable batch sizes that don't evenly + # divide the configured mini_batch_size. Rather than dropping samples + # (which wastes the expensive hint rollouts), reduce mini_batch_size to + # the largest value that divides both len(data) and dp_size. + if len(data) % mini_batch_size != 0: + original_mbs = mini_batch_size + # Step down by dp_size to stay dp-divisible + while mini_batch_size > 0 and (len(data) % mini_batch_size != 0 or mini_batch_size % dp_size != 0): + mini_batch_size -= dp_size + if mini_batch_size <= 0: + mini_batch_size = dp_size + logger.info( + f"Adjusted mini_batch_size from {original_mbs} to {mini_batch_size} " + f"to evenly divide batch of {len(data)} samples " + f"({len(data) // mini_batch_size} mini-batches)." + ) assert ( mini_batch_size % dp_size == 0 ), f"mini_batch_size must be divisible by dp_size, got {mini_batch_size} and {dp_size}" diff --git a/skyrl/backends/skyrl_train/distributed/fsdp_utils.py b/skyrl/backends/skyrl_train/distributed/fsdp_utils.py index 449cf0266e..c192632a88 100644 --- a/skyrl/backends/skyrl_train/distributed/fsdp_utils.py +++ b/skyrl/backends/skyrl_train/distributed/fsdp_utils.py @@ -23,6 +23,20 @@ import torch import torch.distributed as dist import torch.nn as nn + +# Patch torch.nn.Parameter.__new__ to accept and ignore _is_hf_initialized. +# accelerate's init_empty_weights passes param.__dict__ (which includes +# _is_hf_initialized set by transformers 5.x) to Parameter(), but torch 2.10 +# rejects unknown kwargs. This patch filters them out. +_orig_param_new = torch.nn.Parameter.__new__ + + +def _patched_param_new(cls, *args, **kwargs): + kwargs.pop("_is_hf_initialized", None) + return _orig_param_new(cls, *args, **kwargs) + + +torch.nn.Parameter.__new__ = _patched_param_new from omegaconf import DictConfig from packaging import version from peft.utils.save_and_load import get_peft_model_state_dict diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index c3040fc1a2..7d687926d8 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, Hashable, List, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Dict, Hashable, List, Optional, TypedDict, Union if TYPE_CHECKING: from skyrl.backends.skyrl_train.weight_sync import WeightUpdateRequest @@ -7,7 +7,30 @@ WeightSyncInitInfo, ) -MessageType = Dict[str, str] + +# --- Multimodal Message Types (OpenAI-compatible) --- +class TextContent(TypedDict): + type: str # "text" + text: str + + +class ImageUrlContent(TypedDict): + url: str # "data:image/png;base64,..." or URL + + +class ImageContent(TypedDict): + type: str # "image_url" + image_url: ImageUrlContent + + +ContentType = Union[str, List[Union[TextContent, ImageContent]]] + + +class MessageType(TypedDict): + role: str + content: ContentType + + ConversationType = List[MessageType] @@ -17,6 +40,9 @@ class InferenceEngineInput(TypedDict): prompt_token_ids: Optional[List[List[int]]] sampling_params: Optional[Dict[str, Any]] session_ids: Optional[List[Hashable]] + # Multimodal data for VL models. Each element corresponds to a prompt. + # Format: {"image": [PIL.Image, ...]} or {"image": ["base64_string", ...]} + multi_modal_data: Optional[List[Optional[Dict[str, Any]]]] class InferenceEngineOutput(TypedDict): diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py index 5f11761b91..271110156d 100644 --- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py +++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py @@ -95,6 +95,7 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu prompt_token_ids = input_batch.get("prompt_token_ids") session_ids = input_batch.get("session_ids") sampling_params = input_batch.get("sampling_params") + multi_modal_data = input_batch.get("multi_modal_data") if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None): raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.") @@ -122,9 +123,11 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items(): # index prompt_token_ids with prompt_ids cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids] + cur_mm_data = [multi_modal_data[i] for i in prompt_ids] if multi_modal_data else None engine_input = InferenceEngineInput( prompt_token_ids=cur_prompt_token_ids, sampling_params=sampling_params, + multi_modal_data=cur_mm_data, ) tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input))) indices_list.append(prompt_ids) diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index 6d827ac327..fd28948fee 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -135,6 +135,7 @@ def _preprocess_prompts(self, input_batch: InferenceEngineInput): prompts = input_batch.get("prompts") prompt_token_ids = input_batch.get("prompt_token_ids") request_sampling_params = input_batch.get("sampling_params") + multi_modal_data = input_batch.get("multi_modal_data") assert ( prompts is None and prompt_token_ids is not None @@ -144,7 +145,7 @@ def _preprocess_prompts(self, input_batch: InferenceEngineInput): SamplingParams(**request_sampling_params) if request_sampling_params is not None else SamplingParams() ) - return prompt_token_ids, sampling_params + return prompt_token_ids, sampling_params, multi_modal_data def _postprocess_outputs(self, outputs): """Common output processing logic.""" @@ -247,7 +248,7 @@ def _create_engine(self, *args, **kwargs): return vllm.LLM(*args, **kwargs) async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: - prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch) + prompt_token_ids, sampling_params, multi_modal_data = self._preprocess_prompts(input_batch) # Check if LoRA is enabled and create LoRA requests lora_requests = None @@ -261,9 +262,18 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path") ] * batch_size + # Build prompts with multimodal data for VL models + prompts = [] + for i, token_ids in enumerate(prompt_token_ids): + mm_data = multi_modal_data[i] if multi_modal_data and i < len(multi_modal_data) else None + if mm_data: + prompts.append({"prompt_token_ids": token_ids, "multi_modal_data": mm_data}) + else: + prompts.append(TokensPrompt(prompt_token_ids=token_ids)) + outputs = await asyncio.to_thread( self.llm.generate, - prompts=[TokensPrompt(prompt_token_ids=r) for r in prompt_token_ids], + prompts=prompts, sampling_params=sampling_params, lora_request=lora_requests, ) @@ -460,7 +470,13 @@ async def _load_lora_from_disk(self, lora_path: str): result = await self.llm.add_lora(lora_request) return result - async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_params: SamplingParams): + async def _collect_outputs( + self, + prompt_token_ids, + request_id: str, + sampling_params: SamplingParams, + multi_modal_data: Optional[Dict[str, Any]] = None, + ): """Collect outputs for a single prompt.""" # Check if LoRA is enabled and create LoRA request final_output = None @@ -475,8 +491,16 @@ async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_par lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path" ) + # Build prompt with multimodal data for VL models + if multi_modal_data: + num_images = len(multi_modal_data.get("image", [])) + logger.info(f"VL generate: {num_images} images, {len(prompt_token_ids)} input tokens") + prompt = {"prompt_token_ids": prompt_token_ids, "multi_modal_data": multi_modal_data} + else: + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + async for request_output in self.llm.generate( - prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), + prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request, @@ -487,14 +511,15 @@ async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_par async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: """Generate responses using vLLM's async engine.""" - prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch) + prompt_token_ids, sampling_params, multi_modal_data = self._preprocess_prompts(input_batch) tasks = [] - for prompt in prompt_token_ids: + for i, prompt in enumerate(prompt_token_ids): # Schedule the collection of outputs for each prompt. # Avoid duplicate request_ids request_id = str(uuid4().hex) - task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params)) + mm_data = multi_modal_data[i] if multi_modal_data and i < len(multi_modal_data) else None + task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params, mm_data)) tasks.append(task) outputs = await asyncio.gather(*tasks) diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py index 189fecccc0..fe48a8dded 100644 --- a/skyrl/backends/skyrl_train/utils/ppo_utils.py +++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py @@ -1192,17 +1192,24 @@ def compute_grpo_outcome_advantage( index: np.ndarray, epsilon: float = 1e-6, grpo_norm_by_std: bool = True, + is_hinted: Optional[np.ndarray] = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for GRPO, operating only on Outcome reward (with only one scalar reward for each response). + When ``is_hinted`` is provided, uses a **first-turn baseline**: the group mean and std + are computed from raw (unhinted) samples only. All samples (raw + hinted) are then + centered using this raw-only baseline. This prevents hinted samples from contaminating + the baseline for raw samples (RLTF-SD paper, Section 3.2). + Expects: - token_level_rewards: Float[torch.Tensor, "batch_size seqlen"] - response_mask: Float[torch.Tensor, "batch_size seqlen"] - index: np.ndarray (batch_size) - epsilon: float - grpo_norm_by_std: bool + - is_hinted: Optional[np.ndarray] bool array (batch_size), True for hinted samples Returns: - advantages: Float[torch.Tensor, "batch_size seqlen"] @@ -1211,23 +1218,50 @@ def compute_grpo_outcome_advantage( # this assumes response-level rewards scores = token_level_rewards.sum(dim=-1) - id2score = defaultdict(list) id2mean = {} id2std = {} + use_first_turn_baseline = is_hinted is not None and np.any(is_hinted) + with torch.no_grad(): bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - id2std[idx] = torch.tensor(1.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - id2std[idx] = torch.std(torch.tensor([id2score[idx]])) - else: - raise ValueError(f"no score in prompt index: {idx}") + + if use_first_turn_baseline: + # First-turn baseline: compute mean/std from raw (unhinted) samples only + id2raw_scores = defaultdict(list) + for i in range(bsz): + if not is_hinted[i]: + id2raw_scores[index[i]].append(scores[i]) + + for idx in id2raw_scores: + raw = id2raw_scores[idx] + if len(raw) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + else: + id2mean[idx] = torch.mean(torch.tensor(raw)) + id2std[idx] = torch.std(torch.tensor([raw])) + + # For groups with only hinted samples (no raw), use 0 mean / 1 std + for i in range(bsz): + if index[i] not in id2mean: + id2mean[index[i]] = torch.tensor(0.0) + id2std[index[i]] = torch.tensor(1.0) + else: + # Standard GRPO: compute mean/std from all samples + id2score = defaultdict(list) + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): if grpo_norm_by_std: scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) diff --git a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py index a0a1990f5d..d66ccbc7b7 100644 --- a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py @@ -188,6 +188,7 @@ def init_model(self, model_path, num_training_steps: int = None): rope_scaling=get_rope_scaling_config(self.cfg), rope_theta=get_rope_theta_config(self.cfg), model_config_kwargs=self.cfg.policy.model_config_kwargs, + loss_chunk_size=self.cfg.loss_chunk_size, ) # in-place patch self._seq_parallel_monkey_patch(model=wrapped_model.model) @@ -395,10 +396,17 @@ def forward( class FSDPRefWorkerBase(RefWorkerBase): def offload_to_cpu(self, pin_memory=True, non_blocking=True, **kwargs): self._set_numa_affinity(torch.distributed.get_rank() % torch.cuda.device_count()) - self.strategy.offload_to_cpu(self.model, None, pin_memory, non_blocking) + # Force synchronous transfers + barrier to prevent cudaErrorIllegalAddress + # when policy workers access GPU memory that ref workers are still offloading + # across nodes (no shared CUDA context in multi-node). + self.strategy.offload_to_cpu(self.model, None, pin_memory, non_blocking=False) + if torch.distributed.is_initialized(): + torch.distributed.barrier() def backload_to_gpu(self, non_blocking=True, **kwargs): - self.strategy.backload_to_gpu(self.model, None, non_blocking) + self.strategy.backload_to_gpu(self.model, None, non_blocking=False) + if torch.distributed.is_initialized(): + torch.distributed.barrier() def init_model(self, model_path): assert self.cfg.strategy in ("fsdp", "fsdp2") @@ -426,6 +434,7 @@ def init_model(self, model_path): rope_scaling=get_rope_scaling_config(self.cfg), rope_theta=get_rope_theta_config(self.cfg), model_config_kwargs=self.cfg.ref.model_config_kwargs, + loss_chunk_size=self.cfg.loss_chunk_size, ) self._seq_parallel_monkey_patch(model=wrapped_model.model) diff --git a/skyrl/backends/skyrl_train/workers/model_wrapper.py b/skyrl/backends/skyrl_train/workers/model_wrapper.py index 3eb45f80a7..5bd10112a4 100644 --- a/skyrl/backends/skyrl_train/workers/model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/model_wrapper.py @@ -8,7 +8,9 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F import transformers +from torch.utils.checkpoint import checkpoint as gradient_checkpoint from flash_attn.bert_padding import pad_input, unpad_input from loguru import logger from packaging.version import Version @@ -33,6 +35,36 @@ ) +def _chunked_logprobs_only(hidden_chunk, labels_chunk, weight, bias, temperature): + """Compute logprobs for one chunk. Used inside gradient_checkpoint. + Uses F.linear with raw weight/bias to avoid DTensor issues with FSDP2. + """ + logits = F.linear(hidden_chunk, weight, bias) + logits = logits / temperature + return logprobs_from_logits(logits, labels_chunk, inplace_backward=False) + + +def _chunked_logprobs_and_entropy(hidden_chunk, labels_chunk, weight, bias, temperature): + """Compute logprobs and entropy for one chunk. Used inside gradient_checkpoint. + Uses F.linear with raw weight/bias to avoid DTensor issues with FSDP2. + """ + logits = F.linear(hidden_chunk, weight, bias) + logits = logits / temperature + lp = logprobs_from_logits(logits, labels_chunk, inplace_backward=False) + log_softmax_vals = F.log_softmax(logits, dim=-1) + entropy = -(log_softmax_vals.exp() * log_softmax_vals).sum(dim=-1) + return lp, entropy + + +class _IdentityLMHead(nn.Module): + """Dummy lm_head that passes hidden states through unchanged. + Used to prevent the HF model from materializing the full (B, S, vocab) logits tensor. + """ + + def forward(self, x): + return x + + class HFModelWrapper(nn.Module): """ Base class for wrapped HF models in reinforcement learning. @@ -74,6 +106,7 @@ def __init__( use_liger_kernel=False, sequence_parallel_size=1, use_sample_packing: bool = False, + loss_chunk_size: int = 0, use_torch_compile: bool = False, rope_scaling: Dict[str, Any] = {}, rope_theta: float | None = None, @@ -85,6 +118,7 @@ def __init__( self.sequence_parallel_size = sequence_parallel_size self.attn_implementation = "flash_attention_2" if use_flash_attention_2 else "sdpa" self.use_sample_packing = use_sample_packing + self.loss_chunk_size = loss_chunk_size self.is_vlm = False # packing samples using Flash Attention 2 if use_sample_packing: @@ -351,31 +385,62 @@ def forward( sequences_rolled, None, None, self.sequence_parallel_size ) - if self.is_vlm: - output = self.model( - sequences_fwd, - attention_mask=attention_mask_fwd, - position_ids=None, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - ) - # NOTE (sumanthrh): Once we have position_ids, we don't need attention mask with flash attention. - elif self.use_sample_packing and self.attn_implementation == "flash_attention_2": - # NOTE (sumanthrh): Don't use attention mask. position_ids is enough. - # Not using attention mask leads to higher perf since flash attention varlen func is enabled - output = self.model(sequences_fwd, attention_mask=None, position_ids=position_ids_fwd) - else: - output = self.model(sequences_fwd, attention_mask=attention_mask_fwd, position_ids=position_ids_fwd) + use_chunked = self.loss_chunk_size > 0 - logits_BSV = output["logits"] - logits_BSV.div_(temperature) + if use_chunked: + # Chunked lm_head: avoid materializing full (B, S, vocab_size) logits tensor. + # Replace lm_head with identity so the model returns hidden states instead. + lm_head = self.model.lm_head + self.model.lm_head = _IdentityLMHead() - # NOTE: this is slightly inaccurate with sample packing because last token from nth seq -> first token of n+1th seq loss is added. - log_probs = logprobs_from_logits( - logits_BSV, - sequences_rolled, - inplace_backward=True, - ) + try: + if self.is_vlm: + output = self.model( + sequences_fwd, + attention_mask=attention_mask_fwd, + position_ids=None, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + # NOTE (sumanthrh): Once we have position_ids, we don't need attention mask with flash attention. + elif self.use_sample_packing and self.attn_implementation == "flash_attention_2": + # NOTE (sumanthrh): Don't use attention mask. position_ids is enough. + # Not using attention mask leads to higher perf since flash attention varlen func is enabled + output = self.model(sequences_fwd, attention_mask=None, position_ids=position_ids_fwd) + else: + output = self.model(sequences_fwd, attention_mask=attention_mask_fwd, position_ids=position_ids_fwd) + finally: + if use_chunked: + self.model.lm_head = lm_head + + if use_chunked: + # output["logits"] is actually hidden_states (B, S, hidden_dim) since lm_head was identity + hidden_states = output["logits"] + entropy_mask = None + if compute_entropy and not self.use_sample_packing: + entropy_mask = attention_mask_fwd + log_probs, entropy_BS = self._chunked_lm_head_forward( + hidden_states, + lm_head, + sequences_rolled, + temperature, + self.loss_chunk_size, + compute_entropy=compute_entropy, + entropy_requires_grad=entropy_requires_grad, + attention_mask=entropy_mask, + ) + # Replace hidden_states in output with None to free memory + output["logits"] = None + else: + logits_BSV = output["logits"] + logits_BSV.div_(temperature) + + # NOTE: this is slightly inaccurate with sample packing because last token from nth seq -> first token of n+1th seq loss is added. + log_probs = logprobs_from_logits( + logits_BSV, + sequences_rolled, + inplace_backward=True, + ) # gather output if sp > 1 if self.sequence_parallel_size > 1: @@ -392,7 +457,20 @@ def forward( log_probs.transpose(0, 1), indices=nnz_indices, batch=batch_size, seqlen=seqlen ).squeeze(-1) - if compute_entropy: + if use_chunked: + # Entropy already computed in _chunked_lm_head_forward + if compute_entropy: + if self.sequence_parallel_size > 1: + dim = entropy_BS.ndim - 1 + entropy_BS = gather_outputs_and_unpad( + entropy_BS, gather_dim=dim, unpad_dim=dim, padding_size=pad_size + ) + if self.use_sample_packing: + entropy_BS = pad_input( + entropy_BS.transpose(0, 1), indices=nnz_indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1) + output["entropy"] = entropy_BS + elif compute_entropy: # For sample packing: entropy is calculated on unpacked data, so no attention mask needed # For non-sample packing: pass the attention mask to exclude padding tokens entropy_mask = None @@ -431,6 +509,93 @@ def forward( else: return action_log_probs + def _chunked_lm_head_forward( + self, + hidden_states, + lm_head, + labels, + temperature, + chunk_size, + compute_entropy=False, + entropy_requires_grad=True, + attention_mask=None, + ): + """Compute log_probs (and optionally entropy) via chunked lm_head projection. + + Instead of materializing the full (B, S, vocab_size) logits tensor, this + computes lm_head in chunks of `chunk_size` tokens along the sequence dimension. + Each chunk uses gradient checkpointing so logits are recomputed during backward + rather than stored, keeping peak memory at (B, chunk_size, vocab_size). + """ + B, S, H = hidden_states.shape + all_log_probs = [] + all_entropy = [] if compute_entropy else None + + # Extract weight/bias from lm_head module. With FSDP2, parameters are DTensors; + # calling the module inside gradient_checkpoint causes DTensor/Tensor mismatch. + # We all-gather DTensors to regular tensors via full_tensor() which is differentiable. + weight = lm_head.weight + bias = lm_head.bias + try: + from torch.distributed.tensor import DTensor + + if isinstance(weight, DTensor): + weight = weight.full_tensor() + if bias is not None and isinstance(bias, DTensor): + bias = bias.full_tensor() + except ImportError: + pass + + # When not computing gradients (ref model), skip gradient_checkpoint entirely — + # just compute each chunk directly with no_grad already active from caller. + use_checkpointing = torch.is_grad_enabled() + + for start in range(0, S, chunk_size): + end = min(start + chunk_size, S) + chunk_hidden = hidden_states[:, start:end] + chunk_labels = labels[:, start:end] + + if compute_entropy: + if use_checkpointing: + chunk_lp, chunk_ent = gradient_checkpoint( + _chunked_logprobs_and_entropy, + chunk_hidden, + chunk_labels, + weight, + bias, + temperature, + use_reentrant=False, + ) + else: + chunk_lp, chunk_ent = _chunked_logprobs_and_entropy( + chunk_hidden, chunk_labels, weight, bias, temperature + ) + if not entropy_requires_grad: + chunk_ent = chunk_ent.detach() + if attention_mask is not None: + chunk_mask = attention_mask[:, start:end] + chunk_ent = chunk_ent * chunk_mask + all_entropy.append(chunk_ent) + else: + if use_checkpointing: + chunk_lp = gradient_checkpoint( + _chunked_logprobs_only, + chunk_hidden, + chunk_labels, + weight, + bias, + temperature, + use_reentrant=False, + ) + else: + chunk_lp = _chunked_logprobs_only(chunk_hidden, chunk_labels, weight, bias, temperature) + + all_log_probs.append(chunk_lp) + + log_probs = torch.cat(all_log_probs, dim=1) + entropy = torch.cat(all_entropy, dim=1) if compute_entropy else None + return log_probs, entropy + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={"use_reentrant": False}): self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index a88104a975..2609bf9347 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -805,6 +805,7 @@ def _forward_backward_micro( # SFT path: skip KL/entropy terms, return per-token outputs for Tinker API if resolved_loss_name == "cross_entropy": loss = policy_loss + torch.cuda.empty_cache() # defrag before backward to prevent OOM on 35B self.strategy.backward(loss, self.model, self.optimizer) # Compute elementwise loss for Tinker API (per-token NLL) @@ -870,6 +871,7 @@ def _forward_backward_micro( kl_loss_term = kl_loss * self.cfg.algorithm.kl_loss_coef loss = policy_loss + kl_loss_term - entropy_loss_term + torch.cuda.empty_cache() # defrag before backward to prevent OOM on 35B self.strategy.backward(loss, self.model, self.optimizer) # Build per-sequence loss_fn_outputs with logprobs. @@ -1100,6 +1102,7 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: loss_mask=loss_mask, ) # NO loss scaling here - gradient scaling happens at optim_step + torch.cuda.empty_cache() # defrag before backward to prevent OOM on 35B self.strategy.backward(loss, self.model, self.optimizer) status = { diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 310f5da196..c04c9634e5 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -41,6 +41,8 @@ def from_dict_config(cls, cfg: DictConfig) -> "BaseConfig": class DataConfig(BaseConfig): train_data: List[str] = field(default_factory=lambda: [os.path.expanduser("~/data/gsm8k/train.parquet")]) val_data: List[str] = field(default_factory=lambda: [os.path.expanduser("~/data/gsm8k/validation.parquet")]) + env_filter: Optional[str] = None + """Comma-separated list of data_source values to include (e.g. 'outlook,github'). None = no filtering.""" # --------------------------------------------------------------------------- @@ -517,6 +519,12 @@ class GeneratorConfig(BaseConfig): """Can differ from the trainer's ``rope_scaling``, useful for thinking models.""" rope_theta: Optional[float] = None step_wise_trajectories: bool = False + inject_context_status: bool = False + """Inject context length status into the conversation.""" + context_warning_threshold: float = 0.90 + """Threshold for context length warning (fraction of max_input_length).""" + trajectory_timeout_seconds: Optional[int] = None + """Timeout in seconds for each trajectory rollout.""" def __post_init__(self): @@ -544,6 +552,8 @@ class SkyRLGymConfig(BaseConfig): text2sql: Text2SQLEnvConfig = field(default_factory=Text2SQLEnvConfig) llm_as_a_judge: GSM8kLLMJudgeEnvConfig = field(default_factory=GSM8kLLMJudgeEnvConfig) search: SearchEnvConfig = field(default_factory=SearchEnvConfig) + fleet_task: Optional[Dict[str, Any]] = None + task_gen: Optional[Dict[str, Any]] = None @dataclass @@ -607,8 +617,15 @@ class TrainerConfig(BaseConfig): logger: str = "wandb" dump_data_batch: bool = False dump_eval_results: bool = True + dump_training_trajectories: bool = False rope_scaling: Optional[Dict[str, Any]] = None rope_theta: Optional[float] = None + loss_chunk_size: Optional[int] = None + """Chunk size for loss computation to reduce memory usage.""" + use_hybrid_env_sampling: bool = False + """Enable hybrid environment sampling for multi-env training.""" + min_samples_per_env: int = 1 + """Minimum number of samples per environment in each batch.""" def __post_init__(self): # ref model defaults to the policy model @@ -786,25 +803,36 @@ def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SkyRLTrainConfig": ) overrides = OmegaConf.from_cli(args) - # Try new format first + # Always load base config and merge overrides. + # Our run scripts use legacy flat keys (e.g. generator.backend) that + # need translation to the new nested format (generator.inference_engine.backend). + # The direct from_dict_config path only works for fully-qualified new-format keys. try: - return cls.from_dict_config(overrides) - except ValueError: - # Fall back to legacy format: load base YAML, merge overrides, translate - try: - base_cfg = get_legacy_config() - merged = OmegaConf.merge(base_cfg, overrides) - merged_dict = OmegaConf.to_container(merged, resolve=True) + base_cfg = get_legacy_config() + except Exception: + # Hydra compose can fail (e.g., GlobalHydra already initialized). + # Fall back to loading YAML directly without Hydra defaults resolution. + import yaml + config_yaml = Path(__file__).parent / "ppo_base_config.yaml" + with open(config_yaml) as f: + raw_yaml = yaml.safe_load(f) + # Remove Hydra defaults key (not needed for direct loading) + raw_yaml.pop("defaults", None) + base_cfg = OmegaConf.create(raw_yaml) - if is_legacy_config(merged_dict): - warn_legacy_config() - translated = translate_legacy_config(merged_dict) - return build_nested_dataclass(cls, translated) - except Exception: - pass # Legacy translation failed, re-raise original error + # Disable struct flag so overrides can add new keys to dict-typed fields + # (e.g., chat_template_kwargs={enable_thinking:true}). + # OmegaConf loads empty dicts from YAML as closed structs by default. + OmegaConf.set_struct(base_cfg, False) - # Re-raise original error if not a legacy config issue - raise + merged = OmegaConf.merge(base_cfg, overrides) + merged_dict = OmegaConf.to_container(merged, resolve=True) + + if is_legacy_config(merged_dict): + warn_legacy_config() + merged_dict = translate_legacy_config(merged_dict) + + return build_nested_dataclass(cls, merged_dict) def make_config( @@ -873,5 +901,18 @@ def get_config_as_dict(cfg: Union[dict, BaseConfig]) -> dict: return asdict(cfg) -def get_config_as_yaml_str(cfg: BaseConfig) -> str: - return yaml.dump(asdict(cfg)) +def get_config_as_yaml_str(cfg) -> str: + if dataclasses.is_dataclass(cfg) and not isinstance(cfg, type): + try: + return yaml.dump(asdict(cfg)) + except TypeError: + # asdict can fail in some Ray serialization edge cases; fall back to str + return str(cfg) + # Handle OmegaConf DictConfig (from Hydra entrypoints) + try: + from omegaconf import OmegaConf + if OmegaConf.is_config(cfg): + return OmegaConf.to_yaml(cfg, resolve=True) + except ImportError: + pass + return str(cfg) diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index f2a52006e3..a32f40e3f4 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -11,6 +11,7 @@ defaults: data: train_data: ["${oc.env:HOME}/data/gsm8k/train.parquet"] val_data: ["${oc.env:HOME}/data/gsm8k/validation.parquet"] + env_filter: null trainer: placement: @@ -250,11 +251,15 @@ trainer: run_name: "test_run" logger: "wandb" dump_data_batch: false + dump_training_trajectories: false dump_eval_results: true # YaRN: rope_scaling: null rope_theta: null + loss_chunk_size: null + use_hybrid_env_sampling: false + min_samples_per_env: 1 # rope_scaling: # rope_type: yarn # factor: 1.0 @@ -262,6 +267,42 @@ trainer: generator: + # New structured inference_engine config (provides defaults for validate_cfg). + # Legacy flat fields below are translated into this section at runtime. + inference_engine: + model_dtype: "bfloat16" + run_engines_locally: true + num_engines: 1 + backend: "vllm" + weight_sync_backend: "nccl" + weight_transfer_threshold_cuda_ipc_GB: 1.0 + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + data_parallel_size: 1 + async_engine: true + vllm_v1_disable_multiproc: true + enable_prefix_caching: true + enable_chunked_prefill: true + enable_return_routed_experts: false + max_num_batched_tokens: 8192 + enforce_eager: true + fully_sharded_loras: false + enable_ray_prometheus_stats: false + gpu_memory_utilization: 0.8 + max_num_seqs: 1024 + remote_urls: ["127.0.0.1:8001"] + enable_http_endpoint: false + http_endpoint_host: "127.0.0.1" + http_endpoint_port: 8000 + served_model_name: null + distributed_executor_backend: "ray" + engine_init_kwargs: {} + override_existing_update_group: "auto" + external_proxy_url: null + external_server_urls: null + + # ---- Legacy flat fields (kept for backward compat; translated at runtime) ---- model_name: ${trainer.policy.model.path} model_dtype: "bfloat16" # should match dtype for inference engine run_engines_locally: true @@ -381,6 +422,9 @@ generator: rope_theta: ${trainer.rope_theta} step_wise_trajectories: false + inject_context_status: false + context_warning_threshold: 0.90 + trajectory_timeout_seconds: null environment: env_class: "gsm8k" diff --git a/skyrl/train/config/skyrl_gym_config/default.yaml b/skyrl/train/config/skyrl_gym_config/default.yaml index a94985f1e5..c6a2fec7a8 100644 --- a/skyrl/train/config/skyrl_gym_config/default.yaml +++ b/skyrl/train/config/skyrl_gym_config/default.yaml @@ -14,3 +14,25 @@ search: search_url: "http://127.0.0.1:8000/retrieve" topk: 3 timeout: 30 + +fleet_task: + tasks_file: null + api_key: null + ttl_seconds: 900 + partial_reward: false + enable_hints: false + hint_reward_threshold: 0.0 + n_hint_samples: 2 + enable_context_tools: false + use_llm_hints: false + hint_model: "openrouter/anthropic/claude-sonnet-4" + hint_llm_timeout: 30.0 + +task_gen: + max_turns: 10 + judge_model: "anthropic/claude-sonnet-4.5" + evaluator_model: "anthropic/claude-sonnet-4.5" + k_rollouts: 4 + eval_k_rollouts: 8 + alpha: 1.0 + max_eval_steps: 20 diff --git a/skyrl/train/dataset/dataset.py b/skyrl/train/dataset/dataset.py index 82383fd9b9..f49e09f448 100644 --- a/skyrl/train/dataset/dataset.py +++ b/skyrl/train/dataset/dataset.py @@ -15,12 +15,14 @@ def __init__( num_workers: int = 8, prompt_key: str = "prompt", env_class_key: str = "env_class", + env_filter: str | None = None, ): self.tokenizer = tokenizer self.max_prompt_length = max_prompt_length self.prompt_key = prompt_key self.env_class_key = env_class_key self.num_workers = num_workers + self.env_filter = env_filter self.datasets = datasets if isinstance(self.datasets, str): @@ -55,6 +57,17 @@ def _read_files_and_tokenize(self): logger.info(f"Total dataset size: {len(self.dataframe)}") + # Filter by data_source if env_filter is set + if self.env_filter: + allowed = {e.strip() for e in self.env_filter.split(",") if e.strip()} + before = len(self.dataframe) + self.dataframe = self.dataframe.filter( + lambda row: row.get("data_source", "") in allowed, + num_proc=self.num_workers, + desc=f"Filtering by env_filter ({allowed})", + ) + logger.info(f"env_filter={allowed}: {before} -> {len(self.dataframe)} rows") + # filter out too long prompts tokenizer = self.tokenizer prompt_key = self.prompt_key diff --git a/skyrl/train/entrypoints/main_base.py b/skyrl/train/entrypoints/main_base.py index 9e2fe6e3da..4de2e4dd16 100644 --- a/skyrl/train/entrypoints/main_base.py +++ b/skyrl/train/entrypoints/main_base.py @@ -168,6 +168,7 @@ def get_train_dataset(self): tokenizer=self.tokenizer, max_prompt_length=self.cfg.trainer.max_prompt_length, num_workers=8, + env_filter=getattr(self.cfg.data, "env_filter", None), ) # make sure the dataset is large enough to train on assert ( @@ -187,6 +188,7 @@ def get_eval_dataset(self): tokenizer=self.tokenizer, max_prompt_length=self.cfg.trainer.max_prompt_length, num_workers=8, + env_filter=getattr(self.cfg.data, "env_filter", None), ) return prompts_dataset return None @@ -230,6 +232,7 @@ def get_generator(self, cfg, tokenizer, inference_engine_client): skyrl_gym_cfg=cfg.environment.skyrl_gym, inference_engine_client=inference_engine_client, tokenizer=tokenizer, + model_name=cfg.trainer.policy.model.path, ) def get_trainer( diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py index c2456b974f..b41d64d36a 100644 --- a/skyrl/train/generators/base.py +++ b/skyrl/train/generators/base.py @@ -41,8 +41,11 @@ class GeneratorOutput(TypedDict): rollout_logprobs: Optional[List[List[float]]] trajectory_ids: Optional[List[TrajectoryID]] rollout_expert_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] + env_metrics: Optional[List[Dict[str, Any]]] # Applicable only for step-wise training is_last_step: Optional[List[bool]] + # Hint augmentation: True for samples generated with hint feedback + is_hinted: Optional[List[bool]] class MetricsOutput(TypedDict): diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 446ad0e572..cbc9137349 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -7,6 +7,7 @@ import asyncio import copy +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -31,10 +32,14 @@ TrajectoryID, ) from skyrl.train.generators.utils import ( + apply_chat_template_with_images, apply_overlong_filtering, + extract_images_from_conversation, get_custom_chat_template, get_generation_prompt_ids, get_rollout_metrics, + is_multimodal_conversation, + try_load_processor, ) from skyrl_gym.envs.base_text_env import BaseTextEnvStepOutput @@ -51,6 +56,7 @@ class TrajectoryOutput: rollout_logprobs: Optional[List[float]] env_metrics: Dict[str, Any] rollout_expert_indices: Optional[List[List[List[int]]]] = None + multi_modal_data: Optional[Dict[str, Any]] = None @dataclass @@ -69,6 +75,7 @@ class AgentLoopState: response_end_idx: Optional[int] done: bool rollout_expert_indices: Optional[List[List[List[int]]]] = None + accumulated_images: Optional[List[Any]] = None @dataclass @@ -139,17 +146,22 @@ def __init__( skyrl_gym_cfg: SkyRLGymConfig, inference_engine_client: InferenceEngineClient, tokenizer, + model_name: str = "", ): """ Args: generator_cfg: GeneratorConfig object containing the generator configuration inference_engine_client: InferenceEngineClient object for interacting with the inference engines tokenizer: tokenizer object for encoding and decoding text + model_name: HuggingFace model name (used for VL processor detection) """ self.generator_cfg = generator_cfg self.skyrl_gym_cfg = skyrl_gym_cfg self.inference_engine_client = inference_engine_client self.tokenizer = tokenizer + self.model_name = model_name + self.processor = try_load_processor(model_name) if model_name else None + self.is_vl_model = self.processor is not None self.max_turns = generator_cfg.max_turns self.batched = generator_cfg.batched self.use_conversation_multi_turn = generator_cfg.use_conversation_multi_turn @@ -211,6 +223,28 @@ def _validate_cfg(self, generator_cfg: GeneratorConfig): if not self.use_conversation_multi_turn: raise ValueError("`step_wise_trajectories` doesn't support `use_conversation_multi_turn=False`") + def _apply_chat_template( + self, + conversation: ConversationType, + add_generation_prompt: bool = True, + **kwargs, + ) -> List[int]: + """Apply chat template, routing to VL processor for multimodal conversations.""" + if self.is_vl_model and is_multimodal_conversation(conversation): + return apply_chat_template_with_images( + self.processor, + conversation, + add_generation_prompt=add_generation_prompt, + **kwargs, + ) + return self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_dict=False, + **self.generator_cfg.chat_template_kwargs, + ) + async def _run_in_executor_if_available(self, func, *args, **kwargs): if (executor := self.env_executor) is not None: loop = asyncio.get_running_loop() @@ -218,6 +252,47 @@ async def _run_in_executor_if_available(self, func, *args, **kwargs): else: return func(*args, **kwargs) + async def _env_init(self, env, *args, **kwargs): + """Call env.init, using async path if available to avoid event loop isolation.""" + if hasattr(env, "init_async"): + return await env.init_async(*args, **kwargs) + return await self._run_in_executor_if_available(env.init, *args, **kwargs) + + async def _env_step(self, env, action): + """Call env.step, using async path if available to avoid event loop isolation.""" + if hasattr(env, "step_async"): + return await env.step_async(action) + return await self._run_in_executor_if_available(env.step, action) + + async def _env_close(self, env): + """Call env.close, using async path if available to avoid event loop isolation.""" + if hasattr(env, "close_async"): + return await env.close_async() + return await self._run_in_executor_if_available(env.close) + + def _make_zero_reward_output( + self, + prompt: ConversationType, + zero_reward: Union[float, list], + is_step_wise: bool, + stop_reason: str = "trajectory_error", + env_metrics: Optional[Dict[str, Any]] = None, + ) -> Union[TrajectoryOutput, StepWiseOutput]: + """Create a zero-reward output for failed/cancelled trajectories.""" + prompt_ids = self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, return_dict=False) + output = TrajectoryOutput( + response_ids=[self.tokenizer.eos_token_id], + reward=zero_reward, + stop_reason=stop_reason, + loss_mask=[0], + prompt_ids=prompt_ids, + rollout_logprobs=[0.0], + env_metrics=env_metrics or {stop_reason: 1.0}, + ) + if is_step_wise: + return StepWiseOutput(step_outputs=[output]) + return output + async def agent_loop( self, prompt: ConversationType, @@ -272,18 +347,53 @@ async def agent_loop( chat_history = copy.deepcopy(prompt) # init() returns the first prompt to be given to the model, and optional metadata dict - chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history) + try: + chat_history, _ = await self._env_init(env, chat_history) + except Exception as e: + logger.warning(f"Session {session_id}: env.init failed ({type(e).__name__}: {e}), returning zero-reward trajectory") + # Return a minimal failed trajectory so training can continue + dummy_ids = self.tokenizer.apply_chat_template( + chat_history, add_generation_prompt=False, tokenize=True, return_dict=False, + **self.generator_cfg.chat_template_kwargs, + ) + eos_id = self.tokenizer.eos_token_id + # Match reward format: custom_chat_template uses float, otherwise per-token List[float] + reward_val = 0.0 if self.custom_chat_template else [0.0] + return TrajectoryOutput( + response_ids=[eos_id] if eos_id is not None else [0], + reward=reward_val, + stop_reason="env_init_error", + loss_mask=[0], + prompt_ids=dummy_ids, + rollout_logprobs=[0.0], + env_metrics={"env_init_error": str(e), "final_reward": 0.0}, + ) initial_chat_history_length = len(chat_history) - initial_input_ids = self.tokenizer.apply_chat_template( - chat_history, - # If retokenize_chat_history==True, avoid including the generation prompt in both the - # prompt_ids and response_ids due to how `response_encodings["input_ids"]` works. - add_generation_prompt=not retokenize_chat_history, - chat_template=self.custom_chat_template if retokenize_chat_history else None, - tokenize=True, - return_dict=False, - **self.generator_cfg.chat_template_kwargs, + + # VL: extract images from initial prompt for multimodal models + initial_images = ( + extract_images_from_conversation(chat_history) + if self.is_vl_model and is_multimodal_conversation(chat_history) + else [] ) + if self.is_vl_model and initial_images: + logger.info(f"Session {session_id}: VL model, extracted {len(initial_images)} initial images") + + # Tokenize initial prompt (VL-aware for multimodal content) + if self.is_vl_model and is_multimodal_conversation(chat_history): + initial_input_ids = self._apply_chat_template( + chat_history, + add_generation_prompt=not retokenize_chat_history, + ) + else: + initial_input_ids = self.tokenizer.apply_chat_template( + chat_history, + add_generation_prompt=not retokenize_chat_history, + chat_template=self.custom_chat_template if retokenize_chat_history else None, + tokenize=True, + return_dict=False, + **self.generator_cfg.chat_template_kwargs, + ) initial_prompt_length = len(initial_input_ids) loss_mask = [] # this excludes the prompt @@ -310,6 +420,7 @@ async def agent_loop( rollout_logprobs=[] if get_logprobs else None, response_end_idx=None, done=False, + accumulated_images=initial_images if initial_images else None, ) while not agent_loop_state.done: @@ -332,8 +443,16 @@ async def agent_loop( agent_loop_state.loss_mask = [] agent_loop_state.rollout_logprobs = None + # VL: build multimodal data for engine input + mm_data = None + if agent_loop_state.accumulated_images: + mm_data = [{"image": agent_loop_state.accumulated_images}] + engine_input = InferenceEngineInput( - prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params + prompt_token_ids=[agent_loop_state.input_ids], + session_ids=[session_id], + sampling_params=sampling_params, + multi_modal_data=mm_data, ) engine_output = await self.inference_engine_client.generate(engine_input) output = engine_output["responses"][0] @@ -367,7 +486,7 @@ async def agent_loop( added_eos = True # 2. Environment step - env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output) + env_step_output: BaseTextEnvStepOutput = await self._env_step(env, output) new_obs = env_step_output["observations"] step_reward: float = env_step_output["reward"] agent_loop_state.done = env_step_output["done"] @@ -384,6 +503,14 @@ async def agent_loop( obs_ids = self.get_obs_ids_from_obs(new_obs, agent_loop_state.done) + # VL: accumulate images from observations + if new_obs and is_multimodal_conversation(new_obs): + new_images = extract_images_from_conversation(new_obs) + if new_images: + if agent_loop_state.accumulated_images is None: + agent_loop_state.accumulated_images = [] + agent_loop_state.accumulated_images.extend(new_images) + # final turn output containing generated response and environment observations turn_output = TurnOutput( output=output, @@ -438,10 +565,14 @@ async def agent_loop( per_step_rewards.append((step_reward, agent_loop_state.response_end_idx)) + # Close the environment first so final_reward and verifier feedback + # are captured into the env before we read metrics. Otherwise + # env_metrics is missing final_reward / verifier_stdout / tool_errors, + # which breaks downstream hint recovery metrics (they read + # m.get("final_reward", 0.0) and get 0 for every hinted rollout). + await self._env_close(env) # Get environment-specific metrics after the episode is done env_metrics = env.get_metrics() - # Close the environment - await self._run_in_executor_if_available(env.close) prompt_ids = agent_loop_state.input_ids[:initial_prompt_length] rollout_logprobs = None @@ -520,6 +651,9 @@ async def agent_loop( rollout_logprobs=rollout_logprobs, env_metrics=env_metrics, rollout_expert_indices=rollout_expert_indices_out, + multi_modal_data={"images": agent_loop_state.accumulated_images} + if agent_loop_state.accumulated_images + else None, ) return agent_loop_output @@ -562,6 +696,37 @@ def _build_per_token_rewards( reward_out = token_level_rewards return reward_out + @staticmethod + def _sanitize_messages_for_template(messages: ConversationType) -> ConversationType: + """Ensure message content is compatible with the model's chat template. + + Converts list-format content (multimodal observations from fleet env) to + plain text. This handles two cases from OpenEnv: + 1. List of strings (multiple text results): joined into one string + 2. List of dicts (image_url / text blocks): text extracted, images replaced + """ + sanitized = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + text_parts = [] + for item in content: + if isinstance(item, str): + text_parts.append(item) + elif isinstance(item, dict): + if "text" in item: + text_parts.append(item["text"]) + elif "image_url" in item or "image" in item: + text_parts.append("[image]") + else: + text_parts.append(str(item)) + else: + text_parts.append(str(item)) + sanitized.append({**msg, "content": "\n".join(text_parts)}) + else: + sanitized.append(msg) + return sanitized + def get_obs_ids_from_obs(self, new_obs: ConversationType, is_done: bool) -> List[int]: """ Returns observation token ids from observation messages for a turn. @@ -578,10 +743,13 @@ def get_obs_ids_from_obs(self, new_obs: ConversationType, is_done: bool) -> List # 2. apply chat template for observations, also generate generation prompt for next turn obs_ids_to_add = [] if len(new_obs) > 0: + # Sanitize list-format content (multimodal) to plain text for + # compatibility with text-only chat templates (e.g. Qwen3.5-35B-A3B) + safe_obs = self._sanitize_messages_for_template(new_obs) # For Qwen, this will generate `\n<|user|>Some observation<|im_end|>\n`. Note that the # first `\n` is generated since we stripped it in ``base_conversation_token_ids``. obs_ids_to_add = self.tokenizer.apply_chat_template( - [*self.base_conversation, *new_obs], + [*self.base_conversation, *safe_obs], add_generation_prompt=not is_done, tokenize=True, return_dict=False, @@ -594,7 +762,8 @@ def get_obs_ids_from_obs(self, new_obs: ConversationType, is_done: bool) -> List # no generation prompt is added in this case obs_ids_to_add = [] if len(new_obs) > 0: - for obs in new_obs: + safe_obs = self._sanitize_messages_for_template(new_obs) + for obs in safe_obs: obs_tokens = self.tokenizer.encode(obs["content"], add_special_tokens=False) obs_ids_to_add.extend(obs_tokens) return obs_ids_to_add @@ -629,6 +798,188 @@ def _update_chat_history( chat_history += new_obs return chat_history + @staticmethod + def _extract_task_prompt(prompt: ConversationType) -> str: + """Extract the user's task prompt text from a conversation.""" + for msg in prompt: + if msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, list): + return " ".join( + b.get("text", "") for b in content + if isinstance(b, dict) and b.get("type") == "text" + ) + return str(content) + return "" + + async def _run_hint_augmentation( + self, + all_outputs: List[TrajectoryOutput], + prompts: List[ConversationType], + env_classes: List[str], + env_extras: List[Dict[str, Any]], + trajectory_ids: List[TrajectoryID], + max_tokens: int, + max_input_length: int, + sampling_params: Optional[Dict[str, Any]], + hint_cfg, + ) -> Tuple[List[TrajectoryOutput], List[TrajectoryID], List[str]]: + """Run hinted rollouts for prompts where all raw samples failed. + + Groups raw outputs by instance_id, identifies groups where max_reward < threshold, + synthesizes hints (LLM or static), and launches additional hinted rollouts. + + Uses RLTF-SD: hinted rollout prompt_ids are replaced with the original unhinted + prompt_ids so the model learns to produce hint-quality outputs conditioned on the + original prompt alone: grad log pi(y_hint | x_0) instead of grad log pi(y_hint | x_0 + hint). + + Returns: + Tuple of (hinted_outputs, hinted_trajectory_ids, hinted_env_classes) + """ + from skyrl_gym.envs.fleet_task.env import FleetTaskEnv + + use_llm_hints = hint_cfg.get("use_llm_hints", False) if hasattr(hint_cfg, "get") else False + hint_model = hint_cfg.get("hint_model", "openrouter/anthropic/claude-sonnet-4-20250514") if hasattr(hint_cfg, "get") else "openrouter/anthropic/claude-sonnet-4-20250514" + hint_timeout = hint_cfg.get("hint_llm_timeout", 30.0) if hasattr(hint_cfg, "get") else 30.0 + + # 1. Group outputs by instance_id + groups: Dict[str, List[Tuple[int, TrajectoryOutput]]] = defaultdict(list) + for i, output in enumerate(all_outputs): + iid = trajectory_ids[i].instance_id + groups[iid].append((i, output)) + + # 2. Identify prompts needing hints and collect data for LLM synthesis + failed_groups = [] # (iid, best_orig_idx, best_output, best_reward) + hint_reward_threshold = hint_cfg.get("hint_reward_threshold", 0.0) if hasattr(hint_cfg, "get") else 0.0 + + for iid, items in groups.items(): + rewards = [] + for _, output in items: + r = output.reward + rewards.append(r if isinstance(r, (int, float)) else sum(r)) + max_reward = max(rewards) + + if max_reward > hint_reward_threshold: + continue # at least one raw sample has signal + + best_idx = max(range(len(items)), key=lambda j: rewards[j]) + best_orig_idx, best_output = items[best_idx] + failed_groups.append((iid, best_orig_idx, best_output, rewards[best_idx])) + + if not failed_groups: + return [], [], [] + + # 3. Build hints — LLM-synthesized or static + if use_llm_hints: + from skyrl_gym.envs.fleet_task.hint_synthesizer import synthesize_hints_batch + + hint_requests = [] + for iid, best_orig_idx, best_output, _ in failed_groups: + metrics = best_output.env_metrics + hint_requests.append({ + "task_prompt": self._extract_task_prompt(prompts[best_orig_idx]), + "chat_history": metrics.get("chat_history", []), + "verifier_stdout": metrics.get("verifier_stdout"), + "verifier_error": metrics.get("verifier_error"), + "tool_error_messages": metrics.get("tool_error_messages"), + "instance_id": iid, + }) + + hint_results = await synthesize_hints_batch( + hint_requests=hint_requests, + model=hint_model, + timeout=hint_timeout, + static_fallback_fn=FleetTaskEnv.build_hint_text, + ) + else: + hint_results = [] + for iid, best_orig_idx, best_output, _ in failed_groups: + metrics = best_output.env_metrics + hint_text = FleetTaskEnv.build_hint_text( + verifier_stdout=metrics.get("verifier_stdout"), + verifier_error=metrics.get("verifier_error"), + tool_error_messages=metrics.get("tool_error_messages"), + ) + hint_results.append((hint_text, "static_fallback")) + + # 4. Create hinted agent_loop tasks + hint_tasks = [] + hint_tids = [] + hint_envs = [] + orig_prompt_ids = [] + hint_categories = [] + prompts_hinted = 0 + + for group_idx, (iid, best_orig_idx, best_output, best_reward) in enumerate(failed_groups): + hint_text, hint_category = hint_results[group_idx] + if not hint_text: + continue + + logger.info( + f"Hint [{hint_category}] for instance {iid} " + f"(best_reward={best_reward:.3f}):\n{hint_text[:500]}" + ) + prompts_hinted += 1 + + items = groups[iid] + base_rep_id = max(item[0] for item in items) + 1 + n_hint = hint_cfg.get("n_hint_samples", 2) if hasattr(hint_cfg, "get") else 2 + for h in range(n_hint): + hinted_extras = dict(env_extras[best_orig_idx]) + hinted_extras["hint"] = hint_text + hinted_extras["is_hinted"] = True + hinted_extras["hint_category"] = hint_category + tid = TrajectoryID(instance_id=iid, repetition_id=base_rep_id + h) + hint_tasks.append( + self.agent_loop( + prompts[best_orig_idx], + env_classes[best_orig_idx], + hinted_extras, + max_tokens, + max_input_length, + sampling_params=sampling_params, + trajectory_id=tid, + ) + ) + hint_tids.append(tid) + hint_envs.append(env_classes[best_orig_idx]) + orig_prompt_ids.append(best_output.prompt_ids) + hint_categories.append(hint_category) + + # 5. Clean up chat_history from env_metrics to free memory + for _, _, best_output, _ in failed_groups: + best_output.env_metrics.pop("chat_history", None) + + # 6. Run all hinted rollouts in parallel + if hint_tasks: + logger.info( + f"Hint augmentation: {prompts_hinted} prompts need hints, " + f"launching {len(hint_tasks)} hinted rollouts" + ) + hint_outputs = await tqdm.gather( + *hint_tasks, + desc="Hinted Rollouts", + miniters=1, + mininterval=5, + ) + # RLTF-SD: strip hint from training prompt. Replace hinted prompt_ids + # with the original unhinted prompt_ids so the model learns to produce + # hint-quality outputs conditioned on the original prompt alone. + hint_outputs = list(hint_outputs) + for i, output in enumerate(hint_outputs): + hinted_len = len(output.prompt_ids) + output.prompt_ids = orig_prompt_ids[i] + # Propagate hint_category into env_metrics for tracking + if isinstance(output.env_metrics, dict): + output.env_metrics["hint_category"] = hint_categories[i] + logger.debug( + f"RLTF-SD: replaced hinted prompt ({hinted_len} tokens) " + f"with original prompt ({len(output.prompt_ids)} tokens)" + ) + return hint_outputs, hint_tids, hint_envs + + return [], [], [] + async def generate_batched( self, prompts: List[ConversationType], @@ -656,7 +1007,7 @@ async def generate_batched( env_extra["max_turns"] = self.max_turns env_config = getattr(self.skyrl_gym_cfg, env_class, dict()) env = skyrl_gym.make(env_class, env_config=env_config, extras=env_extra) - init_prompt, _ = await self._run_in_executor_if_available(env.init, prompt) + init_prompt, _ = await self._env_init(env, prompt) init_prompts.append(init_prompt) envs.append(env) @@ -684,7 +1035,7 @@ async def generate_batched( for i, (output, response, env, env_class) in enumerate(zip(outputs, responses, envs, env_classes)): # step on environment and compute reward - env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output) + env_step_output: BaseTextEnvStepOutput = await self._env_step(env, output) reward = env_step_output["reward"] rewards.append(reward) @@ -700,10 +1051,11 @@ async def generate_batched( prompt_len = len(prompt_token_ids[i]) truncated_indices.append(sample_indices[: prompt_len + len(response)]) + # Close the environment first so final_reward and verifier + # feedback are populated before get_metrics() reads them. + await self._env_close(env) # Get environment-specific metrics env_metrics.append(env.get_metrics()) - # Close the environment - await self._run_in_executor_if_available(env.close) rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes) @@ -720,6 +1072,7 @@ async def generate_batched( "rollout_metrics": rollout_metrics, "rollout_logprobs": truncated_logprobs, "rollout_expert_indices": truncated_indices, + "env_metrics": env_metrics, } return generator_output @@ -749,27 +1102,88 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False return await self.generate_batched(prompts, env_classes, env_extras, max_tokens, sampling_params) # Async agent loop to generate trajectories in parallel. - tasks = [] + # Use asyncio.wait() instead of gather() so individual trajectory failures + # don't crash the entire batch — failed trajectories get zero-reward outputs. + is_step_wise = self.generator_cfg.step_wise_trajectories + zero_reward = 0.0 if self.custom_chat_template else [0.0] + + async_tasks = [] for i in range(len(prompts)): - tasks.append( - self.agent_loop( - prompts[i], - env_classes[i], - env_extras[i], - max_tokens, - max_input_length, - sampling_params=sampling_params, - trajectory_id=trajectory_ids[i] if trajectory_ids is not None else None, - ) + coro = self.agent_loop( + prompts[i], + env_classes[i], + env_extras[i], + max_tokens, + max_input_length, + sampling_params=sampling_params, + trajectory_id=trajectory_ids[i] if trajectory_ids is not None else None, ) + async_tasks.append(asyncio.ensure_future(coro)) + + task_to_idx = {id(t): i for i, t in enumerate(async_tasks)} - all_outputs = await tqdm.gather( - *tasks, + pbar = tqdm( + total=len(async_tasks), desc="Generating Trajectories", - miniters=max(1, len(tasks) // 10), + miniters=max(1, len(async_tasks) // 10), mininterval=5, disable=disable_tqdm, ) + for t in async_tasks: + t.add_done_callback(lambda _: pbar.update(1)) + + done, pending = await asyncio.wait(async_tasks) + pbar.close() + + all_outputs: list = [None] * len(async_tasks) + for t in done: + idx = task_to_idx[id(t)] + if t.exception() is not None: + logger.error(f"Trajectory {idx} raised exception: {t.exception()}") + all_outputs[idx] = self._make_zero_reward_output(prompts[idx], zero_reward, is_step_wise) + else: + all_outputs[idx] = t.result() + + # --- Hint augmentation: rescue GRPO signal on dead prompts --- + # Only during training; eval should not run hints. + n_raw = len(all_outputs) + batch_metadata = input_batch.get("batch_metadata") + is_training = batch_metadata is not None and batch_metadata.training_phase == "train" + hint_cfg = getattr(self.skyrl_gym_cfg, "fleet_task", None) + enable_hints = hint_cfg is not None and (hint_cfg.get("enable_hints", False) if hasattr(hint_cfg, "get") else False) + if ( + enable_hints + and not self.generator_cfg.step_wise_trajectories + and trajectory_ids is not None + and is_training + ): + hint_outputs, hint_tids, hint_env_classes = await self._run_hint_augmentation( + all_outputs=list(all_outputs), + prompts=prompts, + env_classes=env_classes, + env_extras=env_extras, + trajectory_ids=trajectory_ids, + max_tokens=max_tokens, + max_input_length=max_input_length, + sampling_params=sampling_params, + hint_cfg=hint_cfg, + ) + if hint_outputs: + all_outputs = list(all_outputs) + hint_outputs + # Extend in-place so input_batch references are updated (trainer reads these) + trajectory_ids.extend(hint_tids) + env_classes.extend(hint_env_classes) + # Also extend prompts and env_extras arrays to stay aligned + for tid in hint_tids: + # Find original prompt index for this instance_id + for orig_i, orig_tid in enumerate(input_batch.get("trajectory_ids", [])): + if orig_tid.instance_id == tid.instance_id: + prompts.append(prompts[orig_i]) + env_extras.append(env_extras[orig_i]) + break + + # Build is_hinted array: raw samples are False, hint-augmented samples are True + is_hinted = [False] * n_raw + [True] * (len(all_outputs) - n_raw) if self.generator_cfg.step_wise_trajectories: responses = [] @@ -827,6 +1241,57 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes) + # Log hint augmentation metrics + hinted_metrics = [m for m in env_metrics if isinstance(m, dict) and m.get("is_hinted")] + if hinted_metrics: + n_hinted = len(hinted_metrics) + hinted_rewards = [] + for m in hinted_metrics: + r = m.get("final_reward", 0.0) + hinted_rewards.append(r if r is not None else 0.0) + n_success = sum(1 for r in hinted_rewards if r > 0) + rollout_metrics["hint/total_hinted_rollouts"] = n_hinted + rollout_metrics["hint/hint_success_rate"] = n_success / n_hinted if n_hinted > 0 else 0.0 + # Count unique prompts that were hinted (by instance_id) + hinted_iids = set() + for i, m in enumerate(env_metrics): + if m.get("is_hinted") and trajectory_ids is not None and i < len(trajectory_ids): + hinted_iids.add(trajectory_ids[i].instance_id) + rollout_metrics["hint/prompts_hinted"] = len(hinted_iids) + # Signal rescued: prompts where at least 1 hinted sample scored > 0 + rescued = 0 + for iid in hinted_iids: + for j, m in enumerate(env_metrics): + if m.get("is_hinted") and trajectory_ids is not None and j < len(trajectory_ids): + if trajectory_ids[j].instance_id == iid: + r = m.get("final_reward", 0.0) + if r is not None and r > 0: + rescued += 1 + break + rollout_metrics["hint/signal_rescued"] = rescued / len(hinted_iids) if hinted_iids else 0.0 + + # Category-level metrics (LLM-synthesized vs static) + from skyrl_gym.envs.fleet_task.hint_synthesizer import CATEGORY_LLM + llm_metrics = [m for m in hinted_metrics if m.get("hint_category") == CATEGORY_LLM] + n_llm = len(llm_metrics) + if n_llm > 0: + llm_rewards = [m.get("final_reward", 0.0) or 0.0 for m in llm_metrics] + llm_success = sum(1 for r in llm_rewards if r > 0) + rollout_metrics["hint/category_llm_synthesized_count"] = n_llm + rollout_metrics["hint/category_llm_synthesized_success_rate"] = llm_success / n_llm + static_metrics = [m for m in hinted_metrics if m.get("hint_category", "").endswith("fallback")] + n_static = len(static_metrics) + if n_static > 0: + static_rewards = [m.get("final_reward", 0.0) or 0.0 for m in static_metrics] + static_success = sum(1 for r in static_rewards if r > 0) + rollout_metrics["hint/category_static_fallback_count"] = n_static + rollout_metrics["hint/category_static_fallback_success_rate"] = static_success / n_static + + # Clean up chat_history from env_metrics to prevent it from being serialized downstream + for m in env_metrics: + if isinstance(m, dict): + m.pop("chat_history", None) + if self.generator_cfg.zero_reward_on_non_stop: # set reward to 0 if the stop reason is not "stop" rewards = self._zero_reward_if_not_stop(rewards, stop_reasons) @@ -835,6 +1300,9 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False # set loss mask to 0 if the stop reason is not "stop" loss_masks = apply_overlong_filtering(loss_masks, stop_reasons) + # Collect per-trajectory images (for dump_training_trajectories) + multi_modal_data_list = [output.multi_modal_data for output in all_outputs] + generator_output: GeneratorOutput = { "prompt_token_ids": prompt_token_ids, "response_ids": responses, @@ -845,7 +1313,10 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False "rollout_logprobs": rollout_logprobs, "trajectory_ids": out_trajectory_ids, "rollout_expert_indices": rollout_expert_indices, + "env_metrics": env_metrics, "is_last_step": is_last_step, + "is_hinted": is_hinted, + "multi_modal_data": multi_modal_data_list, } return generator_output diff --git a/skyrl/train/generators/utils.py b/skyrl/train/generators/utils.py index 39410b9d6b..a7bf440702 100644 --- a/skyrl/train/generators/utils.py +++ b/skyrl/train/generators/utils.py @@ -1,7 +1,7 @@ import copy import os from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -574,3 +574,185 @@ def get_response_ids_and_loss_mask_from_messages( assert len(rollout_logprobs) == len(response_ids) if rollout_logprobs is not None else True return response_ids, loss_mask, rollout_logprobs + + +# --- Multimodal/VL Utilities --- + + +def is_multimodal_message(message: Dict[str, Any]) -> bool: + """Check if a message contains multimodal content (images).""" + content = message.get("content") + if isinstance(content, str): + return False + if isinstance(content, list): + return any(isinstance(item, dict) and item.get("type") == "image_url" for item in content) + return False + + +def is_multimodal_conversation(conversation: ConversationType) -> bool: + """Check if any message in a conversation contains multimodal content.""" + return any(is_multimodal_message(msg) for msg in conversation) + + +def extract_images_from_conversation(conversation: ConversationType) -> List[Any]: + """Extract all images from a conversation in order. + + Supports base64 data URLs, HTTP URLs, and local file paths. + Returns a list of PIL Images or image URL strings. + """ + images = [] + for message in conversation: + content = message.get("content") + if not isinstance(content, list): + continue + for item in content: + if not isinstance(item, dict) or item.get("type") != "image_url": + continue + image_url_data = item.get("image_url", {}) + url = image_url_data.get("url") or "" + if url.startswith("data:image"): + images.append(decode_base64_image(url)) + elif url.startswith(("http://", "https://")): + images.append(url) + elif url: + images.append(load_image_from_path(url)) + return images + + +def decode_base64_image(data_url: str) -> "Image.Image": + """Decode a base64 image from a data URL.""" + import base64 + import io + + try: + from PIL import Image + except ImportError: + raise ImportError("PIL/Pillow is required for multimodal support. Install with: pip install pillow") + + if "," in data_url: + base64_data = data_url.split(",", 1)[1] + else: + base64_data = data_url + + image_bytes = base64.b64decode(base64_data) + return Image.open(io.BytesIO(image_bytes)) + + +def load_image_from_path(path: str) -> "Image.Image": + """Load an image from a file path.""" + try: + from PIL import Image + except ImportError: + raise ImportError("PIL/Pillow is required for multimodal support. Install with: pip install pillow") + + return Image.open(path) + + +def get_text_from_multimodal_content(content: Any) -> str: + """Extract text from multimodal content, ignoring images.""" + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + texts.append(item.get("text", "")) + return " ".join(texts) + return "" + + +def convert_to_text_only_conversation(conversation: ConversationType) -> ConversationType: + """Convert multimodal conversation to processor format (image_url -> image type).""" + text_only = [] + for message in conversation: + content = message.get("content") + if isinstance(content, str): + text_only.append(message) + elif isinstance(content, list): + new_content = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "image_url": + new_content.append({"type": "image"}) + else: + new_content.append(item) + else: + new_content.append(item) + text_only.append({"role": message["role"], "content": new_content}) + else: + text_only.append(message) + return text_only + + +def try_load_processor(model_name: str) -> Optional[Any]: + """Try to load a HuggingFace processor for VL models. Returns None for text-only models.""" + try: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + if hasattr(processor, "image_processor") or "VL" in type(processor).__name__: + logger.info(f"Loaded VL processor for model: {model_name}") + return processor + return None + except Exception as e: + logger.debug(f"No processor available for {model_name}: {e}") + return None + + +def apply_chat_template_with_images( + processor_or_tokenizer, + conversation: ConversationType, + add_generation_prompt: bool = True, + chat_template: Optional[str] = None, + **kwargs, +) -> List[int]: + """Apply chat template handling both text-only and multimodal conversations. + + For VL models (with processor), extracts images and uses the processor to get + correctly-sized token IDs (vision tokens expand based on image dimensions). + """ + has_processor = hasattr(processor_or_tokenizer, "image_processor") or hasattr( + processor_or_tokenizer, "tokenizer" + ) + + if has_processor and is_multimodal_conversation(conversation): + processor = processor_or_tokenizer + converted = convert_to_text_only_conversation(conversation) + images = extract_images_from_conversation(conversation) + + text = processor.apply_chat_template( + converted, + tokenize=False, + add_generation_prompt=add_generation_prompt, + **kwargs, + ) + + if images: + inputs = processor(text=text, images=images, return_tensors="pt") + token_ids = inputs.input_ids[0].tolist() + else: + tokenizer = getattr(processor, "tokenizer", processor) + token_ids = tokenizer.encode(text, add_special_tokens=False) + return token_ids + else: + tokenizer = getattr(processor_or_tokenizer, "tokenizer", processor_or_tokenizer) + + if is_multimodal_conversation(conversation): + text_conversation = [] + for msg in conversation: + content = msg.get("content") + if isinstance(content, list): + text = get_text_from_multimodal_content(content) + text_conversation.append({"role": msg["role"], "content": text}) + else: + text_conversation.append(msg) + conversation = text_conversation + + return tokenizer.apply_chat_template( + conversation, + add_generation_prompt=add_generation_prompt, + tokenize=True, + chat_template=chat_template, + return_dict=False, + **kwargs, + ) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index e22312c56e..e3bb0d3028 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -70,6 +70,7 @@ ResumeMode, build_dataloader, cleanup_old_checkpoints, + dump_training_trajectories, extract_step_from_path, run_on_each_node, validate_consistency_for_latest_checkpoint, @@ -172,6 +173,22 @@ async def eval(self) -> Dict[str, float]: global_step=self.global_step, tokenizer=self.tokenizer, ) + + # Upload eval results to S3 + if self.cfg.trainer.dump_eval_results: + try: + from integrations.fleet.s3_checkpoints import upload_eval_results_to_s3 + + step_suffix = "eval_only" if self.global_step is None else f"global_step_{self.global_step}_evals" + local_dir = os.path.join(self.cfg.trainer.export_path, "dumped_evals", step_suffix) + upload_eval_results_to_s3( + local_dir=local_dir, + run_name=self.cfg.trainer.run_name, + global_step=self.global_step, + ) + except Exception as e: + logger.warning(f"Failed to upload eval results to S3: {e}") + return eval_metrics async def train(self): @@ -231,6 +248,10 @@ async def train(self): # NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids` # this is because in step-wise training, len(uids) != len(generator_output["response_ids"]) uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]] + elif "trajectory_ids" in generator_input and generator_input["trajectory_ids"] is not None: + # Hint augmentation may extend trajectory_ids in-place during generate(). + # Re-derive uids to stay aligned with rewards/responses. + uids = [tid.instance_id for tid in generator_input["trajectory_ids"]] # dynamic sampling if self.cfg.trainer.algorithm.dynamic_sampling.type is not None: @@ -244,6 +265,26 @@ async def train(self): # if we are not continuing sampling, we sleep the inference engine await self.inference_engine_client.sleep() + # 1.1.5 dump training trajectories + if self.cfg.trainer.dump_training_trajectories: + with Timer("dump_training_trajectories", self.all_timings): + traj_file = dump_training_trajectories( + dump_dir=self.cfg.trainer.export_path, + tokenizer=self.tokenizer, + generator_output=generator_output, + env_extras=generator_input.get("env_extras", []), + global_step=self.global_step, + ) + try: + from integrations.fleet.s3_checkpoints import upload_training_trajectories_to_s3 + upload_training_trajectories_to_s3( + local_path=traj_file, + run_name=self.cfg.trainer.run_name, + global_step=self.global_step, + ) + except Exception as e: + logger.warning(f"Failed to upload training trajectories to S3: {e}") + # 1.2 postprocess rewards with Timer("postprocess_generator_output", self.all_timings): generator_output = self.postprocess_generator_output(generator_output, uids) @@ -295,6 +336,16 @@ async def train(self): if self.cfg.trainer.ckpt_interval > 0 and self.global_step % self.cfg.trainer.ckpt_interval == 0: with Timer("save_checkpoints", self.all_timings): self.save_checkpoints() + if self.cfg.trainer.dump_training_trajectories: + try: + from integrations.fleet.s3_checkpoints import upload_reward_rollouts_to_s3 + reward_rollout_dir = os.environ.get("REWARD_ROLLOUT_DIR", "/workspace/reward_rollouts") + upload_reward_rollouts_to_s3( + rollout_dir=reward_rollout_dir, + run_name=self.cfg.trainer.run_name, + ) + except Exception as e: + logger.warning(f"Failed to upload reward rollouts to S3: {e}") if ( self.cfg.trainer.hf_save_interval > 0 and self.global_step % self.cfg.trainer.hf_save_interval == 0 @@ -659,6 +710,9 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis }, ) training_input.metadata = {"uids": uids} + # Track which samples are hint-augmented for first-turn baseline + if generator_output.get("is_hinted") is not None: + training_input.metadata["is_hinted"] = generator_output["is_hinted"] # padded response length training_input.metadata["response_length"] = response_masks_tensor.shape[1] batch_num_seq, batch_padded_seq_len = sequences_tensor.shape @@ -798,10 +852,15 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn """ token_level_rewards = data["rewards"] + # Convert is_hinted metadata to numpy array for advantage computation + is_hinted_list = data.metadata.get("is_hinted") + is_hinted = np.array(is_hinted_list) if is_hinted_list is not None else None + if self.cfg.generator.step_wise_trajectories: is_last_step = data["is_last_step"].bool() index = np.array(data.metadata["uids"]) values = data["values"] + last_step_is_hinted = is_hinted[is_last_step.cpu().numpy()] if is_hinted is not None else None # Use the last step of each trajectory to compute advantages. Compatible with any advantage estimator # NOTE(Charlie): so we ignore per-step rewards in step-wise training. last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns( @@ -814,6 +873,7 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn gamma=self.cfg.trainer.algorithm.gamma, lambd=self.cfg.trainer.algorithm.lambd, grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std, + is_hinted=last_step_is_hinted, ) # Broadcast each trajectory's advantage and return to all steps of each trajectory. traj_ids = ( @@ -836,6 +896,7 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn gamma=self.cfg.trainer.algorithm.gamma, lambd=self.cfg.trainer.algorithm.lambd, grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std, + is_hinted=is_hinted, ) data["returns"] = returns data["advantages"] = advantages @@ -920,8 +981,10 @@ def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch: new_training_input.metadata["trajectory_ids"] = training_input.metadata["trajectory_ids"] + [ f"pad{i}" for i in range(pad_size) ] + if "is_hinted" in training_input.metadata: + new_training_input.metadata["is_hinted"] = training_input.metadata["is_hinted"] + [False] * pad_size for key, value in training_input.metadata.items(): - if key not in ["uids", "trajectory_ids"]: + if key not in ["uids", "trajectory_ids", "is_hinted"]: new_training_input.metadata[key] = copy.deepcopy(value) return new_training_input diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index 7fe3e53fb5..d82f4fc1c7 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -1,5 +1,6 @@ import json import os +import time from collections import defaultdict from enum import Enum from pathlib import Path @@ -253,6 +254,15 @@ def dump_per_dataset_eval_results( # Prepare common data input_prompts = [tokenizer.decode(prompt) for prompt in concat_generator_outputs["prompt_token_ids"]] output_responses = [tokenizer.decode(response) for response in concat_generator_outputs["response_ids"]] + multi_modal_data_list = concat_generator_outputs.get("multi_modal_data") or [] + + # Save screenshots if any trajectories have images + images_dir = dump_dir_path / "images" + has_any_images = any( + mm and mm.get("images") for mm in multi_modal_data_list if isinstance(mm, dict) + ) + if has_any_images: + images_dir.mkdir(parents=True, exist_ok=True) # Group indices by data source data_source_indices = {} @@ -264,12 +274,36 @@ def dump_per_dataset_eval_results( data_source_indices[data_source].append(i) # Dump per-dataset files + total_images_saved = 0 for data_source, indices in data_source_indices.items(): sanitized_data_source = sanitize_data_source(data_source) filename = dump_dir_path / f"{sanitized_data_source}.jsonl" with open(filename, "w") as f: for i in indices: + # Save screenshots for this eval trajectory + image_paths = [] + mm_data = multi_modal_data_list[i] if i < len(multi_modal_data_list) else None + if isinstance(mm_data, dict) and mm_data.get("images"): + for j, img in enumerate(mm_data["images"]): + img_filename = f"eval_{i:04d}_img_{j:03d}.jpg" + img_path = images_dir / img_filename + try: + if hasattr(img, "save"): + img.save(str(img_path), "JPEG", quality=85) + image_paths.append(str(img_path)) + total_images_saved += 1 + elif isinstance(img, str) and img.startswith(("http://", "https://")): + image_paths.append(img) + total_images_saved += 1 + elif isinstance(img, bytes): + with open(img_path, "wb") as img_f: + img_f.write(img) + image_paths.append(str(img_path)) + total_images_saved += 1 + except Exception as e: + logger.warning(f"Failed to save eval image {j} for trajectory {i}: {e}") + entry = { "input_prompt": input_prompts[i], "output_response": output_responses[i], @@ -279,10 +313,16 @@ def dump_per_dataset_eval_results( "env_extras": concat_env_extras[i], "data_source": data_source, } + if image_paths: + entry["image_paths"] = image_paths + entry["num_screenshots"] = len(image_paths) f.write(json.dumps(entry, ensure_ascii=False) + "\n") logger.info(f"Dumped eval data for {data_source} to {filename}") + if total_images_saved: + logger.info(f"Saved {total_images_saved} eval screenshots to {images_dir}") + # Dump aggregated results file aggregated_filename = dump_dir_path / "aggregated_results.jsonl" with open(aggregated_filename, "w") as f: @@ -291,6 +331,104 @@ def dump_per_dataset_eval_results( logger.info(f"Dumped aggregated eval metrics to {aggregated_filename}") +def dump_training_trajectories( + dump_dir: str, + tokenizer: AutoTokenizer, + generator_output: GeneratorOutput, + env_extras: List[Dict[str, Any]], + global_step: int, +) -> str: + """Dump training trajectories to a JSONL file for analysis. + + Each line contains: step, env_key, data_source, stop_reason, reward, turns, tokens, prompt, text, timestamp. + """ + traj_dir = Path(dump_dir) / "dumped_trajectories" + traj_dir.mkdir(parents=True, exist_ok=True) + filename = traj_dir / f"global_step_{global_step}.jsonl" + + env_metrics_list = generator_output.get("env_metrics") or [] + multi_modal_data_list = generator_output.get("multi_modal_data") or [] + rewards_list = generator_output["rewards"] + stop_reasons = generator_output.get("stop_reasons") or [] + ts = time.time() + + # Save screenshots alongside JSONL if any trajectories have images + images_dir = traj_dir / f"global_step_{global_step}_images" + has_any_images = any( + mm and mm.get("images") for mm in multi_modal_data_list if isinstance(mm, dict) + ) + if has_any_images: + images_dir.mkdir(parents=True, exist_ok=True) + + with open(filename, "w") as f: + for i in range(len(generator_output["response_ids"])): + env_m = env_metrics_list[i] if i < len(env_metrics_list) and env_metrics_list[i] else {} + env_key = env_m.get("env_key", "unknown") + turns = env_m.get("turns", env_m.get("num_turns", 0)) + extras = env_extras[i] if i < len(env_extras) else {} + data_source = extras.get("data_source", "unknown") if isinstance(extras, dict) else "unknown" + + reward = rewards_list[i] + if isinstance(reward, list): + reward = float(sum(reward)) + else: + reward = float(reward) + + stop_reason = stop_reasons[i] if i < len(stop_reasons) else "unknown" + tokens = len(generator_output["response_ids"][i]) + + # Save screenshots for this trajectory + image_paths = [] + mm_data = multi_modal_data_list[i] if i < len(multi_modal_data_list) else None + if isinstance(mm_data, dict) and mm_data.get("images"): + for j, img in enumerate(mm_data["images"]): + img_filename = f"traj_{i:03d}_img_{j:03d}.jpg" + img_path = images_dir / img_filename + try: + if hasattr(img, "save"): + # PIL Image + img.save(str(img_path), "JPEG", quality=85) + image_paths.append(str(img_path)) + elif isinstance(img, str) and img.startswith(("http://", "https://")): + # URL — store the URL, don't download during training + image_paths.append(img) + elif isinstance(img, bytes): + with open(img_path, "wb") as img_f: + img_f.write(img) + image_paths.append(str(img_path)) + except Exception as e: + logger.warning(f"Failed to save image {j} for trajectory {i}: {e}") + + entry = { + "step": global_step, + "env_key": env_key, + "data_source": data_source, + "stop_reason": stop_reason, + "reward": reward, + "turns": turns, + "tokens": tokens, + "prompt": tokenizer.decode(generator_output["prompt_token_ids"][i]), + "text": tokenizer.decode(generator_output["response_ids"][i]), + "timestamp": ts, + } + if image_paths: + entry["image_paths"] = image_paths + entry["num_screenshots"] = len(image_paths) + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + n_images = sum( + len(entry.get("images", [])) + for mm in multi_modal_data_list + if isinstance(mm, dict) + for entry in [mm] + ) + logger.info( + f"Dumped {len(generator_output['response_ids'])} training trajectories to {filename}" + + (f" ({n_images} screenshots saved)" if has_any_images else "") + ) + return str(filename) + + class DynamicSamplingState(TypedDict, total=False): """Schema for dynamic sampling state dictionary. @@ -565,6 +703,10 @@ def filter_generator_output(output: GeneratorOutput, kept_indices: List[int]) -> if output.get("stop_reasons"): filtered["stop_reasons"] = [output["stop_reasons"][i] for i in kept_indices] + filtered["env_metrics"] = ( + [output["env_metrics"][i] for i in kept_indices] if output.get("env_metrics") else None + ) + return filtered @@ -744,6 +886,102 @@ def _validate_step_wise_fields(generator_output: GeneratorOutput, num_responses: ) +class HybridEnvSampler(torch.utils.data.Sampler): + """Ensures minimum representation from each environment per batch. + + Prevents batches dominated by large envs (zillow 1000 tasks) while small + envs (rops-mail 93 tasks) get zero samples. Each batch gets at least + min_samples_per_env from every env, remaining slots filled proportionally. + + Ported from fleet-ai/SkyRL-archived. + """ + + def __init__(self, dataset, batch_size, min_samples_per_env=1, generator=None, drop_last=True): + self.dataset = dataset + self.batch_size = batch_size + self.min_samples_per_env = min_samples_per_env + self.generator = generator + self.drop_last = drop_last + + self.env_to_indices: Dict[str, List[int]] = defaultdict(list) + for idx in range(len(dataset)): + row = dataset.dataframe[idx] + group = row.get("data_source") or row.get(dataset.env_class_key, "unknown") + self.env_to_indices[group].append(idx) + + self.env_classes = list(self.env_to_indices.keys()) + self.num_envs = len(self.env_classes) + + min_required = self.num_envs * min_samples_per_env + if min_required > batch_size: + logger.warning( + f"HybridEnvSampler: {self.num_envs} envs × {min_samples_per_env} = {min_required} " + f"> batch_size {batch_size}. Reducing min_samples_per_env." + ) + self.min_samples_per_env = max(1, batch_size // self.num_envs) + + total_samples = len(dataset) + self.env_weights = {env: len(indices) / total_samples for env, indices in self.env_to_indices.items()} + + logger.info(f"HybridEnvSampler: {self.num_envs} envs, batch_size={batch_size}, min_per_env={self.min_samples_per_env}") + for env, indices in sorted(self.env_to_indices.items()): + logger.info(f" {env}: {len(indices)} samples ({self.env_weights[env]*100:.1f}%)") + + def __iter__(self): + env_indices_shuffled = {} + for env, indices in self.env_to_indices.items(): + shuffled = indices.copy() + perm = torch.randperm(len(shuffled), generator=self.generator).tolist() + env_indices_shuffled[env] = [shuffled[i] for i in perm] + + env_positions = {env: 0 for env in self.env_classes} + + min_batches_per_env = [len(indices) // self.min_samples_per_env for indices in self.env_to_indices.values()] + num_batches = min(min_batches_per_env) + total_samples = sum(len(indices) for indices in self.env_to_indices.values()) + num_batches = min(num_batches, total_samples // self.batch_size) + + for _ in range(num_batches): + batch_indices = [] + + for env in self.env_classes: + available = len(env_indices_shuffled[env]) - env_positions[env] + samples_to_take = min(self.min_samples_per_env, available) + for _ in range(samples_to_take): + batch_indices.append(env_indices_shuffled[env][env_positions[env]]) + env_positions[env] += 1 + + remaining = self.batch_size - len(batch_indices) + if remaining > 0: + available_by_env = {env: env_indices_shuffled[env][env_positions[env]:] for env in self.env_classes} + for _ in range(remaining): + envs_with_samples = [env for env, avail in available_by_env.items() if avail] + if not envs_with_samples: + break + weights = [self.env_weights[env] for env in envs_with_samples] + total_w = sum(weights) + weights = [w / total_w for w in weights] + rand_val = torch.rand(1, generator=self.generator).item() + cumsum = 0 + chosen = envs_with_samples[-1] + for env, w in zip(envs_with_samples, weights): + cumsum += w + if rand_val < cumsum: + chosen = env + break + batch_indices.append(available_by_env[chosen].pop(0)) + env_positions[chosen] += 1 + + perm = torch.randperm(len(batch_indices), generator=self.generator).tolist() + yield [batch_indices[i] for i in perm] + + def __len__(self): + min_batches_per_env = [len(indices) // self.min_samples_per_env for indices in self.env_to_indices.values()] + num_batches = min(min_batches_per_env) + total_samples = sum(len(indices) for indices in self.env_to_indices.values()) + return min(num_batches, total_samples // self.batch_size) + + def build_dataloader( cfg: SkyRLTrainConfig, dataset: PromptDataset, is_train=True, is_fully_async=False ) -> StatefulDataLoader: @@ -764,20 +1002,47 @@ def build_dataloader( seeded_generator = torch.Generator() seeded_generator.manual_seed(cfg.trainer.seed) - dataloader = StatefulDataLoader( - dataset, - batch_size=batch_size if not is_fully_async else 1, - shuffle=True if is_train else False, - collate_fn=dataset.collate_fn, - # TODO(Charlie): debug why inference http endpoint is slow when num_workers is 8 - num_workers=0 if cfg.generator.inference_engine.enable_http_endpoint else 8, - drop_last=True if is_train else False, - generator=seeded_generator, - # NOTE (sumanthrh): We use ray and thus use `spawn` start method. - # forking within ray leads to undefined behaviour and often causes hard to debug - # memory leaks. See: https://docs.ray.io/en/latest/ray-core/patterns/fork-new-processes.html - multiprocessing_context="spawn" if not cfg.generator.inference_engine.enable_http_endpoint else None, + use_hybrid_sampling = ( + is_train + and not is_fully_async + and getattr(cfg.trainer, "use_hybrid_env_sampling", False) + and hasattr(dataset, "dataframe") + and hasattr(dataset, "env_class_key") ) + + num_workers = 0 if cfg.generator.inference_engine.enable_http_endpoint else 8 + mp_context = "spawn" if not cfg.generator.inference_engine.enable_http_endpoint else None + + if use_hybrid_sampling: + from torch.utils.data import DataLoader + + min_samples_per_env = getattr(cfg.trainer, "min_samples_per_env", 1) + sampler = HybridEnvSampler( + dataset=dataset, + batch_size=batch_size, + min_samples_per_env=min_samples_per_env, + generator=seeded_generator, + drop_last=True, + ) + dataloader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=num_workers, + ) + logger.info(f"Using HybridEnvSampler with min_samples_per_env={min_samples_per_env}") + else: + dataloader = StatefulDataLoader( + dataset, + batch_size=batch_size if not is_fully_async else 1, + shuffle=True if is_train else False, + collate_fn=dataset.collate_fn, + num_workers=num_workers, + drop_last=True if is_train else False, + generator=seeded_generator, + multiprocessing_context=mp_context, + ) + if is_train: if not is_fully_async: logger.info(f"Total steps: {len(dataloader) * cfg.trainer.epochs}") diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 1fa7b1d47b..ffa0f63d14 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -638,6 +638,11 @@ def prepare_runtime_environment(cfg: SkyRLTrainConfig) -> dict[str, str]: logger.info("Exporting mlflow tracking token to ray runtime env") env_vars["MLFLOW_TRACKING_TOKEN"] = os.environ["MLFLOW_TRACKING_TOKEN"] + # Fleet env vars needed by fleet_task and task_gen environments + for var_name in ["FLEET_API_KEY", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"]: + if value := os.environ.get(var_name): + env_vars[var_name] = value + # NOTE(charlie): these are for Harbor. We should remove these once we have a sustainable way to handle these environment vars. for var_name in ["DAYTONA_API_KEY", "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET"]: if value := os.environ.get(var_name): diff --git a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml new file mode 100644 index 0000000000..c65312b3db --- /dev/null +++ b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml @@ -0,0 +1,73 @@ +# Fleet Task GRPO Training via SkyPilot - Qwen3.5-35B-A3B (MoE, Multi-Node) +# Usage: sky launch tasks/openenv-fleet-grpo-qwen3_5-35b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= --env OPENROUTER_API_KEY= +# +# MoE: 35B total, 3B active (256 experts, 9 active/token). GatedDeltaNet architecture. +# 262K native context. All 35B params in memory (~70GB fp16), optimizer ~140GB, gradients ~70GB. +# +# Multi-node (2-node default, 16 GPUs total): 8x H200 per node +# NOTE: Requires vLLM >= 0.17.0 for native Qwen3.5/GDN support + FlashAttention 4 + +name: fleet-task-grpo-qwen3-5-35b + +resources: + disk_size: 750 + memory: 1500+ + ports: 6479 + ordered: + # RunPod reserved H200s (priority) + - accelerators: H200:8 + cloud: runpod + # GKE fallback + - accelerators: H200:8 + cloud: kubernetes + network_tier: best + use_spot: true + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + # Other providers + - accelerators: H200:8 + cloud: lambda + - accelerators: H200:8 + cloud: nebius + +num_nodes: 2 + +workdir: + url: https://github.com/fleet-ai/SkyRL-v2.git + ref: main + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + OPENROUTER_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v6" + ENV_KEYS: "" + DIFFICULTY: "" + MODALITY: "tool_use" + MAX_TURNS: 50 + MAX_INPUT_LENGTH: 72000 + MAX_GENERATE_LENGTH: 4096 + NUM_EPOCHS: 20 + RUN_ID: "" + MAX_TASKS: "" + RESUME_RUN_NAME: "" + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + # TP=2 -> 8 engines (each uses 2 GPUs) to match 16 policy GPUs with colocate_all + NUM_INFERENCE_ENGINES: 8 + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh + +run: | + bash scripts/fleet-35b-run.sh diff --git a/tasks/openenv-fleet-grpo-vl-taste.yaml b/tasks/openenv-fleet-grpo-vl-taste.yaml new file mode 100644 index 0000000000..e74fc24e50 --- /dev/null +++ b/tasks/openenv-fleet-grpo-vl-taste.yaml @@ -0,0 +1,131 @@ +# Fleet VL/CUA GRPO Training WITH TASTE JUDGE - Qwen3.5-9B (Vision-Language) +# +# Reward shape: GATED TASTE +# effective_taste = max(taste_floor, taste_score) (1.0 on judge fail) +# reward = verifier_reward * effective_taste +# Closes the "pretty failure" hack (verifier=0 -> reward=0 always) while +# preserving within-success taste variance via the floor (default 0.1). +# +# Delta from tasks/openenv-fleet-grpo-vl.yaml: +# - environment.skyrl_gym.fleet_task.taste_floor=0.1 (NEW) +# - environment.skyrl_gym.fleet_task.taste_judge_timeout_s=10.0 (NEW) +# - trainer.algorithm.grpo_norm_by_std=false (FLIPPED, was true default) +# - ANTHROPIC_API_KEY / OPENAI_API_KEY env vars added (NEW) +# +# Required env vars (pass each via `sky launch --env KEY=VALUE`): +# FLEET_API_KEY - Fleet API access for OpenEnv environments +# WANDB_API_KEY - WandB logging +# AWS_ACCESS_KEY_ID - S3 dataset/checkpoint/trajectory buckets +# AWS_SECRET_ACCESS_KEY - S3 credentials +# ANTHROPIC_API_KEY - NEW. Claude judge (research/judge/judge.py default). +# OPENAI_API_KEY - NEW. Reserved for inter-rater / GPT-4o judge path. +# +# Usage: +# sky launch configs/openenv-fleet-grpo-vl-taste.yaml \ +# --env FLEET_API_KEY=... \ +# --env WANDB_API_KEY=... \ +# --env AWS_ACCESS_KEY_ID=... \ +# --env AWS_SECRET_ACCESS_KEY=... \ +# --env ANTHROPIC_API_KEY=... \ +# --env OPENAI_API_KEY=... + +name: fleet-vl-grpo-qwen3-5-9b-taste + +resources: + disk_size: 750 + ports: 6479 + ordered: + - accelerators: H200:8 + cloud: runpod + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + - accelerators: H200:8 + cloud: lambda + - accelerators: H200:8 + cloud: nebius + - accelerators: H200-SXM:8 + cloud: vast + +num_nodes: 1 + +workdir: + url: https://github.com/fleet-ai/skyrl-fleet.git + ref: taste-reward-shaping + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + # NEW: judge credentials. Anthropic is the default judge backend; OpenAI + # is used for inter-rater agreement / fallback paths. + ANTHROPIC_API_KEY: "" + OPENAI_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v6" + ENV_KEYS: "" + DIFFICULTY: "" + MODALITY: "browser_use" + MAX_TURNS: 80 + MAX_INPUT_LENGTH: 80000 + MAX_GENERATE_LENGTH: 4096 + NUM_EPOCHS: 10 + RUN_ID: "" + MAX_TASKS: "" + RESUME_RUN_NAME: "" + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + # NEW: runtime kill-switch for taste judge. Set to "1" to fall back to + # verifier-only reward without touching the patched env or restarting. + SKYRL_TASTE_DISABLED: "0" + # NEW: floor for the gated taste reward (forwarded into Hydra override). + # reward = verifier * max(taste_floor, taste_score) + # floor=1.0 -> pure verifier (clean ablation baseline) + # floor=0.0 -> pure multiplicative (every ugly success -> 0 reward) + # floor=0.1 -> default; offline taste of verifier=1 trajectories sits + # around 0.13 so floor=0.1 acts as multiplicative-with-cushion; + # re-tune after a 50-100 step pilot using effective_taste P25. + TASTE_FLOOR: "0.1" + # NEW: production judge selection. OpenRouter is used at training time so + # we don't hit per-org Anthropic rate limits during burst-end-of-step + # judge calls. + OPENROUTER_API_KEY: "" + SKYRL_TASTE_PROVIDER: "openrouter" + SKYRL_TASTE_MODEL: "anthropic/claude-haiku-4.5" + # Stream 4 finding: pass blind_outcome=True at training time to suppress + # outcome bleed (judge sees outcome=True and inflates ~+1.4 weighted-pts). + SKYRL_TASTE_BLIND_OUTCOME: "1" + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh + # NEW: install the taste-judge package next to skyrl_gym so the env's + # `from skyrl_taste.judge import score_trajectory_async` import resolves. + pip install --no-deps anthropic openai + # skyrl_taste is in skyrl-gym/ — ensure it's importable even if editable install misses it + pip install --no-deps -e ./skyrl-gym 2>/dev/null || true + +run: | + # We delegate to the existing fleet-vl-run.sh wrapper, which forwards extra + # Hydra overrides via the trailing args. The new flags below are the ONLY + # delta from the upstream script. + bash scripts/fleet-vl-run.sh \ + environment.skyrl_gym.fleet_task.taste_floor=${TASTE_FLOOR} \ + environment.skyrl_gym.fleet_task.taste_judge_timeout_s=10.0 \ + trainer.algorithm.grpo_norm_by_std=false + # ^ grpo_norm_by_std=false (flipped from default true): + # Even under gated taste, within-group reward std is inflated whenever a + # group has a mix of pretty and ugly successes (rewards in {0.1, ..., 1.0}) + # on top of the binary verifier signal. Default GRPO normalization would + # divide advantages by that larger denominator and damp the gradient. + # Stream 1's analysis showed std-norm collapses learning under shaped + # reward; that conclusion still applies here. Disabling std normalization + # keeps advantage magnitudes proportional to (reward - mean), which is + # the part the taste signal actually increases. Re-enable and tune + # advantage_batch_normalize=true if cross-prompt magnitudes get unstable. diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml new file mode 100644 index 0000000000..27d9152ae1 --- /dev/null +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -0,0 +1,69 @@ +# Fleet VL/CUA GRPO Training via SkyPilot - Qwen3.5-9B (Vision-Language) +# Usage: sky launch tasks/openenv-fleet-grpo-vl.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= +# +# VL (Vision-Language) training for browser_use environments with screenshots. +# Based on working config from SkyRL PR #288 (feat/vl-support-clean). +# +# Model: Qwen/Qwen3.5-9B (9B params, natively multimodal, GatedDeltaNet) +# GPUs: 8x H200 (single node, TP=1) +# +# NOTE: Requires vLLM >= 0.17.0 for native Qwen3.5/GDN support + FlashAttention 4 + +name: fleet-vl-grpo-qwen3-5-9b + +resources: + disk_size: 750 + ports: 6479 + ordered: + # RunPod reserved H200s (priority) + - accelerators: H200:8 + cloud: runpod + # GCP fallback + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + # Other providers + - accelerators: H200:8 + cloud: lambda + - accelerators: H200:8 + cloud: nebius + - accelerators: H200-SXM:8 + cloud: vast + +num_nodes: 2 + +workdir: + url: https://github.com/fleet-ai/SkyRL-v2.git + ref: main + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v6" + ENV_KEYS: "" + DIFFICULTY: "" + MODALITY: "browser_use" + MAX_TURNS: 80 + MAX_INPUT_LENGTH: 80000 + MAX_GENERATE_LENGTH: 4096 + NUM_EPOCHS: 10 + RUN_ID: "" + MAX_TASKS: "" + RESUME_RUN_NAME: "" + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh + +run: | + bash scripts/fleet-vl-run.sh diff --git a/tasks/task-gen-grpo-qwen3_5-35b.yaml b/tasks/task-gen-grpo-qwen3_5-35b.yaml new file mode 100644 index 0000000000..db2d53b8e8 --- /dev/null +++ b/tasks/task-gen-grpo-qwen3_5-35b.yaml @@ -0,0 +1,65 @@ +# Task Generation GRPO Training via SkyPilot - Qwen3.5-35B-A3B (MoE, Multi-Node) +# Usage: sky launch tasks/task-gen-grpo-qwen3_5-35b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= +# +# MoE: 35B total, 3B active (256 experts, 9 active/token). GatedDeltaNet architecture. +# Multi-node (2-node default, 16 GPUs total): TP=2, 8 inference engines. +# flash_attn=false (SDPA) to avoid Xid 31 in GatedDeltaNet with vLLM 0.18.0. + +name: task-gen-grpo-qwen3-5-35b + +resources: + disk_size: 750 + memory: 1500+ + ports: 6479 + ordered: + # RunPod reserved H200s (priority) + - accelerators: H200:8 + cloud: runpod + # GCP fallback + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + # Other providers + - accelerators: H200:8 + cloud: lambda + +num_nodes: 2 + +workdir: + url: https://github.com/fleet-ai/SkyRL-v2.git + ref: main + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v55" + MODALITY: "tool_use" + MAX_TURNS: 10 + MAX_INPUT_LENGTH: 72000 + MAX_GENERATE_LENGTH: 4096 + NUM_EPOCHS: 20 + JUDGE_MODEL: "anthropic/claude-sonnet-4.5" + OPENROUTER_API_KEY: "" + EVALUATOR_MODEL: "anthropic/claude-sonnet-4.5" + K_ROLLOUTS: 4 + ALPHA: "1.0" + MAX_EVAL_STEPS: 20 + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + NUM_INFERENCE_ENGINES: 8 + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh \ + --env-class task_gen + +run: | + bash scripts/fleet-task-gen-35b-run.sh diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml new file mode 100644 index 0000000000..7c3c1edbf5 --- /dev/null +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -0,0 +1,63 @@ +# Task Generation GRPO Training via SkyPilot - Qwen3.5-9B +# Usage: sky launch tasks/task-gen-grpo-qwen3_5-9b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= +# +# Qwen3.5-9B: MoE with ~1B active params. Fits on single H200 GPU (TP=1). +# 8 inference engines on 8x H200 node. + +name: task-gen-grpo-qwen3-5-9b + +resources: + disk_size: 500 + memory: 800+ + ports: 6479 + ordered: + # RunPod reserved H200s (priority) + - accelerators: H200:8 + cloud: runpod + # GCP fallback + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + # Other providers + - accelerators: H200:8 + cloud: lambda + +num_nodes: 1 + +workdir: + url: https://github.com/fleet-ai/SkyRL-v2.git + ref: main + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v55" + MODALITY: "tool_use" + MAX_TURNS: 10 + MAX_INPUT_LENGTH: 65536 + MAX_GENERATE_LENGTH: 4096 + NUM_EPOCHS: 20 + JUDGE_MODEL: "anthropic/claude-sonnet-4.5" + OPENROUTER_API_KEY: "" + EVALUATOR_MODEL: "anthropic/claude-sonnet-4.5" + K_ROLLOUTS: 4 + ALPHA: "1.0" + MAX_EVAL_STEPS: 20 + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh \ + --env-class task_gen + +run: | + bash scripts/fleet-task-gen-run.sh