Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 2 additions & 38 deletions agentrun/model/__model_proxy_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pydash

from agentrun.model.api.data import BaseInfo, ModelDataAPI
from agentrun.model.api.model_api import ModelAPI
from agentrun.utils.config import Config
from agentrun.utils.model import Status
from agentrun.utils.resource import ResourceBase
Expand All @@ -30,6 +31,7 @@ class ModelProxy(
ModelProxyImmutableProps,
ModelProxyMutableProps,
ModelProxySystemProps,
ModelAPI,
ResourceBase,
):
"""模型服务"""
Expand Down Expand Up @@ -230,41 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
)

return self._data_client.model_info()

def completions(
    self,
    messages: list,
    model: Optional[str] = None,
    stream: bool = False,
    config: Optional[Config] = None,
    **kwargs,
):
    """Run a chat completion through the proxy's data-plane client.

    ``model_info(config)`` is invoked first for its side effects;
    afterwards ``self._data_client`` is expected to be usable.

    Args:
        messages: Chat messages forwarded verbatim to the data client.
        model: Optional model override; ``None`` lets the client decide.
        stream: Whether to request a streaming response.
        config: Optional request configuration, also forwarded downstream.
        **kwargs: Extra keyword arguments passed through unchanged.
    """
    self.model_info(config)
    # NOTE(review): presumably model_info() initializes _data_client as a
    # side effect — confirm; the assert is stripped under `python -O`.
    assert self._data_client

    return self._data_client.completions(
        **kwargs,
        messages=messages,
        model=model,
        stream=stream,
        config=config,
    )

def responses(
    self,
    messages: list,
    model: Optional[str] = None,
    stream: bool = False,
    config: Optional[Config] = None,
    **kwargs,
):
    """Call the Responses-style endpoint through the proxy's data client.

    Mirrors ``completions()``: refreshes ``model_info(config)`` first,
    then delegates everything to ``self._data_client.responses``.

    Args:
        messages: Input messages forwarded verbatim to the data client.
        model: Optional model override; ``None`` lets the client decide.
        stream: Whether to request a streaming response.
        config: Optional request configuration, also forwarded downstream.
        **kwargs: Extra keyword arguments passed through unchanged.
    """
    self.model_info(config)
    # NOTE(review): presumably model_info() initializes _data_client as a
    # side effect — confirm; the assert is stripped under `python -O`.
    assert self._data_client

    return self._data_client.responses(
        **kwargs,
        messages=messages,
        model=model,
        stream=stream,
        config=config,
    )
39 changes: 3 additions & 36 deletions agentrun/model/__model_service_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from typing import List, Optional

from agentrun.model.api.data import BaseInfo, ModelCompletionAPI
from agentrun.model.api.data import BaseInfo
from agentrun.model.api.model_api import ModelAPI
from agentrun.utils.config import Config
from agentrun.utils.model import PageableInput
from agentrun.utils.resource import ResourceBase
Expand All @@ -27,6 +28,7 @@ class ModelService(
ModelServiceImmutableProps,
ModelServiceMutableProps,
ModelServicesSystemProps,
ModelAPI,
ResourceBase,
):
"""模型服务"""
Expand Down Expand Up @@ -230,38 +232,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
model=default_model,
headers=cfg.get_headers(),
)

def completions(
    self,
    messages: list,
    model: Optional[str] = None,
    stream: bool = False,
    **kwargs,
):
    """Run a chat completion against this model service.

    Builds a short-lived ``ModelCompletionAPI`` from the connection info
    returned by ``model_info``. The effective model falls back from the
    explicit argument, to the service's default model, to the service name.

    Args:
        messages: Chat messages forwarded to the completion API.
        model: Optional model override.
        stream: Whether to request a streaming response.
        **kwargs: Extra keyword arguments; if it contains ``config`` it is
            both used to resolve ``model_info`` and forwarded downstream.
    """
    info = self.model_info(config=kwargs.get("config"))

    # NOTE(review): unlike responses(), no provider= is passed here —
    # presumably ModelCompletionAPI applies a default; confirm.
    m = ModelCompletionAPI(
        api_key=info.api_key or "",
        base_url=info.base_url or "",
        model=model or info.model or self.model_service_name or "",
    )

    return m.completions(**kwargs, messages=messages, stream=stream)

def responses(
    self,
    messages: list,
    model: Optional[str] = None,
    stream: bool = False,
    **kwargs,
):
    """Call the Responses-style endpoint of this model service.

    Builds a short-lived ``ModelCompletionAPI`` from the connection info
    returned by ``model_info``, additionally pinning the provider
    (lower-cased, defaulting to ``"openai"``).

    Args:
        messages: Input messages forwarded to the responses API.
        model: Optional model override.
        stream: Whether to request a streaming response.
        **kwargs: Extra keyword arguments; if it contains ``config`` it is
            both used to resolve ``model_info`` and forwarded downstream.
    """
    info = self.model_info(config=kwargs.get("config"))

    m = ModelCompletionAPI(
        api_key=info.api_key or "",
        base_url=info.base_url or "",
        model=model or info.model or self.model_service_name or "",
        provider=(self.provider or "openai").lower(),
    )

    return m.responses(**kwargs, messages=messages, stream=stream)
156 changes: 156 additions & 0 deletions agentrun/model/api/model_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from abc import ABC, abstractmethod
from typing import Optional, TYPE_CHECKING, Union

from .data import BaseInfo

if TYPE_CHECKING:
from litellm import ResponseInputParam


class ModelAPI(ABC):
    """Mixin that adds LLM invocation helpers backed by ``litellm``.

    Subclasses implement :meth:`model_info` to supply credentials and
    routing (API key, base URL, default model, provider). Every helper
    resolves those once per call and forwards to the matching ``litellm``
    function. ``litellm`` is imported lazily inside each method so that
    importing this module does not require the dependency.
    """

    @abstractmethod
    def model_info(self) -> "BaseInfo":
        """Return connection info (api_key / base_url / model / provider)."""
        ...

    def _resolve_target(
        self,
        model: Optional[str],
        custom_llm_provider: Optional[str],
    ):
        """Resolve ``(info, model, provider)`` for one litellm call.

        Explicit arguments take precedence over ``model_info()`` values;
        the model falls back to ``""`` and the provider to ``"openai"``.
        """
        info = self.model_info()
        return (
            info,
            model or info.model or "",
            custom_llm_provider or info.provider or "openai",
        )

    def completions(
        self,
        **kwargs,
    ):
        """Deprecated. Use :meth:`completion` instead."""
        import warnings

        warnings.warn(
            "completions() is deprecated, use completion() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.completion(**kwargs)

    def completion(
        self,
        messages: Optional[list] = None,
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Synchronous chat completion via ``litellm.completion``.

        Args:
            messages: Chat messages; ``None`` is treated as an empty list
                (avoids a shared mutable default).
            model: Optional model override.
            custom_llm_provider: Optional provider override.
            **kwargs: Forwarded verbatim to ``litellm.completion``.
        """
        from litellm import completion

        info, target_model, provider = self._resolve_target(
            model, custom_llm_provider
        )
        return completion(
            **kwargs,
            api_key=info.api_key,
            base_url=info.base_url,
            model=target_model,
            custom_llm_provider=provider,
            messages=[] if messages is None else messages,
        )

    async def acompletion(
        self,
        messages: Optional[list] = None,
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Async chat completion via ``litellm.acompletion``.

        See :meth:`completion` for parameter semantics.
        """
        from litellm import acompletion

        info, target_model, provider = self._resolve_target(
            model, custom_llm_provider
        )
        return await acompletion(
            **kwargs,
            api_key=info.api_key,
            base_url=info.base_url,
            model=target_model,
            custom_llm_provider=provider,
            messages=[] if messages is None else messages,
        )

    def responses(
        self,
        input: Union[str, "ResponseInputParam"],
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Synchronous Responses-API call via ``litellm.responses``.

        Args:
            input: Prompt string or structured response input.
            model: Optional model override.
            custom_llm_provider: Optional provider override.
            **kwargs: Forwarded verbatim to ``litellm.responses``.
        """
        from litellm import responses

        info, target_model, provider = self._resolve_target(
            model, custom_llm_provider
        )
        return responses(
            **kwargs,
            api_key=info.api_key,
            base_url=info.base_url,
            model=target_model,
            custom_llm_provider=provider,
            input=input,
        )

    async def aresponses(
        self,
        input: Union[str, "ResponseInputParam"],
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Async Responses-API call via ``litellm.aresponses``.

        See :meth:`responses` for parameter semantics.
        """
        from litellm import aresponses

        info, target_model, provider = self._resolve_target(
            model, custom_llm_provider
        )
        return await aresponses(
            **kwargs,
            api_key=info.api_key,
            base_url=info.base_url,
            model=target_model,
            custom_llm_provider=provider,
            input=input,
        )

    def embedding(
        self,
        input: Optional[list] = None,
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Synchronous embedding via ``litellm.embedding``.

        Args:
            input: Items to embed; ``None`` is treated as an empty list
                (avoids a shared mutable default).
            model: Optional model override.
            custom_llm_provider: Optional provider override.
            **kwargs: Forwarded verbatim to ``litellm.embedding``.
        """
        from litellm import embedding

        info, target_model, provider = self._resolve_target(
            model, custom_llm_provider
        )
        # litellm's embedding entry points take `api_base`, not `base_url`.
        return embedding(
            **kwargs,
            api_key=info.api_key,
            api_base=info.base_url,
            model=target_model,
            custom_llm_provider=provider,
            input=[] if input is None else input,
        )

    async def aembedding(
        self,
        input: Optional[list] = None,
        model: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """Async embedding via ``litellm.aembedding``.

        Declared ``async`` and awaited, matching :meth:`acompletion` /
        :meth:`aresponses` (the previous sync ``def`` returned a bare,
        un-awaited coroutine). See :meth:`embedding` for parameters.
        """
        from litellm import aembedding

        info, target_model, provider = self._resolve_target(
            model, custom_llm_provider
        )
        return await aembedding(
            **kwargs,
            api_key=info.api_key,
            api_base=info.base_url,
            model=target_model,
            custom_llm_provider=provider,
            input=[] if input is None else input,
        )
40 changes: 2 additions & 38 deletions agentrun/model/model_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pydash

from agentrun.model.api.data import BaseInfo, ModelDataAPI
from agentrun.model.api.model_api import ModelAPI
from agentrun.utils.config import Config
from agentrun.utils.model import Status
from agentrun.utils.resource import ResourceBase
Expand All @@ -40,6 +41,7 @@ class ModelProxy(
ModelProxyImmutableProps,
ModelProxyMutableProps,
ModelProxySystemProps,
ModelAPI,
ResourceBase,
):
"""模型服务"""
Expand Down Expand Up @@ -399,41 +401,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
)

return self._data_client.model_info()

def completions(
    self,
    messages: list,
    model: Optional[str] = None,
    stream: bool = False,
    config: Optional[Config] = None,
    **kwargs,
):
    """Run a chat completion through the proxy's data-plane client.

    ``model_info(config)`` is invoked first for its side effects;
    afterwards ``self._data_client`` is expected to be usable.

    Args:
        messages: Chat messages forwarded verbatim to the data client.
        model: Optional model override; ``None`` lets the client decide.
        stream: Whether to request a streaming response.
        config: Optional request configuration, also forwarded downstream.
        **kwargs: Extra keyword arguments passed through unchanged.
    """
    self.model_info(config)
    # NOTE(review): presumably model_info() initializes _data_client as a
    # side effect — confirm; the assert is stripped under `python -O`.
    assert self._data_client

    return self._data_client.completions(
        **kwargs,
        messages=messages,
        model=model,
        stream=stream,
        config=config,
    )

def responses(
    self,
    messages: list,
    model: Optional[str] = None,
    stream: bool = False,
    config: Optional[Config] = None,
    **kwargs,
):
    """Call the Responses-style endpoint through the proxy's data client.

    Mirrors ``completions()``: refreshes ``model_info(config)`` first,
    then delegates everything to ``self._data_client.responses``.

    Args:
        messages: Input messages forwarded verbatim to the data client.
        model: Optional model override; ``None`` lets the client decide.
        stream: Whether to request a streaming response.
        config: Optional request configuration, also forwarded downstream.
        **kwargs: Extra keyword arguments passed through unchanged.
    """
    self.model_info(config)
    # NOTE(review): presumably model_info() initializes _data_client as a
    # side effect — confirm; the assert is stripped under `python -O`.
    assert self._data_client

    return self._data_client.responses(
        **kwargs,
        messages=messages,
        model=model,
        stream=stream,
        config=config,
    )
39 changes: 3 additions & 36 deletions agentrun/model/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from typing import List, Optional

from agentrun.model.api.data import BaseInfo, ModelCompletionAPI
from agentrun.model.api.data import BaseInfo
from agentrun.model.api.model_api import ModelAPI
from agentrun.utils.config import Config
from agentrun.utils.model import PageableInput
from agentrun.utils.resource import ResourceBase
Expand All @@ -37,6 +38,7 @@ class ModelService(
ModelServiceImmutableProps,
ModelServiceMutableProps,
ModelServicesSystemProps,
ModelAPI,
ResourceBase,
):
"""模型服务"""
Expand Down Expand Up @@ -401,38 +403,3 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
model=default_model,
headers=cfg.get_headers(),
)

def completions(
    self,
    messages: list,
    model: Optional[str] = None,
    stream: bool = False,
    **kwargs,
):
    """Run a chat completion against this model service.

    Builds a short-lived ``ModelCompletionAPI`` from the connection info
    returned by ``model_info``. The effective model falls back from the
    explicit argument, to the service's default model, to the service name.

    Args:
        messages: Chat messages forwarded to the completion API.
        model: Optional model override.
        stream: Whether to request a streaming response.
        **kwargs: Extra keyword arguments; if it contains ``config`` it is
            both used to resolve ``model_info`` and forwarded downstream.
    """
    info = self.model_info(config=kwargs.get("config"))

    # NOTE(review): unlike responses(), no provider= is passed here —
    # presumably ModelCompletionAPI applies a default; confirm.
    m = ModelCompletionAPI(
        api_key=info.api_key or "",
        base_url=info.base_url or "",
        model=model or info.model or self.model_service_name or "",
    )

    return m.completions(**kwargs, messages=messages, stream=stream)

def responses(
    self,
    messages: list,
    model: Optional[str] = None,
    stream: bool = False,
    **kwargs,
):
    """Call the Responses-style endpoint of this model service.

    Builds a short-lived ``ModelCompletionAPI`` from the connection info
    returned by ``model_info``, additionally pinning the provider
    (lower-cased, defaulting to ``"openai"``).

    Args:
        messages: Input messages forwarded to the responses API.
        model: Optional model override.
        stream: Whether to request a streaming response.
        **kwargs: Extra keyword arguments; if it contains ``config`` it is
            both used to resolve ``model_info`` and forwarded downstream.
    """
    info = self.model_info(config=kwargs.get("config"))

    m = ModelCompletionAPI(
        api_key=info.api_key or "",
        base_url=info.base_url or "",
        model=model or info.model or self.model_service_name or "",
        provider=(self.provider or "openai").lower(),
    )

    return m.responses(**kwargs, messages=messages, stream=stream)
Loading