feat: add more embedding providers: an OpenAI-compatible embedding provider plus Zhipu, ByteDance, and Ollama providers #6642
Open
whatevertogo wants to merge 8 commits into AstrBotDevs:master from whatevertogo:feat/more-embedding-providers
+561
−44
8 commits
fb912c4 feat: add an OpenAI-compatible embedding provider and related configuration, with support for Zhipu and Volcano Engine embeddings (whatevertogo)
61b3e79 feat: add an Ollama-compatible embedding provider and related configuration; update docs and test cases (whatevertogo)
2aaad51 fix: address PR review feedback for embedding provider (whatevertogo)
448b5e4 docs: improve embedding provider hints for clarity (whatevertogo)
f9e29ac feat: update docs and tests for the OpenAI-compatible embedding provider; improve compatibility and warning handling (whatevertogo)
6db3a86 docs: clarify send_dimensions_param purpose in i18n (whatevertogo)
04e23fa fix: use i18n keys for send_dimensions_param in default config (whatevertogo)
856f13c feat: update the descriptions and hints for send_dimensions_param and proxy to make them easier to understand (whatevertogo)
137 changes: 137 additions & 0 deletions
astrbot/core/provider/sources/openai_compatible_embedding_source.py
```diff
@@ -0,0 +1,137 @@
+from urllib.parse import urlsplit
+
+import httpx
+from openai import AsyncOpenAI
+
+from astrbot import logger
+
+from ..entities import ProviderType
+from ..provider import EmbeddingProvider
+from ..register import register_provider_adapter
+
+
+def normalize_openai_compatible_embedding_api_base(api_base: str) -> str:
+    """Normalize API base while preserving provider-specific path prefixes.
+    Handles URLs with or without scheme:
+    - Empty/whitespace → https://api.openai.com/v1
+    - Host only (api.openai.com) → https://api.openai.com/v1
+    - Full URL with path (https://example.com/api/v3) → preserved as-is
+    """
+    cleaned_api_base = api_base.strip().removesuffix("/")
+    if not cleaned_api_base:
+        return "https://api.openai.com/v1"
+
+    parsed_api_base = urlsplit(cleaned_api_base)
+    # If no scheme, the URL is parsed incorrectly (host becomes path)
+    if not parsed_api_base.scheme:
+        cleaned_api_base = f"https://{cleaned_api_base}"
+        parsed_api_base = urlsplit(cleaned_api_base)
+
+    if parsed_api_base.path and parsed_api_base.path != "/":
+        return cleaned_api_base
+
+    return f"{cleaned_api_base}/v1"
+
+
+def parse_embedding_dimensions(provider_config: dict) -> int:
+    """Return the configured local vector size, or 0 when unset/invalid."""
+    raw_dimensions = provider_config.get("embedding_dimensions")
+    if raw_dimensions in (None, ""):
+        return 0
+
+    try:
+        return int(raw_dimensions)
+    except (ValueError, TypeError):
+        logger.warning(
+            "embedding_dimensions in embedding configs is not a valid integer: '%s', ignored.",
+            raw_dimensions,
+        )
+        return 0
+
+
+def should_send_dimensions_param(provider_config: dict) -> bool:
+    """Read the explicit bool switch used by OpenAI-compatible presets."""
+    raw_value = provider_config.get("send_dimensions_param", False)
+    if isinstance(raw_value, bool):
+        return raw_value
+    if raw_value not in (None, ""):
+        logger.warning(
+            "send_dimensions_param should be a boolean in embedding configs: '%s', treated as disabled.",
+            raw_value,
+        )
+    return False
+
+
+@register_provider_adapter(
+    "openai_compatible_embedding",
+    "OpenAI Compatible Embedding provider adapter",
+    provider_type=ProviderType.EMBEDDING,
+)
+class OpenAICompatibleEmbeddingProvider(EmbeddingProvider):
+    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
+        super().__init__(provider_config, provider_settings)
+        self.provider_config = provider_config
+        self.provider_settings = provider_settings
+        self._http_client = None
+
+        proxy = provider_config.get("proxy", "")
+        if proxy:
+            logger.info(f"[OpenAI Compatible Embedding] Using proxy: {proxy}")
+            self._http_client = httpx.AsyncClient(proxy=proxy)
+
+        try:
+            timeout = int(provider_config.get("timeout", 20))
+        except (ValueError, TypeError):
+            logger.warning(
+                "Invalid timeout value in provider config: '%s'. Using default 20s.",
+                provider_config.get("timeout"),
+            )
+            timeout = 20
+
+        self.client = AsyncOpenAI(
+            api_key=provider_config.get("embedding_api_key"),
+            base_url=normalize_openai_compatible_embedding_api_base(
+                provider_config.get("embedding_api_base", "")
+            ),
+            timeout=timeout,
+            http_client=self._http_client,
+        )
+        self.model = provider_config.get("embedding_model", "text-embedding-3-small")
+
+    async def get_embedding(self, text: str) -> list[float]:
+        """Get the embedding for a single text."""
+        kwargs = self._embedding_kwargs()
+        embedding = await self.client.embeddings.create(
+            input=text,
+            model=self.model,
+            **kwargs,
+        )
+        return embedding.data[0].embedding
+
+    async def get_embeddings(self, text: list[str]) -> list[list[float]]:
+        """Get the embeddings for a batch of texts."""
+        kwargs = self._embedding_kwargs()
+        embeddings = await self.client.embeddings.create(
+            input=text,
+            model=self.model,
+            **kwargs,
+        )
+        return [item.embedding for item in embeddings.data]
+
+    def _embedding_kwargs(self) -> dict:
+        """Only send optional parameters the upstream explicitly needs."""
+        dimensions = parse_embedding_dimensions(self.provider_config)
+        if should_send_dimensions_param(self.provider_config) and dimensions > 0:
+            return {"dimensions": dimensions}
+        return {}
+
+    def get_dim(self) -> int:
+        """Get the vector dimension."""
+        return parse_embedding_dimensions(self.provider_config)
+
+    async def terminate(self):
+        if self.client:
+            await self.client.close()
+
+        if self._http_client:
+            await self._http_client.aclose()
```
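To make the helper behavior in this file easier to review, here is a minimal standalone sketch of the three module-level functions, exercised with a few sample configs. It is not part of the PR. The `logging` logger stands in for `astrbot.logger`, and `embedding_kwargs` is a hypothetical free-function version of `_embedding_kwargs`; both are assumptions made so the snippet runs without AstrBot installed.

```python
import logging
from urllib.parse import urlsplit

# Stand-in for `from astrbot import logger` (assumption, so this runs standalone).
logger = logging.getLogger("embedding_sketch")


def normalize_openai_compatible_embedding_api_base(api_base: str) -> str:
    """Append /v1 only when the configured base has no path of its own."""
    cleaned_api_base = api_base.strip().removesuffix("/")  # needs Python 3.9+
    if not cleaned_api_base:
        return "https://api.openai.com/v1"
    parsed_api_base = urlsplit(cleaned_api_base)
    if not parsed_api_base.scheme:
        # A bare host is parsed as a path, so add a scheme and parse again.
        cleaned_api_base = f"https://{cleaned_api_base}"
        parsed_api_base = urlsplit(cleaned_api_base)
    if parsed_api_base.path and parsed_api_base.path != "/":
        return cleaned_api_base
    return f"{cleaned_api_base}/v1"


def parse_embedding_dimensions(provider_config: dict) -> int:
    """Return the configured vector size, or 0 when unset or invalid."""
    raw_dimensions = provider_config.get("embedding_dimensions")
    if raw_dimensions in (None, ""):
        return 0
    try:
        return int(raw_dimensions)
    except (ValueError, TypeError):
        logger.warning("invalid embedding_dimensions: %r, ignored", raw_dimensions)
        return 0


def should_send_dimensions_param(provider_config: dict) -> bool:
    """Only a real bool enables the switch; anything else counts as disabled."""
    raw_value = provider_config.get("send_dimensions_param", False)
    if isinstance(raw_value, bool):
        return raw_value
    if raw_value not in (None, ""):
        logger.warning("send_dimensions_param is not a bool: %r", raw_value)
    return False


def embedding_kwargs(provider_config: dict) -> dict:
    """Hypothetical free-function version of the provider's _embedding_kwargs."""
    dimensions = parse_embedding_dimensions(provider_config)
    if should_send_dimensions_param(provider_config) and dimensions > 0:
        return {"dimensions": dimensions}
    return {}


# URL normalization: the default, a bare host, and a provider-specific path.
print(normalize_openai_compatible_embedding_api_base(""))                # https://api.openai.com/v1
print(normalize_openai_compatible_embedding_api_base("api.openai.com"))  # https://api.openai.com/v1
print(normalize_openai_compatible_embedding_api_base("https://example.com/api/v3/"))  # https://example.com/api/v3

# The dimensions value is sent upstream only when the bool switch is on.
print(embedding_kwargs({"embedding_dimensions": "1024", "send_dimensions_param": True}))  # {'dimensions': 1024}
print(embedding_kwargs({"embedding_dimensions": "1024"}))                                 # {}
```

Keeping `send_dimensions_param` separate from `embedding_dimensions` means `get_dim` can still report the local vector size for upstreams where the `dimensions` request parameter is not sent, which appears to be the intent behind the switch.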