diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 181e3f428a..a1464112f0 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -701,11 +701,16 @@ def forward_backward( micro_batch_size = self.cfg.micro_train_batch_size_per_gpu all_metrics = defaultdict(list) all_loss_fn_outputs = [] # Handle separately from scalar metrics + total_tokens = data["loss_mask"].sum().clamp(min=1).item() for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False): microbatch_weight = micro_batch_size / len(data) metrics = self._forward_backward_micro( - micro_batch, microbatch_weight, loss_fn=loss_fn, loss_fn_config=loss_fn_config + micro_batch, + microbatch_weight, + total_tokens=total_tokens, + loss_fn=loss_fn, + loss_fn_config=loss_fn_config, ) # Extract loss_fn_outputs before reduce_metrics (it's not a scalar metric) @@ -715,10 +720,11 @@ def forward_backward( for k, v in metrics.items(): all_metrics[k].append(v) - # TODO: SFT path still averages metrics across microbatches and workers. - # This needs to be unified with the RL path which sums. + # "token_mean_legacy" preserves the old per-microbatch averaging behavior. resolved_loss_name = loss_fn or self.cfg.algorithm.policy_loss_type - sum_loss_metrics = resolved_loss_name != "cross_entropy" + sum_loss_metrics = not ( + resolved_loss_name == "cross_entropy" and self.cfg.algorithm.loss_reduction == "token_mean_legacy" + ) # Reduce across microbatches and all-reduce metrics across DP ranks # NOTE: Sum loss metrics because scaling is already applied at the advantage level @@ -736,6 +742,7 @@ def _forward_backward_micro( self, experience: Experience, microbatch_weight: float, + total_tokens: float, loss_fn: Optional[str] = None, loss_fn_config: Optional[Dict[str, Any]] = None, ) -> Dict[str, float]: @@ -812,10 +819,20 @@ def _forward_backward_micro( rollout_logprobs=rollout_action_logprobs, ) + # DP all-reduce averages gradients, but policy losses are pre-scaled sums + # (see `apply_loss_reduction_to_advantages_minibatch`), so we multiply by + # dp_size to recover the correct sum reduction across workers. + grad_sum_correction_factor = self.mesh_rank.dp_size + # SFT path: skip KL/entropy terms, return per-token outputs for Tinker API if resolved_loss_name == "cross_entropy": unscaled_loss = policy_loss - loss = unscaled_loss * microbatch_weight + + if self.cfg.algorithm.loss_reduction == "token_mean_legacy": + loss = unscaled_loss * microbatch_weight + else: + loss = (unscaled_loss / total_tokens) * grad_sum_correction_factor + self.strategy.backward(loss, self.model, self.optimizer) # Compute elementwise loss for Tinker API (per-token NLL) @@ -848,7 +865,7 @@ def _forward_backward_micro( ) status = { - "loss": loss.item(), + "sft_loss": (unscaled_loss / total_tokens).item(), "response_length": num_actions, "lr": self.scheduler.get_last_lr()[0], "loss_fn_outputs": loss_fn_outputs, @@ -880,11 +897,6 @@ def _forward_backward_micro( kl_loss = torch.tensor(0.0) kl_loss_term = kl_loss * self.cfg.algorithm.kl_loss_coef - # DP all-reduce averages gradients, but policy losses are pre-scaled sums - # (see `apply_loss_reduction_to_advantages_minibatch`), so we multiply by - # dp_size to recover the correct sum reduction across workers. - grad_sum_correction_factor = self.mesh_rank.dp_size - # NOTE: The KL and entropy loss terms are not pre-scaled, # so we just average them across microbatches and DP workers. loss = policy_loss * grad_sum_correction_factor + (kl_loss_term - entropy_loss_term) * microbatch_weight diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py index 9c9d90884f..13ae074a6e 100644 --- a/skyrl/backends/skyrl_train_backend.py +++ b/skyrl/backends/skyrl_train_backend.py @@ -475,11 +475,13 @@ def _extract_metrics(self, data: dict) -> dict[str, float]: """ metrics: dict[str, float] = {} - # SFT path returns 'loss'; RL path returns 'final_loss' / 'policy_loss' - if "loss" in data: - metrics["total_loss:sum"] = float(data["loss"]) + # SFT path returns 'sft_loss'; RL path returns 'final_loss' / 'policy_loss' + if "sft_loss" in data: + metrics["total_loss:sum"] = float(data["sft_loss"]) elif "final_loss" in data: metrics["total_loss:sum"] = float(data["final_loss"]) + elif "loss" in data: + metrics["total_loss:sum"] = float(data["loss"]) if "policy_loss" in data: metrics["pg_loss:sum"] = float(data["policy_loss"]) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py index 77bb868a2e..e9e09f2add 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py @@ -268,7 +268,7 @@ async def test_sft_forward_backward_with_cross_entropy(ray_init_fixture, cfg, st all_loss_fn_outputs = [] for result in results: assert isinstance(result, dict) - assert "loss" in result + assert "sft_loss" in result assert "loss_fn_outputs" in result, "SFT path should return loss_fn_outputs" loss_fn_outputs = result["loss_fn_outputs"] @@ -288,3 +288,38 @@ async def test_sft_forward_backward_with_cross_entropy(ray_init_fixture, cfg, st finally: ray.shutdown() + + +@pytest.mark.asyncio +async def test_sft_cross_entropy_microbatch_invariance(ray_init_fixture): + """Loss should be identical regardless of micro_train_batch_size_per_gpu.""" + batch_size = 4 + num_actions = 4 + # Variable-length sequences to exercise the bug (uniform lengths would hide it) + dummy_batch = make_dummy_training_batch(batch_size=batch_size, num_actions=num_actions, action_lengths=[1, 2, 3, 4]) + + losses = [] + for micro_batch_size in [1, 2]: + cfg = get_test_actor_config() + cfg.trainer.placement.policy_num_gpus_per_node = 1 + cfg.trainer.strategy = "fsdp2" + cfg.trainer.micro_train_batch_size_per_gpu = micro_batch_size + cfg.trainer.use_sample_packing = False + validate_cfg(cfg) + + try: + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=1, + cfg=cfg, + ) + results = ray.get( + actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch, loss_fn="cross_entropy") + ) + losses.append(results[0]["sft_loss"]) + finally: + ray.shutdown() + + assert abs(losses[0] - losses[1]) < 1e-5, f"Loss should be microbatch-invariant but got {losses[0]} vs {losses[1]}" diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py index 2bb81e1c6b..51449b0965 100644 --- a/tests/train/test_trainer.py +++ b/tests/train/test_trainer.py @@ -522,7 +522,9 @@ def create_test_worker(worker_class): # Mock _forward_backward_micro to track calls policy_forward_backward_micro_calls = [] - def mock_policy_forward_backward_micro(experience, microbatch_weight, loss_fn=None, loss_fn_config=None): + def mock_policy_forward_backward_micro( + experience, microbatch_weight, total_tokens=None, loss_fn=None, loss_fn_config=None + ): policy_forward_backward_micro_calls.append(experience) return {"policy_loss": 0.5, "ppo_clip_ratio": 0.1, "policy_entropy": 2.0, "response_length": response_length}