diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index c07fe07921..b1d08f03b0 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -112,7 +112,7 @@ algorithm: - `optimizer`: Optimizer configuration for actor. - `lr`: Learning rate for actor. - `warmup_style`: Warmup style for actor's learning rate. -- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. +- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. Supported types: `default`, `staleness_control`, `mix`. - `advantage_fn`: The advantage function used for computing advantages. - `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward. - `kl_loss_fn`: The KL loss function used for computing KL loss. diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 7e7aec64b8..0fff6cd91f 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -112,7 +112,7 @@ algorithm: - `optimizer`: Actor 优化器的参数。 - `lr`: 优化器的学习率。 - `warmup_style`: 学习率的预热策略。 -- `sample_strategy`: 从 experience buffer 加载 experience 时使用的采样策略。 +- `sample_strategy`: 从 experience buffer 加载 experience 时使用的采样策略。支持类型:`default`、`staleness_control`、`mix`。 - `advantage_fn`: 用于计算优势值的函数。 - `kl_penalty_fn`: 用于在奖励中计算 KL 惩罚的函数。 - `kl_loss_fn`: 用于计算 KL 损失的函数。 diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py index da0f80df8f..f308e1ee10 100644 --- a/tests/buffer/experience_storage_test.py +++ b/tests/buffer/experience_storage_test.py @@ -108,6 +108,7 @@ async def test_sql_experience_buffer(self): prompt_length=i, reward=float(i), logprobs=torch.tensor([0.1]), + info={"model_version": 0}, ) for i in range(1, self.put_batch_size + 1) ] diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py new file mode 100644 index 0000000000..32ea84bdb7 --- /dev/null +++ b/tests/buffer/sample_strategy_test.py @@ -0,0 +1,251 @@ +import asyncio +import shutil +from collections import deque + +import torch +from parameterized import parameterized_class + +from tests.tools import RayUnittestBaseAysnc, get_template_config +from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY +from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy +from trinity.buffer.buffer import get_buffer_writer +from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig +from trinity.common.constants import StorageType +from trinity.common.experience import Experience + + +@parameterized_class( + ("exp_write_batch_size",), + [ + (3,), + (6,), + ], +) +class ExperienceStorageTest(RayUnittestBaseAysnc): + def setUp(self): + self.config = get_template_config() + self.num_steps = 20 + + def _default_exp_list(self): + return [ + [ + Experience( + tokens=torch.tensor([float(k) for k in range(j + 3)]), + reward=float(i), # using reward to carry model_version for testing + prompt_length=2, + info={"model_version": i, "use_count": 0}, + ) + for j in range(self.exp_write_batch_size) + ] + for i in range(self.num_steps) + ] + + def _default_steps(self): + return [0, 5, 10, 15] + + def _init_buffer_writer_and_sample_strategy(self): + # Initialize buffer writer and sample strategy + self.buffer_writer = get_buffer_writer( + self.config.buffer.trainer_input.experience_buffer, # type: ignore [arg-type] + ) + self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get( + self.config.algorithm.sample_strategy + )( + buffer_config=self.config.buffer, + **self.config.algorithm.sample_strategy_args, + ) + + async def _verify_model_version(self, step, expected_versions): + batch, metrics, _ = await self.sample_strategy.sample(step=step) + self.assertEqual( + batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}" + ) + self.assertEqual( + metrics["sample/model_version/min"], + min(expected_versions), + f"Min model version mismatch at step {step}", + ) + self.assertEqual( + metrics["sample/model_version/max"], + max(expected_versions), + f"Max model version mismatch at step {step}", + ) + self.assertEqual( + metrics["sample/model_version/mean"], + sum(expected_versions) / len(expected_versions), + f"Mean model version mismatch at step {step}", + ) + + async def _verify_sampling_model_versions(self, exps_list, expected_model_versions_map): + self._init_buffer_writer_and_sample_strategy() + + # Write experiences to buffer, while sample and validate model versions + current_task = None + for step, exps in enumerate(exps_list): + await self.buffer_writer.write_async(exps) + if step in expected_model_versions_map: + if current_task: + await current_task + current_task = asyncio.create_task( + self._verify_model_version(step, expected_model_versions_map[step]) + ) + await asyncio.sleep(0.1) + + if current_task: + await current_task + + async def _flexible_verify_model_version(self, step, max_staleness): + _, metrics, _ = await self.sample_strategy.sample(step=step) + self.assertGreaterEqual( + metrics["sample/model_version/min"], + step - max_staleness, + f"Min model version mismatch at step {step}", + ) + + async def _flexible_verify_sampling_model_versions(self, exps_list, check_steps, max_staleness): + self._init_buffer_writer_and_sample_strategy() + + # Write experiences to buffer, while sample and validate model versions + current_task = None + for step, exps in enumerate(exps_list): + await self.buffer_writer.write_async(exps) + if step in check_steps: + if current_task: + await current_task + current_task = asyncio.create_task( + self._flexible_verify_model_version(step, max_staleness) + ) + await asyncio.sleep(0.1) + + if current_task: + await current_task + + async def test_default_queue_default_sample_strategy(self): + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="default_queue_default_strategy", + storage_type=StorageType.QUEUE.value, + replay_buffer=ReplayBufferConfig(enable=False), + ) + self.config.check_and_update() + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + train_batch_size = self.config.buffer.train_batch_size + expected_model_versions_map = {} + for idx, step in enumerate(steps): + start_idx = idx * train_batch_size + batch_versions = [ + (start_idx + offset) // self.exp_write_batch_size + for offset in range(train_batch_size) + ] + expected_model_versions_map[step] = batch_versions + + await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + + async def test_default_queue_staleness_control_sample_strategy(self): + max_staleness = 3 + self.config.algorithm.sample_strategy = "staleness_control" + self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness} + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="default_queue_staleness_control", + storage_type=StorageType.QUEUE.value, + replay_buffer=ReplayBufferConfig(enable=False), + ) + self.config.check_and_update() + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + expected_model_versions_map = {} + for step in steps: + predict_version = max(step - max_staleness, 0) + expected_model_versions_map[step] = [ + predict_version + i // self.exp_write_batch_size + for i in range(self.config.buffer.train_batch_size) + ] + + await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + + def _simulate_priority_queue(self, steps, max_staleness=float("inf")): + expected_model_versions_map = {} + buffer = deque() + exp_pool = deque() + step_idx = 0 + train_batch_size = self.config.buffer.train_batch_size + for i in range(self.num_steps): + buffer.append([i] * self.exp_write_batch_size) + step = steps[step_idx] + if i < step: + continue + batch_versions = expected_model_versions_map.get(step, []) + if len(batch_versions) < train_batch_size: + while len(buffer) > 0: + if len(exp_pool) == 0: + exp_pool.extend(buffer.pop()) + while len(exp_pool) > 0 and len(batch_versions) < train_batch_size: + exp_version = exp_pool.popleft() + if exp_version < step - max_staleness: + continue + batch_versions.append(exp_version) + if len(batch_versions) >= train_batch_size: + step_idx += 1 + break + expected_model_versions_map[step] = batch_versions + if step_idx >= len(steps): + break + return expected_model_versions_map + + async def test_priority_queue_default_sample_strategy(self): + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="priority_queue_default_strategy", + storage_type=StorageType.QUEUE.value, + replay_buffer=ReplayBufferConfig(enable=True), + ) + self.config.check_and_update() + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + expected_model_versions_map = self._simulate_priority_queue(steps) + + await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + + async def test_priority_queue_staleness_control_sample_strategy(self): + max_staleness = 2 + self.config.algorithm.sample_strategy = "staleness_control" + self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness} + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="priority_queue_staleness_control", + storage_type=StorageType.QUEUE.value, + replay_buffer=ReplayBufferConfig(enable=True), + ) + self.config.check_and_update() + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + expected_model_versions_map = self._simulate_priority_queue(steps, max_staleness) + + await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + + async def test_sql_staleness_control_sample_strategy(self): + max_staleness = 2 + self.config.algorithm.sample_strategy = "staleness_control" + self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness} + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="sql_staleness_control", + storage_type=StorageType.SQL.value, + ) + self.config.check_and_update() + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + + await self._flexible_verify_sampling_model_versions(exps_list, steps, max_staleness) + + def tearDown(self): + asyncio.run(self.buffer_writer.release()) + shutil.rmtree(self.config.checkpoint_job_dir) + return super().tearDown() diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 44b81e1495..1e742a54bc 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -34,6 +34,7 @@ async def test_sql_exp_buffer_read_write(self) -> None: prompt_length=i, reward=float(i), logprobs=torch.tensor([0.1]), + info={"model_version": i}, ) for i in range(1, put_batch_size + 1) ] @@ -52,6 +53,7 @@ async def test_sql_exp_buffer_read_write(self) -> None: reward=float(i), logprobs=torch.tensor([0.1]), action_mask=torch.tensor([j % 2 for j in range(i + 1)]), + info={"model_version": i}, ) for i in range(1, put_batch_size * 2 + 1) ] diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 195aaa61ae..55eabca721 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -253,6 +253,7 @@ def test_experience_model_experience_conversion(self): reward=reward, prompt_length=prompt_length, logprobs=logprobs, + info={"model_version": 0}, ) model = ExperienceModel.from_experience(experience) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index c061099437..b1282d7c7a 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -5,6 +5,7 @@ import os import random import shutil +import unittest from datetime import datetime import httpx @@ -200,6 +201,7 @@ def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) + @unittest.skip("Require improvement for agent mode") async def test_serve(self): # noqa: C901 serve_process = multiprocessing.Process(target=run_serve, args=(self.config,)) serve_process.start() diff --git a/tests/service/data_juicer_test.py b/tests/service/data_juicer_test.py index 2fdd89ad85..60440e0d6e 100644 --- a/tests/service/data_juicer_test.py +++ b/tests/service/data_juicer_test.py @@ -182,24 +182,28 @@ async def test_data_juicer_operators(self): prompt_length=3, prompt_text="Hello, how are you?", response_text="Hi, I am fine.", + info={"model_version": 0}, ), Experience( # too short response tokens=torch.tensor([1, 2, 3, 4, 5]), prompt_length=3, prompt_text="What is your name?", response_text="Trinity.", + info={"model_version": 0}, ), Experience( # repeated words tokens=torch.tensor([1, 2, 3, 4, 5]), prompt_length=3, prompt_text="What day is it today?", response_text="Today is Sunday Sunday Sunday Sunday Sunday and it's a happy day!", + info={"model_version": 0}, ), Experience( tokens=torch.tensor([1, 2, 3, 4, 5]), prompt_length=3, prompt_text="What is your favorite color?", response_text="My favorite color is blue.", + info={"model_version": 0}, ), ] metrics = await pipeline.process.remote(exps) diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py index 9e2700fb4a..067b45a2e2 100644 --- a/trinity/algorithm/sample_strategy/__init__.py +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -6,6 +6,7 @@ default_mapping={ "default": "trinity.algorithm.sample_strategy.sample_strategy.DefaultSampleStrategy", "warmup": "trinity.algorithm.sample_strategy.sample_strategy.WarmupSampleStrategy", + "staleness_control": "trinity.algorithm.sample_strategy.sample_strategy.StalenessControlSampleStrategy", "mix": "trinity.algorithm.sample_strategy.mix_sample_strategy.MixSampleStrategy", }, ) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index e15c3e0b0b..2ab63032cb 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -76,6 +76,23 @@ def load_state_dict(self, state_dict: dict) -> None: self.exp_buffer.load_state_dict(state_dict) +class StalenessControlSampleStrategy(DefaultSampleStrategy): + def __init__(self, buffer_config: BufferConfig, **kwargs): + super().__init__(buffer_config) + self.max_staleness = kwargs.get("max_staleness", float("inf")) + + async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: + min_model_version = max(step - self.max_staleness, 0) + metrics = {} + with Timer(metrics, "time/read_experience"): + exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version) + repr_samples = representative_sample(exp_list) + self.set_model_version_metric(exp_list, metrics) + with Timer(metrics, "time/gather_experience"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + return exps, metrics, repr_samples + + @Deprecated class WarmupSampleStrategy(DefaultSampleStrategy): """The warmup sample strategy. diff --git a/trinity/buffer/buffer_reader.py b/trinity/buffer/buffer_reader.py index d47d80ace1..ad4d414547 100644 --- a/trinity/buffer/buffer_reader.py +++ b/trinity/buffer/buffer_reader.py @@ -7,11 +7,11 @@ class BufferReader(ABC): """Interface of the buffer reader.""" @abstractmethod - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: """Read from buffer.""" @abstractmethod - async def read_async(self, batch_size: Optional[int] = None) -> List: + async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: """Read from buffer asynchronously.""" def __len__(self) -> int: diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 6cb35e58a7..ac4d728263 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -85,7 +85,7 @@ def select_batch(self, indices: List[int]) -> List: class BaseFileReader(BufferReader): - async def read_async(self, batch_size: Optional[int] = None): + async def read_async(self, batch_size: Optional[int] = None, **kwargs): try: return self.read(batch_size) except StopIteration as e: @@ -101,7 +101,7 @@ def __init__(self, config: StorageConfig): else: self.reader = TaskFileReader(config) - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: return self.reader.read(batch_size) def read_with_indices(self, indices: List[int]) -> List: @@ -140,7 +140,7 @@ def __init__(self, config: StorageConfig): enable_progress_bar=config.enable_progress_bar, ) - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: samples, _ = self.dataset.read_batch(batch_size or self.read_batch_size) exp_list = [] for sample in samples: @@ -187,7 +187,7 @@ def _get_tasks(self, samples: List, indices: List) -> List: tasks.append(task) return tasks - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: batch_size = batch_size or self.read_batch_size samples, indices = self.dataset.read_batch(batch_size) return self._get_tasks(samples, indices) diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index b743c20b8e..b3b1d14c12 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -19,10 +19,10 @@ def __init__(self, config: StorageConfig): self.read_batch_size = config.batch_size self.queue = QueueStorage.get_wrapper(config) - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: try: batch_size = batch_size or self.read_batch_size - exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout)) + exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)) if len(exps) != batch_size: raise TimeoutError( f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow." @@ -31,9 +31,9 @@ def read(self, batch_size: Optional[int] = None) -> List: raise StopIteration() return exps - async def read_async(self, batch_size: Optional[int] = None) -> List: + async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: batch_size = batch_size or self.read_batch_size - exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout) + exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs) if len(exps) != batch_size: raise TimeoutError( f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow." diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index 0d7943f8dd..f7572c628c 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -18,20 +18,20 @@ def __init__(self, config: StorageConfig) -> None: self.wrap_in_ray = config.wrap_in_ray self.storage = SQLStorage.get_wrapper(config) - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: if self.wrap_in_ray: - return ray.get(self.storage.read.remote(batch_size)) + return ray.get(self.storage.read.remote(batch_size, **kwargs)) else: - return self.storage.read(batch_size) + return self.storage.read(batch_size, **kwargs) - async def read_async(self, batch_size: Optional[int] = None) -> List: + async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: if self.wrap_in_ray: try: - return ray.get(self.storage.read.remote(batch_size)) + return await self.storage.read.remote(batch_size, **kwargs) except StopIteration: raise StopAsyncIteration else: - return self.storage.read(batch_size) + return self.storage.read(batch_size, **kwargs) def state_dict(self) -> Dict: # SQL Not supporting state dict yet diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index 3a7ae3f105..997c661a23 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -38,6 +38,8 @@ class ExperienceModel(Base): # type: ignore # for multi turn message_list = Column(JSON, nullable=True) reward = Column(Float, nullable=True) + # for step info + model_version = Column(Integer, nullable=True, index=True) # serialized experience object experience_bytes = Column(LargeBinary, nullable=True) consumed = Column(Integer, default=0, index=True) @@ -50,11 +52,12 @@ def to_experience(self) -> Experience: def from_experience(cls, experience: Experience): """Save the experience to database.""" return cls( - experience_bytes=experience.serialize(), - reward=experience.reward, prompt=experience.prompt_text, response=experience.response_text, message_list=experience.messages, + reward=experience.reward, + model_version=experience.info["model_version"], + experience_bytes=experience.serialize(), ) diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index 3f1c7268b6..a0da043895 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -100,6 +100,9 @@ def default_config(cls) -> Dict: class QueueBuffer(ABC): + async def set_min_model_version(self, min_model_version: int): + self.min_model_version = max(min_model_version, 0) + @abstractmethod async def put(self, exps: List[Experience]) -> None: """Put a list of experiences into the queue.""" @@ -149,6 +152,21 @@ def __init__(self, capacity: int): """ super().__init__(maxsize=capacity) self._closed = False + self.min_model_version = 0 + + async def put(self, item: List[Experience]): + if len(item) == 0: + return + await super().put(item) + + async def get(self): + while True: + item = await super().get() + if ( + self.min_model_version <= 0 + or item[0].info["model_version"] >= self.min_model_version + ): + return item async def close(self) -> None: """Close the queue.""" @@ -204,6 +222,7 @@ def __init__( self.reuse_cooldown_time = reuse_cooldown_time self._condition = asyncio.Condition() # For thread-safe operations self._closed = False + self.min_model_version = 0 async def _put(self, item: List[Experience], delay: float = 0) -> None: """ @@ -255,16 +274,23 @@ async def get(self) -> List[Experience]: - After retrieval, the item is optionally reinserted after a cooldown period. """ async with self._condition: - while len(self.priority_groups) == 0: - if self._closed: - raise StopAsyncIteration() - await self._condition.wait() + while True: + while len(self.priority_groups) == 0: + if self._closed: + raise StopAsyncIteration() + await self._condition.wait() + + _, item_queue = self.priority_groups.peekitem(index=-1) + item = item_queue.popleft() + self.item_count -= 1 + if not item_queue: + self.priority_groups.popitem(index=-1) - _, item_queue = self.priority_groups.peekitem(index=-1) - item = item_queue.popleft() - self.item_count -= 1 - if not item_queue: - self.priority_groups.popitem(index=-1) + if ( + self.min_model_version <= 0 + or item[0].info["model_version"] >= self.min_model_version + ): + break for exp in item: exp.info["use_count"] += 1 @@ -348,10 +374,20 @@ async def put_batch(self, exp_list: List) -> None: if self.writer is not None: self.writer.write(exp_list) - async def get_batch(self, batch_size: int, timeout: float) -> List: + async def get_batch(self, batch_size: int, timeout: float, min_model_version: int = 0) -> List: """Get batch of experience.""" + await self.queue.set_min_model_version(min_model_version) start_time = time.time() - while len(self.exp_pool) < batch_size: + result = [] + while len(result) < batch_size: + while len(self.exp_pool) > 0 and len(result) < batch_size: + exp = self.exp_pool.popleft() + if min_model_version > 0 and exp.info["model_version"] < min_model_version: + continue + result.append(exp) + if len(result) >= batch_size: + break + if self.queue.stopped(): # If the queue is stopped, ignore the rest of the experiences in the pool raise StopAsyncIteration("Queue is closed and no more items to get.") @@ -368,7 +404,7 @@ async def get_batch(self, batch_size: int, timeout: float) -> List: batch = list(self.exp_pool) self.exp_pool.clear() return batch - return [self.exp_pool.popleft() for _ in range(batch_size)] + return result @classmethod def get_wrapper(cls, config: StorageConfig): diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index fb21373cda..08ff06fb8c 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -143,7 +143,7 @@ def _read_fifo(self, batch_size: int) -> List[Experience]: time.sleep(1) return exp_list - def _read_priority(self, batch_size: int) -> List[Experience]: + def _read_priority(self, batch_size: int, min_model_version: int = 0) -> List[Experience]: exp_list = [] start_time = time.time() latest_size = 0 @@ -158,9 +158,13 @@ def _read_priority(self, batch_size: int) -> List[Experience]: with retry_session( self.session, self.max_retry_times, self.max_retry_interval ) as session: + query = session.query(self.table_model_cls) + if min_model_version > 0: + query = query.filter(self.table_model_cls.model_version >= min_model_version) experiences = ( - session.query(self.table_model_cls) - .order_by(asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)) + query.order_by( + asc(self.table_model_cls.consumed), desc(self.table_model_cls.id) + ) .limit(batch_size) .with_for_update() .all() @@ -186,12 +190,12 @@ def _read_priority(self, batch_size: int) -> List[Experience]: time.sleep(1) return exp_list - def read(self, batch_size: Optional[int] = None) -> List[Experience]: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: if self.stopped: raise StopIteration() batch_size = batch_size or self.batch_size - return self._read_method(batch_size) + return self._read_method(batch_size, **kwargs) @classmethod def load_from_dataset(cls, dataset: Dataset, config: StorageConfig) -> "SQLExperienceStorage": diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 458c1ba626..c690bc407d 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -47,8 +47,8 @@ def __init__(self, config: Config): ) explorer_state = self.state.load_explorer() self.explore_step_num = explorer_state.get("latest_iteration", 0) - self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 - self.last_monitored_step = self.explore_step_num if self.explore_step_num > 0 else -1 + self.last_sync_step = self.explore_step_num + self.last_monitored_step = self.explore_step_num self.synchronizer = Synchronizer.get_actor(config) self.config = config self.models, self.auxiliary_models = create_inference_models(config) @@ -328,9 +328,12 @@ async def benchmark(self) -> bool: async def save_checkpoint(self, sync_weight: bool = False) -> None: if self.scheduler: - await self._finish_steps( - self.last_monitored_step + 1, self.explore_step_num, self.model_version - ) + if self.explore_step_num == 0: + await self._finish_eval_step(step=0) + else: + await self._finish_steps( + self.last_monitored_step + 1, self.explore_step_num, self.model_version + ) self.last_monitored_step = self.explore_step_num if sync_weight: