From cfffd23220b07b9e8b3ff25c444f827bafd68538 Mon Sep 17 00:00:00 2001
From: OhYee
Date: Tue, 13 Jan 2026 12:04:18 +0800
Subject: [PATCH 1/2] refactor(model): integrate ModelAPI into model proxy and service classes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This change introduces ModelAPI as a mixin to both the ModelProxy and
ModelService classes, consolidating common API functionality. The
completions and responses methods have been removed from these classes
because they are now inherited from ModelAPI, which reduces code
duplication and improves maintainability.

The refactoring touches __model_proxy_async_template.py,
__model_service_async_template.py, model_proxy.py, and model_service.py.
Two new files are added: agentrun/model/api/model_api.py, which defines
the ModelAPI base class, and examples/embedding.py, which demonstrates
the embedding functionality.

Change-Id: I81045039bef65df9d7daab6a56ff2e1442fddcb7
Signed-off-by: OhYee
---
 .../model/__model_proxy_async_template.py     |  40 +----
 .../model/__model_service_async_template.py   |  39 +----
 agentrun/model/api/model_api.py               | 156 ++++++++++++++++++
 agentrun/model/model_proxy.py                 |  40 +----
 agentrun/model/model_service.py               |  39 +----
 examples/embedding.py                         | 141 ++++++++++++++++
 6 files changed, 307 insertions(+), 148 deletions(-)
 create mode 100644 agentrun/model/api/model_api.py
 create mode 100644 examples/embedding.py

diff --git a/agentrun/model/__model_proxy_async_template.py b/agentrun/model/__model_proxy_async_template.py
index dfa15a7..847dc52 100644
--- a/agentrun/model/__model_proxy_async_template.py
+++ b/agentrun/model/__model_proxy_async_template.py
@@ -9,6 +9,7 @@
 import pydash
 
 from agentrun.model.api.data import BaseInfo, ModelDataAPI
+from agentrun.model.api.model_api import ModelAPI
 from agentrun.utils.config import Config
 from agentrun.utils.model import Status
 from agentrun.utils.resource import ResourceBase
@@ -30,6 +31,7 @@ class ModelProxy(
     ModelProxyImmutableProps,
     ModelProxyMutableProps,
     ModelProxySystemProps,
+    ModelAPI,
     ResourceBase,
 ):
     """模型服务"""
@@ -230,41 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
         )
 
         return self._data_client.model_info()
-
-    def completions(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        config: Optional[Config] = None,
-        **kwargs,
-    ):
-        self.model_info(config)
-        assert self._data_client
-
-        return self._data_client.completions(
-            **kwargs,
-            messages=messages,
-            model=model,
-            stream=stream,
-            config=config,
-        )
-
-    def responses(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        config: Optional[Config] = None,
-        **kwargs,
-    ):
-        self.model_info(config)
-        assert self._data_client
-
-        return self._data_client.responses(
-            **kwargs,
-            messages=messages,
-            model=model,
-            stream=stream,
-            config=config,
-        )
diff --git a/agentrun/model/__model_service_async_template.py b/agentrun/model/__model_service_async_template.py
index e94331d..a3cfdcb 100644
--- a/agentrun/model/__model_service_async_template.py
+++ 
b/agentrun/model/__model_service_async_template.py @@ -6,7 +6,8 @@ from typing import List, Optional -from agentrun.model.api.data import BaseInfo, ModelCompletionAPI +from agentrun.model.api.data import BaseInfo +from agentrun.model.api.model_api import ModelAPI from agentrun.utils.config import Config from agentrun.utils.model import PageableInput from agentrun.utils.resource import ResourceBase @@ -27,6 +28,7 @@ class ModelService( ModelServiceImmutableProps, ModelServiceMutableProps, ModelServicesSystemProps, + ModelAPI, ResourceBase, ): """模型服务""" @@ -230,38 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo: model=default_model, headers=cfg.get_headers(), ) - - def completions( - self, - messages: list, - model: Optional[str] = None, - stream: bool = False, - **kwargs, - ): - info = self.model_info(config=kwargs.get("config")) - - m = ModelCompletionAPI( - api_key=info.api_key or "", - base_url=info.base_url or "", - model=model or info.model or self.model_service_name or "", - ) - - return m.completions(**kwargs, messages=messages, stream=stream) - - def responses( - self, - messages: list, - model: Optional[str] = None, - stream: bool = False, - **kwargs, - ): - info = self.model_info(config=kwargs.get("config")) - - m = ModelCompletionAPI( - api_key=info.api_key or "", - base_url=info.base_url or "", - model=model or info.model or self.model_service_name or "", - provider=(self.provider or "openai").lower(), - ) - - return m.responses(**kwargs, messages=messages, stream=stream) diff --git a/agentrun/model/api/model_api.py b/agentrun/model/api/model_api.py new file mode 100644 index 0000000..001097b --- /dev/null +++ b/agentrun/model/api/model_api.py @@ -0,0 +1,156 @@ +from abc import ABC, abstractmethod +from typing import Optional, TYPE_CHECKING, Union + +from .data import BaseInfo + +if TYPE_CHECKING: + from litellm import ResponseInputParam + + +class ModelAPI(ABC): + + @abstractmethod + def model_info(self) -> BaseInfo: + ... + + def completions( + self, + **kwargs, + ): + """ + Deprecated. Use completion() instead. 
+        """
+        import warnings
+
+        warnings.warn(
+            "completions() is deprecated, use completion() instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.completion(**kwargs)
+
+    def completion(
+        self,
+        messages=[],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import completion
+
+        info = self.model_info()
+        return completion(
+            **kwargs,
+            api_key=info.api_key,
+            base_url=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            messages=messages,
+        )
+
+    async def acompletion(
+        self,
+        messages=[],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import acompletion
+
+        info = self.model_info()
+        return await acompletion(
+            **kwargs,
+            api_key=info.api_key,
+            base_url=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            messages=messages,
+        )
+
+    def responses(
+        self,
+        input: Union[str, "ResponseInputParam"],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import responses
+
+        info = self.model_info()
+        return responses(
+            **kwargs,
+            api_key=info.api_key,
+            base_url=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            input=input,
+        )
+
+    async def aresponses(
+        self,
+        input: Union[str, "ResponseInputParam"],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import aresponses
+
+        info = self.model_info()
+        return await aresponses(
+            **kwargs,
+            api_key=info.api_key,
+            base_url=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            input=input,
+        )
+
+    def embedding(
+        self,
+        input=[],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import embedding
+
+        info = self.model_info()
+        return embedding(
+            **kwargs,
+            api_key=info.api_key,
+            api_base=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            input=input,
+        )
+
+    async def aembedding(
+        self,
+        input=[],
+        model: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        **kwargs,
+    ):
+        from litellm import aembedding
+
+        info = self.model_info()
+        return await aembedding(
+            **kwargs,
+            api_key=info.api_key,
+            api_base=info.base_url,
+            model=model or info.model or "",
+            custom_llm_provider=custom_llm_provider
+            or info.provider
+            or "openai",
+            input=input,
+        )
diff --git a/agentrun/model/model_proxy.py b/agentrun/model/model_proxy.py
index 248d210..889ee2f 100644
--- a/agentrun/model/model_proxy.py
+++ b/agentrun/model/model_proxy.py
@@ -19,6 +19,7 @@
 import pydash
 
 from agentrun.model.api.data import BaseInfo, ModelDataAPI
+from agentrun.model.api.model_api import ModelAPI
 from agentrun.utils.config import Config
 from agentrun.utils.model import Status
 from agentrun.utils.resource import ResourceBase
@@ -40,6 +41,7 @@ class ModelProxy(
     ModelProxyImmutableProps,
     ModelProxyMutableProps,
     ModelProxySystemProps,
+    ModelAPI,
     ResourceBase,
 ):
     """模型服务"""
@@ -399,41 +401,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
         )
 
         return self._data_client.model_info()
-
-    def completions(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        config: Optional[Config] = None,
-        **kwargs,
-    ):
-        self.model_info(config)
-        assert self._data_client
-
-        return self._data_client.completions(
-            **kwargs,
-            messages=messages,
-            model=model,
-            stream=stream,
-            config=config,
-        )
-
-    def responses(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        config: Optional[Config] = None,
-        **kwargs,
-    ):
-        self.model_info(config)
-        assert self._data_client
-
-        return self._data_client.responses(
-            **kwargs,
-            messages=messages,
-            model=model,
-            stream=stream,
-            config=config,
-        )
diff --git a/agentrun/model/model_service.py b/agentrun/model/model_service.py
index 24f9cce..270b355 100644
--- a/agentrun/model/model_service.py
+++ b/agentrun/model/model_service.py
@@ -16,7 +16,8 @@
 
 from typing import List, Optional
 
-from agentrun.model.api.data import BaseInfo, ModelCompletionAPI
+from agentrun.model.api.data import BaseInfo
+from agentrun.model.api.model_api import ModelAPI
 from agentrun.utils.config import Config
 from agentrun.utils.model import PageableInput
 from agentrun.utils.resource import ResourceBase
@@ -37,6 +38,7 @@ class ModelService(
     ModelServiceImmutableProps,
     ModelServiceMutableProps,
     ModelServicesSystemProps,
+    ModelAPI,
     ResourceBase,
 ):
     """模型服务"""
@@ -401,38 +403,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
             model=default_model,
             headers=cfg.get_headers(),
         )
-
-    def completions(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        **kwargs,
-    ):
-        info = self.model_info(config=kwargs.get("config"))
-
-        m = ModelCompletionAPI(
-            api_key=info.api_key or "",
-            base_url=info.base_url or "",
-            model=model or info.model or self.model_service_name or "",
-        )
-
-        return m.completions(**kwargs, messages=messages, stream=stream)
-
-    def responses(
-        self,
-        messages: list,
-        model: Optional[str] = None,
-        stream: bool = False,
-        **kwargs,
-    ):
-        info = self.model_info(config=kwargs.get("config"))
-
-        m = ModelCompletionAPI(
-            api_key=info.api_key or "",
-            base_url=info.base_url or "",
-            model=model or info.model or self.model_service_name or "",
-            provider=(self.provider or "openai").lower(),
-        )
-
-        return m.responses(**kwargs, messages=messages, stream=stream)
diff --git a/examples/embedding.py b/examples/embedding.py
new file mode 100644
index 0000000..f2b5e5f
--- /dev/null
+++ b/examples/embedding.py
@@ -0,0 +1,141 @@
+import os
+import re
+import time
+
+from agentrun import model
+from agentrun.model import (
+    BackendType,
+    ModelClient,
+    ModelService,
+    ModelServiceCreateInput,
+    ModelServiceListInput,
+    ModelServiceUpdateInput,
+)
+from agentrun.utils.exception import (
+    ResourceAlreadyExistError,
+    ResourceNotExistError,
+)
+from agentrun.utils.log import logger
+from agentrun.utils.model import Status
+
+base_url = os.getenv(
+    "BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"
+)
+api_key = os.getenv("API_KEY", "sk-xxxxx")
+model_names = re.split(
+    r"[\s,]+", os.getenv("MODEL_NAMES", "text-embedding-v1").strip()
+)
+
+
+client = ModelClient()
+model_service_name = "sdk-test-embedding"
+
+
+def create_or_get_model_service():
+    """
+    Demonstrates how to create a resource, or fetch an existing one.
+    """
+    logger.info("Creating the resource, or fetching it if it already exists")
+
+    try:
+        ms = client.create(
+            ModelServiceCreateInput(
+                model_service_name=model_service_name,
+                description="Test model service",
+                model_type=model.ModelType.EMBEDDING,
+                provider="openai",
+                provider_settings=model.ProviderSettings(
+                    api_key=api_key,
+                    base_url=base_url,
+                    model_names=model_names,
+                ),
+            )
+        )
+    except ResourceAlreadyExistError:
logger.info("已存在,获取已有资源") + ms = client.get( + name=model_service_name, backend_type=BackendType.SERVICE + ) + + ms.wait_until_ready_or_failed() + if ms.status != Status.READY: + raise Exception(f"状态异常:{ms.status}") + + logger.info("已就绪状态,当前信息: %s", ms) + + return ms + + +def update_model_service(ms: ModelService): + """ + 为您演示如何进行更新 + """ + logger.info("更新描述为当前时间") + + # 也可以使用 client.update + ms.update( + ModelServiceUpdateInput(description=f"当前时间戳:{time.time()}"), + ) + ms.wait_until_ready_or_failed() + if ms.status != Status.READY: + raise Exception(f"状态异常:{ms.status}") + + logger.info("更新成功,当前信息: %s", ms) + + +def list_model_services(): + """ + 为您演示如何进行枚举 + """ + logger.info("枚举资源列表") + ms_arr = client.list(ModelServiceListInput(model_type=model.ModelType.LLM)) + logger.info( + "共有 %d 个资源,分别为 %s", + len(ms_arr), + [c.model_service_name for c in ms_arr], + ) + + +def delete_model_service(ms: ModelService): + """ + 为您演示如何进行删除 + """ + logger.info("开始清理资源") + # 也可以使用 client.delete / cred.delete + 轮询状态 + ms.delete_and_wait_until_finished() + + logger.info("再次尝试获取") + try: + ms.refresh() + except ResourceNotExistError as e: + logger.info("得到资源不存在报错,删除成功,%s", e) + + +def invoke_model_service(ms: ModelService): + logger.info("调用模型服务进行推理") + + result = ms.embedding(input=["你好", "今天是周几"]) + logger.info("Embedding result: %s", result) + + +def model_example(): + """ + 为您演示模型模块的基本功能 + """ + logger.info("==== 模型模块基本功能示例 ====") + logger.info(" base_url=%s", base_url) + logger.info(" api_key=%s", len(api_key) * "*") + logger.info(" model_names=%s", model_names) + + list_model_services() + ms = create_or_get_model_service() + update_model_service(ms) + + invoke_model_service(ms) + + delete_model_service(ms) + list_model_services() + + +if __name__ == "__main__": + model_example() From 70876defa81ca60f16eb83765054c6bf0e1743fd Mon Sep 17 00:00:00 2001 From: OhYee Date: Thu, 15 Jan 2026 13:59:37 +0800 Subject: [PATCH 2/2] test(model): update ModelProxy tests to use proper mocking and parameters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update the unit tests for ModelProxy completions and responses methods to properly mock litellm functions instead of mocking the data client completions/responses methods. Also correct the test_responses method to pass the correct 'input' parameter instead of 'messages'. The changes ensure that the tests accurately reflect how the ModelProxy interacts with the underlying litellm API while maintaining proper mocking behavior. 
Change-Id: I5fffbf165d23efa7607c2618b1fd9604df3caeb6
Signed-off-by: OhYee
---
 tests/unittests/model/test_model_proxy.py | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/tests/unittests/model/test_model_proxy.py b/tests/unittests/model/test_model_proxy.py
index fc88b01..6f0890e 100644
--- a/tests/unittests/model/test_model_proxy.py
+++ b/tests/unittests/model/test_model_proxy.py
@@ -527,23 +527,25 @@ class TestModelProxyCompletions:
             "AGENTRUN_ACCOUNT_ID": "test-account",
         },
     )
-    def test_completions(self):
+    @patch("litellm.completion")
+    def test_completions(self, mock_completion):
         from agentrun.model.api.data import BaseInfo
 
+        mock_completion.return_value = {"choices": []}
+
         proxy = ModelProxy(model_proxy_name="test-proxy")
 
-        # Create a mock _data_client directly
+        # Create a mock _data_client to provide model_info
         mock_data_client = MagicMock()
         mock_info = BaseInfo(model="gpt-4", base_url="https://api.example.com")
         mock_data_client.model_info.return_value = mock_info
-        mock_data_client.completions.return_value = {"choices": []}
 
-        # Bypass the model_info call by setting _data_client
+        # Set _data_client so model_info() returns our mock info
         proxy._data_client = mock_data_client
 
         proxy.completions(messages=[{"role": "user", "content": "Hello"}])
 
-        mock_data_client.completions.assert_called_once()
+        mock_completion.assert_called_once()
 
 
 class TestModelProxyResponses:
@@ -557,20 +559,24 @@ class TestModelProxyResponses:
             "AGENTRUN_ACCOUNT_ID": "test-account",
         },
     )
-    def test_responses(self):
+    @patch("litellm.responses")
+    def test_responses(self, mock_responses):
         from agentrun.model.api.data import BaseInfo
 
+        mock_responses.return_value = {}
+
         proxy = ModelProxy(model_proxy_name="test-proxy")
 
-        # Create a mock _data_client directly
+        # Create a mock _data_client to provide model_info
         mock_data_client = MagicMock()
         mock_info = BaseInfo(model="gpt-4", base_url="https://api.example.com")
         mock_data_client.model_info.return_value = mock_info
-        mock_data_client.responses.return_value = {}
 
-        # Bypass the model_info call by setting _data_client
+        # Set _data_client so model_info() returns our mock info
         proxy._data_client = mock_data_client
 
-        proxy.responses(messages=[{"role": "user", "content": "Hello"}])
+        # Note: the responses method expects the 'input' parameter (not
+        # 'messages'), per the ModelAPI.responses signature
+        proxy.responses(input="Hello")
 
-        mock_data_client.responses.assert_called_once()
+        mock_responses.assert_called_once()
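
---
Usage sketch (illustrative only, not part of the applied patch): with
ModelAPI mixed in, inference calls on a ModelService or ModelProxy go
through litellm directly. Assuming a ready embedding service named
"sdk-test-embedding", as set up in examples/embedding.py:

    from agentrun.model import BackendType, ModelClient

    client = ModelClient()
    ms = client.get(
        name="sdk-test-embedding", backend_type=BackendType.SERVICE
    )

    # Synchronous embedding call inherited from ModelAPI
    vectors = ms.embedding(input=["hello", "world"])

    # Chat-capable services also get completion()/responses() and the
    # async variants, e.g. (inside an event loop):
    #     resp = await ms.acompletion(
    #         messages=[{"role": "user", "content": "Hi"}]
    #     )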