diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 5b4ea7686a..d933dfa89b 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1621,17 +1621,63 @@ class ChatProviderTemplate(TypedDict): "gemini_tts_voice_name": "Leda", "proxy": "", }, - "OpenAI Embedding": { - "id": "openai_embedding", - "type": "openai_embedding", + "OpenAI Compatible Embedding": { + "id": "openai_compatible_embedding", + "type": "openai_compatible_embedding", "provider": "openai", "provider_type": "embedding", - "hint": "provider_group.provider.openai_embedding.hint", + "hint": "provider_group.provider.openai_compatible_embedding.hint", "enable": True, "embedding_api_key": "", "embedding_api_base": "", "embedding_model": "", "embedding_dimensions": 1024, + "send_dimensions_param": False, + "timeout": 20, + "proxy": "", + }, + "Zhipu Embedding": { + "id": "zhipu_embedding", + "type": "openai_compatible_embedding", + "provider": "zhipu", + "provider_type": "embedding", + "hint": "provider_group.provider.zhipu_embedding.hint", + "enable": True, + "embedding_api_key": "", + "embedding_api_base": "https://open.bigmodel.cn/api/paas/v4", + "embedding_model": "embedding-3", + "embedding_dimensions": 2048, + "send_dimensions_param": True, + "timeout": 20, + "proxy": "", + }, + "Volcengine Embedding": { + "id": "volcengine_embedding", + "type": "openai_compatible_embedding", + "provider": "volcengine", + "provider_type": "embedding", + "hint": "provider_group.provider.volcengine_embedding.hint", + "enable": True, + "embedding_api_key": "", + "embedding_api_base": "https://ark.cn-beijing.volces.com/api/v3", + "embedding_model": "doubao-embedding-vision", + "embedding_dimensions": 2048, + "send_dimensions_param": True, + "timeout": 20, + "proxy": "", + }, + "Ollama Embedding": { + "id": "ollama_embedding", + "type": "openai_compatible_embedding", + "provider": "ollama", + "provider_type": "embedding", + "hint": "provider_group.provider.ollama_embedding.hint", + "enable": True, + "embedding_api_key": "ollama", + "embedding_api_base": "http://127.0.0.1:11434", + "embedding_model": "embeddinggemma", + "embedding_dimensions": 768, + "send_dimensions_param": False, "timeout": 20, "proxy": "", }, @@ -1937,6 +1983,12 @@ class ChatProviderTemplate(TypedDict): "description": "API Base URL", "type": "string", }, + "send_dimensions_param": { + "description": "透传 dimensions 参数", + "type": "bool", + "hint": "启用后,将 embedding_dimensions 作为 dimensions 参数发送给上游 API。支持自定义维度的服务(OpenAI、智谱、火山等)可开启此项以实现降维;若上游不支持自定义维度则关闭。", + "condition": {"type": "openai_compatible_embedding"}, + }, "volcengine_cluster": { "type": "string", "description": "火山引擎集群", @@ -2357,9 +2409,9 @@ class ChatProviderTemplate(TypedDict): "type": "string", }, "proxy": { - "description": "provider_group.provider.proxy.description", + "description": "代理地址", "type": "string", - "hint": "provider_group.provider.proxy.hint", + "hint": "HTTP/HTTPS 代理地址,格式如 http://127.0.0.1:7890。仅对该提供商的 API 请求生效,不影响 Docker 内网通信。", }, "model": { "description": "模型 ID", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 0df9f791ae..5336ff8263 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -443,6 +443,10 @@ def dynamic_import_provider(self, type: str) -> None: from .sources.openai_embedding_source import ( OpenAIEmbeddingProvider as OpenAIEmbeddingProvider, ) + case "openai_compatible_embedding": + from .sources.openai_compatible_embedding_source import ( + OpenAICompatibleEmbeddingProvider as OpenAICompatibleEmbeddingProvider, + ) case "gemini_embedding": from .sources.gemini_embedding_source import ( GeminiEmbeddingProvider as GeminiEmbeddingProvider, diff --git a/astrbot/core/provider/sources/openai_compatible_embedding_source.py b/astrbot/core/provider/sources/openai_compatible_embedding_source.py new file mode 100644 index 0000000000..024ae3fda1 --- /dev/null +++ b/astrbot/core/provider/sources/openai_compatible_embedding_source.py @@ -0,0 +1,137 @@ +from urllib.parse import urlsplit + +import httpx +from openai import AsyncOpenAI + +from astrbot import logger + +from ..entities import ProviderType +from ..provider import EmbeddingProvider +from ..register import register_provider_adapter + + +def normalize_openai_compatible_embedding_api_base(api_base: str) -> str: + """Normalize API base while preserving provider-specific path prefixes. + + Handles URLs with or without scheme: + - Empty/whitespace → https://api.openai.com/v1 + - Host only (api.openai.com) → https://api.openai.com/v1 + - Full URL with path (https://example.com/api/v3) → preserved as-is + """ + cleaned_api_base = api_base.strip().removesuffix("/") + if not cleaned_api_base: + return "https://api.openai.com/v1" + + parsed_api_base = urlsplit(cleaned_api_base) + # If no scheme, the URL is parsed incorrectly (host becomes path) + if not parsed_api_base.scheme: + cleaned_api_base = f"https://{cleaned_api_base}" + parsed_api_base = urlsplit(cleaned_api_base) + + if parsed_api_base.path and parsed_api_base.path != "/": + return cleaned_api_base + + return f"{cleaned_api_base}/v1" + + +def parse_embedding_dimensions(provider_config: dict) -> int: + """Return the configured local vector size, or 0 when unset/invalid.""" + raw_dimensions = provider_config.get("embedding_dimensions") + if raw_dimensions in (None, ""): + return 0 + + try: + return int(raw_dimensions) + except (ValueError, TypeError): + logger.warning( + "embedding_dimensions in embedding configs is not a valid integer: '%s', ignored.", + raw_dimensions, + ) + return 0 + + +def should_send_dimensions_param(provider_config: dict) -> bool: + """Read the explicit bool switch used by OpenAI-compatible presets.""" + raw_value = provider_config.get("send_dimensions_param", False) + if isinstance(raw_value, bool): + return raw_value + if raw_value not in (None, ""): + logger.warning( + "send_dimensions_param should be a boolean in embedding configs: '%s', treated as disabled.", + raw_value, + ) + return False + + +@register_provider_adapter( + "openai_compatible_embedding", + "OpenAI Compatible Embedding 提供商适配器", + provider_type=ProviderType.EMBEDDING, +) +class OpenAICompatibleEmbeddingProvider(EmbeddingProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.provider_config = provider_config + self.provider_settings = provider_settings + self._http_client = None + + proxy = provider_config.get("proxy", "") + if proxy: + logger.info(f"[OpenAI Compatible Embedding] 使用代理: {proxy}") + self._http_client = httpx.AsyncClient(proxy=proxy) + + try: + timeout = int(provider_config.get("timeout", 20)) + except (ValueError, TypeError): + logger.warning( + "Invalid timeout value in provider config: '%s'. Using default 20s.", + provider_config.get("timeout"), + ) + timeout = 20 + + self.client = AsyncOpenAI( + api_key=provider_config.get("embedding_api_key"), + base_url=normalize_openai_compatible_embedding_api_base( + provider_config.get("embedding_api_base", "") + ), + timeout=timeout, + http_client=self._http_client, + ) + self.model = provider_config.get("embedding_model", "text-embedding-3-small") + + async def get_embedding(self, text: str) -> list[float]: + """获取文本的嵌入。""" + kwargs = self._embedding_kwargs() + embedding = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs, + ) + return embedding.data[0].embedding + + async def get_embeddings(self, text: list[str]) -> list[list[float]]: + """批量获取文本的嵌入。""" + kwargs = self._embedding_kwargs() + embeddings = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs, + ) + return [item.embedding for item in embeddings.data] + + def _embedding_kwargs(self) -> dict: + """Only send optional parameters the upstream explicitly needs.""" + dimensions = parse_embedding_dimensions(self.provider_config) + if should_send_dimensions_param(self.provider_config) and dimensions > 0: + return {"dimensions": dimensions} + return {} + + def get_dim(self) -> int: + """获取向量的维度。""" + return parse_embedding_dimensions(self.provider_config) + + async def terminate(self): + if self.client: + await self.client.close() + if self._http_client: + await self._http_client.aclose() diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index bc1c86bdfc..254ceb9f62 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -50,36 +50,11 @@ const filteredIterable = computed(() => { const providerHint = computed(() => { const hint = props.iterable?.hint - if (typeof hint !== 'string' || !hint) return '' - - if ( - hint === 'provider_group.provider.openai_embedding.hint' - || hint === 'provider_group.provider.gemini_embedding.hint' - ) { - return '' - } - - return hint + return typeof hint === 'string' ? hint : '' }) -const getItemHint = (itemKey, itemMeta) => { - if (itemMeta?.hint) return itemMeta.hint - - if (itemKey !== 'embedding_api_base') return '' - - const providerType = props.iterable?.type - if (providerType === 'openai_embedding') { - return getRaw('provider_group.provider.openai_embedding.hint') - ? 'provider_group.provider.openai_embedding.hint' - : '' - } - if (providerType === 'gemini_embedding') { - return getRaw('provider_group.provider.gemini_embedding.hint') - ? 'provider_group.provider.gemini_embedding.hint' - : '' - } - - return '' +const getItemHint = (_itemKey, itemMeta) => { + return itemMeta?.hint || '' } const dialog = ref(false) diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 2e12143725..089007fe4e 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -1188,8 +1188,24 @@ "embedding_api_base": { "description": "API Base URL" }, + "send_dimensions_param": { + "description": "Forward dimensions parameter", + "hint": "When enabled, sends embedding_dimensions to the upstream API as the dimensions parameter. This allows dimension reduction on services that support it (OpenAI, Zhipu, Volcengine). Leave disabled if the upstream does not support custom dimensions." + }, "openai_embedding": { - "hint": "OpenAI Embedding automatically appends /v1 at request time." + "hint": "OpenAI Embedding is kept for existing configs and standard /v1 endpoints. It automatically appends /v1 and still forwards dimensions whenever embedding_dimensions is configured." + }, + "openai_compatible_embedding": { + "hint": "Use this for broader OpenAI-compatible embedding services. AstrBot appends /v1 only when the API Base URL has no path, and preserves existing paths such as /api/paas/v4 or /api/v3. Unlike the legacy OpenAI Embedding preset, dimensions forwarding is explicitly configurable here." + }, + "zhipu_embedding": { + "hint": "The Zhipu preset defaults to https://open.bigmodel.cn/api/paas/v4 and embedding-3, and keeps the /api/paas/v4 path unchanged." + }, + "volcengine_embedding": { + "hint": "The Volcengine preset defaults to https://ark.cn-beijing.volces.com/api/v3 and doubao-embedding-vision, and keeps the /api/v3 path unchanged. AstrBot still uses this model for text input only." + }, + "ollama_embedding": { + "hint": "The Ollama preset defaults to local http://127.0.0.1:11434, model embeddinggemma, and 768 dimensions. The API key defaults to ollama only for OpenAI SDK compatibility, and Ollama ignores it." }, "gemini_embedding": { "hint": "Gemini Embedding does not require manually adding /v1beta." @@ -1518,4 +1534,4 @@ "helpMiddle": "or", "helpSuffix": "." } -} \ No newline at end of file +} diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index 56d12c9838..478f3c4c68 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -1193,8 +1193,24 @@ "embedding_api_base": { "description": "Адрес прокси-сервера" }, + "send_dimensions_param": { + "description": "Передавать параметр dimensions", + "hint": "Если включено, отправляет embedding_dimensions в upstream API как параметр dimensions. Это позволяет уменьшить размерность для сервисов, которые это поддерживают (OpenAI, Zhipu, Volcengine). Отключите, если upstream не поддерживает пользовательскую размерность." + }, "openai_embedding": { - "hint": "OpenAI Embedding автоматически добавляет /v1 при запросе." + "hint": "OpenAI Embedding сохранён для существующих конфигураций и стандартных /v1 endpoint-ов. Он автоматически добавляет /v1 и по-прежнему отправляет dimensions, если задан embedding_dimensions." + }, + "openai_compatible_embedding": { + "hint": "Используйте этот вариант для более широкого круга OpenAI-совместимых embedding-сервисов. AstrBot добавляет /v1 только если в API Base URL нет пути, а уже заданные пути, такие как /api/paas/v4 или /api/v3, сохраняются без изменений. В отличие от старого OpenAI Embedding, здесь можно явно управлять передачей dimensions." + }, + "zhipu_embedding": { + "hint": "Пресет Zhipu по умолчанию использует https://open.bigmodel.cn/api/paas/v4 и embedding-3, сохраняя путь /api/paas/v4 без изменений." + }, + "volcengine_embedding": { + "hint": "Пресет Volcengine по умолчанию использует https://ark.cn-beijing.volces.com/api/v3 и doubao-embedding-vision, сохраняя путь /api/v3 без изменений. Сейчас AstrBot использует эту модель только для текстового ввода." + }, + "ollama_embedding": { + "hint": "Пресет Ollama по умолчанию использует локальный http://127.0.0.1:11434, модель embeddinggemma и размерность 768. API key по умолчанию равен ollama только для совместимости с OpenAI SDK; сам Ollama его игнорирует." }, "gemini_embedding": { "hint": "Gemini Embedding не требует ручного добавления /v1beta." @@ -1523,4 +1539,4 @@ "helpMiddle": "или", "helpSuffix": "." } -} \ No newline at end of file +} diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 0c9148bd0b..caa865bfe4 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -1190,8 +1190,24 @@ "embedding_api_base": { "description": "API Base URL" }, + "send_dimensions_param": { + "description": "透传 dimensions 参数", + "hint": "启用后,将 embedding_dimensions 作为 dimensions 参数发送给上游 API。支持自定义维度的服务(OpenAI、智谱、火山等)可开启此项以实现降维;若上游不支持自定义维度则关闭。" + }, "openai_embedding": { - "hint": "OpenAI Embedding 会在请求时自动补上 /v1。" + "hint": "OpenAI Embedding 保留给已有配置和标准 /v1 接口使用,会在请求时自动补上 /v1,并在填写 embedding_dimensions 时继续直接发送 dimensions。" + }, + "openai_compatible_embedding": { + "hint": "用于更广义的 OpenAI 兼容 embedding 服务。如果 API Base URL 不带路径,AstrBot 会自动补上 /v1;如果已经带了 /api/paas/v4 或 /api/v3 这类路径,则保持原样。相比旧的 OpenAI Embedding,这里还可以显式控制是否透传 dimensions。" + }, + "zhipu_embedding": { + "hint": "智谱预设默认使用 https://open.bigmodel.cn/api/paas/v4 和 embedding-3,并保留 /api/paas/v4 路径不变。" + }, + "volcengine_embedding": { + "hint": "火山预设默认使用 https://ark.cn-beijing.volces.com/api/v3 和 doubao-embedding-vision,并保留 /api/v3 路径不变。AstrBot 当前仍只按文本输入使用该模型。" + }, + "ollama_embedding": { + "hint": "Ollama 预设默认使用本地 http://127.0.0.1:11434、模型 embeddinggemma、维度 768。API Key 默认填 ollama,仅用于兼容 OpenAI SDK,Ollama 会忽略它。" }, "gemini_embedding": { "hint": "Gemini Embedding 无需手动添加 /v1beta。" @@ -1520,4 +1536,4 @@ "helpMiddle": "或", "helpSuffix": "。" } -} \ No newline at end of file +} diff --git a/docs/en/use/knowledge-base.md b/docs/en/use/knowledge-base.md index b1f9e1dc12..9b1a0152ed 100644 --- a/docs/en/use/knowledge-base.md +++ b/docs/en/use/knowledge-base.md @@ -10,12 +10,25 @@ Open the service provider page, click "Add Service Provider", and select Embedding. -Currently, AstrBot supports embedding vector services compatible with OpenAI API and Gemini API. +AstrBot now includes built-in presets for OpenAI-compatible Embedding, Zhipu Embedding, Volcengine Embedding, Ollama Embedding, and Gemini Embedding. + +If you want to connect another OpenAI-compatible embedding service, use `OpenAI Compatible Embedding` first. When `embedding api base` only contains the host, AstrBot automatically appends `/v1`. If the URL already contains a path such as Zhipu `/api/paas/v4` or Volcengine Ark `/api/v3`, AstrBot preserves that path as-is. + +The legacy `OpenAI Embedding` preset is still kept for backward compatibility with existing configurations. It remains a good fit for standard OpenAI-style `/v1` endpoints and keeps the previous behavior of forwarding `dimensions` whenever `embedding_dimensions` is configured. `OpenAI Compatible Embedding` targets broader compatibility by preserving provider-specific path prefixes and making `dimensions` forwarding opt-in, so the two presets are intentionally not interchangeable. Click on the provider card above to enter the configuration page and fill in the configuration. After completing the configuration, click Save. +> [!NOTE] +> `OpenAI Compatible Embedding` includes a `send_dimensions_param` switch. When enabled, AstrBot sends `embedding_dimensions` to the upstream embedding API as the `dimensions` parameter. Disable it for OpenAI-compatible services that only need the local vector size and do not support `dimensions`. + +> [!NOTE] +> The Volcengine preset defaults to `doubao-embedding-vision`. AstrBot's knowledge-base pipeline is still text chunking plus text embedding only, so this integration uses the model with text input only and does not add multimodal knowledge-base support yet, although it is a multimodal embedding model. + +> [!NOTE] +> The Ollama preset defaults to local `http://127.0.0.1:11434`, model `embeddinggemma`, and 768 dimensions. Before using it, run `ollama pull embeddinggemma` locally and make sure the Ollama service is running. + ## Configuring Reranker Model (Optional) A reranker model can improve the precision of final retrieval results to some extent. @@ -53,7 +66,7 @@ In the configuration file, you can specify different knowledge bases for differe 2. Go to the [Model Marketplace](https://ppio.cn/model-api/console) and click on Embedding Models. 3. Click on BAAI:BGE-M3 (as of 2025-06-02, this model is free on this platform). 4. Find the API integration guide and apply for a Key. -5. Fill in the AstrBot OpenAI Embedding model provider configuration: +5. Fill in the AstrBot `OpenAI Compatible Embedding` model provider configuration: 1. API Key is the PPIO API Key you just applied for 2. embedding api base: enter `https://api.ppinfra.com/v3/openai` 3. model: enter the model you selected, in this example `baai/bge-m3`. diff --git a/docs/zh/use/knowledge-base.md b/docs/zh/use/knowledge-base.md index d79336c251..77fc06f44e 100644 --- a/docs/zh/use/knowledge-base.md +++ b/docs/zh/use/knowledge-base.md @@ -11,12 +11,25 @@ 打开服务提供商页面,点击新增服务提供商,选择 Embedding。 -目前 AstrBot 支持兼容 OpenAI API 和 Gemini API 的嵌入向量服务。 +目前 AstrBot 内置了通用 OpenAI-compatible Embedding、智谱 Embedding、火山 Embedding、Ollama Embedding 和 Gemini Embedding。 + +如果你要接入其他兼容 OpenAI API 的嵌入服务,优先选择 `OpenAI Compatible Embedding`。当 `embedding api base` 只填写域名时,AstrBot 会自动补上 `/v1`;如果你填写的是带路径的地址,例如智谱的 `/api/paas/v4` 或火山 Ark 的 `/api/v3`,AstrBot 会保持原样,不会额外拼接 `/v1`。 + +保留旧的 `OpenAI Embedding` 主要是为了兼容已有配置。它仍然适合标准 OpenAI 风格的 `/v1` 接口,并会继续沿用原来的行为:当你配置了 `embedding_dimensions` 时,请求里会直接发送 `dimensions`。`OpenAI Compatible Embedding` 则用于更广义的兼容服务,重点解决自定义路径前缀和是否透传 `dimensions` 这两个兼容性问题,因此两者不是简单重复。 点击上面的提供商卡片进入配置页面,填写配置。 配置完成后,点击保存。 +> [!NOTE] +> `OpenAI Compatible Embedding` 提供了 `send_dimensions_param` 开关。开启后,AstrBot 会把 `embedding_dimensions` 作为 `dimensions` 参数发送给上游接口;如果你的兼容服务只需要本地向量维度、但不支持 `dimensions` 参数,请关闭它。 + +> [!NOTE] +> 火山预设默认模型为 `doubao-embedding-vision`。AstrBot 当前的知识库链路仍然只按文本分块和文本 embedding 工作,所以本次接入只会按文本输入使用该模型,不代表已经支持多模态知识库。 + +> [!NOTE] +> Ollama 预设默认指向本地 `http://127.0.0.1:11434`,模型为 `embeddinggemma`,默认维度为 768。开始使用前请先在本机执行 `ollama pull embeddinggemma`,并确保 Ollama 服务已经启动。 + ## 配置重排序模型(可选) 重排序模型可以一定程度上提高最终召回结果的精度。 @@ -54,7 +67,7 @@ AstrBot 支持多知识库管理。在聊天时,您可以**自由指定知识 2. 进入 [模型广场](https://ppio.cn/model-api/console),点击嵌入模型 3. 点击 BAAI:BGE-M3 (截止至 2025-06-02,该模型在该平台免费)。 4. 找到 API 接入指南,申请 Key。 -5. 填写 AstrBot OpenAI Embedding 模型提供商配置: +5. 填写 AstrBot `OpenAI Compatible Embedding` 模型提供商配置: 1. API Key 为刚刚申请的 PPIO 的 API Key 2. embedding api base 填写 `https://api.ppinfra.com/v3/openai` 3. model 填写你选择的模型,此例子中为 `baai/bge-m3`。 diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index bf14aa4c72..f00da85539 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -15,6 +15,7 @@ from astrbot.core import LogBroker from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db.sqlite import SQLiteDatabase +from astrbot.core.provider.sources import openai_compatible_embedding_source from astrbot.core.star.star import star_registry from astrbot.core.star.star_handler import star_handlers_registry from astrbot.core.utils.pip_installer import PipInstallError @@ -107,6 +108,87 @@ async def test_get_stat(app: Quart, authenticated_header: dict): assert data["status"] == "ok" and "platform" in data["data"] +@pytest.mark.asyncio +async def test_provider_template_exposes_openai_compatible_embedding_presets( + app: Quart, + authenticated_header: dict, +): + test_client = app.test_client() + response = await test_client.get( + "/api/config/provider/template", + headers=authenticated_header, + ) + + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + templates = data["data"]["config_schema"]["provider"]["config_template"] + assert "OpenAI Compatible Embedding" in templates + assert "Zhipu Embedding" in templates + assert "Volcengine Embedding" in templates + assert "Ollama Embedding" in templates + assert templates["OpenAI Compatible Embedding"]["type"] == ( + "openai_compatible_embedding" + ) + assert templates["Ollama Embedding"]["provider"] == "ollama" + + +class _FakeDashboardEmbeddingsAPI: + async def create(self, **kwargs): + return SimpleNamespace( + data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3, 0.4])], + ) + + +class _FakeDashboardAsyncOpenAI: + def __init__(self, **kwargs) -> None: + self.kwargs = kwargs + self.embeddings = _FakeDashboardEmbeddingsAPI() + + async def close(self): + return None + + +@pytest.mark.asyncio +async def test_get_embedding_dim_supports_openai_compatible_embedding( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + test_client = app.test_client() + monkeypatch.setattr( + openai_compatible_embedding_source, + "AsyncOpenAI", + _FakeDashboardAsyncOpenAI, + ) + + response = await test_client.post( + "/api/config/provider/get_embedding_dim", + json={ + "provider_config": { + "id": "dashboard-openai-compatible-embedding", + "type": "openai_compatible_embedding", + "provider_type": "embedding", + "embedding_api_key": "test-key", + "embedding_api_base": "https://example.com", + "embedding_model": "text-embedding-3-small", + "embedding_dimensions": 2048, + "send_dimensions_param": False, + "timeout": 20, + "proxy": "", + "enable": True, + } + }, + headers=authenticated_header, + ) + + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + assert data["data"]["embedding_dimensions"] == 4 + + @pytest.mark.asyncio async def test_subagent_config_accepts_default_persona( app: Quart, diff --git a/tests/test_openai_compatible_embedding_source.py b/tests/test_openai_compatible_embedding_source.py new file mode 100644 index 0000000000..4fad605cc1 --- /dev/null +++ b/tests/test_openai_compatible_embedding_source.py @@ -0,0 +1,193 @@ +from types import SimpleNamespace + +import pytest + +from astrbot.core.provider.sources import openai_compatible_embedding_source as source + + +class _FakeEmbeddingsAPI: + def __init__(self, create_calls: list[dict]) -> None: + self._create_calls = create_calls + + async def create(self, **kwargs): + self._create_calls.append(kwargs) + return SimpleNamespace( + data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])], + ) + + +class _FakeAsyncOpenAI: + instances: list["_FakeAsyncOpenAI"] = [] + + def __init__(self, **kwargs) -> None: + self.kwargs = kwargs + self.create_calls: list[dict] = [] + self.embeddings = _FakeEmbeddingsAPI(self.create_calls) + self.closed = False + self.__class__.instances.append(self) + + async def close(self): + self.closed = True + + +class _FakeHTTPClient: + instances: list["_FakeHTTPClient"] = [] + + def __init__(self, **kwargs) -> None: + self.kwargs = kwargs + self.closed = False + self.__class__.instances.append(self) + + async def aclose(self): + self.closed = True + + +def _make_provider_config(**overrides) -> dict: + provider_config = { + "id": "test-openai-compatible-embedding", + "type": "openai_compatible_embedding", + "embedding_api_key": "test-key", + "embedding_api_base": "", + "embedding_model": "text-embedding-3-small", + "embedding_dimensions": 1024, + "send_dimensions_param": False, + "timeout": 20, + "proxy": "", + } + provider_config.update(overrides) + return provider_config + + +@pytest.mark.parametrize( + ("api_base", "expected_base_url"), + [ + ("", "https://api.openai.com/v1"), + ("api.openai.com", "https://api.openai.com/v1"), + ("https://example.com", "https://example.com/v1"), + ("https://example.com/", "https://example.com/v1"), + ( + "https://open.bigmodel.cn/api/paas/v4", + "https://open.bigmodel.cn/api/paas/v4", + ), + ( + "https://ark.cn-beijing.volces.com/api/v3", + "https://ark.cn-beijing.volces.com/api/v3", + ), + ], +) +def test_normalize_openai_compatible_embedding_api_base(api_base, expected_base_url): + assert ( + source.normalize_openai_compatible_embedding_api_base(api_base) + == expected_base_url + ) + + +@pytest.mark.asyncio +async def test_openai_compatible_embedding_provider_appends_v1_only_for_host_url( + monkeypatch: pytest.MonkeyPatch, +): + _FakeAsyncOpenAI.instances.clear() + monkeypatch.setattr(source, "AsyncOpenAI", _FakeAsyncOpenAI) + + provider = source.OpenAICompatibleEmbeddingProvider( + _make_provider_config(embedding_api_base="https://example.com"), + {}, + ) + + try: + assert _FakeAsyncOpenAI.instances[-1].kwargs["base_url"] == ( + "https://example.com/v1" + ) + finally: + await provider.terminate() + assert provider.client.closed is True + + +@pytest.mark.asyncio +async def test_openai_compatible_embedding_provider_preserves_existing_api_path( + monkeypatch: pytest.MonkeyPatch, +): + _FakeAsyncOpenAI.instances.clear() + monkeypatch.setattr(source, "AsyncOpenAI", _FakeAsyncOpenAI) + + provider = source.OpenAICompatibleEmbeddingProvider( + _make_provider_config( + embedding_api_base="https://open.bigmodel.cn/api/paas/v4", + ), + {}, + ) + + try: + assert _FakeAsyncOpenAI.instances[-1].kwargs["base_url"] == ( + "https://open.bigmodel.cn/api/paas/v4" + ) + finally: + await provider.terminate() + assert provider.client.closed is True + + +@pytest.mark.asyncio +async def test_openai_compatible_embedding_provider_sends_dimensions_only_when_enabled( + monkeypatch: pytest.MonkeyPatch, +): + _FakeAsyncOpenAI.instances.clear() + monkeypatch.setattr(source, "AsyncOpenAI", _FakeAsyncOpenAI) + + provider_without_dimensions = source.OpenAICompatibleEmbeddingProvider( + _make_provider_config(send_dimensions_param=False), + {}, + ) + provider_with_dimensions = source.OpenAICompatibleEmbeddingProvider( + _make_provider_config(send_dimensions_param=True, embedding_dimensions=2048), + {}, + ) + + try: + await provider_without_dimensions.get_embedding("hello") + await provider_with_dimensions.get_embedding("hello") + + assert "dimensions" not in _FakeAsyncOpenAI.instances[0].create_calls[0] + assert _FakeAsyncOpenAI.instances[1].create_calls[0]["dimensions"] == 2048 + finally: + await provider_without_dimensions.terminate() + await provider_with_dimensions.terminate() + assert provider_without_dimensions.client.closed is True + assert provider_with_dimensions.client.closed is True + + +@pytest.mark.asyncio +async def test_openai_compatible_embedding_provider_closes_proxy_http_client( + monkeypatch: pytest.MonkeyPatch, +): + _FakeAsyncOpenAI.instances.clear() + _FakeHTTPClient.instances.clear() + monkeypatch.setattr(source, "AsyncOpenAI", _FakeAsyncOpenAI) + monkeypatch.setattr(source.httpx, "AsyncClient", _FakeHTTPClient) + + provider = source.OpenAICompatibleEmbeddingProvider( + _make_provider_config(proxy="http://127.0.0.1:7890"), + {}, + ) + + try: + assert _FakeHTTPClient.instances[-1].kwargs["proxy"] == "http://127.0.0.1:7890" + finally: + await provider.terminate() + assert provider.client.closed is True + assert provider._http_client.closed is True + + +def test_should_send_dimensions_param_requires_boolean( + monkeypatch: pytest.MonkeyPatch, +): + warnings: list[str] = [] + + def _capture_warning(message, *args): + warnings.append(message % args) + + monkeypatch.setattr(source.logger, "warning", _capture_warning) + + assert ( + source.should_send_dimensions_param({"send_dimensions_param": "true"}) is False + ) + assert "send_dimensions_param should be a boolean" in warnings[0]