77
88import json
99from contextlib import asynccontextmanager
10- from dataclasses import replace as dataclass_replace
1110from typing import (
1211 AsyncGenerator ,
1312 List ,
3029)
3130from select_ai .conversation import AsyncConversation
3231from select_ai .db import async_cursor , async_get_connection
33- from select_ai .errors import ProfileExistsError , ProfileNotFoundError
32+ from select_ai .errors import (
33+ ProfileAttributesEmptyError ,
34+ ProfileNotFoundError ,
35+ )
3436from select_ai .feedback import (
3537 FeedbackOperation ,
3638 FeedbackType ,
@@ -69,35 +71,23 @@ async def _init_profile(self):
6971 if self .profile_name :
7072 profile_exists = False
7173 try :
72- saved_attributes = await self ._get_attributes (
74+ saved_description = await self ._get_profile_description (
7375 profile_name = self .profile_name
7476 )
7577 profile_exists = True
76- if not self .replace and not self .merge :
77- if (
78- self .attributes is not None
79- or self .description is not None
80- ):
81- if self .raise_error_if_exists :
82- raise ProfileExistsError (self .profile_name )
83-
84- if self .description is None and not self .replace :
85- self .description = await self ._get_profile_description (
86- profile_name = self .profile_name
87- )
78+ saved_attributes = await self ._get_attributes (
79+ profile_name = self .profile_name ,
80+ raise_on_empty = True ,
81+ )
82+ self ._raise_error_if_profile_exists ()
83+ except ProfileAttributesEmptyError :
84+ if self .raise_error_on_empty_attributes :
85+ raise
8886 except ProfileNotFoundError :
8987 if self .attributes is None and self .description is None :
9088 raise
9189 else :
92- if self .attributes is None :
93- self .attributes = saved_attributes
94- if self .merge :
95- self .replace = True
96- if self .attributes is not None :
97- self .attributes = dataclass_replace (
98- saved_attributes ,
99- ** self .attributes .dict (exclude_null = True ),
100- )
90+ self ._merge_attributes (saved_attributes , saved_description )
10191 if self .replace or not profile_exists :
10292 await self .create (replace = self .replace )
10393 else : # profile name is None:
@@ -132,12 +122,15 @@ async def _get_profile_description(profile_name) -> Union[str, None]:
132122 raise ProfileNotFoundError (profile_name )
133123
134124 @staticmethod
135- async def _get_attributes (profile_name ) -> ProfileAttributes :
125+ async def _get_attributes (
126+ profile_name : str , raise_on_empty : bool = True
127+ ) -> Union [ProfileAttributes , None ]:
136128 """Asynchronously gets AI profile attributes from the Database
137129
138130 :param str profile_name: Name of the profile
131+ :param bool raise_on_empty: Raise an error if attributes are empty
139132 :return: select_ai.provider.ProviderAttributes
140- :raises: ProfileNotFoundError
133+ :raises: select_ai.errors.ProfileAttributesEmptyError
141134
142135 """
143136 async with async_cursor () as cr :
@@ -149,7 +142,11 @@ async def _get_attributes(profile_name) -> ProfileAttributes:
149142 if attributes :
150143 return await ProfileAttributes .async_create (** dict (attributes ))
151144 else :
152- raise ProfileNotFoundError (profile_name = profile_name )
145+ if raise_on_empty :
146+ raise ProfileAttributesEmptyError (
147+ profile_name = profile_name
148+ )
149+ return None
153150
154151 async def get_attributes (self ) -> ProfileAttributes :
155152 """Asynchronously gets AI profile attributes from the Database
@@ -387,7 +384,9 @@ async def list(
387384 for row in rows :
388385 profile_name = row [0 ]
389386 yield await cls (
390- profile_name = profile_name , raise_error_if_exists = False
387+ profile_name = profile_name ,
388+ raise_error_if_exists = False ,
389+ raise_error_on_empty_attributes = False ,
391390 )
392391
393392 async def generate (
@@ -623,6 +622,34 @@ async def run_pipeline(
623622 responses .append (result .error )
624623 return responses
625624
625+ async def translate (
626+ self , text : str , source_language : str , target_language : str
627+ ) -> Union [str , None ]:
628+ """
629+ Translate a text using a source language and a target language
630+
631+ :param str text: Text to translate
632+ :param str source_language: Source language
633+ :param str target_language: Target language
634+ :return: str
635+ """
636+ parameters = {
637+ "profile_name" : self .profile_name ,
638+ "text" : text ,
639+ "source_language" : source_language ,
640+ "target_language" : target_language ,
641+ }
642+ async with async_cursor () as cr :
643+ data = await cr .callfunc (
644+ "DBMS_CLOUD_AI.TRANSLATE" ,
645+ oracledb .DB_TYPE_CLOB ,
646+ keyword_parameters = parameters ,
647+ )
648+ if data is not None :
649+ result = await data .read ()
650+ return result
651+ return None
652+
626653
627654class AsyncSession :
628655 """AsyncSession lets you persist request parameters across DBMS_CLOUD_AI
0 commit comments