diff --git a/pyproject.toml b/pyproject.toml index 225079c305..f7a8162bfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ requires-python = ">=3.10,<3.13" dependencies = [ "verl==0.5.0", - "ray[default]>=2.48.0", + "ray[default]>=2.50.0", "vllm>=0.10.2,<=0.11.0", "tensordict", "wandb", diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv index 3d40a279c5..82d5389ada 100644 --- a/scripts/docker/Dockerfile.uv +++ b/scripts/docker/Dockerfile.uv @@ -22,15 +22,14 @@ RUN chmod 1777 /tmp && apt update && apt install -y \ && ln -sf /usr/bin/python3 /usr/bin/python \ && ln -sf /usr/bin/pip3 /usr/bin/pip -# For Aliyun users: update pip mirror to aliyun to speed up pip install -# ENV PIP_INDEX_URL=http://mirrors.cloud.aliyuncs.com/pypi/simple/ -# ENV PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com - ENV VIRTUAL_ENV=/opt/venv # copy the Trinity-RFT dir into the workspace COPY . . +# For Aliyun users: update pip mirror to aliyun to speed up pip install +# ENV UV_DEFAULT_INDEX=http://mirrors.cloud.aliyuncs.com/pypi/simple/ + # Install uv RUN pip install uv && uv venv /opt/venv --python=python3.12 @@ -40,7 +39,7 @@ RUN . /opt/venv/bin/activate && \ # Install flash_attn and Megatron RUN . /opt/venv/bin/activate && \ - uv pip install flash_attn==2.8.1 --no-deps --no-cache-dir && \ + uv pip install flash_attn==2.8.1 --no-cache-dir && \ uv pip install -e .[megatron] && \ NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 \ uv pip install -v --no-build-isolation \ diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index db7156cd4d..18731f361e 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -442,7 +442,7 @@ async def test_api(self): ) self.assertEqual(2, len(response.choices)) self.assertTrue(hasattr(response.choices[0], "token_ids")) - self.assertTrue(len(response.choices[0].token_ids) > 0) + self.assertTrue(response.choices[0].token_ids is None) with self.assertRaises(ValueError): self.model_wrapper_no_history.extract_experience_from_history() self.assertEqual(len(self.model_wrapper_no_history.history), 0) @@ -496,6 +496,7 @@ def setUp(self): self.config.explorer.rollout_model.tensor_parallel_size = 1 self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE self.config.explorer.rollout_model.enable_openai_api = True + self.config.explorer.rollout_model.enable_log_requests = True self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) @@ -540,17 +541,17 @@ async def test_logprobs_api(self): logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8) self.assertEqual(logprobs_1.shape, logprobs_2.shape) self.assertEqual(logprobs_3.shape, logprobs_4.shape) - self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4)) - self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4)) + self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3)) logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1] logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1] logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1] logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1] self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape) - self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4)) - self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4)) - 
self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4)) - self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4)) + self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3)) logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :] logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :] logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :] @@ -559,10 +560,18 @@ async def test_logprobs_api(self): self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) - self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5)) - self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5)) - self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8)) - self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8)) + self.assertTrue( + torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3) + ) + self.assertFalse( + torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3) + ) + self.assertTrue( + torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3) + ) + self.assertFalse( + torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3) + ) # test vllm engine logprobs with different temperature response_1 = self.model_wrapper.chat( @@ -581,17 +590,17 @@ async def test_logprobs_api(self): logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8) self.assertEqual(logprobs_1.shape, logprobs_2.shape) self.assertEqual(logprobs_3.shape, logprobs_4.shape) - self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4)) - self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4)) + self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3)) logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1] logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1] logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1] logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1] self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape) - self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4)) - self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4)) - self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4)) - self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4)) + self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3)) logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :] logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :] 
logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :] @@ -600,10 +609,18 @@ async def test_logprobs_api(self): self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) - self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5)) - self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5)) - self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8)) - self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8)) + self.assertTrue( + torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3) + ) + self.assertFalse( + torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3) + ) + self.assertTrue( + torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3) + ) + self.assertFalse( + torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3) + ) # test openai api and vllm engine logprobs consistency await self.model_wrapper.clean_workflow_state() @@ -747,7 +764,7 @@ async def test_api_async(self): ) self.assertEqual(2, len(response.choices)) self.assertTrue(hasattr(response.choices[0], "token_ids")) - self.assertTrue(len(response.choices[0].token_ids) > 0) + self.assertTrue(response.choices[0].token_ids is None) with self.assertRaises(ValueError): self.model_wrapper_no_history.extract_experience_from_history() self.assertEqual(len(self.model_wrapper_no_history.history), 0) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 6e6c462c56..d9c0d95771 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -249,6 +249,7 @@ def debug( os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir load_plugins() config = load_config(config_path) + config.mode = "explore" config.check_and_update() sys.path.insert(0, os.getcwd()) config.ray_namespace = DEBUG_NAMESPACE diff --git a/trinity/common/config.py b/trinity/common/config.py index 9738ec3b8e..b2a824a2db 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -484,7 +484,7 @@ class InferenceModelConfig: use_v1: bool = True enforce_eager: bool = False enable_prefix_caching: bool = False - enable_chunked_prefill: bool = False + enable_chunked_prefill: bool = True gpu_memory_utilization: float = 0.9 dtype: str = "bfloat16" seed: int = 42 diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index e08062f401..110db0a3b7 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -3,8 +3,7 @@ import asyncio import socket from abc import ABC, abstractmethod -from functools import partial -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import httpx import numpy as np @@ -13,7 +12,6 @@ import torch from PIL import Image from torch import Tensor -from vllm.lora.request import LoRARequest from trinity.common.constants import RunningStatus from trinity.common.experience import Experience @@ -96,7 +94,17 @@ def __init__( engine_type: str = "vllm", enable_lora: bool = False, enable_history: bool = False, + enable_thinking: Optional[bool] = None, ): + """Initialize the ModelWrapper. + + Args: + model (InferenceModel): The inference model Ray actor. + engine_type (str): The type of the model engine. Default to "vllm". + enable_lora (bool): Whether to enable LoRA. 
Default to False. + enable_history (bool): Whether to enable history recording. Default to False. + enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models. + """ assert engine_type.startswith("vllm"), "Only vLLM model is supported for now." self.model = model self.api_address: str = None @@ -105,6 +113,7 @@ def __init__( self.logger = get_logger(__name__) self.enable_lora = enable_lora self.enable_history = enable_history + self.enable_thinking = enable_thinking self.history = [] self.status = RunningStatus.RUNNING self.workflow_state: Dict = {} @@ -270,13 +279,13 @@ async def model_path_async(self) -> str: """Get the model path.""" return await self.model.get_model_path.remote() - def get_lora_request(self) -> Optional[LoRARequest]: + def get_lora_request(self) -> Any: if self.enable_lora: return ray.get(self.model.get_lora_request.remote()) else: return None - async def get_lora_request_async(self) -> Optional[LoRARequest]: + async def get_lora_request_async(self) -> Any: if self.enable_lora: return await self.model.get_lora_request.remote() else: @@ -303,10 +312,18 @@ def get_openai_client(self) -> openai.OpenAI: ) if self.enable_history: # add a decorator to the openai client to record history - ori_create = partial(self.openai_client.chat.completions.create, logprobs=True) + + ori_create = self.openai_client.chat.completions.create def record_chat_completions(*args, **kwargs): - response = ori_create(*args, **kwargs) + logprobs = kwargs.pop("logprobs", True) + extra_body = kwargs.pop("extra_body", {}) + if self.enable_thinking is not None: + if "chat_template_kwargs" not in extra_body: + extra_body["chat_template_kwargs"] = {} + extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking + extra_body["return_token_ids"] = True + response = ori_create(*args, extra_body=extra_body, logprobs=logprobs, **kwargs) self.history.extend(convert_api_output_to_experience(response)) return response @@ -333,10 +350,20 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI: ) if self.enable_history: # add a decorator to the openai client to record history - ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True) + + ori_create = self.openai_async_client.chat.completions.create async def record_chat_completions(*args, **kwargs): - response = await ori_create(*args, **kwargs) + logprobs = kwargs.pop("logprobs", True) + extra_body = kwargs.pop("extra_body", {}) + if self.enable_thinking is not None: + if "chat_template_kwargs" not in extra_body: + extra_body["chat_template_kwargs"] = {} + extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking + extra_body["return_token_ids"] = True + response = await ori_create( + *args, extra_body=extra_body, logprobs=logprobs, **kwargs + ) self.history.extend(convert_api_output_to_experience(response)) return response diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index dbddd53a5f..d1f1d3f61f 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -2,17 +2,15 @@ import asyncio import os +from collections import defaultdict from typing import Any, Dict, List, Optional, Sequence import numpy as np import ray import torch -import vllm from packaging.version import parse as parse_version from PIL import Image from transformers import AutoProcessor -from vllm.lora.request import LoRARequest -from vllm.sampling_params import RequestOutputKind from 
trinity.common.config import InferenceModelConfig from trinity.common.experience import Experience @@ -22,7 +20,7 @@ ) from trinity.common.models.model import InferenceModel from trinity.common.models.utils import get_action_mask_method -from trinity.common.models.vllm_patch.api_patch import get_vllm_version +from trinity.common.models.vllm_patch import get_vllm_version from trinity.utils.log import get_logger @@ -38,20 +36,25 @@ def __init__( self, config: InferenceModelConfig, ) -> None: + import vllm + from vllm.sampling_params import RequestOutputKind + self.logger = get_logger(__name__) + self.vllm_version = get_vllm_version() self.config = config self.use_v1 = config.use_v1 if config.tensor_parallel_size != 1: os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices - if not vllm.envs.is_set("VLLM_USE_V1"): + if self.vllm_version <= parse_version("0.11.0") and not vllm.envs.is_set("VLLM_USE_V1"): self.logger.info(f"Using vLLM v{int(config.use_v1)} engine") os.environ["VLLM_USE_V1"] = str(int(config.use_v1)) if config.use_v1: + os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1)) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - if get_vllm_version() >= parse_version("0.11.0"): + if self.vllm_version >= parse_version("0.11.0"): os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" if not config.enforce_eager: # To avoid torch compile conflicts when multiple model are started simultaneously. @@ -81,11 +84,22 @@ def __init__( max_model_len = config.max_model_len self.enable_lora = config.enable_lora self.default_lora_path = config.lora_kwargs.pop("default_lora_path", None) - rope_kwargs = { - key: getattr(config, key) - for key in ["rope_scaling", "rope_theta"] - if getattr(config, key) is not None - } + if self.vllm_version >= parse_version("0.12.0"): + rope_params = defaultdict(dict) + if config.rope_scaling is not None: + rope_params["rope_parameters"] = config.rope_scaling + if config.rope_theta is not None: + rope_params["rope_parameters"]["rope_theta"] = config.rope_theta + if len(rope_params) > 0: + rope_kwargs = {"hf_overrides": rope_params} + else: + rope_kwargs = {} + else: + rope_kwargs = { + key: getattr(config, key) + for key in ["rope_scaling", "rope_theta"] + if getattr(config, key) is not None + } engine_args = vllm.AsyncEngineArgs( model=config.model_path, enforce_eager=config.enforce_eager, @@ -99,7 +113,6 @@ def __init__( trust_remote_code=True, task="generate", gpu_memory_utilization=config.gpu_memory_utilization, - enable_chunked_prefill=config.enable_chunked_prefill, # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage override_generation_config={ # TODO: find a way to unittest this "temperature": config.temperature, @@ -114,12 +127,13 @@ def __init__( **rope_kwargs, **config.lora_kwargs, ) - if get_vllm_version() > parse_version("0.10.0"): + if self.vllm_version > parse_version("0.10.0"): engine_args.enable_log_requests = config.enable_log_requests else: engine_args.disable_log_requests = not config.enable_log_requests - if get_vllm_version() >= parse_version("0.11.0"): + if self.vllm_version >= parse_version("0.11.0"): engine_args.reasoning_parser = config.reasoning_parser + self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) self.processor = None self.tokenizer = None @@ -157,9 +171,7 @@ async def prepare( await 
self.run_api_server() self._prepared = True - async def chat( - self, messages: List[Dict], lora_request: LoRARequest = None, **kwargs - ) -> Sequence[Experience]: + async def chat(self, messages: List[Dict], lora_request=None, **kwargs) -> Sequence[Experience]: """Chat with the model with a list of messages in async. Args: @@ -190,9 +202,7 @@ async def chat( ) return await self.generate(prompt=prompt, lora_request=lora_request, **kwargs) - async def generate( - self, prompt: str, lora_request: LoRARequest = None, **kwargs - ) -> Sequence[Experience]: + async def generate(self, prompt: str, lora_request=None, **kwargs) -> Sequence[Experience]: """Generate a response from the provided prompt in async. Args: @@ -361,7 +371,7 @@ async def generate_mm( async def logprobs( # type: ignore [override] self, token_ids: List[int], - lora_request: LoRARequest = None, + lora_request=None, temperature: Optional[float] = None, ) -> torch.Tensor: """Calculate the logprobs of the given tokens in async. Please slice the result carefully @@ -392,9 +402,7 @@ async def logprobs( # type: ignore [override] dtype=torch.float32, ) - async def _generate_internal( - self, prompt: Any, lora_request: LoRARequest = None, **kwargs - ) -> Any: + async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any: # Send the request to the LLM engine. self.request_id += 1 stream = self.async_llm.generate( @@ -561,23 +569,42 @@ async def run_api_server(self) -> bool: self.logger.info("OpenAI API server is already running. Skipping...") return True # already running - from trinity.common.models.vllm_patch.api_patch import ( - run_api_server_in_ray_actor, - ) - api_server_host, api_server_port = self.get_available_address() - self.api_server = asyncio.create_task( - run_api_server_in_ray_actor( - self.async_llm, - api_server_host, - api_server_port, - self.config.model_path, # type: ignore [arg-type] - self.config.enable_auto_tool_choice, - self.config.tool_call_parser, - self.config.reasoning_parser, - self.config.enable_log_requests, + if self.vllm_version <= parse_version("0.11.0"): + from trinity.common.models.vllm_patch.api_patch import ( + run_api_server_in_ray_actor, + ) + + self.api_server = asyncio.create_task( + run_api_server_in_ray_actor( + self.async_llm, + api_server_host, + api_server_port, + self.config.model_path, # type: ignore [arg-type] + self.config.enable_auto_tool_choice, + self.config.tool_call_parser, + self.config.reasoning_parser, + self.config.enable_log_requests, + ) + ) + else: + from trinity.common.models.vllm_patch.api_patch_v12 import ( + run_api_server_in_ray_actor_v12, + ) + + self.api_server = asyncio.create_task( + run_api_server_in_ray_actor_v12( + self.async_llm, + api_server_host, + api_server_port, + self.config.model_path, # type: ignore [arg-type] + logger=self.logger, + enable_auto_tool_choice=self.config.enable_auto_tool_choice, + tool_call_parser=self.config.tool_call_parser, + reasoning_parser=self.config.reasoning_parser, + enable_log_requests=self.config.enable_log_requests, + ) ) - ) self.api_server_host = api_server_host self.api_server_port = api_server_port return True @@ -604,7 +631,9 @@ def get_model_version(self) -> int: def get_model_path(self) -> str: return self.config.model_path # type: ignore [return-value] - def get_lora_request(self, lora_path: Optional[str] = None) -> LoRARequest: + def get_lora_request(self, lora_path: Optional[str] = None) -> Any: + from vllm.lora.request import LoRARequest + assert self.config.lora_modules is not None 
lora_request = LoRARequest(**self.config.lora_modules[0]) if lora_path is not None: diff --git a/trinity/common/models/vllm_patch/api_patch.py b/trinity/common/models/vllm_patch/api_patch.py index 0036b0956c..623f9c04af 100644 --- a/trinity/common/models/vllm_patch/api_patch.py +++ b/trinity/common/models/vllm_patch/api_patch.py @@ -1,4 +1,4 @@ -"""Patch for vllm OpenAI API server. +"""Patch for vllm OpenAI API server. Only for vllm versions >=0.8.5, <=0.11.0. 1. Mocks the `add_signal_handler` method to do nothing. 2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`. @@ -51,7 +51,6 @@ class PatchedChatCompletionResponse(ChatCompletionResponse): choices: list[PatchedChatCompletionResponseChoice] = list[ChatCompletionResponseChoice] -# TODO: add patch to stream generator async def chat_completion_full_generator( # noqa C901 self, request, @@ -304,7 +303,11 @@ async def patch_and_serve_http(app, sock, args): loop = asyncio.get_event_loop() original_add_signal_handler = loop.add_signal_handler loop.add_signal_handler = functools.partial(dummy_add_signal_handler, loop) - OpenAIServingChat.chat_completion_full_generator = chat_completion_full_generator + vllm_version = get_vllm_version() + + # from 0.10.2, vllm added token_ids to ChatCompletionResponseChoice, so no need to patch + if vllm_version < parse_version("0.10.2"): + OpenAIServingChat.chat_completion_full_generator = chat_completion_full_generator try: shutdown_task = await serve_http( diff --git a/trinity/common/models/vllm_patch/api_patch_v12.py b/trinity/common/models/vllm_patch/api_patch_v12.py new file mode 100644 index 0000000000..b926b158a1 --- /dev/null +++ b/trinity/common/models/vllm_patch/api_patch_v12.py @@ -0,0 +1,168 @@ +"""Patch for vllm OpenAI API server. Only for vllm versions >0.11.0. + +1. Mocks the `add_signal_handler` method to do nothing. +2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`. +""" +import logging +from typing import Optional + +import vllm +import vllm.envs as envs +from packaging.version import parse as parse_version +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.openai.api_server import ( + build_app, + create_server_socket, + create_server_unix_socket, + init_app_state, + validate_api_server_args, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import log_non_default_args +from vllm.reasoning import ReasoningParserManager +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.network_utils import is_valid_ipv6_address +from vllm.utils.system_utils import set_ulimit +from vllm.version import __version__ as VLLM_VERSION + +from trinity.common.models.vllm_patch import get_vllm_version + + +def setup_server_in_ray(args, logger): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + + logger.info("vLLM API server version %s", VLLM_VERSION) + log_non_default_args(args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: + ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) + + validate_api_server_args(args) + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. 
+ # see https://github.com/vllm-project/vllm/issues/8204 + if args.uds: + sock = create_server_unix_socket(args.uds) + else: + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + if args.uds: + listen_address = f"unix:{args.uds}" + else: + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + return listen_address, sock + + +async def run_server_worker_in_ray( + listen_address, + sock, + args, + engine_client, + logger, +) -> None: + # Modified from vllm.entrypoints.openai.api_server.run_server_worker + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: + ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) + + app = build_app(args) + + await init_app_state(engine_client, app.state, args) + + logger.info( + "Starting vLLM API server %d on %s", + engine_client.vllm_config.parallel_config._api_process_rank, + listen_address, + ) + + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. + access_log=not args.disable_uvicorn_access_log, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + h11_max_incomplete_event_size=args.h11_max_incomplete_event_size, + h11_max_header_count=args.h11_max_header_count, + ) + + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + + +async def run_server_in_ray(args, engine_client, logger): + # Modified from vllm.entrypoints.openai.api_server.run_server + listen_address, sock = setup_server_in_ray(args, logger) + logger.info("vLLM API server listening on %s", listen_address) + await run_server_worker_in_ray(listen_address, sock, args, engine_client, logger) + + +async def run_api_server_in_ray_actor_v12( + async_llm, + host: str, + port: int, + model_path: str, + logger: logging.Logger, + enable_auto_tool_choice: bool = False, + tool_call_parser: Optional[str] = None, + reasoning_parser: Optional[str] = None, + enable_log_requests: bool = False, +): + vllm_version = get_vllm_version() + if vllm_version <= parse_version("0.11.0"): + raise ValueError( + f"Unsupported vllm version: {vllm.__version__}. 
" + "This patch requires vllm version > 0.11.0" + ) + + parser = FlexibleArgumentParser(description="Run the OpenAI API server.") + args = make_arg_parser(parser) + cli_args = [ + "--host", + str(host), + "--port", + str(port), + "--model", + model_path, + "--enable-server-load-tracking", # enable tracking for load balancing + ] + if enable_log_requests: + cli_args.append("--enable-log-requests") + if enable_auto_tool_choice: + cli_args.append("--enable-auto-tool-choice") + if tool_call_parser: + cli_args.extend(["--tool-call-parser", tool_call_parser]) + if reasoning_parser: + cli_args.extend(["--reasoning-parser", reasoning_parser]) + args = parser.parse_args(cli_args) + if vllm_version >= parse_version("0.11.0"): + args.structured_outputs_config.reasoning_parser = reasoning_parser + logger.info(f"Starting vLLM OpenAI API server with args: {args}") + await run_server_in_ray(args, async_llm, logger) diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index ebe9d47ac3..c58decbf7c 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -10,15 +10,17 @@ from trinity.common.models.vllm_patch import get_vllm_version -def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): +def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): # noqa: C901 """Patch vLLM model runner to support prompt logprobs extraction.""" - if get_vllm_version() < parse_version("0.10.2"): + version = get_vllm_version() + if version < parse_version("0.10.2") or version > parse_version("0.12.0"): raise ValueError( f"Unsupported vllm version: {vllm.__version__}. " - "This patch requires vllm version >= 0.10.2, <= 0.11.0." + "This patch requires vllm version >= 0.10.2, <= 0.12.0." ) + is_v0102 = version == parse_version("0.10.2") - def _get_prompt_logprobs_dict( + def _get_prompt_logprobs_dict_v11( self, hidden_states: torch.Tensor, num_scheduled_tokens: dict[str, int], @@ -45,7 +47,131 @@ def _get_prompt_logprobs_dict( # maintainable loop over optimal performance. completed_prefill_reqs = [] for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): - num_tokens = num_scheduled_tokens[req_id] + num_tokens = num_scheduled_tokens.get(req_id) + if num_tokens is None: + # This can happen if the request was preempted in prefill stage. + continue + + # Get metadata for this request. + request = self.requests[req_id] + if request.prompt_token_ids is None: + # Prompt logprobs is incompatible with prompt embeddings + continue + + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True + ) + + # Set up target LogprobsTensors object. + logprobs_tensors = in_progress_dict.get(req_id) + if not logprobs_tensors: + # Create empty logprobs CPU tensors for the entire prompt. + # If chunked, we'll copy in slice by slice. + logprobs_tensors = LogprobsTensors.empty_cpu( + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) + in_progress_dict[req_id] = logprobs_tensors + + # Determine number of logits to retrieve. + start_idx = request.num_computed_tokens + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. 
+ num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + prompt_logprobs_dict[req_id] = logprobs_tensors + + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + continue + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. + req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc.np[req_idx].item() + prompt_hidden_states = hidden_states[offset : offset + num_logits] + # PATCH START + if is_v0102: + logits = self.model.compute_logits(prompt_hidden_states, None) + else: + logits = self.model.compute_logits(prompt_hidden_states) + + temp = request.sampling_params.temperature + if temp >= 1e-5: + logits.div_(temp) + # PATCH END + + # Get the "target" tokens for each index. For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] + + # Compute prompt logprobs. + logprobs = self.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids + ) + + # Transfer GPU->CPU async. + chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, non_blocking=True) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, non_blocking=True) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + del in_progress_dict[req_id] + + # Must synchronize the non-blocking GPU->CPU transfers. + if prompt_logprobs_dict: + self._sync_device() + + return prompt_logprobs_dict + + def _get_prompt_logprobs_dict_v12( + self, + hidden_states: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, Optional[LogprobsTensors]]: + """Patched version of _get_prompt_logprobs_dict. + + This is a monkey-patched version of `_get_prompt_logprobs_dict` from + `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions + >= 0.12.0). + + The original function does not apply temperature scaling to logits when + calculating prompt logprobs, which can lead to incorrect logprob values + when the temperature is not 1.0. This patch adds the missing + temperature scaling. + """ + num_prompt_logprobs_dict = self.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + num_tokens = num_scheduled_tokens.get(req_id) + if num_tokens is None: + # This can happen if the request was preempted in prefill stage. + continue # Get metadata for this request. 
request = self.requests[req_id] @@ -133,4 +259,11 @@ def _get_prompt_logprobs_dict( return prompt_logprobs_dict - model_runner._get_prompt_logprobs_dict = MethodType(_get_prompt_logprobs_dict, model_runner) + if get_vllm_version() < parse_version("0.12.0"): + model_runner._get_prompt_logprobs_dict = MethodType( + _get_prompt_logprobs_dict_v11, model_runner + ) + else: + model_runner._get_prompt_logprobs_dict = MethodType( + _get_prompt_logprobs_dict_v12, model_runner + ) diff --git a/trinity/explorer/api/api.py b/trinity/explorer/api/api.py index 66b5e2e97a..b702c39f5a 100644 --- a/trinity/explorer/api/api.py +++ b/trinity/explorer/api/api.py @@ -15,6 +15,8 @@ async def chat_completions(request: Request): # Currently, we do not support streaming chat completions body = await request.json() + if "return_token_ids" not in body: + body["return_token_ids"] = True url = await request.app.state.service.allocate_model() try: async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client: diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 2590dae064..57c5c25bf0 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -278,6 +278,9 @@ async def debug(self) -> None: with VizTracer(output_file=self.output_profiling_file): status, exps = await self.run_task(task, 1, 0) + if not status.ok and len(exps) == 0: + exps = self.model_wrapper.extract_experience_from_history() + self.logger.info(f"Debugging failed, extracting {len(exps)} experiences from history.") await self.sqlite_writer.write_async(exps) if status.ok: print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
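
Usage sketch for the new `enable_thinking` plumbing in `ModelWrapper.get_openai_client()` / `get_openai_async_client()`: with history enabled, the wrapped `chat.completions.create` call now injects `logprobs=True`, `extra_body["return_token_ids"] = True`, and (when `enable_thinking` is not None) `extra_body["chat_template_kwargs"]["enable_thinking"]`, and records each response into the wrapper history. The sketch below is a minimal illustration, not part of this patch; it assumes a live `InferenceModel` Ray actor (`model_actor`) whose OpenAI-compatible server is already running (see `run_api_server` in trinity/common/models/vllm_model.py), and the model name is a placeholder.

    from trinity.common.models.model import ModelWrapper

    # Assumption: `model_actor` is a prepared vLLM InferenceModel Ray actor with its
    # OpenAI-compatible API server started.
    wrapper = ModelWrapper(
        model_actor,
        engine_type="vllm",
        enable_history=True,
        enable_thinking=False,  # forwarded as chat_template_kwargs={"enable_thinking": False}
    )
    client = wrapper.get_openai_client()
    response = client.chat.completions.create(
        model="Qwen/Qwen3-8B",  # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
    )
    # The wrapped create() call was recorded with token_ids and logprobs, so the
    # corresponding experiences can be extracted afterwards:
    experiences = wrapper.extract_experience_from_history()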