Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/buffer/experience_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ray
import torch

from tests.tools import RayUnittestBaseAysnc, get_template_config
from tests.tools import RayUnittestBaseAsync, get_template_config
from trinity.buffer import get_buffer_reader
from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
from trinity.common.config import (
Expand Down Expand Up @@ -34,7 +34,7 @@ def get_experiences(task_num: int, repeat_times: int = 1, step_num: int = 1) ->
]


class TestExperiencePipeline(RayUnittestBaseAysnc):
class TestExperiencePipeline(RayUnittestBaseAsync):
def setUp(self):
if os.path.exists(BUFFER_FILE_PATH):
os.remove(BUFFER_FILE_PATH)
Expand Down
4 changes: 2 additions & 2 deletions tests/buffer/experience_storage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from parameterized import parameterized

from tests.tools import RayUnittestBaseAysnc
from tests.tools import RayUnittestBaseAsync
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
Expand All @@ -17,7 +17,7 @@
DB_PATH = os.path.join(os.path.dirname(__file__), "test.db")


class ExperienceStorageTest(RayUnittestBaseAysnc):
class ExperienceStorageTest(RayUnittestBaseAsync):
def setUp(self):
self.total_num = 8
self.put_batch_size = 2
Expand Down
4 changes: 2 additions & 2 deletions tests/buffer/queue_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from parameterized import parameterized

from tests.tools import RayUnittestBaseAysnc
from tests.tools import RayUnittestBaseAsync
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
Expand All @@ -17,7 +17,7 @@
BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")


class TestQueueBuffer(RayUnittestBaseAysnc):
class TestQueueBuffer(RayUnittestBaseAsync):
@parameterized.expand(
[
(
Expand Down
4 changes: 2 additions & 2 deletions tests/buffer/reader_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tests.tools import RayUnittestBaseAysnc, get_unittest_dataset_config
from tests.tools import RayUnittestBaseAsync, get_unittest_dataset_config
from trinity.buffer.buffer import get_buffer_reader
from trinity.buffer.reader import READER
from trinity.buffer.reader.file_reader import FileReader, TaskFileReader
Expand All @@ -12,7 +12,7 @@ def __init__(self, config):
super().__init__(config)


class TestBufferReader(RayUnittestBaseAysnc):
class TestBufferReader(RayUnittestBaseAsync):
async def test_buffer_reader_registration(self) -> None:
config = get_unittest_dataset_config("countdown", "train")
config.batch_size = 2
Expand Down
6 changes: 3 additions & 3 deletions tests/buffer/sample_strategy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from parameterized import parameterized_class

from tests.tools import RayUnittestBaseAysnc, get_template_config
from tests.tools import RayUnittestBaseAsync, get_template_config
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
from trinity.buffer.buffer import get_buffer_writer
Expand All @@ -21,7 +21,7 @@
(6,),
],
)
class ExperienceStorageTest(RayUnittestBaseAysnc):
class ExperienceStorageTest(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.num_steps = 20
Expand Down Expand Up @@ -249,5 +249,5 @@ async def test_sql_staleness_control_sample_strategy(self):

def tearDown(self):
asyncio.run(self.buffer_writer.release())
shutil.rmtree(self.config.checkpoint_job_dir)
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
return super().tearDown()
4 changes: 2 additions & 2 deletions tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from parameterized import parameterized

from tests.tools import RayUnittestBaseAysnc
from tests.tools import RayUnittestBaseAsync
from trinity.buffer import get_buffer_reader
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
Expand All @@ -19,7 +19,7 @@
db_path = os.path.join(os.path.dirname(__file__), "test.db")


class TestSQLBuffer(RayUnittestBaseAysnc):
class TestSQLBuffer(RayUnittestBaseAsync):
@parameterized.expand(
[
(True,),
Expand Down
2 changes: 1 addition & 1 deletion tests/buffer/task_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def setUpClass(cls):
def tearDownClass(cls):
super().tearDownClass()
if os.path.exists(cls.temp_output_path):
shutil.rmtree(cls.temp_output_path)
shutil.rmtree(cls.temp_output_path, ignore_errors=True)

def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, int]]) -> None:
for task, index in zip(batch_tasks, indices):
Expand Down
2 changes: 1 addition & 1 deletion tests/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,4 @@ def test_chat_template_path(self):

def tearDown(self):
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR)
shutil.rmtree(CHECKPOINT_ROOT_DIR, ignore_errors=True)
20 changes: 10 additions & 10 deletions tests/common/vllm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers import AutoTokenizer

from tests.tools import (
RayUnittestBaseAysnc,
RayUnittestBaseAsync,
get_api_model_path,
get_model_path,
get_template_config,
Expand Down Expand Up @@ -113,7 +113,7 @@ async def prepare_engines(engines, auxiliary_engines):
(2, 1, 3, True, True),
],
)
class ModelWrapperTest(RayUnittestBaseAysnc):
class ModelWrapperTest(RayUnittestBaseAsync):
def setUp(self):
# configure the model
self.config = get_template_config()
Expand Down Expand Up @@ -233,7 +233,7 @@ async def test_generate(self):
(20, 5, 15),
],
)
class TestModelLen(RayUnittestBaseAysnc):
class TestModelLen(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -302,7 +302,7 @@ def _check_experience(exp):
)


class TestModelLenWithoutPromptTruncation(RayUnittestBaseAysnc):
class TestModelLenWithoutPromptTruncation(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -351,7 +351,7 @@ async def test_model_len(self):
)


