From 12bdb893da84de2227091979104196097386cf5d Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Sun, 28 Dec 2025 19:27:38 +0800
Subject: [PATCH 1/4] [Refactor] refactor RL Trainer fit loop and support debug_rollout

---
 xtuner/v1/train/rl_trainer.py | 204 +++++++++++++++++++++------------
 1 file changed, 127 insertions(+), 77 deletions(-)

diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index ef2d37ecb..a46412d51 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -6,7 +6,6 @@
 from shutil import rmtree
 from typing import cast
 
-import numpy as np
 import ray
 import torch
 from mmengine import load
@@ -85,6 +84,8 @@ class RLTrainerConfig(BaseModel):
     hf_max_keep: int | None = None
     seed: int = 42
     debug: bool = False
+    debug_rollout: bool = False
+    rollout_steps: int | None = None
 
     @model_validator(mode="after")
     def _convert_work_dir(self):
@@ -182,6 +183,9 @@ class RLTrainer:
             Defaults to None.
         seed (int): Random seed for reproducible training. Defaults to 42.
         debug (bool): Enable debug mode with additional logging. Defaults to False.
+        debug_rollout (bool): If True, skip the training phase and only run rollout generation, which is useful for debugging the rollout pipeline. Defaults to False.
+        rollout_steps (int | None): Total number of rollout steps to perform.
+            If set, overrides the step count derived from total_epochs. Defaults to None.
 
     **Examples:**
 
@@ -235,6 +239,8 @@ def __init__(
         hf_max_keep: int | None = None,
         seed: int = 42,
         debug: bool = False,
+        debug_rollout: bool = False,
+        rollout_steps: int | None = None,
         trainer_cfg: RLTrainerConfig | None = None,
     ):
         """Initialize the RL training system."""
@@ -273,6 +279,7 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
 
         self._debug = debug
+        self._debug_rollout = debug_rollout
         self._seed = seed
         self._set_deterministic()
         self._set_random_seed(seed)
@@ -356,6 +363,10 @@ def __init__(
             // dataflow_config.global_batch_size
             * total_epochs
         )
+        if rollout_steps is not None:
+            self._rollout_steps = rollout_steps
+            self.logger.info(f"Set rollout steps to {self._rollout_steps} from the rollout_steps argument")
+
         bind_train_rollout(train_controller=self._train_controller, env_controller=self._rollout_env_controller)
         # update weights if rollout_config.skip_load_weights == True
         if rollout_config.skip_load_weights:
@@ -463,6 +474,79 @@ def _build_train_controller(self, train_worker_cfg: WorkerConfig) -> TrainingCon
         train_controller = TrainingController.remote(workers=train_workers)
         return train_controller
 
