2 changes: 1 addition & 1 deletion pyproject.toml
@@ -22,7 +22,7 @@ classifiers = [
requires-python = ">=3.10,<3.13"
dependencies = [
"verl==0.5.0",
"ray[default]>=2.48.0",
"ray[default]>=2.50.0",
"vllm>=0.10.2,<=0.11.0",
"tensordict",
"wandb",
9 changes: 4 additions & 5 deletions scripts/docker/Dockerfile.uv
@@ -22,15 +22,14 @@ RUN chmod 1777 /tmp && apt update && apt install -y \
&& ln -sf /usr/bin/python3 /usr/bin/python \
&& ln -sf /usr/bin/pip3 /usr/bin/pip

# For Aliyun users: update pip mirror to aliyun to speed up pip install
# ENV PIP_INDEX_URL=http://mirrors.cloud.aliyuncs.com/pypi/simple/
# ENV PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com

ENV VIRTUAL_ENV=/opt/venv

# copy the Trinity-RFT dir into the workspace
COPY . .

# For Aliyun users: update pip mirror to aliyun to speed up pip install
# ENV UV_DEFAULT_INDEX=http://mirrors.cloud.aliyuncs.com/pypi/simple/

# Install uv
RUN pip install uv && uv venv /opt/venv --python=python3.12

@@ -40,7 +39,7 @@ RUN . /opt/venv/bin/activate && \

# Install flash_attn and Megatron
RUN . /opt/venv/bin/activate && \
uv pip install flash_attn==2.8.1 --no-deps --no-cache-dir && \
uv pip install flash_attn==2.8.1 --no-cache-dir && \
uv pip install -e .[megatron] && \
NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 \
uv pip install -v --no-build-isolation \
61 changes: 39 additions & 22 deletions tests/common/vllm_test.py
@@ -442,7 +442,7 @@ async def test_api(self):
)
self.assertEqual(2, len(response.choices))
self.assertTrue(hasattr(response.choices[0], "token_ids"))
self.assertTrue(len(response.choices[0].token_ids) > 0)
self.assertTrue(response.choices[0].token_ids is None)
with self.assertRaises(ValueError):
self.model_wrapper_no_history.extract_experience_from_history()
self.assertEqual(len(self.model_wrapper_no_history.history), 0)
@@ -496,6 +496,7 @@ def setUp(self):
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True
self.config.explorer.rollout_model.enable_log_requests = True

self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
@@ -540,17 +541,17 @@ async def test_logprobs_api(self):
logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
self.assertEqual(logprobs_1.shape, logprobs_2.shape)
self.assertEqual(logprobs_3.shape, logprobs_4.shape)
self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3))
self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3))
logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3))
self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3))
self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
@@ -559,10 +560,18 @@ async def test_logprobs_api(self):
self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
self.assertTrue(
torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)
)
self.assertFalse(
torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)
)
self.assertTrue(
torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)
)
self.assertFalse(
torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)
)

# test vllm engine logprobs with different temperature
response_1 = self.model_wrapper.chat(
@@ -581,17 +590,17 @@ async def test_logprobs_api(self):
logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
self.assertEqual(logprobs_1.shape, logprobs_2.shape)
self.assertEqual(logprobs_3.shape, logprobs_4.shape)
self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3))
self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3))
logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3))
self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3))
self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
@@ -600,10 +609,18 @@ async def test_logprobs_api(self):
self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
self.assertTrue(
torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)
)
self.assertFalse(
torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)
)
self.assertTrue(
torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)
)
self.assertFalse(
torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)
)

# test openai api and vllm engine logprobs consistency
await self.model_wrapper.clean_workflow_state()
@@ -747,7 +764,7 @@ async def test_api_async(self):
)
self.assertEqual(2, len(response.choices))
self.assertTrue(hasattr(response.choices[0], "token_ids"))
self.assertTrue(len(response.choices[0].token_ids) > 0)
self.assertTrue(response.choices[0].token_ids is None)
with self.assertRaises(ValueError):
self.model_wrapper_no_history.extract_experience_from_history()
self.assertEqual(len(self.model_wrapper_no_history.history), 0)
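Note on the tightened tolerances in the test changes above: torch.allclose(a, b, rtol, atol) passes when |a - b| <= atol + rtol * |b| holds elementwise, so the added atol=1e-3 keeps near-zero logprobs from failing on tiny absolute differences, while rtol=0.3 bounds the relative drift. A minimal standalone sketch with made-up numbers (not part of the test file):

```python
import torch

# torch.allclose(a, b, rtol, atol) checks |a - b| <= atol + rtol * |b| elementwise.
a = torch.tensor([-0.0005, -2.00, -5.00])  # e.g. recomputed logprobs
b = torch.tensor([-0.0012, -2.30, -5.80])  # e.g. logprobs returned with the response

print(torch.allclose(a, b, rtol=0.3, atol=1e-3))  # True: every diff fits within 1e-3 + 0.3 * |b|
print(torch.allclose(a, b, rtol=0.3))             # False: the near-zero entry needs the atol term
```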
1 change: 1 addition & 0 deletions trinity/cli/launcher.py
@@ -249,6 +249,7 @@ def debug(
os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir
load_plugins()
config = load_config(config_path)
config.mode = "explore"
config.check_and_update()
sys.path.insert(0, os.getcwd())
config.ray_namespace = DEBUG_NAMESPACE
2 changes: 1 addition & 1 deletion trinity/common/config.py
@@ -484,7 +484,7 @@ class InferenceModelConfig:
use_v1: bool = True
enforce_eager: bool = False
enable_prefix_caching: bool = False
enable_chunked_prefill: bool = False
enable_chunked_prefill: bool = True
gpu_memory_utilization: float = 0.9
dtype: str = "bfloat16"
seed: int = 42
45 changes: 36 additions & 9 deletions trinity/common/models/model.py
@@ -3,8 +3,7 @@
import asyncio
import socket
from abc import ABC, abstractmethod
from functools import partial
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import httpx
import numpy as np
@@ -13,7 +12,6 @@
import torch
from PIL import Image
from torch import Tensor
from vllm.lora.request import LoRARequest

from trinity.common.constants import RunningStatus
from trinity.common.experience import Experience
@@ -96,7 +94,17 @@ def __init__(
engine_type: str = "vllm",
enable_lora: bool = False,
enable_history: bool = False,
enable_thinking: Optional[bool] = None,
):
"""Initialize the ModelWrapper.

Args:
model (InferenceModel): The inference model Ray actor.
engine_type (str): The type of the model engine. Default to "vllm".
enable_lora (bool): Whether to enable LoRA. Default to False.
enable_history (bool): Whether to enable history recording. Default to False.
enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models.
"""
assert engine_type.startswith("vllm"), "Only vLLM model is supported for now."
self.model = model
self.api_address: str = None
@@ -105,6 +113,7 @@ def __init__(
self.logger = get_logger(__name__)
self.enable_lora = enable_lora
self.enable_history = enable_history
self.enable_thinking = enable_thinking
self.history = []
self.status = RunningStatus.RUNNING
self.workflow_state: Dict = {}
@@ -270,13 +279,13 @@ async def model_path_async(self) -> str:
"""Get the model path."""
return await self.model.get_model_path.remote()

def get_lora_request(self) -> Optional[LoRARequest]:
def get_lora_request(self) -> Any:
if self.enable_lora:
return ray.get(self.model.get_lora_request.remote())
else:
return None

async def get_lora_request_async(self) -> Optional[LoRARequest]:
async def get_lora_request_async(self) -> Any:
if self.enable_lora:
return await self.model.get_lora_request.remote()
else:
@@ -303,10 +312,18 @@ def get_openai_client(self) -> openai.OpenAI:
)
if self.enable_history:
# add a decorator to the openai client to record history
ori_create = partial(self.openai_client.chat.completions.create, logprobs=True)

ori_create = self.openai_client.chat.completions.create

def record_chat_completions(*args, **kwargs):
response = ori_create(*args, **kwargs)
logprobs = kwargs.pop("logprobs", True)
extra_body = kwargs.pop("extra_body", {})
if self.enable_thinking is not None:
if "chat_template_kwargs" not in extra_body:
extra_body["chat_template_kwargs"] = {}
extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking
extra_body["return_token_ids"] = True
response = ori_create(*args, extra_body=extra_body, logprobs=logprobs, **kwargs)
self.history.extend(convert_api_output_to_experience(response))
return response

@@ -333,10 +350,20 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
)
if self.enable_history:
# add a decorator to the openai client to record history
ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True)

ori_create = self.openai_async_client.chat.completions.create

async def record_chat_completions(*args, **kwargs):
response = await ori_create(*args, **kwargs)
logprobs = kwargs.pop("logprobs", True)
extra_body = kwargs.pop("extra_body", {})
if self.enable_thinking is not None:
if "chat_template_kwargs" not in extra_body:
extra_body["chat_template_kwargs"] = {}
extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking
extra_body["return_token_ids"] = True
response = await ori_create(
*args, extra_body=extra_body, logprobs=logprobs, **kwargs
)
self.history.extend(convert_api_output_to_experience(response))
return response

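The get_openai_client / get_openai_async_client changes above replace the functools.partial wrapper with a shim that merges any caller-supplied extra_body, optionally sets chat_template_kwargs.enable_thinking, requests return_token_ids (the extra-body flag the serving backend uses to echo token ids), and records the response. A reduced standalone sketch of that monkey-patching pattern, using a fake create callable in place of the real OpenAI client; the names are illustrative and history recording is simplified to a plain append:

```python
from typing import Any, Callable, Dict, List, Optional


def patch_create(
    ori_create: Callable[..., Any],
    history: List[Any],
    enable_thinking: Optional[bool] = None,
) -> Callable[..., Any]:
    """Wrap a chat.completions.create-style callable: inject defaults, record responses."""

    def record_chat_completions(*args, **kwargs):
        logprobs = kwargs.pop("logprobs", True)                # request logprobs unless the caller opts out
        extra_body: Dict[str, Any] = kwargs.pop("extra_body", {})
        if enable_thinking is not None:                        # only meaningful for chat templates
            extra_body.setdefault("chat_template_kwargs", {})  # that honor this flag (e.g. Qwen3)
            extra_body["chat_template_kwargs"]["enable_thinking"] = enable_thinking
        extra_body["return_token_ids"] = True                  # ask the backend to echo token ids
        response = ori_create(*args, extra_body=extra_body, logprobs=logprobs, **kwargs)
        history.append(response)                               # simplified stand-in for experience extraction
        return response

    return record_chat_completions


# Usage with a stand-in "create" that just echoes its keyword arguments:
if __name__ == "__main__":
    history: List[Any] = []

    def fake_create(**kwargs):
        return kwargs

    create = patch_create(fake_create, history, enable_thinking=False)
    out = create(model="demo", messages=[{"role": "user", "content": "hi"}])
    print(out["extra_body"])  # {'chat_template_kwargs': {'enable_thinking': False}, 'return_token_ids': True}
    print(len(history))       # 1
```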