Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class BaiLianEmbeddingModelParams(BaseForm):


class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
dashscope_api_key = forms.PasswordInputField('API Key', required=True)

def is_valid(
self,
Expand Down Expand Up @@ -91,5 +93,3 @@ def encryption_dict(self, model: Dict[str, Any]) -> Dict[str, Any]:

def get_model_params_setting_form(self, model_name):
return BaiLianEmbeddingModelParams()

dashscope_api_key = forms.PasswordInputField('API Key', required=True)
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class QwenModelParams(BaseForm):


class QwenVLModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
api_key = forms.PasswordInputField('API Key', required=True)

def is_valid(
self,
Expand Down Expand Up @@ -84,7 +86,5 @@ def is_valid(
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

api_key = forms.PasswordInputField('API Key', required=True)

def get_model_params_setting_form(self, model_name):
return QwenModelParams()
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.utils.translation import gettext_lazy as _, gettext

from common.exception.app_exception import AppApiException
from common import forms
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from common.forms.switch_field import SwitchField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
Expand Down Expand Up @@ -41,6 +42,7 @@ class ImageToVideoModelCredential(BaseForm, BaseModelCredential):
Provides validation and encryption for the model credentials.
"""

api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
api_key = PasswordInputField('API Key', required=True)

def is_valid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class BaiLianLLMModelParams(BaseForm):


class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField(_('API URL'), required=True)
api_base = forms.TextInputField(_('API URL'), required=True, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
api_key = forms.PasswordInputField(_('API Key'), required=True)

def is_valid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_core.documents import Document

from common.exception.app_exception import AppApiException
from common import forms
from common.forms import BaseForm, PasswordInputField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
Expand All @@ -17,6 +18,7 @@ class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
Provides validation and encryption for the model credentials.
"""

api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
dashscope_api_key = PasswordInputField('API Key', required=True)

def is_valid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.utils.translation import gettext_lazy as _, gettext

from common.exception.app_exception import AppApiException
from common import forms
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
Expand Down Expand Up @@ -68,6 +69,7 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
Provides validation and encryption for the model credentials.
"""

api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
api_key = PasswordInputField('API Key', required=True)

def is_valid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.utils.translation import gettext_lazy as _, gettext

from common.exception.app_exception import AppApiException
from common import forms
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
Expand Down Expand Up @@ -59,6 +60,7 @@ class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
Provides validation and encryption for the model credentials.
"""

api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
api_key = PasswordInputField("API Key", required=True)

def is_valid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.utils.translation import gettext_lazy as _, gettext

from common.exception.app_exception import AppApiException
from common import forms
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from common.forms.switch_field import SwitchField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
Expand Down Expand Up @@ -43,6 +44,7 @@ class TextToVideoModelCredential(BaseForm, BaseModelCredential):
Provides validation and encryption for the model credentials.
"""

api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1')
api_key = PasswordInputField('API Key', required=True)

def is_valid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ class AliyunBaiLianEmbedding(MaxKBBaseModel):
model_name: str
optional_params: dict

def __init__(self, api_key, model_name: str, optional_params: dict):
self.client = OpenAI(api_key=api_key, base_url='https://dashscope.aliyuncs.com/compatible-mode/v1').embeddings
def __init__(self, api_key, api_base, model_name: str, optional_params: dict):
api_base = api_base or 'https://dashscope.aliyuncs.com/compatible-mode/v1'
self.client = OpenAI(api_key=api_key, base_url=api_base).embeddings
self.model_name = model_name
self.optional_params = optional_params

Expand All @@ -30,6 +31,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return AliyunBaiLianEmbedding(
api_key=model_credential.get('dashscope_api_key'),
api_base=model_credential.get('api_base'),
model_name=model_name,
optional_params=optional_params
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ def is_cache_model():
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
api_base = model_credential.get('api_base') or 'https://dashscope.aliyuncs.com/compatible-mode/v1'
chat_tong_yi = QwenVLChatModel(
model_name=model_name,
openai_api_key=model_credential.get('api_key'),
openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
openai_api_base=api_base,
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
Expand All @@ -41,7 +42,15 @@ def check_auth(self, api_key):

def get_upload_policy(self, api_key, model_name):
"""获取文件上传凭证"""
url = "https://dashscope.aliyuncs.com/api/v1/uploads"
# 如果有自定义api_base,提取host部分,否则使用默认URL
if hasattr(self, 'openai_api_base') and self.openai_api_base:
# 从api_base中提取host,替换默认URL
from urllib.parse import urlparse
parsed_url = urlparse(self.openai_api_base)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
url = f"{base_url}/api/v1/uploads"
else:
url = "https://dashscope.aliyuncs.com/api/v1/uploads"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
Expand Down Expand Up @@ -109,7 +118,11 @@ def stream(
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
# 如果有自定义api_base,使用它,否则使用默认URL
if hasattr(self, 'openai_api_base') and self.openai_api_base:
url = f"{self.openai_api_base}/chat/completions"
else:
url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"

headers = {
"Authorization": f"Bearer {self.openai_api_key.get_secret_value()}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
import dashscope
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.documents import BaseDocumentCompressor

from models_provider.base_model_provider import MaxKBBaseModel


class AliyunBaiLianReranker(MaxKBBaseModel, BaseDocumentCompressor):
model: Optional[str]
api_key: Optional[str]
api_base: Optional[str]

top_n: Optional[int] = 3 # 取前 N 个最相关的结果

Expand All @@ -31,6 +31,7 @@ def is_cache_model():
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return AliyunBaiLianReranker(model=model_name,
api_key=model_credential.get('dashscope_api_key'),
api_base=model_credential.get('api_base'),
top_n=model_kwargs.get('top_n', 3))

def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Expand All @@ -39,6 +40,9 @@ def compress_documents(self, documents: Sequence[Document], query: str, callback
return []

texts = [doc.page_content for doc in documents]
# 如果提供了api_base,则配置dashscope使用自定义endpoint
if self.api_base:
dashscope.base_http_url = self.api_base
resp = dashscope.TextReRank.call(
model=self.model,
query=query,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
from http import HTTPStatus
from typing import Dict
from typing import Dict, Optional

from dashscope import ImageSynthesis, MultiModalConversation
from django.utils.translation import gettext
Expand All @@ -15,12 +15,14 @@

class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
api_key: str
api_base: Optional[str]
model_name: str
params: dict

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.model_name = kwargs.get('model_name')
self.params = kwargs.get('params')

Expand All @@ -37,6 +39,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
chat_tong_yi = QwenTextToImageModel(
model_name=model_name,
api_key=model_credential.get('api_key'),
api_base=model_credential.get('api_base'),
**optional_params,
)
return chat_tong_yi
Expand All @@ -47,9 +50,11 @@ def check_auth(self):

def generate_image(self, prompt: str, negative_prompt: str = None):
if self.model_name.startswith("wan"):
# 如果提供了api_base,则使用自定义base_url,否则使用默认URL
base_url = self.api_base or 'https://dashscope.aliyuncs.com/compatible-mode/v1'
rsp = ImageSynthesis.call(api_key=self.api_key,
model=self.model_name,
base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
base_url=base_url,
prompt=prompt,
negative_prompt=negative_prompt,
**self.params)
Expand All @@ -73,12 +78,14 @@ def generate_image(self, prompt: str, negative_prompt: str = None):
]
}
]
# 如果提供了api_base,则使用自定义base_url,否则使用默认URL
base_url = self.api_base or 'https://dashscope.aliyuncs.com/v1'
rsp = MultiModalConversation.call(
api_key=self.api_key,
model=self.model_name,
messages=messages,
result_format='message',
base_url='https://dashscope.aliyuncs.com/v1',
base_url=base_url,
stream=False,
negative_prompt=negative_prompt,
**self.params
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict
from typing import Dict, Optional

import dashscope
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse

from django.utils.translation import gettext as _

Expand All @@ -11,12 +12,14 @@

class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
api_key: str
api_base: Optional[str]
model: str
params: dict

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.model = kwargs.get('model')
self.params = kwargs.get('params')

Expand All @@ -34,6 +37,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
return AliyunBaiLianTextToSpeech(
model=model_name,
api_key=model_credential.get('api_key'),
api_base=model_credential.get('api_base'),
**optional_params,
)

Expand All @@ -42,6 +46,9 @@ def check_auth(self):

def text_to_speech(self, text):
dashscope.api_key = self.api_key
# 如果提供了api_base,则配置dashscope使用自定义endpoint
if self.api_base:
dashscope.base_http_url = self.api_base
text = _remove_empty_lines(text)
if 'sambert' in self.model:
from dashscope.audio.tts import SpeechSynthesizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo):
api_key: str
api_base: Optional[str]
model_name: str
params: dict
max_retries: int = 3
Expand All @@ -22,6 +23,7 @@ class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.model_name = kwargs.get('model_name')
self.params = kwargs.get('params', {})
self.max_retries = kwargs.get('max_retries', 3)
Expand All @@ -40,6 +42,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
return GenerationVideoModel(
model_name=model_name,
api_key=model_credential.get('api_key'),
api_base=model_credential.get('api_base'),
**optional_params,
)

Expand Down Expand Up @@ -83,6 +86,9 @@ def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, las
params.update(self.params)

# --- 异步提交任务 ---
# 如果提供了api_base,则配置dashscope使用自定义endpoint
if self.api_base:
params['base_url'] = self.api_base
rsp = self._safe_call(VideoSynthesis.async_call, **params)
if rsp.status_code != HTTPStatus.OK:
maxkb_logger.info(f'提交任务失败,status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}')
Expand Down