From 07a643687586cc5ccd11dd2430e022c2d6caa881 Mon Sep 17 00:00:00 2001 From: yuchang Date: Mon, 21 Apr 2025 12:19:20 +0800 Subject: [PATCH 1/3] fix some bugs for eval --- scripts/config/alfworld.yaml | 4 ++-- scripts/config/countdown.yaml | 4 ++-- scripts/config/gsm8k.yaml | 5 +++-- scripts/config/webshop.yaml | 4 ++-- tests/common/tmp/template_config.yaml | 3 ++- trinity/cli/launcher.py | 15 ++++++++++----- trinity/common/config.py | 8 ++++++++ trinity/common/workflows/workflow.py | 9 +++++---- trinity/explorer/explorer.py | 8 +++----- trinity/trainer/trainer.py | 3 +++ trinity/trainer/verl_trainer.py | 3 +++ 11 files changed, 43 insertions(+), 23 deletions(-) diff --git a/scripts/config/alfworld.yaml b/scripts/config/alfworld.yaml index fbb984d243..8af91b763d 100644 --- a/scripts/config/alfworld.yaml +++ b/scripts/config/alfworld.yaml @@ -3,8 +3,8 @@ data: batch_size: 4 dataset_path: 'scripts/data_prepare/alfworld_data' default_workflow_type: 'alfworld_workflow' - dataset_config: - split: 'train' + train_split: 'train' + eval_split: '' format_config: prompt_key: 'game_file' model: diff --git a/scripts/config/countdown.yaml b/scripts/config/countdown.yaml index 23df47aba9..3c523b1e52 100644 --- a/scripts/config/countdown.yaml +++ b/scripts/config/countdown.yaml @@ -3,8 +3,8 @@ data: batch_size: 96 dataset_path: 'countdown_dataset/oneshot-split' default_workflow_type: 'math_workflow' - dataset_config: - split: 'train' + train_split: 'train' + eval_split: '' default_reward_fn_type: 'countdown_reward' format_config: prompt_key: 'question' diff --git a/scripts/config/gsm8k.yaml b/scripts/config/gsm8k.yaml index 182a0166d0..9ea63c665d 100644 --- a/scripts/config/gsm8k.yaml +++ b/scripts/config/gsm8k.yaml @@ -1,8 +1,8 @@ data: # basic info dataset_path: '/PATH/TO/DATASET/' - dataset_config: - split: 'train' + train_split: 'train' + eval_split: '' format_config: prompt_key: 'question' response_key: 'answer' @@ -70,6 +70,7 @@ trainer: algorithm_type: ppo trainer_config_path: 'scripts/config/train_gsm8k.yaml' sft_warmup_iteration: 0 # Set to integer to enable sft warmup + eval_interval: 10 monitor: cache_root_dir: "" project: "Trinity-RFT-gsm8k" diff --git a/scripts/config/webshop.yaml b/scripts/config/webshop.yaml index e8cbaeb165..43f5107723 100644 --- a/scripts/config/webshop.yaml +++ b/scripts/config/webshop.yaml @@ -3,8 +3,8 @@ data: batch_size: 4 dataset_path: 'scripts/data_prepare/webshop_data' default_workflow_type: 'webshop_workflow' - dataset_config: - split: 'train' + train_split: 'train' + eval_split: '' format_config: prompt_key: 'task_id' model: diff --git a/tests/common/tmp/template_config.yaml b/tests/common/tmp/template_config.yaml index f218531ef3..0394f89487 100644 --- a/tests/common/tmp/template_config.yaml +++ b/tests/common/tmp/template_config.yaml @@ -3,7 +3,8 @@ data: dataset_path: '' total_epoch: 1 batch_size: 1 - split: train + train_split: 'train' + eval_split: '' default_workflow_type: '' default_reward_fn_type: '' dataset_config: {} diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index daebabd4d1..8a9ce5b007 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -72,7 +72,6 @@ def both(config: Config) -> None: ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) algo_type = config.trainer.algorithm_type - global_iter_num = 0 while True: try: explore_continue = explorer.explore_step.remote() @@ -89,10 +88,16 @@ def both(config: Config) -> None: logger.error(e) logger.error("Training stopped due to exception.") raise e - global_iter_num += 1 - if global_iter_num % config.trainer.eval_interval == 0: - ray.wait([explorer.eval.remote()]) - logger.info("Eval step finished.") + train_step_num = ray.get(trainer.get_current_step.remote()) + if (train_step_num - 1) % config.trainer.eval_interval == 0: + ref, _ = ray.wait([explorer.eval.remote(step=train_step_num)]) + try: + ray.get(ref) + logger.info("Evaluation finished.") + except Exception as e: + logger.error(e) + logger.error("Evaluation failed.") + raise e def main() -> None: diff --git a/trinity/common/config.py b/trinity/common/config.py index 93f1ec94b3..601438ebd3 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -306,6 +306,14 @@ def check_and_update(self) -> None: self.synchronizer.backend = self.explorer.backend if self.synchronizer.sync_method == "online" and self.mode != "both": raise ValueError("Online synchronization is only supported in both mode") + + # check eval_interval + if self.trainer.eval_interval % self.synchronizer.sync_iteration_interval != 0: + self.trainer.eval_interval = ( + self.trainer.eval_interval // self.synchronizer.sync_iteration_interval + ) * self.synchronizer.sync_iteration_interval + print(f"Warning: eval_interval is not a multiple of sync_iteration_interval; adjusted to the nearest integer={self.trainer.eval_interval}.") + # check monitor if not self.monitor.cache_root_dir: # create a cache dir in /.cache diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 63ddab3a71..f60412df13 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -58,7 +58,8 @@ def run(self) -> List[Experience]: else: messages = [{"role": "user", "content": self.task_desc}] logger.debug("start chat") - responses = self.model.chat(messages, n=self.repeat_times) + n = 1 if self.is_eval else self.repeat_times + responses = self.model.chat(messages, n=n) for response in responses: reward = self.reward_fn( # type: ignore [misc] response=response.response_text, # type: ignore [arg-type] @@ -69,9 +70,9 @@ def run(self) -> List[Experience]: f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" ) if isinstance(reward, dict): - if response.info is None: - response.info = {} - response.info.update(reward) + if response.metrics is None: + response.metrics = {} + response.metrics.update(reward) reward = sum(reward.values()) response.reward = reward return responses diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 1ee219c0df..80817b15f6 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -210,9 +210,9 @@ def explore_step(self) -> bool: self.logger.info("Explore step finished.") return True - def eval(self) -> bool: + def eval(self, step) -> bool: """Evaluation on all evaluation data samples.""" - self.logger.info("\n\nEvaluation started.\n\n") + self.logger.info("Evaluation started.") st = time.time() all_metrics = defaultdict(list) @@ -231,11 +231,9 @@ def eval(self) -> bool: for metric_name, metric_value in status.metric.items(): all_metrics[metric_name].append(metric_value) - self.logger.info("Evaluation finished.") - log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval") # type: ignore log_metrics["eval/total_time"] = time.time() - st - self.monitor.log(log_metrics, step=self.iteration) # type: ignore + self.monitor.log(log_metrics, step=step) # type: ignore return True def sync_weight(self) -> None: diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index df20e02d72..8d6b4e0ed8 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -99,6 +99,9 @@ def sync_weight(self) -> None: if self.config.synchronizer.sync_method == "online": self.engine.sync_weight() + def get_current_step(self) -> int: + return self.engine.get_current_step() + class TrainEngineWrapper(ABC): """A wrapper class to wrap various training engines.""" diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index a8e39a1f97..8ba1668118 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -559,3 +559,6 @@ def sft_to_rft(self) -> None: def shutdown(self) -> None: pass + + def get_current_step(self) -> int: + return self.global_steps From e44b20c90fb21707cf4cf8a4b538f2f5f11dd412 Mon Sep 17 00:00:00 2001 From: yuchang Date: Mon, 21 Apr 2025 17:28:26 +0800 Subject: [PATCH 2/3] fix some minor bugs --- trinity/cli/launcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 5d4eced566..06af1fa3a6 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -72,10 +72,10 @@ def both(config: Config) -> None: algo_type = config.trainer.algorithm_type while True: try: - ref = explorer.explore_step.remote() - explore_continue, explore_iter_num = ray.get(ref) - ref = trainer.train_step.remote(algo_type) - train_continue, train_iter_num = ray.get(ref) + ref_explore = explorer.explore_step.remote() + ref_train = trainer.train_step.remote(algo_type) + explore_continue, _ = ray.get(ref_explore) + train_continue, train_iter_num = ray.get(ref_train) if not explore_continue: logger.info("Explorer finished, stopping...") break From e5f3c5ff806e290f808548f75151148f53ea4dcd Mon Sep 17 00:00:00 2001 From: yuchang Date: Mon, 21 Apr 2025 17:34:31 +0800 Subject: [PATCH 3/3] fix some typos --- scripts/config/gsm8k.yaml | 2 +- trinity/cli/launcher.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/config/gsm8k.yaml b/scripts/config/gsm8k.yaml index 9ea63c665d..629d2d13d8 100644 --- a/scripts/config/gsm8k.yaml +++ b/scripts/config/gsm8k.yaml @@ -70,7 +70,7 @@ trainer: algorithm_type: ppo trainer_config_path: 'scripts/config/train_gsm8k.yaml' sft_warmup_iteration: 0 # Set to integer to enable sft warmup - eval_interval: 10 + eval_interval: 50 monitor: cache_root_dir: "" project: "Trinity-RFT-gsm8k" diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 06af1fa3a6..18bdacb174 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -19,7 +19,7 @@ def explore(config: Config) -> None: try: ray.get(explorer.prepare.remote()) ray.get(explorer.sync_weight.remote()) - ray.get([explorer.explore.remote()]) + ray.get(explorer.explore.remote()) logger.info("Explore finished.") except Exception as e: logger.error(f"Explore failed: {e}") @@ -33,7 +33,7 @@ def train(config: Config) -> None: trainer = Trainer.remote(config) try: ray.get(trainer.prepare.remote()) - ray.get([trainer.train.remote(algo_type)]) + ray.get(trainer.train.remote(algo_type)) logger.info("Train finished.") except Exception as e: logger.error(f"Train failed {e}.")