diff --git a/scripts/config/alfworld.yaml b/scripts/config/alfworld.yaml
index fbb984d243..8af91b763d 100644
--- a/scripts/config/alfworld.yaml
+++ b/scripts/config/alfworld.yaml
@@ -3,8 +3,8 @@ data:
   batch_size: 4
   dataset_path: 'scripts/data_prepare/alfworld_data'
   default_workflow_type: 'alfworld_workflow'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   format_config:
     prompt_key: 'game_file'
 model:
diff --git a/scripts/config/countdown.yaml b/scripts/config/countdown.yaml
index 23df47aba9..3c523b1e52 100644
--- a/scripts/config/countdown.yaml
+++ b/scripts/config/countdown.yaml
@@ -3,8 +3,8 @@ data:
   batch_size: 96
   dataset_path: 'countdown_dataset/oneshot-split'
   default_workflow_type: 'math_workflow'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   default_reward_fn_type: 'countdown_reward'
   format_config:
     prompt_key: 'question'
diff --git a/scripts/config/gsm8k.yaml b/scripts/config/gsm8k.yaml
index 182a0166d0..629d2d13d8 100644
--- a/scripts/config/gsm8k.yaml
+++ b/scripts/config/gsm8k.yaml
@@ -1,8 +1,8 @@
 data:
   # basic info
   dataset_path: '/PATH/TO/DATASET/'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   format_config:
     prompt_key: 'question'
     response_key: 'answer'
@@ -70,6 +70,7 @@ trainer:
   algorithm_type: ppo
   trainer_config_path: 'scripts/config/train_gsm8k.yaml'
   sft_warmup_iteration: 0 # Set to integer to enable sft warmup
+  eval_interval: 50
 monitor:
   cache_root_dir: ""
   project: "Trinity-RFT-gsm8k"
diff --git a/scripts/config/webshop.yaml b/scripts/config/webshop.yaml
index e8cbaeb165..43f5107723 100644
--- a/scripts/config/webshop.yaml
+++ b/scripts/config/webshop.yaml
@@ -3,8 +3,8 @@ data:
   batch_size: 4
   dataset_path: 'scripts/data_prepare/webshop_data'
   default_workflow_type: 'webshop_workflow'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   format_config:
     prompt_key: 'task_id'
 model:
diff --git a/tests/common/tmp/template_config.yaml b/tests/common/tmp/template_config.yaml
index f218531ef3..0394f89487 100644
--- a/tests/common/tmp/template_config.yaml
+++ b/tests/common/tmp/template_config.yaml
@@ -3,7 +3,8 @@ data:
   dataset_path: ''
   total_epoch: 1
   batch_size: 1
-  split: train
+  train_split: 'train'
+  eval_split: ''
   default_workflow_type: ''
   default_reward_fn_type: ''
   dataset_config: {}
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index daebabd4d1..18bdacb174 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -19,8 +19,7 @@ def explore(config: Config) -> None:
     try:
         ray.get(explorer.prepare.remote())
         ray.get(explorer.sync_weight.remote())
-        ref, _ = ray.wait([explorer.explore.remote()])
-        ray.get(ref)
+        ray.get(explorer.explore.remote())
         logger.info("Explore finished.")
     except Exception as e:
         logger.error(f"Explore failed: {e}")
@@ -34,8 +33,7 @@ def train(config: Config) -> None:
     trainer = Trainer.remote(config)
     try:
         ray.get(trainer.prepare.remote())
-        ref, _ = ray.wait([trainer.train.remote(algo_type)])
-        ray.get(ref)
+        ray.get(trainer.train.remote(algo_type))
         logger.info("Train finished.")
     except Exception as e:
         logger.error(f"Train failed {e}.")
@@ -67,20 +65,21 @@ def both(config: Config) -> None:
 
     if config.trainer.sft_warmup_iteration > 0:
         for step in range(config.trainer.sft_warmup_iteration):
-            ray.get([trainer.train_step.remote(AlgorithmType.SFT)])
+            ray.get(trainer.train_step.remote(AlgorithmType.SFT))
             logger.info(f"SFT warmup step {step} finished.")
 
     ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
     algo_type = config.trainer.algorithm_type
-    global_iter_num = 0
     while True:
         try:
-            explore_continue = explorer.explore_step.remote()
-            train_continue = trainer.train_step.remote(algo_type)
-            if not ray.get(explore_continue):
+            ref_explore = explorer.explore_step.remote()
+            ref_train = trainer.train_step.remote(algo_type)
+            explore_continue, _ = ray.get(ref_explore)
+            train_continue, train_iter_num = ray.get(ref_train)
+            if not explore_continue:
                 logger.info("Explorer finished, stopping...")
                 break
-            if not ray.get(train_continue):
+            if not train_continue:
                 logger.info("Trainer finished, stopping...")
                 break
             ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
@@ -89,10 +88,14 @@ def both(config: Config) -> None:
             logger.error(e)
             logger.error("Training stopped due to exception.")
             raise e
-        global_iter_num += 1
-        if global_iter_num % config.trainer.eval_interval == 0:
-            ray.wait([explorer.eval.remote()])
-            logger.info("Eval step finished.")
+        if (train_iter_num - 1) % config.trainer.eval_interval == 0:
+            try:
+                ray.get(explorer.eval.remote(train_iter_num))
+                logger.info("Evaluation finished.")
+            except Exception as e:
+                logger.error(e)
+                logger.error("Evaluation failed.")
+                raise e
 
 
 def main() -> None:
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 93f1ec94b3..b9dc6bcad5 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -306,6 +306,16 @@ def check_and_update(self) -> None:
             self.synchronizer.backend = self.explorer.backend
         if self.synchronizer.sync_method == "online" and self.mode != "both":
             raise ValueError("Online synchronization is only supported in both mode")
+
+        # check eval_interval
+        if self.trainer.eval_interval % self.synchronizer.sync_iteration_interval != 0:
+            self.trainer.eval_interval = (
+                self.trainer.eval_interval // self.synchronizer.sync_iteration_interval
+            ) * self.synchronizer.sync_iteration_interval
+            print(
+                f"Warning: eval_interval is not a multiple of sync_iteration_interval; rounded down to the nearest multiple={self.trainer.eval_interval}."
+            )
+
         # check monitor
         if not self.monitor.cache_root_dir:
             # create a cache dir in /.cache
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 63ddab3a71..f60412df13 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -58,7 +58,8 @@ def run(self) -> List[Experience]:
         else:
             messages = [{"role": "user", "content": self.task_desc}]
         logger.debug("start chat")
-        responses = self.model.chat(messages, n=self.repeat_times)
+        n = 1 if self.is_eval else self.repeat_times
+        responses = self.model.chat(messages, n=n)
         for response in responses:
             reward = self.reward_fn(  # type: ignore [misc]
                 response=response.response_text,  # type: ignore [arg-type]
@@ -69,9 +70,9 @@ def run(self) -> List[Experience]:
                 f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
             )
             if isinstance(reward, dict):
-                if response.info is None:
-                    response.info = {}
-                response.info.update(reward)
+                if response.metrics is None:
+                    response.metrics = {}
+                response.metrics.update(reward)
                 reward = sum(reward.values())
             response.reward = reward
         return responses
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 1ee219c0df..6beac2b049 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -3,7 +3,7 @@
 import os
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import ray
 import torch
@@ -149,16 +149,20 @@ def get_weight(self, name: str) -> torch.Tensor:
 
     def explore(self) -> None:
         """Explore the entire dataset."""
-        while self.explore_step():
+        explore_status, _ = self.explore_step()
+        while explore_status:
             self.sync_weight()
         self.logger.info("Explorer finished.")
 
-    def explore_step(self) -> bool:
+    def explore_step(self) -> Tuple[bool, int]:
         """Explore for one step.
 
         Different from `explore()` which consumes all tasks in the task set,
         `explore_step()` only consume `sync_iteration_interval * batch_size` number of tasks.
 
+        Returns:
+            explore_status: whether there are more tasks to explore.
+            explore_iter_num: the number of explore iterations.
         """
         if self.task_iter is None:
             self.task_iter = iter(self.taskset)
@@ -175,7 +179,7 @@ def explore_step(self) -> bool:
             self.runner_pool.run_tasks(tasks)
         except StopIteration:
             self.logger.warning("No more tasks in the task set. Stop exploring.")
-            return False
+            return False, self.iteration
 
         # wait for all tasks of this step to finish
         while self.runner_pool.has_next():
@@ -190,7 +194,7 @@ def explore_step(self) -> bool:
                     self.runner_pool.run_tasks(next(self.task_iter))  # type: ignore
                 except StopIteration:
                     self.logger.warning("No more tasks in the task set. Stop exploring.")
-                    return False
+                    return False, self.iteration
             else:
                 for metric_name, metric_value in status.metric.items():
                     all_metrics[metric_name].append(metric_value)
@@ -208,11 +212,11 @@ def explore_step(self) -> bool:
         )
         self.logger.info("Explore step finished.")
-        return True
+        return True, self.iteration
 
-    def eval(self) -> bool:
+    def eval(self, step) -> bool:
         """Evaluation on all evaluation data samples."""
-        self.logger.info("\n\nEvaluation started.\n\n")
+        self.logger.info("Evaluation started.")
         st = time.time()
 
         all_metrics = defaultdict(list)
@@ -231,11 +235,9 @@ def eval(self) -> bool:
             for metric_name, metric_value in status.metric.items():
                 all_metrics[metric_name].append(metric_value)
 
-        self.logger.info("Evaluation finished.")
-
         log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval")  # type: ignore
         log_metrics["eval/total_time"] = time.time() - st
-        self.monitor.log(log_metrics, step=self.iteration)  # type: ignore
+        self.monitor.log(log_metrics, step=step)  # type: ignore
         return True
 
     def sync_weight(self) -> None:
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index df20e02d72..dd5d184fad 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -7,6 +7,7 @@
 Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
 """
 
 from abc import ABC, abstractmethod
+from typing import Tuple
 
 import ray
@@ -45,18 +46,23 @@ def prepare(self) -> None:
     def train(self, algo_type: AlgorithmType = AlgorithmType.PPO):
         """Train the model."""
         while True:
-            if not self.train_iteration(algo_type):
+            train_status, _ = self.train_iteration(algo_type)
+            if not train_status:
                 break
 
-    def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> bool:
-        """Train one step. Each step contains `sync_iteration_interval` iteration."""
+    def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
+        """Train one step. Each step contains `sync_iteration_interval` iterations.
+        Returns:
+            train_status: Whether to continue training.
+            train_iter_num: The number of training iterations."""
         for _ in range(self.config.synchronizer.sync_iteration_interval):
-            if not self.train_iteration(algo_type):
-                return False
+            train_status, train_iter_num = self.train_iteration(algo_type)
+            if not train_status:
+                return False, train_iter_num
         self.logger.info("Trainer finished.")
-        return True
+        return True, train_iter_num
 
-    def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> bool:
+    def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
         """Train one iteration.
 
         Args:
@@ -108,15 +114,15 @@ def prepare(self) -> None:
         """Do some preparation before training started."""
 
     @abstractmethod
-    def train_rft_iteration(self, experiences) -> bool:
+    def train_rft_iteration(self, experiences) -> Tuple[bool, int]:
         """Train on the RFT data."""
 
     @abstractmethod
-    def train_sft_iteration(self, experiences) -> bool:
+    def train_sft_iteration(self, experiences) -> Tuple[bool, int]:
         """Train on the SFT data."""
 
     @abstractmethod
-    def train_dpo_iteration(self, experiences) -> bool:
+    def train_dpo_iteration(self, experiences) -> Tuple[bool, int]:
         """Train on the DPO data."""
 
     @abstractmethod
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index a8e39a1f97..6e68ba894c 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -4,6 +4,7 @@
 Modified from verl/trainer/ppo/ray_trainer.py
 """
 import os
+from typing import Tuple
 
 import pandas as pd
 import ray
@@ -182,7 +183,7 @@ def _create_dataloader(self):
         # else:
         self.total_training_steps = float("inf")
 
-    def train_dpo_iteration(self, experiences: Experiences) -> bool:
+    def train_dpo_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
         metrics = {}
         timing_raw = {}
 
@@ -243,9 +244,9 @@ def train_dpo_iteration(self, experiences: Experiences) -> bool:
                 self._save_checkpoint()
 
         self.global_steps += 1
-        return True
+        return True, self.global_steps
 
-    def train_sft_iteration(self, experiences: Experiences) -> bool:
+    def train_sft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
         metrics = {}
         timing_raw = {}
 
@@ -309,9 +310,9 @@ def train_sft_iteration(self, experiences: Experiences) -> bool:
             with _timer("save_checkpoint", timing_raw):
                 self._save_checkpoint()
         self.global_steps += 1
-        return True
+        return True, self.global_steps
 
-    def train_rft_iteration(self, experiences: Experiences) -> bool:
+    def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
         metrics = {}
         timing_raw = {}
 
@@ -456,10 +457,10 @@ def train_rft_iteration(self, experiences: Experiences) -> bool:
                 with _timer("save_checkpoint", timing_raw):
                     self._save_checkpoint()
             # stop training
-            return False
+            return False, self.global_steps
         else:
             # continue
-            return True
+            return True, self.global_steps
 
     def _log_single_experience(
         self, experiences: Experiences, idx: int, skip_special_tokens: bool
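Below is a small standalone sketch (not part of the patch) of the evaluation scheduling this change introduces: `Config.check_and_update` rounds `eval_interval` down to a multiple of `sync_iteration_interval`, and `both()` calls `explorer.eval(train_iter_num)` whenever the iteration number returned by `trainer.train_step` satisfies `(train_iter_num - 1) % eval_interval == 0`. The helper names and the sample iteration numbers below are illustrative assumptions, not code from the repository.

```python
# Hypothetical helpers mirroring the behaviour added in this patch;
# the names are illustrative and do not exist in trinity.

def round_down_eval_interval(eval_interval: int, sync_iteration_interval: int) -> int:
    """Mirror of the check in Config.check_and_update: floor eval_interval
    to a multiple of sync_iteration_interval."""
    if eval_interval % sync_iteration_interval != 0:
        eval_interval = (eval_interval // sync_iteration_interval) * sync_iteration_interval
    return eval_interval


def should_eval(train_iter_num: int, eval_interval: int) -> bool:
    """Mirror of the trigger in both(): evaluate when the trainer's reported
    iteration number satisfies (train_iter_num - 1) % eval_interval == 0."""
    return (train_iter_num - 1) % eval_interval == 0


if __name__ == "__main__":
    # Example values (assumed): eval_interval=50, sync_iteration_interval=20.
    interval = round_down_eval_interval(50, 20)
    print(interval)  # 40

    # Suppose train_step reports iteration numbers 21, 41, 61, 81
    # (one sync interval apart); evaluation then fires at 41 and 81.
    for it in (21, 41, 61, 81):
        print(it, should_eval(it, interval))
```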