From e35e681a5cd3310ee170f35400e62057cb941a6f Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 31 Dec 2025 13:19:03 +0800 Subject: [PATCH] Bug fix in multi stage resume --- tests/trainer/trainer_test.py | 13 ++++++++++--- trinity/cli/launcher.py | 1 + 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index c01b55408d..c37c05559c 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -331,12 +331,15 @@ def test_trainer(self, mock_load): ), ] self.config.check_and_update() + old_taskset_path = self.config.stages[1].buffer.explorer_input.taskset.path + self.config.stages[1].buffer.explorer_input.taskset.path = "/invalid/path" - mock_load.return_value = self.config + mock_load.return_value = deepcopy(self.config) - run(config_path="dummy.yaml") + with self.assertRaises(Exception): + run(config_path="dummy.yaml") - stage_configs = [cfg.check_and_update() for cfg in self.config] + stage_configs = [cfg.check_and_update() for cfg in deepcopy(self.config)] # sft warmup stage sft_config = stage_configs[0] @@ -351,6 +354,10 @@ def test_trainer(self, mock_load): self.assertEqual(parser.metric_min_step(response_metrics[0]), 1) self.assertEqual(parser.metric_max_step(response_metrics[0]), 3) + self.config.stages[1].buffer.explorer_input.taskset.path = old_taskset_path + mock_load.return_value = deepcopy(self.config) + run(config_path="dummy.yaml") + # grpo stage grpo_config = stage_configs[1] parser = TensorBoardParser(os.path.join(grpo_config.monitor.cache_dir, "tensorboard")) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 28ba57ecc8..46e6cb2a0e 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -191,6 +191,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): f"> Skipping completed stage {i + 1}/{len(config.stages)}...\n" "===========================================================" ) + stage_config.check_and_update() else: logger.info( "===========================================================\n"