diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py index e40f60da..ab8d3384 100644 --- a/tinker_cookbook/rl/train.py +++ b/tinker_cookbook/rl/train.py @@ -256,6 +256,7 @@ class Config: eval_every: int = 20 # 0 = disabled save_every: int = 20 # 0 = disabled load_checkpoint_path: str | None = None + checkpoint_name_prefix: str | None = None async_config: AsyncConfig | None = None stream_minibatch_config: StreamMinibatchConfig | None = None @@ -328,7 +329,12 @@ async def do_sync_training_with_stream_minibatch( """ # Initial sampling client sampling_client, _ = await save_checkpoint_and_get_sampling_client( - training_client, start_batch, cfg.log_path, cfg.save_every, start_batch + training_client, + start_batch, + cfg.log_path, + cfg.save_every, + start_batch, + cfg.checkpoint_name_prefix, ) for i_batch in range(start_batch, end_batch): @@ -687,13 +693,17 @@ async def save_checkpoint_and_get_sampling_client( log_path: str, save_every: int, start_batch: int = 0, + checkpoint_name_prefix: str | None = None, ) -> tuple[tinker.SamplingClient, dict[str, Any]]: metrics = {} with timed("save_checkpoint", metrics): if save_every > 0 and i_batch > start_batch and i_batch % save_every == 0: + name = f"{i_batch:06d}" + if checkpoint_name_prefix: + name = f"{checkpoint_name_prefix}_{name}" path_dict = await checkpoint_utils.save_checkpoint_async( training_client=training_client, - name=f"{i_batch:06d}", + name=name, log_path=log_path, loop_state={"batch": i_batch}, kind="both", @@ -753,6 +763,7 @@ async def compute_full_batch_metrics_and_get_sampling_client( log_path: str, save_every: int, do_compute_post_kl: bool, + checkpoint_name_prefix: str | None = None, ) -> tuple[tinker.SamplingClient, dict[str, Any]]: """ At the end of the iteration, this will compute metrics for the full batch @@ -770,7 +781,11 @@ async def compute_full_batch_metrics_and_get_sampling_client( # Get a sampling client using the new weights sampling_client, checkpoint_metrics = await save_checkpoint_and_get_sampling_client( - training_client, i_batch, log_path, save_every + training_client, + i_batch, + log_path, + save_every, + checkpoint_name_prefix, ) metrics.update(checkpoint_metrics) @@ -902,6 +917,7 @@ async def do_train_step_streaming_and_get_sampling_client( cfg.log_path, cfg.save_every, cfg.compute_post_kl, + cfg.checkpoint_name_prefix, ) metrics.update(full_batch_metrics) return sampling_client, metrics @@ -949,6 +965,7 @@ async def do_train_step_and_get_sampling_client( cfg.log_path, cfg.save_every, cfg.compute_post_kl, + cfg.checkpoint_name_prefix, ) metrics.update(full_batch_metrics) @@ -971,7 +988,12 @@ async def do_sync_training( """Implements fully synchronous on-policy training""" # Initial sampling client sampling_client, _ = await save_checkpoint_and_get_sampling_client( - training_client, start_batch, cfg.log_path, cfg.save_every, start_batch + training_client, + start_batch, + cfg.log_path, + cfg.save_every, + start_batch, + cfg.checkpoint_name_prefix, ) for i_batch in range(start_batch, end_batch):