diff --git a/.gitignore b/.gitignore
index 30d77b4300..5ffba098fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -91,3 +91,6 @@ runs/
 # docs
 trinity*.rst
 modules.rst
+
+# wandb
+wandb/
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 18bdacb174..3d2df6e357 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -74,7 +74,7 @@ def both(config: Config) -> None:
         try:
             ref_explore = explorer.explore_step.remote()
             ref_train = trainer.train_step.remote(algo_type)
-            explore_continue, _ = ray.get(ref_explore)
+            explore_continue, explore_iter_num = ray.get(ref_explore)
             train_continue, train_iter_num = ray.get(ref_train)
             if not explore_continue:
                 logger.info("Explorer finished, stopping...")
@@ -88,9 +88,9 @@ def both(config: Config) -> None:
             logger.error(e)
             logger.error("Training stopped due to exception.")
             raise e
-        if (train_iter_num - 1) % config.trainer.eval_interval == 0:
+        if train_iter_num % config.trainer.eval_interval == 0:
             try:
-                ray.get(explorer.eval.remote(train_iter_num))
+                ray.get(explorer.eval.remote())
                 logger.info("Evaluation finished.")
             except Exception as e:
                 logger.error(e)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index b9dc6bcad5..c44960df7f 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -310,7 +310,7 @@ def check_and_update(self) -> None:
         # check eval_interval
         if self.trainer.eval_interval % self.synchronizer.sync_iteration_interval != 0:
             self.trainer.eval_interval = (
-                self.trainer.eval_interval // self.synchronizer.sync_iteration_interval
+                max(self.trainer.eval_interval // self.synchronizer.sync_iteration_interval, 1)
             ) * self.synchronizer.sync_iteration_interval
             print(
                 f"Warning: eval_interval is not a multiple of sync_iteration_interval; adjusted to the nearest integer={self.trainer.eval_interval}."
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 6e68ba894c..8d64f060c9 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -244,7 +244,7 @@ def train_dpo_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
                 self._save_checkpoint()

         self.global_steps += 1
-        return True, self.global_steps
+        return True, self.global_steps - 1

     def train_sft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
         metrics = {}
@@ -310,7 +310,7 @@ def train_sft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
             with _timer("save_checkpoint", timing_raw):
                 self._save_checkpoint()
         self.global_steps += 1
-        return True, self.global_steps
+        return True, self.global_steps - 1

     def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
         metrics = {}
@@ -457,10 +457,10 @@ def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
                 with _timer("save_checkpoint", timing_raw):
                     self._save_checkpoint()
             # stop training
-            return False, self.global_steps
+            return False, self.global_steps - 1
         else:
             # continue
-            return True, self.global_steps
+            return True, self.global_steps - 1

     def _log_single_experience(
         self, experiences: Experiences, idx: int, skip_special_tokens: bool