class TestAPIServer(RayUnittestBaseAysnc):
class TestAPIServer(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -482,7 +482,7 @@ async def test_api(self):
"""


class TestLogprobs(RayUnittestBaseAysnc):
class TestLogprobs(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -669,7 +669,7 @@ async def test_logprobs_api(self):
)


class TestAsyncAPIServer(RayUnittestBaseAysnc):
class TestAsyncAPIServer(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -880,7 +880,7 @@ def test_action_mask_with_tools(self):
(False, None),
],
)
class TestAPIServerToolCall(RayUnittestBaseAysnc):
class TestAPIServerToolCall(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -1161,7 +1161,7 @@ async def test_api_tool_calls(self):
)


class TestSuperLongGeneration(RayUnittestBaseAysnc):
class TestSuperLongGeneration(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -1217,7 +1217,7 @@ async def test_generate(self):
self.assertGreater(response.logprobs.shape[0], 1000)


class TestTinkerAPI(RayUnittestBaseAysnc):
class TestTinkerAPI(RayUnittestBaseAsync):
"""Test the Tinker API integration with the vLLM engine."""

def setUp(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from tests.tools import (
RayUnittestBase,
RayUnittestBaseAysnc,
RayUnittestBaseAsync,
TensorBoardParser,
get_api_model_path,
get_checkpoint_path,
Expand Down Expand Up @@ -180,7 +180,7 @@ def run_agent(proxy_url, model_path: str):
return response.choices[0].message.content


class ServeTest(RayUnittestBaseAysnc):
class ServeTest(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "serve"
Expand Down
34 changes: 24 additions & 10 deletions tests/manager/synchronizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,14 @@ def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
trainer_monkey_patch(config, max_steps, intervals)
train(config)
ray.shutdown()


def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
explorer_monkey_patch(config, max_steps, intervals)
explore(config)
ray.shutdown()


def run_both(
Expand All @@ -97,17 +99,26 @@ def run_both(
trainer_monkey_patch(config, max_steps, trainer_intervals)
explorer_monkey_patch(config, max_steps, explorer_intervals)
both(config)
ray.shutdown()


class BaseTestSynchronizer(unittest.TestCase):
def setUp(self):
if multiprocessing.get_start_method(allow_none=True) != "spawn":
multiprocessing.set_start_method("spawn", force=True)
self.process_list = []

def tearDown(self):
checkpoint_path = get_checkpoint_path()
ray.shutdown(_exiting_interpreter=True)
shutil.rmtree(os.path.join(checkpoint_path, "unittest"))
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR, ignore_errors=True)
for process in self.process_list:
if process.is_alive():
process.terminate()
process.join(timeout=10)
if process.is_alive():
process.kill()
process.join()


class TestSynchronizerExit(BaseTestSynchronizer):
Expand Down Expand Up @@ -151,6 +162,8 @@ def test_synchronizer(self):
target=run_trainer, args=(trainer_config, 8, [2, 1, 2, 1, 2, 1, 2, 1])
)
trainer_process.start()
self.process_list.append(trainer_process)

ray.init(ignore_reinit_error=True)
while True:
try:
Expand All @@ -164,6 +177,7 @@ def test_synchronizer(self):
args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5]),
)
explorer_process_1.start()
self.process_list.append(explorer_process_1)

self.assertEqual(
synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
Expand All @@ -176,14 +190,13 @@ def test_synchronizer(self):
except ValueError:
print("waiting for explorer1 to start.")
time.sleep(5)
trainer_process.terminate()
trainer_process.join()

trainer_process.join(timeout=200)
self.assertEqual(
synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
)

explorer_process_1.terminate()
explorer_process_1.join()
explorer_process_1.join(timeout=200)
time.sleep(6)
with self.assertRaises(ValueError):
ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
Expand Down Expand Up @@ -278,6 +291,8 @@ def test_synchronizer(self):
target=run_trainer, args=(trainer_config, self.max_steps, self.trainer_intervals)
)
trainer_process.start()
self.process_list.append(trainer_process)

ray.init(ignore_reinit_error=True)
while True:
try:
Expand All @@ -291,10 +306,12 @@ def test_synchronizer(self):
args=(explorer1_config, self.max_steps, self.explorer1_intervals),
)
explorer_process_1.start()
self.process_list.append(explorer_process_1)
explorer_process_2 = multiprocessing.Process(
target=run_explorer, args=(explorer2_config, self.max_steps, self.explorer2_intervals)
)
explorer_process_2.start()
self.process_list.append(explorer_process_2)

explorer_process_1.join(timeout=200)
explorer_process_2.join(timeout=200)
Expand Down Expand Up @@ -364,6 +381,7 @@ def test_synchronizer(self):
args=(config, self.max_steps, self.trainer_intervals, self.explorer_intervals),
)
both_process.start()
self.process_list.append(both_process)
both_process.join(timeout=200)

# check the tensorboard
Expand All @@ -375,7 +393,3 @@ def test_synchronizer(self):
)
rollout_metrics = parser.metric_list("rollout")
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)

def tearDown(self):
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR)
2 changes: 1 addition & 1 deletion tests/service/data_juicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async def test_data_juicer_operators(self):
class TestDataJuicerTaskPipeline(RayUnittestBase):
def setUp(self):
if os.path.exists(TASKSET_OUTPUT_DIR):
shutil.rmtree(TASKSET_OUTPUT_DIR)
shutil.rmtree(TASKSET_OUTPUT_DIR, ignore_errors=True)

def test_data_juicer_task_pipeline(self):
config = get_template_config()
Expand Down
2 changes: 1 addition & 1 deletion tests/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def tearDownClass(cls):
ray.shutdown(_exiting_interpreter=True)


class RayUnittestBaseAysnc(unittest.IsolatedAsyncioTestCase):
class RayUnittestBaseAsync(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
Expand Down
Loading