Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ algorithm:
- `optimizer`: Optimizer configuration for actor.
- `lr`: Learning rate for actor.
- `warmup_style`: Warmup style for actor's learning rate.
- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer.
- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. Supported types: `default`, `staleness_control`, `mix`.
- `advantage_fn`: The advantage function used for computing advantages.
- `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
- `kl_loss_fn`: The KL loss function used for computing KL loss.
Expand Down
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ algorithm:
- `optimizer`: Actor 优化器的参数。
- `lr`: 优化器的学习率。
- `warmup_style`: 学习率的预热策略。
- `sample_strategy`: 从 experience buffer 加载 experience 时使用的采样策略。
- `sample_strategy`: 从 experience buffer 加载 experience 时使用的采样策略。支持类型:`default`、`staleness_control`、`mix`。
- `advantage_fn`: 用于计算优势值的函数。
- `kl_penalty_fn`: 用于在奖励中计算 KL 惩罚的函数。
- `kl_loss_fn`: 用于计算 KL 损失的函数。
Expand Down
1 change: 1 addition & 0 deletions tests/buffer/experience_storage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ async def test_sql_experience_buffer(self):
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
info={"model_version": 0},
)
for i in range(1, self.put_batch_size + 1)
]
Expand Down
251 changes: 251 additions & 0 deletions tests/buffer/sample_strategy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
import asyncio
import shutil
from collections import deque

import torch
from parameterized import parameterized_class

from tests.tools import RayUnittestBaseAysnc, get_template_config
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
from trinity.buffer.buffer import get_buffer_writer
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
from trinity.common.constants import StorageType
from trinity.common.experience import Experience


@parameterized_class(
    ("exp_write_batch_size",),
    [
        (3,),
        (6,),
    ],
)
class ExperienceStorageTest(RayUnittestBaseAysnc):
    """Integration tests for sample strategies reading from experience buffers.

    Each test writes ``num_steps`` batches of experiences into a buffer while
    concurrently sampling from it, and checks that the sampled experiences
    carry the expected model versions. Every experience records its producing
    step both in ``info["model_version"]`` and — for easy assertion through
    ``batch.rewards`` — in ``reward``.

    Parameterized over ``exp_write_batch_size``, the number of experiences
    written per producing step.
    """

    def setUp(self):
        self.config = get_template_config()
        # Number of simulated producing steps (i.e. batches written).
        self.num_steps = 20

    def _default_exp_list(self):
        """Build one list of experiences per producing step.

        The producing step ``i`` is stored in ``info["model_version"]`` and
        mirrored into ``reward`` so tests can read it back from the gathered
        batch without inspecting ``info``.
        """
        return [
            [
                Experience(
                    tokens=torch.tensor([float(k) for k in range(j + 3)]),
                    reward=float(i),  # using reward to carry model_version for testing
                    prompt_length=2,
                    info={"model_version": i, "use_count": 0},
                )
                for j in range(self.exp_write_batch_size)
            ]
            for i in range(self.num_steps)
        ]

    def _default_steps(self):
        """Steps at which the tests trigger a sample-and-verify round."""
        return [0, 5, 10, 15]

    def _init_buffer_writer_and_sample_strategy(self):
        # Initialize buffer writer and sample strategy
        self.buffer_writer = get_buffer_writer(
            self.config.buffer.trainer_input.experience_buffer,  # type: ignore [arg-type]
        )
        self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get(
            self.config.algorithm.sample_strategy
        )(
            buffer_config=self.config.buffer,
            **self.config.algorithm.sample_strategy_args,
        )

    async def _verify_model_version(self, step, expected_versions):
        """Sample one batch at ``step`` and assert exact model versions.

        Checks both the per-experience versions (via ``rewards``, see
        ``_default_exp_list``) and the min/max/mean metrics reported by the
        sample strategy.
        """
        batch, metrics, _ = await self.sample_strategy.sample(step=step)
        self.assertEqual(
            batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}"
        )
        self.assertEqual(
            metrics["sample/model_version/min"],
            min(expected_versions),
            f"Min model version mismatch at step {step}",
        )
        self.assertEqual(
            metrics["sample/model_version/max"],
            max(expected_versions),
            f"Max model version mismatch at step {step}",
        )
        self.assertEqual(
            metrics["sample/model_version/mean"],
            sum(expected_versions) / len(expected_versions),
            f"Mean model version mismatch at step {step}",
        )

    async def _verify_sampling_model_versions(self, exps_list, expected_model_versions_map):
        """Interleave writes with sampling and verify exact model versions.

        Verification runs as a background task because sampling may block
        until enough experiences have been written; the previous round is
        awaited before starting the next one.
        """
        self._init_buffer_writer_and_sample_strategy()

        # Write experiences to buffer, while sample and validate model versions
        current_task = None
        for step, exps in enumerate(exps_list):
            await self.buffer_writer.write_async(exps)
            if step in expected_model_versions_map:
                if current_task:
                    await current_task
                current_task = asyncio.create_task(
                    self._verify_model_version(step, expected_model_versions_map[step])
                )
            # Yield so the sampling task can make progress between writes.
            await asyncio.sleep(0.1)

        if current_task:
            await current_task

    async def _flexible_verify_model_version(self, step, max_staleness):
        """Sample at ``step`` and only assert the staleness lower bound."""
        _, metrics, _ = await self.sample_strategy.sample(step=step)
        self.assertGreaterEqual(
            metrics["sample/model_version/min"],
            step - max_staleness,
            f"Min model version mismatch at step {step}",
        )

    async def _flexible_verify_sampling_model_versions(self, exps_list, check_steps, max_staleness):
        """Like ``_verify_sampling_model_versions`` but only checks staleness.

        Used for backends (e.g. SQL) where the exact ordering of returned
        experiences is not deterministic enough to pin exact versions.
        """
        self._init_buffer_writer_and_sample_strategy()

        # Write experiences to buffer, while sample and validate model versions
        current_task = None
        for step, exps in enumerate(exps_list):
            await self.buffer_writer.write_async(exps)
            if step in check_steps:
                if current_task:
                    await current_task
                current_task = asyncio.create_task(
                    self._flexible_verify_model_version(step, max_staleness)
                )
            await asyncio.sleep(0.1)

        if current_task:
            await current_task

    async def test_default_queue_default_sample_strategy(self):
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="default_queue_default_strategy",
            storage_type=StorageType.QUEUE.value,
            replay_buffer=ReplayBufferConfig(enable=False),
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()
        train_batch_size = self.config.buffer.train_batch_size
        expected_model_versions_map = {}
        for idx, step in enumerate(steps):
            # A FIFO queue yields experiences in write order, so the version
            # of the k-th consumed experience is k // exp_write_batch_size.
            start_idx = idx * train_batch_size
            batch_versions = [
                (start_idx + offset) // self.exp_write_batch_size
                for offset in range(train_batch_size)
            ]
            expected_model_versions_map[step] = batch_versions

        await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

    async def test_default_queue_staleness_control_sample_strategy(self):
        max_staleness = 3
        self.config.algorithm.sample_strategy = "staleness_control"
        self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness}
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="default_queue_staleness_control",
            storage_type=StorageType.QUEUE.value,
            replay_buffer=ReplayBufferConfig(enable=False),
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()
        expected_model_versions_map = {}
        for step in steps:
            # Experiences older than step - max_staleness are skipped, so the
            # batch starts at that version (clamped to 0 for early steps).
            predict_version = max(step - max_staleness, 0)
            expected_model_versions_map[step] = [
                predict_version + i // self.exp_write_batch_size
                for i in range(self.config.buffer.train_batch_size)
            ]

        await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

    def _simulate_priority_queue(self, steps, max_staleness=float("inf")):
        """Predict the versions a priority (LIFO-favoring) queue will return.

        Mirrors the buffer's behavior in pure Python: newest step-batches are
        consumed first (``buffer.pop()``), experiences older than
        ``step - max_staleness`` are dropped, and a batch is complete once
        ``train_batch_size`` versions are collected.
        """
        expected_model_versions_map = {}
        buffer = deque()
        exp_pool = deque()
        step_idx = 0
        train_batch_size = self.config.buffer.train_batch_size
        for i in range(self.num_steps):
            buffer.append([i] * self.exp_write_batch_size)
            step = steps[step_idx]
            if i < step:
                continue
            batch_versions = expected_model_versions_map.get(step, [])
            if len(batch_versions) < train_batch_size:
                while len(buffer) > 0:
                    if len(exp_pool) == 0:
                        # Priority queue favors the most recent step-batch.
                        exp_pool.extend(buffer.pop())
                    while len(exp_pool) > 0 and len(batch_versions) < train_batch_size:
                        exp_version = exp_pool.popleft()
                        if exp_version < step - max_staleness:
                            continue  # too stale; discard
                        batch_versions.append(exp_version)
                    if len(batch_versions) >= train_batch_size:
                        step_idx += 1
                        break
                expected_model_versions_map[step] = batch_versions
            if step_idx >= len(steps):
                break
        return expected_model_versions_map

    async def test_priority_queue_default_sample_strategy(self):
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="priority_queue_default_strategy",
            storage_type=StorageType.QUEUE.value,
            replay_buffer=ReplayBufferConfig(enable=True),
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()
        expected_model_versions_map = self._simulate_priority_queue(steps)

        await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

    async def test_priority_queue_staleness_control_sample_strategy(self):
        max_staleness = 2
        self.config.algorithm.sample_strategy = "staleness_control"
        self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness}
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="priority_queue_staleness_control",
            storage_type=StorageType.QUEUE.value,
            replay_buffer=ReplayBufferConfig(enable=True),
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()
        expected_model_versions_map = self._simulate_priority_queue(steps, max_staleness)

        await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

    async def test_sql_staleness_control_sample_strategy(self):
        max_staleness = 2
        self.config.algorithm.sample_strategy = "staleness_control"
        self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness}
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="sql_staleness_control",
            storage_type=StorageType.SQL.value,
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()

        # SQL ordering is not pinned, so only the staleness bound is checked.
        await self._flexible_verify_sampling_model_versions(exps_list, steps, max_staleness)

    def tearDown(self):
        # `buffer_writer` is only created inside the test body via
        # `_init_buffer_writer_and_sample_strategy()`. Guard the cleanup so a
        # failure before that point (e.g. in `check_and_update()`) does not
        # raise AttributeError here and mask the original test error.
        writer = getattr(self, "buffer_writer", None)
        if writer is not None:
            asyncio.run(writer.release())
        # Best-effort cleanup: the job dir may not exist if setup failed.
        shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
        return super().tearDown()
2 changes: 2 additions & 0 deletions tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ async def test_sql_exp_buffer_read_write(self) -> None:
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
info={"model_version": i},
)
for i in range(1, put_batch_size + 1)
]
Expand All @@ -52,6 +53,7 @@ async def test_sql_exp_buffer_read_write(self) -> None:
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
info={"model_version": i},
)
for i in range(1, put_batch_size * 2 + 1)
]
Expand Down
1 change: 1 addition & 0 deletions tests/common/experience_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def test_experience_model_experience_conversion(self):
reward=reward,
prompt_length=prompt_length,
logprobs=logprobs,
info={"model_version": 0},
)

model = ExperienceModel.from_experience(experience)
Expand Down
2 changes: 2 additions & 0 deletions tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import random
import shutil
import unittest
from datetime import datetime

import httpx
Expand Down Expand Up @@ -200,6 +201,7 @@ def setUp(self):
if multiprocessing.get_start_method(allow_none=True) != "spawn":
multiprocessing.set_start_method("spawn", force=True)

@unittest.skip("Require improvement for agent mode")
async def test_serve(self): # noqa: C901
serve_process = multiprocessing.Process(target=run_serve, args=(self.config,))
serve_process.start()
Expand Down
4 changes: 4 additions & 0 deletions tests/service/data_juicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,24 +182,28 @@ async def test_data_juicer_operators(self):
prompt_length=3,
prompt_text="Hello, how are you?",
response_text="Hi, I am fine.",
info={"model_version": 0},
),
Experience( # too short response
tokens=torch.tensor([1, 2, 3, 4, 5]),
prompt_length=3,
prompt_text="What is your name?",
response_text="Trinity.",
info={"model_version": 0},
),
Experience( # repeated words
tokens=torch.tensor([1, 2, 3, 4, 5]),
prompt_length=3,
prompt_text="What day is it today?",
response_text="Today is Sunday Sunday Sunday Sunday Sunday and it's a happy day!",
info={"model_version": 0},
),
Experience(
tokens=torch.tensor([1, 2, 3, 4, 5]),
prompt_length=3,
prompt_text="What is your favorite color?",
response_text="My favorite color is blue.",
info={"model_version": 0},
),
]
metrics = await pipeline.process.remote(exps)
Expand Down
1 change: 1 addition & 0 deletions trinity/algorithm/sample_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
default_mapping={
"default": "trinity.algorithm.sample_strategy.sample_strategy.DefaultSampleStrategy",
"warmup": "trinity.algorithm.sample_strategy.sample_strategy.WarmupSampleStrategy",
"staleness_control": "trinity.algorithm.sample_strategy.sample_strategy.StalenessControlSampleStrategy",
"mix": "trinity.algorithm.sample_strategy.mix_sample_strategy.MixSampleStrategy",
},
)
Expand Down
17 changes: 17 additions & 0 deletions trinity/algorithm/sample_strategy/sample_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ def load_state_dict(self, state_dict: dict) -> None:
self.exp_buffer.load_state_dict(state_dict)


class StalenessControlSampleStrategy(DefaultSampleStrategy):
    """Sample strategy that bounds how stale consumed experiences may be.

    Extends the default strategy by passing a minimum model version to the
    buffer when reading, so that experiences produced more than
    ``max_staleness`` steps before the current train step are excluded
    (the actual filtering is implemented by the buffer's reader — see the
    ``min_model_version`` keyword below).
    """

    def __init__(self, buffer_config: BufferConfig, **kwargs):
        super().__init__(buffer_config)
        # Maximum allowed gap between the current step and an experience's
        # model version; defaults to no limit (inf disables filtering).
        self.max_staleness = kwargs.get("max_staleness", float("inf"))

    async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
        """Read a batch of sufficiently fresh experiences at ``step``.

        Returns a tuple of (gathered experiences batch, metrics dict,
        representative samples for logging).
        """
        # Clamp to 0: at early steps (or with max_staleness=inf) the bound
        # would otherwise be negative.
        min_model_version = max(step - self.max_staleness, 0)
        metrics = {}
        with Timer(metrics, "time/read_experience"):
            exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version)
        repr_samples = representative_sample(exp_list)
        # Records sample/model_version/{min,max,mean} metrics for the batch.
        self.set_model_version_metric(exp_list, metrics)
        with Timer(metrics, "time/gather_experience"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
        return exps, metrics, repr_samples


@Deprecated
class WarmupSampleStrategy(DefaultSampleStrategy):
"""The warmup sample strategy.
Expand Down
4 changes: 2 additions & 2 deletions trinity/buffer/buffer_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ class BufferReader(ABC):
"""Interface of the buffer reader."""

@abstractmethod
def read(self, batch_size: Optional[int] = None) -> List:
def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
"""Read from buffer."""

@abstractmethod
async def read_async(self, batch_size: Optional[int] = None) -> List:
async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List:
"""Read from buffer asynchronously."""

def __len__(self) -> int:
Expand Down
Loading