+    def _initial_evaluate(self):
+        """Performs an initial evaluation before the training loop starts."""
+        if self._debug_rollout:
+            return
+        if self._enable_initial_evaluate and self._enable_evaluate and self._evaluator:
+            ray.get(self._rollout_env_controller.update_active_workers.remote())
+            scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True))
+            trajectory_save_path = self.exp_dir / "eval_0_trajectory.jsonl"
+            self._save_trajectories(eval_data_groups, trajectory_save_path)
+            self.logger.info(f"Initial rollout evaluate scores {scores} and start training")
+
+    def _rollout_step(self, rollout_idx: int, step_timer_dict: dict):
+        """Performs a single rollout step to generate experience."""
+        with timer("generation", step_timer_dict):
+            ray.get(self._rollout_env_controller.update_active_workers.remote())
+            data_groups, multimodal_train_infos = ray.get(self._rollout_dataflow.run.remote())
+        with timer("save_trajectory", step_timer_dict):
+            trajectory_save_path = self.exp_dir / f"rollout_idx_{rollout_idx}_trajectory.jsonl"
+ 
self._save_trajectories(data_groups, trajectory_save_path) + self.logger.info(f"Rollout_idx {rollout_idx} finished, saved trajectories to {trajectory_save_path}") + if not self._debug_rollout: + with timer("rollout_offload", step_timer_dict): + ray.get(self._rollout_dataflow.pause.remote()) + ray.get(self._rollout_env_controller.offload.remote()) + return data_groups, multimodal_train_infos + + def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, step_timer_dict: dict): + """Performs a single training step on the generated experience.""" + with timer( + "onload", + step_timer_dict, + ): + ray.get(self._train_controller.onload.remote(target="all")) + self.logger.info("Training controller loaded") + + with timer("prepare_data", step_timer_dict): + data_batches, data_info = self._prepare_train_data( + data_groups, self._train_worker_cfg.pack_max_length, multimodal_train_infos + ) + self.logger.info(f"Prepared {len(data_batches)} training data batches") + self._log_data_info(rollout_idx, data_info) + + with timer("training", step_timer_dict): + ray.get( + self._train_controller.fit.remote( + data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx + ) + ) + + def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict): + """Synchronizes weights and saves checkpoints.""" + with timer("save_ckpt", step_timer_dict): + ray.get(self._train_controller.offload.remote(target="optimizer")) + self._maybe_save_hf() + self._maybe_save_checkpoint() + + with timer("sync_weight", step_timer_dict): + bind_train_rollout(train_controller=self._train_controller, env_controller=self._rollout_env_controller) + ray.get(self._rollout_env_controller.onload_weights.remote()) + ray.get(self._train_controller.update_weights.remote()) + self.logger.info("Model weights synchronized successfully.") + ray.get(self._train_controller.offload.remote(target="model")) + ray.get(self._rollout_env_controller.onload_kvcache.remote()) + + def _evaluate_step(self, rollout_idx: int, step_timer_dict: dict): + """Performs an evaluation step.""" + if self._enable_evaluate and self._evaluator and rollout_idx % self._eval_step == 0: + with timer("evaluation", step_timer_dict): + scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) + trajectory_save_path = self.exp_dir / f"eval_{rollout_idx}_trajectory.jsonl" + self._save_trajectories(eval_data_groups, trajectory_save_path) + self.logger.info(f"Evaluate idx {rollout_idx} scores {scores}") + def fit(self): """Run the RL training loop. @@ -475,71 +559,29 @@ def fit(self): self.logger.info(f"Rollout steps {self._rollout_steps} reached, stop training") return - if self._enable_initial_evaluate and self._enable_evaluate and self._evaluator: - ray.get(self._rollout_env_controller.update_active_workers.remote()) - scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) - trajectory_save_path = self.exp_dir / "eval_0_trajectory.jsonl" - self._save_trajectories(eval_data_groups, trajectory_save_path) - self.logger.info(f"Initial rollout evaluate scores {scores} and start training") + self._initial_evaluate() for rollout_idx in range(self._cur_step + 1, self._rollout_steps + 1): - timer_log_str = f"Rollout {rollout_idx} start \n" + self.logger.info(f"Rollout {rollout_idx}/{self._rollout_steps} start") step_timer_dict = {} - # 1. 
Rollout - with timer("generation", step_timer_dict): - ray.get(self._rollout_env_controller.update_active_workers.remote()) - data_groups, multimodal_train_infos = ray.get(self._rollout_dataflow.run.remote()) - # 2. Offload rollout models and save trajectories - with timer("offload_and_dump", step_timer_dict): - ray.get(self._rollout_env_controller.offload.remote()) - trajectory_save_path = self.exp_dir / f"rollout_idx_{rollout_idx}_trajectory.jsonl" - self._save_trajectories(data_groups, trajectory_save_path) - self.logger.info(f"Rollout_idx {rollout_idx} finished, saved trajectories to {trajectory_save_path}") - - # 3. Onload training models and prepare data - with timer("onload_and_prepare_data", step_timer_dict): - ray.get(self._train_controller.onload.remote(target="all")) - self.logger.info("Training controller loaded") - data_batches, data_info = self._prepare_train_data( - data_groups, self._train_worker_cfg.pack_max_length, multimodal_train_infos - ) - self.logger.info(f"Prepared {len(data_batches)} training data batches") - self._log_data_info(rollout_idx, data_info) - - # 4. Training Step - with timer("training", step_timer_dict): - ray.get( - self._train_controller.fit.remote( - data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx - ) - ) + with timer("step", step_timer_dict): + # 1. Rollout to generate experience + data_groups, multimodal_train_infos = self._rollout_step(rollout_idx, step_timer_dict) - # 5. Saving and sync weights - with timer("saving and sync_weight", step_timer_dict): - ray.get(self._train_controller.offload.remote(target="optimizer")) - self._maybe_save_hf() - self._maybe_save_checkpoint() + if not self._debug_rollout: + # 2. Train on the generated experience + self._train_step(rollout_idx, data_groups, multimodal_train_infos, step_timer_dict) - bind_train_rollout( - train_controller=self._train_controller, env_controller=self._rollout_env_controller - ) - ray.get(self._rollout_env_controller.onload_weights.remote()) - ray.get(self._train_controller.update_weights.remote()) - self.logger.info("Model weights synchronized successfully.") - ray.get(self._train_controller.offload.remote(target="model")) - ray.get(self._rollout_env_controller.onload_kvcache.remote()) + # 3. Synchronize weights and save checkpoints + self._sync_weights_and_save(rollout_idx, step_timer_dict) + # 4. Evaluate model performance + self._evaluate_step(rollout_idx, step_timer_dict) + + # 5. 
Log timing information timer_log_str = f"Rollout {rollout_idx} training finished and timing listed: \n" timer_log_str += timer_logger(step_timer_dict) - self.logger.info(timer_log_str) - - # evaluate - if self._enable_evaluate and self._evaluator and rollout_idx % self._eval_step == 0: - scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) - trajectory_save_path = self.exp_dir / f"eval_{rollout_idx}_trajectory.jsonl" - self._save_trajectories(eval_data_groups, trajectory_save_path) - self.logger.info(f"Evaluate idx {rollout_idx} scores {scores}") self._cur_step = rollout_idx def _log_data_info(self, rollout_idx: int, data_info: dict): @@ -632,22 +674,26 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf data_batches.append(data_dict) random.shuffle(data_batches) - advantages_list = np.array(advantages_list) + rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float() + advantages_t = torch.tensor(advantages_list).float() if advantages_list else torch.tensor([0.0]).float() + prompt_len_t = torch.tensor(prompt_len_list).float() if prompt_len_list else torch.tensor([0.0]).float() + response_len_t = torch.tensor(response_len_list).float() if response_len_list else torch.tensor([0.0]).float() + info_dict = { "batch_size": len(rewards_list), - "rewards/mean": np.mean(rewards_list), - "rewards/min": np.min(rewards_list), - "rewards/max": np.max(rewards_list), - "advantages/mean": np.mean(advantages_list), - "advantages/min": np.min(advantages_list), - "advantages/max": np.max(advantages_list), - "response_len/mean": np.mean(response_len_list), - "response_len/min": np.min(response_len_list), - "response_len/max": np.max(response_len_list), - "response_len/std": np.std(response_len_list), - "prompt_len/mean": np.mean(prompt_len_list), - "prompt_len/min": np.min(prompt_len_list), - "prompt_len/max": np.max(prompt_len_list), + "rewards/mean": rewards_t.mean().item(), + "rewards/min": rewards_t.min().item(), + "rewards/max": rewards_t.max().item(), + "advantages/mean": advantages_t.mean().item(), + "advantages/min": advantages_t.min().item(), + "advantages/max": advantages_t.max().item(), + "response_len/mean": response_len_t.mean().item(), + "response_len/min": response_len_t.min().item(), + "response_len/max": response_len_t.max().item(), + "response_len/std": response_len_t.std().item(), + "prompt_len/mean": prompt_len_t.mean().item(), + "prompt_len/min": prompt_len_t.min().item(), + "prompt_len/max": prompt_len_t.max().item(), } return data_batches, info_dict @@ -683,18 +729,18 @@ def _save_trajectories(self, data_groups, save_path): response_ids = self.tokenizer.encode(data.env.rollout.response, add_special_tokens=False) rollout_response_len_list.append(len(response_ids)) - rewards = torch.tensor(rewards).float() - rollout_response_lens = None + rewards_tensor = torch.tensor(rewards).float() + rollout_response_lens: torch.Tensor = torch.tensor([0.0]).float() if len(rollout_response_len_list) > 0: rollout_response_lens = torch.tensor(rollout_response_len_list).float() _count = 0 with open(save_path, "w", encoding="utf-8") as f: item = { - "reward_mean": rewards.mean().item(), - "reward_std": rewards.std().item(), - "reward_max": rewards.max().item(), - "reward_min": rewards.min().item(), + "reward_mean": rewards_tensor.mean().item(), + "reward_std": rewards_tensor.std().item(), + "reward_max": rewards_tensor.max().item(), + "reward_min": rewards_tensor.min().item(), "response_len_mean": 
rollout_response_lens.mean().item(), "response_len_std": rollout_response_lens.std().item(), "response_len_max": rollout_response_lens.max().item(), @@ -705,8 +751,12 @@ def _save_trajectories(self, data_groups, save_path): json.dump(item, f, ensure_ascii=False, indent=2) f.write("\n") for group in data_groups: + if not is_valid_for_training(group): + self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.") + continue for data in group: item = { + "action_id": data.uid.action_id, "prompt": data.data.extra_info["raw_prompt"], "response": data.env.rollout.response, "response_len": rollout_response_len_list[_count], From 4e15498aeac26f44f9fc8caa2b15cd23e460ee0c Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Sun, 28 Dec 2025 19:35:48 +0800 Subject: [PATCH 2/4] [Feat] support tensorboard in RL Trainer --- xtuner/v1/rl/base/__init__.py | 3 +- xtuner/v1/rl/base/controller.py | 9 +-- xtuner/v1/rl/base/worker.py | 37 ++++++++-- xtuner/v1/train/rl_trainer.py | 116 ++++++++++++++++++++++++++++++-- 4 files changed, 150 insertions(+), 15 deletions(-) diff --git a/xtuner/v1/rl/base/__init__.py b/xtuner/v1/rl/base/__init__.py index d54016f42..d75603b57 100644 --- a/xtuner/v1/rl/base/__init__.py +++ b/xtuner/v1/rl/base/__init__.py @@ -1,6 +1,6 @@ from .controller import TrainingController, TrainingControllerProxy from .loss import BaseRLLossConfig, RLLossContextInputItem -from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig +from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem __all__ = [ @@ -12,4 +12,5 @@ "WorkerConfig", "BaseRLLossConfig", "RLLossContextInputItem", + "WorkerLogItem", ] diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index e5c2a7088..0346c899a 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -1,5 +1,5 @@ import math -from typing import Literal, TypedDict +from typing import List, Literal, TypedDict import ray import torch @@ -10,7 +10,7 @@ from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import ray_method -from .worker import TrainingWorker +from .worker import TrainingWorker, WorkerLogItem class ColateItem(TypedDict): @@ -165,7 +165,7 @@ def _grouped_by_max_length(self, packed_data_batches): return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True) @ray_method - def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int): + def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int) -> List[WorkerLogItem]: has_rollout_routed_experts = False language_cfg = None if data_batches[0]["seq_ctx"].rollout_routed_experts is not None: @@ -256,7 +256,8 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: rollout_idx=rollout_idx, ) ) - ray.get(handles) + log_infos = ray.get(handles) + return log_infos @ray_method def offload(self, target: Literal["model", "optimizer", "all"] = "all"): diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 5e2f07941..11ab5bbc0 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -14,11 +14,12 @@ from ray.actor import ActorClass, ActorProxy from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.tensor import DTensor +from typing_extensions import NotRequired from xtuner.v1.config.fsdp import FSDPConfig from xtuner.v1.config.optim import LRConfig, 
OptimConfig from xtuner.v1.data_proto.sequence_context import SequenceContext -from xtuner.v1.engine.train_engine import TrainEngine +from xtuner.v1.engine.train_engine import LossLog, OtherLog, TrainEngine from xtuner.v1.engine.vision_compose_train_engine import ( VisionComposeTrainEngine, ) @@ -141,6 +142,19 @@ class WorkerInputItem(TypedDict): rollout_logprobs: torch.Tensor | None +class WorkerTrainLogItem(TypedDict): + loss_log: LossLog + other_log: OtherLog + + +class WorkerLogItem(TypedDict): + train_entropy: float + rollout_entropy: NotRequired[float] + mismatch_metrics: NotRequired[dict[str, float]] + rollout_is_metrics: NotRequired[dict[str, float]] + train_metrics: List[WorkerTrainLogItem] + + class TrainingWorker(SingleAcceleratorWorker): _SAVE_OPTIMIZER_DIR = "optimizer" _SAVE_MODEL_DIR = "model" @@ -312,7 +326,7 @@ def _update_other_log(self, other_log: dict): return other_log @ray_method - def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): + def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLogItem: # NOTE: sglang会清除logger handle, 重新创建 self.logger = get_logger(log_dir=self.log_dir, tag="TrainingWorker") loss_cfg = self.config.loss_cfg @@ -459,27 +473,36 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): all_rollout_is_metrics.append(rollout_is_metrics) all_mismatch_metrics.append(mismatch_metrics) + worker_log_item: WorkerLogItem = { + "train_entropy": 0.0, + "train_metrics": [], + } logger_msg = f"Rollout {rollout_idx}: " - sum_entropy = cast(torch.Tensor, sum_entropy) dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM) - avg_sum_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else 0 + avg_sum_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0) + worker_log_item["train_entropy"] = avg_sum_entropy.item() logger_msg += f"avg entropy: {avg_sum_entropy:.4f}" if sum_rollout_entropy is not None: sum_rollout_entropy = cast(torch.Tensor, sum_rollout_entropy) dist.all_reduce(sum_rollout_entropy, op=dist.ReduceOp.SUM) - avg_rollout_entropy = sum_rollout_entropy / global_grad_tokens if global_grad_tokens > 0 else 0 + avg_rollout_entropy = ( + sum_rollout_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0) + ) + worker_log_item["rollout_entropy"] = avg_rollout_entropy.item() logger_msg += f", avg rollout entropy: {avg_rollout_entropy:.4f}" if len(all_mismatch_metrics) > 0: mismatch_metrics = merge_rollout_is_metrics(all_mismatch_metrics, DEVICE) if len(mismatch_metrics) > 0: + worker_log_item["mismatch_metrics"] = mismatch_metrics logger_msg += f"\n rollout mismatch metrics:\n{json.dumps(mismatch_metrics, indent=4)}" if len(all_rollout_is_metrics) > 0: rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE) if len(rollout_is_metrics) > 0: + worker_log_item["rollout_is_metrics"] = rollout_is_metrics logger_msg += f"\n rollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}" self.logger.info(logger_msg) @@ -527,6 +550,8 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int): other_log = self._update_other_log(other_log) # type: ignore[arg-type] grad_norm = self._engine.clip_grad_norm() self._engine.step_optimizer(grad_norm) + worker_log_item["train_metrics"].append(WorkerTrainLogItem(loss_log=loss_log, other_log=other_log)) + log_info = dict() # type: ignore[var-annotated] log_info.update(loss_log) for k, v in other_log.items(): @@ -543,6 +568,8 @@ def fit(self, 
data_batches: list[WorkerInputItem], rollout_idx: int): log_str = f"Rollout {rollout_idx} Step {i}: " + log_str self.logger.info(log_str) + return worker_log_item + @ray_method def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16): self._engine.save_hf(hf_dir, save_dtype) diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index a46412d51..8cf34d98c 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -4,7 +4,7 @@ from datetime import datetime from pathlib import Path from shutil import rmtree -from typing import cast +from typing import List, cast import ray import torch @@ -16,8 +16,10 @@ from typing_extensions import Self from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1._writer import TensorboardWriter from xtuner.v1.data_proto.rl_data import is_valid_for_training from xtuner.v1.data_proto.sequence_context import SequenceContext +from xtuner.v1.engine.train_engine import LossLog, OtherLog from xtuner.v1.patch import patch_default_save_plan from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers, AutoCPUWorkers, CPUResourcesConfig from xtuner.v1.ray.config.worker import RolloutConfig @@ -31,6 +33,7 @@ TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, + WorkerLogItem, ) from xtuner.v1.rl.base import TrainingWorker as BaseTrainingWorker from xtuner.v1.train import ResumeConfig @@ -395,6 +398,8 @@ def __init__( with env_path.open("w") as f: json.dump(environment_variables, f, indent=2) + self._writer = TensorboardWriter(log_dir / "tb") + def _resolve_load_checkpoint_cfg( self, auto_resume: bool, load_checkpoint_cfg: LoadCheckpointConfig ) -> LoadCheckpointConfig: @@ -482,22 +487,36 @@ def _initial_evaluate(self): ray.get(self._rollout_env_controller.update_active_workers.remote()) scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) trajectory_save_path = self.exp_dir / "eval_0_trajectory.jsonl" - self._save_trajectories(eval_data_groups, trajectory_save_path) + self._save_trajectories(eval_data_groups, trajectory_save_path, is_eval=True) self.logger.info(f"Initial rollout evaluate scores {scores} and start training") + tb_scores = {f"eval/{k}": v for k, v in scores.items()} + self._writer.add_scalars( + tag_scalar_dict=tb_scores, + global_step=0, + ) def _rollout_step(self, rollout_idx: int, step_timer_dict: dict): """Performs a single rollout step to generate experience.""" with timer("generation", step_timer_dict): ray.get(self._rollout_env_controller.update_active_workers.remote()) data_groups, multimodal_train_infos = ray.get(self._rollout_dataflow.run.remote()) + self._writer.add_scalar( + tag="time/generation", scalar_value=step_timer_dict["generation"], global_step=rollout_idx + ) with timer("save_trajectory", step_timer_dict): trajectory_save_path = self.exp_dir / f"rollout_idx_{rollout_idx}_trajectory.jsonl" self._save_trajectories(data_groups, trajectory_save_path) self.logger.info(f"Rollout_idx {rollout_idx} finished, saved trajectories to {trajectory_save_path}") + self._writer.add_scalar( + tag="time/save_trajectory", scalar_value=step_timer_dict["save_trajectory"], global_step=rollout_idx + ) if not self._debug_rollout: with timer("rollout_offload", step_timer_dict): ray.get(self._rollout_dataflow.pause.remote()) ray.get(self._rollout_env_controller.offload.remote()) + self._writer.add_scalar( + tag="time/rollout_offload", scalar_value=step_timer_dict["rollout_offload"], 
global_step=rollout_idx + ) return data_groups, multimodal_train_infos def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, step_timer_dict: dict): @@ -516,12 +535,72 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste self.logger.info(f"Prepared {len(data_batches)} training data batches") self._log_data_info(rollout_idx, data_info) + self._writer.add_scalar( + tag="time/onload", + scalar_value=step_timer_dict["onload"], + global_step=rollout_idx, + ) + + self._writer.add_scalar( + tag="time/prepare_data", + scalar_value=step_timer_dict["prepare_data"], + global_step=rollout_idx, + ) + with timer("training", step_timer_dict): - ray.get( + workers_log_item: List[WorkerLogItem] = ray.get( self._train_controller.fit.remote( data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx ) ) + self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx) + + rank0_log_item = workers_log_item[0] + rank0_rollout_is_metrics = rank0_log_item.get("rollout_is_metrics") + rank0_mismatch_metrics = rank0_log_item.get("mismatch_metrics") + rank0_rollout_entropy = rank0_log_item.get("rollout_entropy") + # These metrics are already aggregated across distributed workers and logging only the metrics from rank 0. + if rank0_rollout_is_metrics is not None: + tb_rollout_is_metrics = {f"rollout_is/{k}": v for k, v in rank0_rollout_is_metrics.items()} + self._writer.add_scalars(tag_scalar_dict=tb_rollout_is_metrics, global_step=rollout_idx) + if rank0_mismatch_metrics is not None: + tb_mismatch_metrics = {f"mismatch/{k}": v for k, v in rank0_mismatch_metrics.items()} + self._writer.add_scalars(tag_scalar_dict=tb_mismatch_metrics, global_step=rollout_idx) + if rank0_rollout_entropy is not None: + tb_rollout_entropy = {"entropy/rollout": rank0_rollout_entropy} + self._writer.add_scalars(tag_scalar_dict=tb_rollout_entropy, global_step=rollout_idx) + tb_entropy = {"entropy/train": rank0_log_item["train_entropy"]} + self._writer.add_scalars(tag_scalar_dict=tb_entropy, global_step=rollout_idx) + + for worker_idx, log_item in enumerate(workers_log_item): + mini_batch_metrics: dict[str, List[float]] = {} + for mini_batch_log in log_item["train_metrics"]: + loss_log: LossLog = mini_batch_log["loss_log"] + other_log: OtherLog = mini_batch_log["other_log"] + # Aggregate logs for the mini-batch + for k, v in loss_log.items(): + v = v.item() if isinstance(v, torch.Tensor) else v + v = cast(float, v) + mini_batch_metrics.setdefault(k, []).append(v) + + for k, v in other_log.items(): + if k == "extra_info" and isinstance(v, dict): + for extra_k, extra_v in v.items(): + extra_v = extra_v.item() if isinstance(extra_v, torch.Tensor) else extra_v + mini_batch_metrics.setdefault(extra_k, []).append(extra_v) + else: + v = v.item() if isinstance(v, torch.Tensor) else v + v = cast(float, v) + mini_batch_metrics.setdefault(k, []).append(v) + + for key, value in mini_batch_metrics.items(): + for i, v in enumerate(value): + global_step = rollout_idx * len(value) + i + self._writer.add_scalar( + tag=f"train_metrics/worker_{worker_idx}/{key}", + scalar_value=v, + global_step=global_step, + ) def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict): """Synchronizes weights and saves checkpoints.""" @@ -538,14 +617,30 @@ def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict): ray.get(self._train_controller.offload.remote(target="model")) 
ray.get(self._rollout_env_controller.onload_kvcache.remote()) + self._writer.add_scalar( + tag="time/save_ckpt", + scalar_value=step_timer_dict["save_ckpt"], + global_step=rollout_idx, + ) + self._writer.add_scalar( + tag="time/sync_weight", + scalar_value=step_timer_dict["sync_weight"], + global_step=rollout_idx, + ) + def _evaluate_step(self, rollout_idx: int, step_timer_dict: dict): """Performs an evaluation step.""" if self._enable_evaluate and self._evaluator and rollout_idx % self._eval_step == 0: with timer("evaluation", step_timer_dict): scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) trajectory_save_path = self.exp_dir / f"eval_{rollout_idx}_trajectory.jsonl" - self._save_trajectories(eval_data_groups, trajectory_save_path) + self._save_trajectories(eval_data_groups, trajectory_save_path, is_eval=True) self.logger.info(f"Evaluate idx {rollout_idx} scores {scores}") + tb_scores = {f"eval/{k}": v for k, v in scores.items()} + self._writer.add_scalars( + tag_scalar_dict=tb_scores, + global_step=rollout_idx, + ) def fit(self): """Run the RL training loop. @@ -579,6 +674,11 @@ def fit(self): self._evaluate_step(rollout_idx, step_timer_dict) # 5. Log timing information + self._writer.add_scalar( + tag="time/step", + scalar_value=step_timer_dict["step"], + global_step=rollout_idx, + ) timer_log_str = f"Rollout {rollout_idx} training finished and timing listed: \n" timer_log_str += timer_logger(step_timer_dict) self.logger.info(timer_log_str) @@ -697,7 +797,7 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf } return data_batches, info_dict - def _save_trajectories(self, data_groups, save_path): + def _save_trajectories(self, data_groups, save_path, is_eval: bool = False): rewards = [] rollout_response_len_list = [] @@ -750,6 +850,12 @@ def _save_trajectories(self, data_groups, save_path): } json.dump(item, f, ensure_ascii=False, indent=2) f.write("\n") + tb_prefix = "eval" if is_eval else "response" + tb_item = {f"{tb_prefix}/{k}": v for k, v in item.items() if isinstance(v, (int, float))} + self._writer.add_scalars( + tag_scalar_dict=tb_item, + global_step=self._cur_step, + ) for group in data_groups: if not is_valid_for_training(group): self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.") From 50d025f15bb67b51b8f53844606a0ced974c6227 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 29 Dec 2025 17:33:33 +0800 Subject: [PATCH 3/4] fix comments --- xtuner/v1/rl/base/worker.py | 49 +++++++++++++++++++++++------------ xtuner/v1/train/rl_trainer.py | 28 +++++--------------- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 11ab5bbc0..8eb181904 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -142,9 +142,19 @@ class WorkerInputItem(TypedDict): rollout_logprobs: torch.Tensor | None +class RLOtherLog(TypedDict): + maxvio: NotRequired[float] + consumed_tokens: float + consumed_img_tokens: NotRequired[float] + efficient_attn_ratio: float + max_ratio: NotRequired[float] + loss: NotRequired[float] + grad_norm: NotRequired[float] + + class WorkerTrainLogItem(TypedDict): loss_log: LossLog - other_log: OtherLog + rl_other_log: RLOtherLog class WorkerLogItem(TypedDict): @@ -313,17 +323,29 @@ def compute_ref_logprobs( self._ref_model.to_device("cpu") return loss_ctx_input_list - def _update_other_log(self, other_log: dict): + def _update_other_log(self, other_log: OtherLog) -> RLOtherLog: 
from xtuner.v1.model.utils import ModelForwardExtraLogInfo - extra_info = other_log.get("extra_info", {}) + extra_info: ModelForwardExtraLogInfo | dict = other_log.get("extra_info", {}) if isinstance(extra_info, ModelForwardExtraLogInfo): extra_info_dict = extra_info.get() else: extra_info_updated = ModelForwardExtraLogInfo(extra_info) extra_info_dict = extra_info_updated.get() - other_log["extra_info"] = extra_info_dict - return other_log + + for k, v in extra_info_dict.items(): + if isinstance(v, torch.Tensor): + extra_info_dict[k] = v.item() + + rl_other_log: RLOtherLog = { + "maxvio": other_log.get("maxvio", 0.0), + "consumed_tokens": other_log.get("consumed_tokens", 0.0), + "consumed_img_tokens": other_log.get("consumed_img_tokens", 0.0), + "efficient_attn_ratio": other_log.get("efficient_attn_ratio", 0.0), + "max_ratio": extra_info_dict.get("max_ratio", 0.0), + "loss": extra_info_dict.get("loss", 0.0), + } + return rl_other_log @ray_method def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLogItem: @@ -547,20 +569,13 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo loss_log, other_log = self._engine.train_step( data_batches=engine_input, ) - other_log = self._update_other_log(other_log) # type: ignore[arg-type] grad_norm = self._engine.clip_grad_norm() self._engine.step_optimizer(grad_norm) - worker_log_item["train_metrics"].append(WorkerTrainLogItem(loss_log=loss_log, other_log=other_log)) - - log_info = dict() # type: ignore[var-annotated] - log_info.update(loss_log) - for k, v in other_log.items(): - if k == "extra_info": - for extra_k, extra_v in v.items(): - log_info[extra_k] = extra_v.item() if isinstance(extra_v, torch.Tensor) else extra_v - else: - log_info[k] = v.item() if isinstance(v, torch.Tensor) else v - log_info["grad_norm"] = grad_norm.item() + rl_other_log = self._update_other_log(other_log) # type: ignore[arg-type] + rl_other_log["grad_norm"] = grad_norm.item() + worker_log_item["train_metrics"].append(WorkerTrainLogItem(loss_log=loss_log, rl_other_log=rl_other_log)) + + log_info = {**loss_log, **rl_other_log} log_str = ", ".join( f"{key}={value:.4f}" if isinstance(value, float) else f"{key}={value}" for key, value in log_info.items() diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 8cf34d98c..f07af0714 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -19,7 +19,6 @@ from xtuner.v1._writer import TensorboardWriter from xtuner.v1.data_proto.rl_data import is_valid_for_training from xtuner.v1.data_proto.sequence_context import SequenceContext -from xtuner.v1.engine.train_engine import LossLog, OtherLog from xtuner.v1.patch import patch_default_save_plan from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers, AutoCPUWorkers, CPUResourcesConfig from xtuner.v1.ray.config.worker import RolloutConfig @@ -556,15 +555,15 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx) rank0_log_item = workers_log_item[0] + # These metrics are already aggregated across distributed workers and logging only the metrics from rank 0. 
rank0_rollout_is_metrics = rank0_log_item.get("rollout_is_metrics") rank0_mismatch_metrics = rank0_log_item.get("mismatch_metrics") rank0_rollout_entropy = rank0_log_item.get("rollout_entropy") - # These metrics are already aggregated across distributed workers and logging only the metrics from rank 0. if rank0_rollout_is_metrics is not None: tb_rollout_is_metrics = {f"rollout_is/{k}": v for k, v in rank0_rollout_is_metrics.items()} self._writer.add_scalars(tag_scalar_dict=tb_rollout_is_metrics, global_step=rollout_idx) if rank0_mismatch_metrics is not None: - tb_mismatch_metrics = {f"mismatch/{k}": v for k, v in rank0_mismatch_metrics.items()} + tb_mismatch_metrics = {f"{k}": v for k, v in rank0_mismatch_metrics.items()} self._writer.add_scalars(tag_scalar_dict=tb_mismatch_metrics, global_step=rollout_idx) if rank0_rollout_entropy is not None: tb_rollout_entropy = {"entropy/rollout": rank0_rollout_entropy} @@ -575,27 +574,14 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste for worker_idx, log_item in enumerate(workers_log_item): mini_batch_metrics: dict[str, List[float]] = {} for mini_batch_log in log_item["train_metrics"]: - loss_log: LossLog = mini_batch_log["loss_log"] - other_log: OtherLog = mini_batch_log["other_log"] + rl_worker_log = {**mini_batch_log["loss_log"], **mini_batch_log["rl_other_log"]} # Aggregate logs for the mini-batch - for k, v in loss_log.items(): - v = v.item() if isinstance(v, torch.Tensor) else v - v = cast(float, v) - mini_batch_metrics.setdefault(k, []).append(v) - - for k, v in other_log.items(): - if k == "extra_info" and isinstance(v, dict): - for extra_k, extra_v in v.items(): - extra_v = extra_v.item() if isinstance(extra_v, torch.Tensor) else extra_v - mini_batch_metrics.setdefault(extra_k, []).append(extra_v) - else: - v = v.item() if isinstance(v, torch.Tensor) else v - v = cast(float, v) - mini_batch_metrics.setdefault(k, []).append(v) + for k, v in rl_worker_log.items(): + mini_batch_metrics.setdefault(k, []).append(cast(float, v)) for key, value in mini_batch_metrics.items(): for i, v in enumerate(value): - global_step = rollout_idx * len(value) + i + global_step = (rollout_idx - 1) * len(value) + i + 1 self._writer.add_scalar( tag=f"train_metrics/worker_{worker_idx}/{key}", scalar_value=v, @@ -851,7 +837,7 @@ def _save_trajectories(self, data_groups, save_path, is_eval: bool = False): json.dump(item, f, ensure_ascii=False, indent=2) f.write("\n") tb_prefix = "eval" if is_eval else "response" - tb_item = {f"{tb_prefix}/{k}": v for k, v in item.items() if isinstance(v, (int, float))} + tb_item = {f"{tb_prefix}/{k}": v for k, v in item.items()} self._writer.add_scalars( tag_scalar_dict=tb_item, global_step=self._cur_step, From 7abdb3ce0cf48092be0c560e7daf8ed166f4e46c Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 29 Dec 2025 17:49:51 +0800 Subject: [PATCH 4/4] replace List with list --- xtuner/v1/rl/base/controller.py | 4 ++-- xtuner/v1/rl/base/worker.py | 14 +++++++------- xtuner/v1/train/rl_trainer.py | 10 +++++----- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index 0346c899a..e22f2f650 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -1,5 +1,5 @@ import math -from typing import List, Literal, TypedDict +from typing import Literal, TypedDict import ray import torch @@ -165,7 +165,7 @@ def _grouped_by_max_length(self, packed_data_batches): return sorted(packed_data_batches, key=lambda x: 
x["seq_ctx"].max_length_q, reverse=True) @ray_method - def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int) -> List[WorkerLogItem]: + def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int) -> list[WorkerLogItem]: has_rollout_routed_experts = False language_cfg = None if data_batches[0]["seq_ctx"].rollout_routed_experts is not None: diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index a050473e1..3d226cf2a 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -144,8 +144,8 @@ class WorkerInputItem(TypedDict): class RLOtherLog(TypedDict): maxvio: NotRequired[float] - consumed_tokens: float - consumed_img_tokens: NotRequired[float] + step_consumed_tokens: int + step_consumed_img_tokens: NotRequired[float] efficient_attn_ratio: float max_ratio: NotRequired[float] loss: NotRequired[float] @@ -323,7 +323,7 @@ def compute_ref_logprobs( self._ref_model.to_device("cpu") return loss_ctx_input_list - def _update_other_log(self, other_log: OtherLog) -> RLOtherLog: + def _get_rl_other_log(self, other_log: OtherLog) -> RLOtherLog: from xtuner.v1.model.utils import ModelForwardExtraLogInfo extra_info: ModelForwardExtraLogInfo | dict = other_log.get("extra_info", {}) @@ -339,9 +339,9 @@ def _update_other_log(self, other_log: OtherLog) -> RLOtherLog: rl_other_log: RLOtherLog = { "maxvio": other_log.get("maxvio", 0.0), - "consumed_tokens": other_log.get("consumed_tokens", 0.0), - "consumed_img_tokens": other_log.get("consumed_img_tokens", 0.0), - "efficient_attn_ratio": other_log.get("efficient_attn_ratio", 0.0), + "step_consumed_tokens": other_log["step_consumed_tokens"], + "step_consumed_img_tokens": float(other_log.get("step_consumed_img_tokens", 0.0)), + "efficient_attn_ratio": other_log["efficient_attn_ratio"], "max_ratio": extra_info_dict.get("max_ratio", 0.0), "loss": extra_info_dict.get("loss", 0.0), } @@ -569,7 +569,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo ) grad_norm = self._engine.clip_grad_norm() self._engine.step_optimizer(grad_norm) - rl_other_log = self._update_other_log(other_log) # type: ignore[arg-type] + rl_other_log = self._get_rl_other_log(other_log) # type: ignore[arg-type] rl_other_log["grad_norm"] = grad_norm.item() worker_log_item["train_metrics"].append(WorkerTrainLogItem(loss_log=loss_log, rl_other_log=rl_other_log)) diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index f07af0714..3c507f8fa 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -486,7 +486,7 @@ def _initial_evaluate(self): ray.get(self._rollout_env_controller.update_active_workers.remote()) scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) trajectory_save_path = self.exp_dir / "eval_0_trajectory.jsonl" - self._save_trajectories(eval_data_groups, trajectory_save_path, is_eval=True) + self._save_trajectories(eval_data_groups, trajectory_save_path, 0, is_eval=True) self.logger.info(f"Initial rollout evaluate scores {scores} and start training") tb_scores = {f"eval/{k}": v for k, v in scores.items()} self._writer.add_scalars( @@ -504,7 +504,7 @@ def _rollout_step(self, rollout_idx: int, step_timer_dict: dict): ) with timer("save_trajectory", step_timer_dict): trajectory_save_path = self.exp_dir / f"rollout_idx_{rollout_idx}_trajectory.jsonl" - self._save_trajectories(data_groups, trajectory_save_path) + self._save_trajectories(data_groups, trajectory_save_path, rollout_idx) 
self.logger.info(f"Rollout_idx {rollout_idx} finished, saved trajectories to {trajectory_save_path}") self._writer.add_scalar( tag="time/save_trajectory", scalar_value=step_timer_dict["save_trajectory"], global_step=rollout_idx @@ -620,7 +620,7 @@ def _evaluate_step(self, rollout_idx: int, step_timer_dict: dict): with timer("evaluation", step_timer_dict): scores, eval_data_groups = ray.get(self._evaluator.run.remote(return_samples=True)) trajectory_save_path = self.exp_dir / f"eval_{rollout_idx}_trajectory.jsonl" - self._save_trajectories(eval_data_groups, trajectory_save_path, is_eval=True) + self._save_trajectories(eval_data_groups, trajectory_save_path, rollout_idx, is_eval=True) self.logger.info(f"Evaluate idx {rollout_idx} scores {scores}") tb_scores = {f"eval/{k}": v for k, v in scores.items()} self._writer.add_scalars( @@ -783,7 +783,7 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf } return data_batches, info_dict - def _save_trajectories(self, data_groups, save_path, is_eval: bool = False): + def _save_trajectories(self, data_groups, save_path, rollout_idx, is_eval: bool = False): rewards = [] rollout_response_len_list = [] @@ -840,7 +840,7 @@ def _save_trajectories(self, data_groups, save_path, is_eval: bool = False): tb_item = {f"{tb_prefix}/{k}": v for k, v in item.items()} self._writer.add_scalars( tag_scalar_dict=tb_item, - global_step=self._cur_step, + global_step=rollout_idx, ) for group in data_groups: if not is_valid_for_training(group):