Skip to content

Commit a76c07a

Browse files
authored
Merge pull request #14 from oracle/dev/v1.2.1
dev/v1.2.1
2 parents aa3cc84 + c1f523d commit a76c07a

File tree

10 files changed

+393
-69
lines changed

10 files changed

+393
-69
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ pytest.env
1919
sample_connect.py
2020
async_pipeline_test.py
2121
parquet.py
22+
local_sample

src/select_ai/action.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ class Action(StrEnum):
2121
SHOWPROMPT = "showprompt"
2222
FEEDBACK = "feedback"
2323
SUMMARIZE = "summarize"
24+
TRANSLATE = "translate"

src/select_ai/async_profile.py

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import json
99
from contextlib import asynccontextmanager
10-
from dataclasses import replace as dataclass_replace
1110
from typing import (
1211
AsyncGenerator,
1312
List,
@@ -30,7 +29,10 @@
3029
)
3130
from select_ai.conversation import AsyncConversation
3231
from select_ai.db import async_cursor, async_get_connection
33-
from select_ai.errors import ProfileExistsError, ProfileNotFoundError
32+
from select_ai.errors import (
33+
ProfileAttributesEmptyError,
34+
ProfileNotFoundError,
35+
)
3436
from select_ai.feedback import (
3537
FeedbackOperation,
3638
FeedbackType,
@@ -69,35 +71,23 @@ async def _init_profile(self):
6971
if self.profile_name:
7072
profile_exists = False
7173
try:
72-
saved_attributes = await self._get_attributes(
74+
saved_description = await self._get_profile_description(
7375
profile_name=self.profile_name
7476
)
7577
profile_exists = True
76-
if not self.replace and not self.merge:
77-
if (
78-
self.attributes is not None
79-
or self.description is not None
80-
):
81-
if self.raise_error_if_exists:
82-
raise ProfileExistsError(self.profile_name)
83-
84-
if self.description is None and not self.replace:
85-
self.description = await self._get_profile_description(
86-
profile_name=self.profile_name
87-
)
78+
saved_attributes = await self._get_attributes(
79+
profile_name=self.profile_name,
80+
raise_on_empty=True,
81+
)
82+
self._raise_error_if_profile_exists()
83+
except ProfileAttributesEmptyError:
84+
if self.raise_error_on_empty_attributes:
85+
raise
8886
except ProfileNotFoundError:
8987
if self.attributes is None and self.description is None:
9088
raise
9189
else:
92-
if self.attributes is None:
93-
self.attributes = saved_attributes
94-
if self.merge:
95-
self.replace = True
96-
if self.attributes is not None:
97-
self.attributes = dataclass_replace(
98-
saved_attributes,
99-
**self.attributes.dict(exclude_null=True),
100-
)
90+
self._merge_attributes(saved_attributes, saved_description)
10191
if self.replace or not profile_exists:
10292
await self.create(replace=self.replace)
10393
else: # profile name is None:
@@ -132,12 +122,15 @@ async def _get_profile_description(profile_name) -> Union[str, None]:
132122
raise ProfileNotFoundError(profile_name)
133123

134124
@staticmethod
135-
async def _get_attributes(profile_name) -> ProfileAttributes:
125+
async def _get_attributes(
126+
profile_name: str, raise_on_empty: bool = True
127+
) -> Union[ProfileAttributes, None]:
136128
"""Asynchronously gets AI profile attributes from the Database
137129
138130
:param str profile_name: Name of the profile
131+
:param bool raise_on_empty: Raise an error if attributes are empty
139132
:return: select_ai.provider.ProviderAttributes
140-
:raises: ProfileNotFoundError
133+
:raises: select_ai.errors.ProfileAttributesEmptyError
141134
142135
"""
143136
async with async_cursor() as cr:
@@ -149,7 +142,11 @@ async def _get_attributes(profile_name) -> ProfileAttributes:
149142
if attributes:
150143
return await ProfileAttributes.async_create(**dict(attributes))
151144
else:
152-
raise ProfileNotFoundError(profile_name=profile_name)
145+
if raise_on_empty:
146+
raise ProfileAttributesEmptyError(
147+
profile_name=profile_name
148+
)
149+
return None
153150

154151
async def get_attributes(self) -> ProfileAttributes:
155152
"""Asynchronously gets AI profile attributes from the Database
@@ -387,7 +384,9 @@ async def list(
387384
for row in rows:
388385
profile_name = row[0]
389386
yield await cls(
390-
profile_name=profile_name, raise_error_if_exists=False
387+
profile_name=profile_name,
388+
raise_error_if_exists=False,
389+
raise_error_on_empty_attributes=False,
391390
)
392391

393392
async def generate(
@@ -623,6 +622,34 @@ async def run_pipeline(
623622
responses.append(result.error)
624623
return responses
625624

625+
async def translate(
626+
self, text: str, source_language: str, target_language: str
627+
) -> Union[str, None]:
628+
"""
629+
Translate a text using a source language and a target language
630+
631+
:param str text: Text to translate
632+
:param str source_language: Source language
633+
:param str target_language: Target language
634+
:return: str
635+
"""
636+
parameters = {
637+
"profile_name": self.profile_name,
638+
"text": text,
639+
"source_language": source_language,
640+
"target_language": target_language,
641+
}
642+
async with async_cursor() as cr:
643+
data = await cr.callfunc(
644+
"DBMS_CLOUD_AI.TRANSLATE",
645+
oracledb.DB_TYPE_CLOB,
646+
keyword_parameters=parameters,
647+
)
648+
if data is not None:
649+
result = await data.read()
650+
return result
651+
return None
652+
626653

627654
class AsyncSession:
628655
"""AsyncSession lets you persist request parameters across DBMS_CLOUD_AI

src/select_ai/base_profile.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
import json
99
from abc import ABC
1010
from dataclasses import dataclass
11+
from dataclasses import replace as dataclass_replace
1112
from typing import List, Mapping, Optional, Tuple
1213

1314
import oracledb
1415

1516
from select_ai._abc import SelectAIDataClass
1617
from select_ai.action import Action
18+
from select_ai.errors import ProfileExistsError
1719
from select_ai.feedback import (
1820
FeedbackOperation,
1921
FeedbackType,
@@ -159,6 +161,10 @@ class BaseProfile(ABC):
159161
if profile exists in the database and replace = False and
160162
merge = False. Default value is True
161163
164+
:param bool raise_error_on_empty_attributes: Raise
165+
ProfileEmptyAttributesError, if profile attributes are empty
166+
in database. Default value is False.
167+
162168
"""
163169

164170
def __init__(
@@ -169,6 +175,7 @@ def __init__(
169175
merge: Optional[bool] = False,
170176
replace: Optional[bool] = False,
171177
raise_error_if_exists: Optional[bool] = True,
178+
raise_error_on_empty_attributes: Optional[bool] = False,
172179
):
173180
"""Initialize a base profile"""
174181
self.profile_name = profile_name
@@ -182,6 +189,34 @@ def __init__(
182189
self.merge = merge
183190
self.replace = replace
184191
self.raise_error_if_exists = raise_error_if_exists
192+
self.raise_error_on_empty_attributes = raise_error_on_empty_attributes
193+
194+
def _raise_error_if_profile_exists(self):
195+
"""
196+
Helper method to raise ProfileExistsError if profile exists
197+
in the database and replace = False and merge = False
198+
"""
199+
if not self.replace and not self.merge:
200+
if self.attributes is not None or self.description is not None:
201+
if self.raise_error_if_exists:
202+
raise ProfileExistsError(self.profile_name)
203+
204+
def _merge_attributes(self, saved_attributes, saved_description):
205+
"""
206+
Helper method to merge user passed attributes with the attributes saved
207+
in the database.
208+
"""
209+
if self.description is None and not self.replace:
210+
self.description = saved_description
211+
if self.attributes is None:
212+
self.attributes = saved_attributes
213+
if self.merge:
214+
self.replace = True
215+
if self.attributes is not None:
216+
self.attributes = dataclass_replace(
217+
saved_attributes,
218+
**self.attributes.dict(exclude_null=True),
219+
)
185220

186221
def __repr__(self):
187222
return (
@@ -206,15 +241,15 @@ def validate_params_for_feedback(
206241
response: Optional[str] = None,
207242
operation: Optional[FeedbackOperation] = FeedbackOperation.ADD,
208243
):
209-
if sql_id and prompt_spec:
210-
raise AttributeError("Either sql_id or prompt_spec must be specified")
211244
if not sql_id and not prompt_spec:
212245
raise AttributeError("Either sql_id or prompt_spec must be specified")
213-
parameters = {
214-
"feedback_type": feedback_type.value,
215-
"feedback_content": feedback_content,
216-
"operation": operation.value,
217-
}
246+
parameters = {"operation": operation.value}
247+
if feedback_content:
248+
parameters["feedback_content"] = feedback_content
249+
if feedback_type:
250+
parameters["feedback_type"] = feedback_type.value
251+
if response:
252+
parameters["response"] = response
218253
if prompt_spec:
219254
prompt, action = prompt_spec
220255
if action not in (Action.RUNSQL, Action.SHOWSQL, Action.EXPLAINSQL):

src/select_ai/errors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ def __str__(self):
5656
)
5757

5858

59+
class ProfileAttributesEmptyError(SelectAIError):
60+
"""Profile attributes empty in the database"""
61+
62+
def __init__(self, profile_name: str):
63+
self.profile_name = profile_name
64+
65+
def __str__(self):
66+
return (
67+
f"Profile {self.profile_name} attributes empty in the database. "
68+
)
69+
70+
5971
class VectorIndexNotFoundError(SelectAIError):
6072
"""VectorIndex not found in the database"""
6173

0 commit comments

Comments
 (0)