Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions skyrl/backends/skyrl_train/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,11 +701,16 @@ def forward_backward(
micro_batch_size = self.cfg.micro_train_batch_size_per_gpu
all_metrics = defaultdict(list)
all_loss_fn_outputs = [] # Handle separately from scalar metrics
total_tokens = data["loss_mask"].sum().clamp(min=1).item()

for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False):
microbatch_weight = micro_batch_size / len(data)
metrics = self._forward_backward_micro(
micro_batch, microbatch_weight, loss_fn=loss_fn, loss_fn_config=loss_fn_config
micro_batch,
microbatch_weight,
total_tokens=total_tokens,
loss_fn=loss_fn,
loss_fn_config=loss_fn_config,
)

# Extract loss_fn_outputs before reduce_metrics (it's not a scalar metric)
Expand All @@ -715,10 +720,11 @@ def forward_backward(
for k, v in metrics.items():
all_metrics[k].append(v)

# TODO: SFT path still averages metrics across microbatches and workers.
# This needs to be unified with the RL path which sums.
# "token_mean_legacy" preserves the old per-microbatch averaging behavior.
resolved_loss_name = loss_fn or self.cfg.algorithm.policy_loss_type
sum_loss_metrics = resolved_loss_name != "cross_entropy"
sum_loss_metrics = not (
resolved_loss_name == "cross_entropy" and self.cfg.algorithm.loss_reduction == "token_mean_legacy"
)

# Reduce across microbatches and all-reduce metrics across DP ranks
# NOTE: Sum loss metrics because scaling is already applied at the advantage level
Expand All @@ -736,6 +742,7 @@ def _forward_backward_micro(
self,
experience: Experience,
microbatch_weight: float,
total_tokens: float,
loss_fn: Optional[str] = None,
loss_fn_config: Optional[Dict[str, Any]] = None,
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
) -> Dict[str, float]:
Expand Down Expand Up @@ -812,10 +819,20 @@ def _forward_backward_micro(
rollout_logprobs=rollout_action_logprobs,
)

# DP all-reduce averages gradients, but policy losses are pre-scaled sums
# (see `apply_loss_reduction_to_advantages_minibatch`), so we multiply by
# dp_size to recover the correct sum reduction across workers.
grad_sum_correction_factor = self.mesh_rank.dp_size

# SFT path: skip KL/entropy terms, return per-token outputs for Tinker API
if resolved_loss_name == "cross_entropy":
unscaled_loss = policy_loss
loss = unscaled_loss * microbatch_weight

if self.cfg.algorithm.loss_reduction == "token_mean_legacy":
loss = unscaled_loss * microbatch_weight
else:
loss = (unscaled_loss / total_tokens) * grad_sum_correction_factor
Comment thread
agolajko marked this conversation as resolved.
Comment thread
agolajko marked this conversation as resolved.

self.strategy.backward(loss, self.model, self.optimizer)

# Compute elementwise loss for Tinker API (per-token NLL)
Expand Down Expand Up @@ -848,7 +865,7 @@ def _forward_backward_micro(
)

status = {
"loss": loss.item(),
"sft_loss": (unscaled_loss / total_tokens).item(),
"response_length": num_actions,
"lr": self.scheduler.get_last_lr()[0],
"loss_fn_outputs": loss_fn_outputs,
Expand Down Expand Up @@ -880,11 +897,6 @@ def _forward_backward_micro(
kl_loss = torch.tensor(0.0)
kl_loss_term = kl_loss * self.cfg.algorithm.kl_loss_coef

# DP all-reduce averages gradients, but policy losses are pre-scaled sums
# (see `apply_loss_reduction_to_advantages_minibatch`), so we multiply by
# dp_size to recover the correct sum reduction across workers.
grad_sum_correction_factor = self.mesh_rank.dp_size

# NOTE: The KL and entropy loss terms are not pre-scaled,
# so we just average them across microbatches and DP workers.
loss = policy_loss * grad_sum_correction_factor + (kl_loss_term - entropy_loss_term) * microbatch_weight
Expand Down
8 changes: 5 additions & 3 deletions skyrl/backends/skyrl_train_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,13 @@ def _extract_metrics(self, data: dict) -> dict[str, float]:
"""
metrics: dict[str, float] = {}

# SFT path returns 'loss'; RL path returns 'final_loss' / 'policy_loss'
if "loss" in data:
metrics["total_loss:sum"] = float(data["loss"])
# SFT path returns 'sft_loss'; RL path returns 'final_loss' / 'policy_loss'
if "sft_loss" in data:
metrics["total_loss:sum"] = float(data["sft_loss"])
elif "final_loss" in data:
metrics["total_loss:sum"] = float(data["final_loss"])
elif "loss" in data:
metrics["total_loss:sum"] = float(data["loss"])

if "policy_loss" in data:
metrics["pg_loss:sum"] = float(data["policy_loss"])
Expand Down
37 changes: 36 additions & 1 deletion tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ async def test_sft_forward_backward_with_cross_entropy(ray_init_fixture, cfg, st
all_loss_fn_outputs = []
for result in results:
assert isinstance(result, dict)
assert "loss" in result
assert "sft_loss" in result
assert "loss_fn_outputs" in result, "SFT path should return loss_fn_outputs"

loss_fn_outputs = result["loss_fn_outputs"]
Expand All @@ -288,3 +288,38 @@ async def test_sft_forward_backward_with_cross_entropy(ray_init_fixture, cfg, st

finally:
ray.shutdown()


@pytest.mark.asyncio
async def test_sft_cross_entropy_microbatch_invariance(ray_init_fixture):
"""Loss should be identical regardless of micro_train_batch_size_per_gpu."""
batch_size = 4
num_actions = 4
# Variable-length sequences to exercise the bug (uniform lengths would hide it)
dummy_batch = make_dummy_training_batch(batch_size=batch_size, num_actions=num_actions, action_lengths=[1, 2, 3, 4])

losses = []
for micro_batch_size in [1, 2]:
cfg = get_test_actor_config()
cfg.trainer.placement.policy_num_gpus_per_node = 1
cfg.trainer.strategy = "fsdp2"
cfg.trainer.micro_train_batch_size_per_gpu = micro_batch_size
cfg.trainer.use_sample_packing = False
validate_cfg(cfg)

try:
actor_group = init_worker_with_type(
"policy",
shared_pg=None,
colocate_all=False,
num_gpus_per_node=1,
cfg=cfg,
)
results = ray.get(
actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch, loss_fn="cross_entropy")
)
losses.append(results[0]["sft_loss"])
finally:
ray.shutdown()

assert abs(losses[0] - losses[1]) < 1e-5, f"Loss should be microbatch-invariant but got {losses[0]} vs {losses[1]}"
4 changes: 3 additions & 1 deletion tests/train/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,9 @@ def create_test_worker(worker_class):
# Mock _forward_backward_micro to track calls
policy_forward_backward_micro_calls = []

def mock_policy_forward_backward_micro(experience, microbatch_weight, loss_fn=None, loss_fn_config=None):
def mock_policy_forward_backward_micro(
experience, microbatch_weight, total_tokens=None, loss_fn=None, loss_fn_config=None
):
policy_forward_backward_micro_calls.append(experience)
return {"policy_loss": 0.5, "ppo_clip_ratio": 0.1, "policy_entropy": 2.0, "response_length": response_length}

Expand Down
Loading